# ibadhasnain's picture
# Update app.py
# 10d0608 verified
# ---------------------------------------------------------
# Chainlit app (Method A)
# Project: Anatomy & Physiology Tutor + /gen Diagram (SD-XL)
# ---------------------------------------------------------
import asyncio
import inspect
import logging
import os, json, time
import uuid
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional

import chainlit as cl
from dotenv import load_dotenv
from pydantic import BaseModel
from openai import AsyncOpenAI as _SDKAsyncOpenAI
from huggingface_hub import InferenceClient
from PIL import Image
# =============================
# Inline "agents" shim (no extra package needed)
# =============================
def set_tracing_disabled(disabled: bool = True):
    """Stand-in for ``agents.set_tracing_disabled``: performs no work, echoes the flag back."""
    echoed = disabled
    return echoed
def function_tool(func: Callable):
    """Tag *func* as an agent tool by setting ``_is_tool`` and return the same function object."""
    setattr(func, "_is_tool", True)
    return func
def handoff(*args, **kwargs):
    """Placeholder for ``agents.handoff``: accepts any arguments and always yields ``None``."""
    del args, kwargs  # intentionally unused
    return None
class InputGuardrail:
    """Container holding a guardrail callable (shim mirroring the agents SDK type)."""

    def __init__(self, guardrail_function: Callable):
        # The callable is stored as-is under the public attribute name.
        setattr(self, "guardrail_function", guardrail_function)
@dataclass
class GuardrailFunctionOutput:
    """Outcome of a guardrail check (shim mirroring the agents SDK type)."""

    # Arbitrary payload produced by the guardrail function.
    output_info: Any
    # Whether the guardrail tripped, plus an optional explanation.
    tripwire_triggered: bool = field(default=False)
    tripwire_message: str = field(default="")
class InputGuardrailTripwireTriggered(Exception):
    """Exception type mirroring the agents SDK's guardrail tripwire error (not raised by Runner)."""
class AsyncOpenAI:
    """Thin wrapper that builds the real OpenAI SDK async client and exposes it as ``client``."""

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        # Only forward base_url when a non-empty value was supplied, so the SDK default applies otherwise.
        options: Dict[str, Any] = dict(api_key=api_key)
        if base_url:
            options.update(base_url=base_url)
        self._client = _SDKAsyncOpenAI(**options)

    @property
    def client(self):
        """The underlying SDK client instance."""
        return self._client
class OpenAIChatCompletionsModel:
    """Pairs a model name with the raw SDK client unwrapped from an ``AsyncOpenAI`` wrapper."""

    def __init__(self, model: str, openai_client: AsyncOpenAI):
        # Store the unwrapped SDK client so callers can use ``self.client.chat.completions`` directly.
        sdk_client = openai_client.client
        self.model, self.client = model, sdk_client
@dataclass
class Agent:
    """Configuration for one LLM persona: system prompt, model, and callable tools.

    ``Runner`` reads ``instructions``, ``model``, ``tools`` (via ``tool_specs``) and
    ``output_type``; ``handoff_description`` and ``input_guardrails`` are stored for
    agents-SDK parity but not read by ``Runner``.
    """

    name: str
    instructions: str
    model: "OpenAIChatCompletionsModel"
    # Functions decorated with @function_tool; undecorated entries are ignored.
    tools: Optional[List[Callable]] = field(default_factory=list)
    handoff_description: Optional[str] = None
    # Optional pydantic model the final reply is validated against.
    output_type: Optional[type] = None
    input_guardrails: Optional[List["InputGuardrail"]] = field(default_factory=list)

    def tool_specs(self) -> List[Dict[str, Any]]:
        """Build OpenAI "function" tool schemas for every ``@function_tool`` in ``tools``.

        Parameters come from ``inspect.signature``: positional and keyword-only parameters
        become properties; ``*args`` / ``**kwargs`` are skipped because they cannot be named
        JSON properties. Parameters annotated ``bool``/``int``/``float`` get the matching JSON
        type; everything else (including unannotated parameters) is ``"string"``. Only
        parameters without a default are listed as required. The description is the cleaned
        docstring, truncated to 512 characters.

        Returns:
            One spec dict per tool, in ``tools`` order (empty list when there are none).
        """
        json_types = ((bool, "boolean"), (int, "integer"), (float, "number"), (str, "string"))
        specs: List[Dict[str, Any]] = []
        for tool in self.tools or []:
            if not getattr(tool, "_is_tool", False):
                continue
            properties: Dict[str, Any] = {}
            required: List[str] = []
            for pname, param in inspect.signature(tool).parameters.items():
                if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                    continue
                # Identity comparison avoids hashing arbitrary (possibly unhashable) annotations.
                json_type = next((js for py, js in json_types if param.annotation is py), "string")
                properties[pname] = {"type": json_type}
                if param.default is param.empty:
                    required.append(pname)
            specs.append({
                "type": "function",
                "function": {
                    "name": tool.__name__,
                    "description": (inspect.getdoc(tool) or "")[:512],
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                    },
                },
            })
        return specs
class Runner:
    """Minimal stand-in for ``agents.Runner``: a chat-completions loop that executes tool calls."""

    # Upper bound on model round trips; each round may execute several tool calls.
    MAX_ROUNDS = 4

    @staticmethod
    async def run(agent: "Agent", user_input: str, context: Optional[Dict[str, Any]] = None):
        """Send *user_input* to *agent*'s model, running requested tools until a final answer.

        Args:
            agent: Supplies ``instructions``, ``model`` (name + SDK client), tools via
                ``tool_specs()`` / ``tools``, and an optional pydantic ``output_type``.
            user_input: The user message.
            context: Dict attached to the result unchanged (``{}`` when omitted).

        Returns:
            A ``SimpleNamespace`` with ``final_output`` (str), ``context`` (dict) and
            ``final_output_as(cls)``. If ``agent.output_type`` is a pydantic model and the reply
            validates, ``final_output`` is the re-serialized JSON and ``final_output_as`` returns
            the model instance. If no tool-free reply arrives within ``MAX_ROUNDS`` rounds,
            ``final_output`` is an apology message.

        Raises:
            Errors from the chat-completions client (network/API) propagate. Tool failures and
            malformed tool arguments do not; they are reported back to the model as
            ``{"error": ...}`` tool results.
        """
        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": agent.instructions},
            {"role": "user", "content": user_input},
        ]
        tool_specs = agent.tool_specs()
        tool_map = {t.__name__: t for t in (agent.tools or []) if getattr(t, "_is_tool", False)}

        request: Dict[str, Any] = {"model": agent.model.model, "messages": messages}
        if tool_specs:
            # Only send tools/tool_choice when there are tools: explicit nulls are not
            # guaranteed to be accepted by OpenAI-compatible endpoints.
            request["tools"] = tool_specs
            request["tool_choice"] = "auto"

        for _ in range(Runner.MAX_ROUNDS):
            response = await agent.model.client.chat.completions.create(**request)
            message = response.choices[0].message
            if not message.tool_calls:
                return Runner._build_result(message.content or "", context, agent.output_type)

            # Echo the assistant turn with plain-dict tool calls so the next request is
            # JSON-serializable regardless of the SDK's response object types.
            messages.append({
                "role": "assistant",
                "content": message.content or "",
                "tool_calls": [Runner._tool_call_as_dict(call) for call in message.tool_calls],
            })
            for call in message.tool_calls:
                name = call.function.name
                outcome = await Runner._call_tool(tool_map, name, call.function.arguments)
                messages.append({
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": name,
                    "content": json.dumps(outcome, ensure_ascii=False, default=str),
                })

        return Runner._build_result("Sorry, I couldn't complete the request.", context)

    @staticmethod
    def _tool_call_as_dict(call: Any) -> Dict[str, Any]:
        """Convert an SDK tool-call object into the plain dict shape the chat API accepts."""
        return {
            "id": call.id,
            "type": "function",
            "function": {"name": call.function.name, "arguments": call.function.arguments or "{}"},
        }

    @staticmethod
    async def _call_tool(tool_map: Dict[str, Callable], name: str, raw_args: Any) -> Any:
        """Run one tool call; return its result, or an ``{"error": ...}`` dict on any failure."""
        tool = tool_map.get(name)
        if tool is None:
            return {"error": f"Unknown tool: {name}"}
        try:
            args = raw_args if isinstance(raw_args, dict) else json.loads(raw_args or "{}")
        except (TypeError, ValueError) as exc:
            return {"error": f"Invalid arguments for {name}: {exc}"}
        if not isinstance(args, dict):
            return {"error": f"Invalid arguments for {name}: expected a JSON object"}
        try:
            result = tool(**args)
            # Allow async tools as well as plain functions.
            if inspect.isawaitable(result):
                result = await result
        except Exception as exc:  # report tool failures to the model instead of aborting the run
            return {"error": str(exc)}
        return result

    @staticmethod
    def _build_result(text: str, context: Optional[Dict[str, Any]], output_type: Optional[type] = None) -> SimpleNamespace:
        """Wrap *text* in the result shape callers expect, parsing it into *output_type* when possible."""
        result = SimpleNamespace(final_output=text, context=context or {})
        result.final_output_as = lambda _cls: text
        if isinstance(output_type, type) and issubclass(output_type, BaseModel):
            try:
                parsed = output_type.model_validate_json(text)
            except ValueError:  # pydantic's ValidationError subclasses ValueError
                return result
            result.final_output = parsed.model_dump_json()
            result.final_output_as = lambda _cls: parsed
        return result
# =============================
# App configuration
# =============================
# Load variables from a .env file (if present) into os.environ before reading settings.
load_dotenv()
# LLM credential: GEMINI_API_KEY is preferred, OPENAI_API_KEY is the fallback.
# NOTE(review): whichever key is found is sent to the Gemini OpenAI-compatible endpoint
# hard-coded below, so a genuine OpenAI key would be rejected there — confirm the fallback
# is meant to hold a Gemini key.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    # Raise at import time rather than on the first chat request.
    raise RuntimeError("Missing GEMINI_API_KEY (or OPENAI_API_KEY). Add it in Space secrets or .env.")
HF_TOKEN = os.environ.get("HF_TOKEN") # for SD-XL generation
# NOTE(review): sdxl_png() re-reads os.getenv("HF_TOKEN") itself; this module-level copy
# appears unused — confirm before removing.
# The local shim only returns its argument; the call is kept for agents-SDK parity.
set_tracing_disabled(True)
# OpenAI-SDK client pointed at Google's OpenAI-compatible Gemini endpoint.
external_client: AsyncOpenAI = AsyncOpenAI(
    api_key=API_KEY,
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
# Single model handle shared by the agents defined in this module.
llm_model: OpenAIChatCompletionsModel = OpenAIChatCompletionsModel(
    model="gemini-2.5-flash",
    openai_client=external_client,
)
# =============================
# A&P tools
# =============================
# (keywords, guide) pairs checked in order: the first entry with a keyword occurring in the
# normalized topic wins, so earlier entries take priority when a topic mentions several areas.
_TOPIC_GUIDES = (
    (
        ("cardiac conduction", "sa node", "av node", "bundle", "purkinje", "ecg"),
        {
            "anatomy": (
                "SA node (RA) → AV node → His bundle → R/L bundle branches → Purkinje fibers.",
                "Fibrous skeleton insulates atria from ventricles; AV node is the gateway.",
            ),
            "physiology": (
                "Pacemaker automaticity (If current) sets heart rate.",
                "AV nodal delay allows ventricular filling; His–Purkinje enables synchronized contraction.",
            ),
            "misconceptions": (
                "Misconception: all tissues pace equally — SA node dominates.",
                "PR interval ≈ AV nodal delay.",
            ),
            "study_tips": (
                "Map ECG waves to mechanics (P/QRS/T).",
                "Relate ion channels to nodal vs myocyte AP phases.",
            ),
        },
    ),
    (
        ("nephron", "kidney", "gfr", "raas", "countercurrent"),
        {
            "anatomy": (
                "Segments: Bowman’s → PCT → DLH → ALH (thin/thick) → DCT → collecting duct.",
                "Cortex (glomeruli/PCT/DCT) vs medulla (loops/collecting ducts).",
            ),
            "physiology": (
                "GFR via Starling forces; PCT bulk reabsorption; ALH generates gradient; DCT/CD fine-tune (ADH/aldosterone).",
                "Countercurrent multiplication + vasa recta maintain medullary gradient.",
            ),
            "misconceptions": (
                "Misconception: water reabsorption is equal everywhere — hormones control CD concentration.",
                "Urea also supports medullary osmolality.",
            ),
            "study_tips": (
                "Sketch transporters per segment.",
                "Practice ‘what if’ with ↑ADH/↑Aldosterone/↑GFR.",
            ),
        },
    ),
    (
        ("alveolar", "gas exchange", "v/q", "ventilation perfusion", "oxygen dissociation"),
        {
            "anatomy": (
                "Conducting vs respiratory zones; Type I (exchange) vs Type II (surfactant) pneumocytes.",
            ),
            "physiology": (
                "V/Q matching optimizes exchange; extremes: dead space (high V/Q) vs shunt (low V/Q).",
                "O2–Hb curve shifts with pH, CO2, temp, 2,3-BPG (Bohr effect).",
            ),
            "misconceptions": (
                "Misconception: uniform V/Q across lung — gravity/disease alter distribution.",
                "Assuming PaO2–SaO2 linearity (it’s sigmoidal).",
            ),
            "study_tips": (
                "Draw V/Q along lung height.",
                "Link spirometry patterns to mechanics.",
            ),
        },
    ),
    (
        ("neuron", "synapse", "action potential", "neurotransmitter"),
        {
            "anatomy": (
                "Neuron: dendrites, soma, axon hillock, axon; myelin & nodes.",
            ),
            "physiology": (
                "AP: Na+ depolarization → K+ repolarization; refractory periods.",
                "Chemical synapse: Ca2+-dependent vesicle release; EPSPs/IPSPs summate.",
            ),
            "misconceptions": (
                "Misconception: any EPSP fires AP — threshold & summation matter.",
            ),
            "study_tips": (
                "Relate channel states to AP phases.",
                "Compare ionotropic vs metabotropic effects.",
            ),
        },
    ),
    (
        ("muscle contraction", "excitation contraction", "sarcomere"),
        {
            "anatomy": (
                "Sarcomere: Z to Z; thin (actin/troponin/tropomyosin) vs thick (myosin).",
                "T-tubules & SR triads (skeletal).",
            ),
            "physiology": (
                "AP → DHPR → RyR → Ca2+ release → troponin C → cross-bridge cycling.",
                "Force–length & force–velocity relationships; motor unit recruitment.",
            ),
            "misconceptions": (
                "ATP also needed for detachment & Ca2+ resequestration.",
            ),
            "study_tips": (
                "Diagram cross-bridge cycle.",
                "Predict effects of length on force.",
            ),
        },
    ),
)

# Returned when no keyword matches.
_GENERIC_GUIDE = {
    "anatomy": ("Identify major structures & relationships.",),
    "physiology": ("Describe inputs → key steps → outputs; control loops.",),
    "misconceptions": ("Clarify commonly conflated terms; distinguish variation from pathology (education-only).",),
    "study_tips": ("Make a labeled sketch; use ‘if-this-then-that’ scenarios to test understanding.",),
}

# Hyphen and Unicode dash variants map to a space so e.g. "excitation-contraction" or
# "ventilation–perfusion" match the space-separated keywords above.
_DASH_TO_SPACE = str.maketrans({"-": " ", "\u2010": " ", "\u2011": " ", "\u2013": " ", "\u2014": " "})


def _copy_guide(guide: dict) -> dict:
    """Return a fresh dict of fresh lists so callers can never mutate the shared tables."""
    return {key: list(items) for key, items in guide.items()}


@function_tool
def topic_reference_guide(topic: str) -> dict:
    """
    Educational bullets for common Anatomy & Physiology topics.
    Returns dict with anatomy, physiology, misconceptions, study_tips.
    """
    # Normalize: lowercase, dashes -> spaces, collapse whitespace runs (incl. newlines/tabs).
    normalized = " ".join(str(topic or "").lower().translate(_DASH_TO_SPACE).split())
    for keywords, guide in _TOPIC_GUIDES:
        if any(keyword in normalized for keyword in keywords):
            return _copy_guide(guide)
    return _copy_guide(_GENERIC_GUIDE)
@function_tool
def study_outline(topic: str) -> dict:
    """Return a suggested outline and practice prompts for the topic."""
    # NOTE: the docstring above doubles as the tool description sent to the model.
    # Generic section skeleton, identical for every topic.
    subtopics = [
        "Key structures & relationships",
        "Mechanism (step-by-step)",
        "Control & feedback",
        "Quantitative intuition (flows, pressures, potentials)",
        "Common misconceptions",
    ]
    # Self-test prompts interpolate the caller's topic verbatim.
    practice_prompts = [
        f"Explain {topic} in 5 steps to a first-year student.",
        f"Draw a block diagram of {topic} with inputs/outputs.",
        f"What changes if a key parameter increases/decreases in {topic}?",
    ]
    return dict(topic=topic, subtopics=subtopics, practice_prompts=practice_prompts)
# =============================
# Agents
# =============================
# System prompt for the tutor: one topic in, structured teaching bullets out.
tutor_instructions = "\n".join([
    "You are an Anatomy & Physiology Tutor. TEACH, do not diagnose.",
    "Given a topic (e.g., 'cardiac conduction', 'nephron physiology'), produce concise bullet points:",
    "1) Anatomy (structures/relationships)",
    "2) Physiology (inputs → steps → outputs)",
    "3) Common misconceptions",
    "4) Study tips",
    "5) Caution: education only, no diagnosis",
    "Use tools (topic_reference_guide, study_outline) to ground the response.",
    "Avoid clinical diagnosis or treatment advice.",
])
tutor_agent = Agent(
    name="A&P Tutor",
    model=llm_model,
    instructions=tutor_instructions,
    tools=[topic_reference_guide, study_outline],
)
# Tool-less classifier whose free-text (JSON-shaped) reply is used to pre-screen user messages.
guardrail_agent = Agent(
    name="Safety Classifier",
    model=llm_model,
    instructions=(
        "Classify if the user's message requests medical diagnosis or unsafe medical advice, and if it "
        "includes personal identifiers. Respond as JSON with fields: {unsafe_medical_advice: bool, "
        "requests_diagnosis: bool, pii_included: bool, reasoning: string}."
    ),
)
# =============================
# SD-XL helper (HF Inference API) -> PNG path
# =============================
def sdxl_png(prompt: str, negative: str = "") -> str:
    """Render *prompt* with SD-XL via the Hugging Face Inference API and save it as a PNG.

    A fixed style suffix is appended to the prompt, a default negative prompt is used when
    *negative* is empty, and generation uses a fixed seed (42) at 1024x768.

    Args:
        prompt: Free-text image description.
        negative: Optional negative prompt overriding the default.

    Returns:
        Filesystem path of the saved PNG, inside ``$CHAINLIT_FILES_DIR`` or ``./.files``.

    Raises:
        RuntimeError: If the HF_TOKEN environment variable is unset or empty.
        Errors from ``InferenceClient.text_to_image`` (network/API) propagate unchanged.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        raise RuntimeError("Missing HF_TOKEN. Add it in Space secrets.")
    client = InferenceClient("stabilityai/stable-diffusion-xl-base-1.0", token=token)
    img: Image.Image = client.text_to_image(
        prompt=prompt + " | educational, clean vector style, white background, no labels, safe-for-work",
        negative_prompt=negative or "text, watermark, logo, gore, photorealistic patient, clutter",
        width=1024,
        height=768,
        guidance_scale=7.5,
        num_inference_steps=30,
        seed=42,
    )
    out_dir = os.environ.get("CHAINLIT_FILES_DIR") or os.path.join(os.getcwd(), ".files")
    os.makedirs(out_dir, exist_ok=True)
    # A timestamp alone collides when two renders finish in the same second (concurrent
    # users), silently overwriting one image with another; the uuid suffix keeps names unique.
    path = os.path.join(out_dir, f"sdxl-{int(time.time())}-{uuid.uuid4().hex[:12]}.png")
    img.save(path)
    return path
# =============================
# Chainlit flows
# =============================
# Markdown greeting sent by on_chat_start when a chat session opens; it advertises the
# tutor mode and the `/gen <description>` diagram command.
WELCOME = (
    "🧠 **Anatomy & Physiology Tutor** + 🎨 **/gen Diagram (SD-XL)**\n\n"
    "• Ask any A&P topic (e.g., *cardiac conduction*, *nephron physiology*, *gas exchange*).\n"
    "• Generate a clean educational diagram (PNG) with:\n"
    " `/gen isometric nephron diagram, flat vector, white background, no labels`\n\n"
    "⚠️ Education only — no diagnosis or clinical advice."
)
@cl.on_chat_start
async def on_chat_start():
    """Greet every new chat session with the WELCOME banner."""
    greeting = cl.Message(content=WELCOME)
    await greeting.send()
def _parse_guardrail_verdict(raw: Any) -> Dict[str, Any]:
    """Best-effort extraction of the safety classifier's JSON object.

    The classifier has no structured output type, so the model may wrap its JSON in Markdown
    code fences or surround it with prose; a plain ``json.loads`` then fails and the guardrail
    would silently let every message through. Returns ``{}`` when nothing parseable is found.
    """
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {}
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end < start:
        return {}
    try:
        data = json.loads(raw[start:end + 1])
    except ValueError:
        return {}
    return data if isinstance(data, dict) else {}


def _flag_set(value: Any) -> bool:
    """Interpret a classifier flag: JSON true, 1, or the strings "true"/"yes"/"1" (any case)."""
    # A plain truthiness test would treat the string "false" as set.
    return str(value).strip().lower() in {"true", "yes", "1"}


def _bullets(items: Any) -> str:
    """Render a list as Markdown bullets; a non-list becomes a single "(n/a)" bullet."""
    return "\n".join(f"- {item}" for item in items) if isinstance(items, list) else "- (n/a)"


async def _handle_gen(desc: str) -> None:
    """Handle `/gen <desc>`: render an SD-XL diagram and post it, or post a usage/error hint."""
    if not desc:
        await cl.Message(content="Please provide a description after `/gen`.\nExample: `/gen isometric nephron diagram, flat vector`").send()
        return
    try:
        # sdxl_png does blocking network I/O for many seconds; run it off the event loop so
        # other chat sessions stay responsive meanwhile.
        path = await asyncio.to_thread(sdxl_png, desc)
    except Exception as e:
        await cl.Message(content=f"Generation failed: {e}").send()
        return
    await cl.Message(
        content="🎨 **Generated diagram** (education only):",
        elements=[cl.Image(path=path, name=os.path.basename(path), display="inline")],
    ).send()


async def _violates_guardrail(text: str) -> bool:
    """Return True when the safety classifier flags a diagnosis or unsafe-advice request."""
    try:
        safety = await Runner.run(guardrail_agent, text)
    except Exception:
        # Fail open (the tutor keeps working if the classifier is down), but leave a trace.
        logging.getLogger(__name__).warning("Safety classifier failed; allowing message", exc_info=True)
        return False
    verdict = _parse_guardrail_verdict(safety.final_output)
    return _flag_set(verdict.get("unsafe_medical_advice")) or _flag_set(verdict.get("requests_diagnosis"))


@cl.on_message
async def on_message(message: cl.Message):
    """Route a user message: `/gen` diagram command, safety screen, then tutor answer.

    The tutor reply combines a locally built quick reference and study outline with the LLM's
    explanation; if the LLM call fails, the local material is still sent with a fallback note.
    """
    text = (message.content or "").strip()

    # Quick command: /gen <desc> -> SD-XL PNG. Split on any whitespace: a prefix test for
    # "/gen " could never match a bare "/gen" (text is already stripped), so the usage hint
    # was unreachable and the command fell through to tutor mode.
    parts = text.split(maxsplit=1)
    if parts and parts[0].lower() == "/gen":
        await _handle_gen(parts[1].strip() if len(parts) > 1 else "")
        return

    # Light safety check (blocks diagnosis/treatment requests)
    if await _violates_guardrail(text):
        await cl.Message(
            content="🚫 I can’t provide diagnosis or treatment advice. I can teach A&P concepts and generate **educational** diagrams."
        ).send()
        return

    # Tutor mode
    topic = text or "anatomy and physiology overview"
    fallback_answer = "I couldn’t generate an explanation."
    try:
        result = await Runner.run(tutor_agent, topic)
        answer = result.final_output or fallback_answer
    except Exception:
        # Still deliver the local quick reference below instead of failing the whole reply.
        logging.getLogger(__name__).exception("Tutor agent failed for topic %r", topic)
        answer = fallback_answer

    # Tool-based quick reference + outline, computed locally (no LLM round trip).
    try:
        guide = topic_reference_guide(topic)
    except Exception:
        guide = {}
    try:
        outline = study_outline(topic)
    except Exception:
        outline = {}

    quick_ref = (
        f"### 📚 Quick Reference: {topic}\n"
        f"**Anatomy**\n{_bullets(guide.get('anatomy', []))}\n\n"
        f"**Physiology**\n{_bullets(guide.get('physiology', []))}\n\n"
        f"**Common Misconceptions**\n{_bullets(guide.get('misconceptions', []))}\n\n"
        f"**Study Tips**\n{_bullets(guide.get('study_tips', []))}\n\n"
    )
    outline_md = ""
    if outline:
        subs = _bullets(outline.get("subtopics", []))
        qs = _bullets(outline.get("practice_prompts", []))
        outline_md = f"### 🗂️ Suggested Outline\n{subs}\n\n### 📝 Practice Prompts\n{qs}\n\n"
    caution = "> ⚠️ Education only — no diagnosis or treatment advice."
    await cl.Message(content=f"{quick_ref}{outline_md}---\n{answer}\n\n{caution}").send()