# ---------------------------------------------------------
# Chainlit app (Method A)
# Project: Anatomy & Physiology Tutor + /gen Diagram (SD-XL)
# ---------------------------------------------------------
import asyncio
import json
import logging
import os
import re
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

import chainlit as cl
from dotenv import load_dotenv
from pydantic import BaseModel
from openai import AsyncOpenAI as _SDKAsyncOpenAI
from huggingface_hub import InferenceClient
from PIL import Image

logger = logging.getLogger(__name__)


# =============================
# Inline "agents" shim (no extra package needed)
# =============================
def set_tracing_disabled(disabled: bool = True):
    """No-op stand-in for the agents SDK tracing switch; returns *disabled* unchanged."""
    return disabled


def function_tool(func: Callable):
    """Mark *func* as a tool so that ``Agent.tool_specs()`` advertises it to the model."""
    func._is_tool = True
    return func


def handoff(*args, **kwargs):
    """Placeholder for the agents SDK ``handoff`` helper; always returns None."""
    return None


class InputGuardrail:
    """Holds a guardrail callable (kept for API compatibility; ``Runner`` does not invoke it)."""

    def __init__(self, guardrail_function: Callable):
        self.guardrail_function = guardrail_function


@dataclass
class GuardrailFunctionOutput:
    """Result shape of a guardrail function (kept for API compatibility)."""

    output_info: Any
    tripwire_triggered: bool = False
    tripwire_message: str = ""


class InputGuardrailTripwireTriggered(Exception):
    """Raised-by-convention when an input guardrail trips (not raised inside this file)."""

    pass


class AsyncOpenAI:
    """Thin wrapper around the OpenAI SDK async client with an optional custom base URL."""

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        self._client = _SDKAsyncOpenAI(**kwargs)

    @property
    def client(self):
        """The underlying OpenAI SDK ``AsyncOpenAI`` client."""
        return self._client


class OpenAIChatCompletionsModel:
    """Pairs a model name with the SDK client used to call chat completions."""

    def __init__(self, model: str, openai_client: AsyncOpenAI):
        self.model = model
        self.client = openai_client.client


@dataclass
class Agent:
    """Minimal agent description: system instructions, model, and optional tools."""

    name: str
    instructions: str  # sent as the system message
    model: OpenAIChatCompletionsModel
    tools: Optional[List[Callable]] = field(default_factory=list)  # only @function_tool items are exposed
    handoff_description: Optional[str] = None
    output_type: Optional[type] = None  # a pydantic BaseModel subclass enables JSON validation
    input_guardrails: Optional[List[InputGuardrail]] = field(default_factory=list)

    def tool_specs(self) -> List[Dict[str, Any]]:
        """Build chat-completions ``tools`` entries for every ``@function_tool`` in ``self.tools``.

        Every positional parameter is advertised as a required string argument, and the
        tool description is the function docstring truncated to 512 characters.
        """
        specs = []
        for tool in (self.tools or []):
            if not getattr(tool, "_is_tool", False):
                continue
            code = tool.__code__
            params = code.co_varnames[:code.co_argcount]
            specs.append({
                "type": "function",
                "function": {
                    "name": tool.__name__,
                    "description": (tool.__doc__ or "")[:512],
                    "parameters": {
                        "type": "object",
                        "properties": {p: {"type": "string"} for p in params},
                        "required": list(params),
                    },
                },
            })
        return specs


_MAX_TOOL_ROUNDS = 4


def _new_result(final_output: str, context: Optional[Dict[str, Any]]):
    """Create the ad-hoc result object returned by ``Runner.run``.

    The object exposes ``final_output``, ``context`` and ``final_output_as(t)``;
    by default ``final_output_as`` ignores ``t`` and returns the text output.
    """
    result = type("Result", (), {})()
    result.final_output = final_output
    result.context = context or {}
    result.final_output_as = lambda _t: final_output
    return result


def _build_final_result(agent: Agent, final_text: str, context: Optional[Dict[str, Any]]):
    """Wrap the model's final text, validating it against ``agent.output_type`` when possible.

    On successful validation ``final_output`` becomes the model's canonical JSON and
    ``final_output_as`` returns the parsed model; otherwise the raw text is kept.
    """
    result = _new_result(final_text, context)
    output_type = agent.output_type
    # isinstance guard: issubclass() raises TypeError for non-class annotations.
    if isinstance(output_type, type) and issubclass(output_type, BaseModel):
        try:
            data = output_type.model_validate_json(final_text)
        except Exception:
            return result
        result.final_output = data.model_dump_json()
        result.final_output_as = lambda _t: data
    return result


def _serialize_tool_calls(tool_calls) -> List[Dict[str, Any]]:
    """Convert SDK tool-call objects into plain dicts suitable for the next request."""
    return [
        {
            "id": call.id,
            "type": "function",
            "function": {
                "name": call.function.name,
                "arguments": call.function.arguments or "{}",
            },
        }
        for call in tool_calls
    ]


def _invoke_tool(tool_map: Dict[str, Callable], fn_name: str, raw_args: Optional[str]) -> Any:
    """Run one requested tool and return its result, or an ``{"error": ...}`` dict.

    Argument parsing happens inside the ``try`` so that malformed JSON from the model
    is reported back to it instead of aborting the whole run.
    """
    if fn_name not in tool_map:
        return {"error": f"Unknown tool: {fn_name}"}
    try:
        args = json.loads(raw_args or "{}")
        return tool_map[fn_name](**args)
    except Exception as e:
        return {"error": str(e)}


class Runner:
    """Drives a chat-completions conversation with simple tool-call support."""

    @staticmethod
    async def run(agent: Agent, user_input: str, context: Optional[Dict[str, Any]] = None):
        """Send *user_input* to *agent*, executing tool calls for up to 4 rounds.

        Returns a result object with ``final_output`` (str), ``context`` (dict) and
        ``final_output_as(t)``. If the model keeps requesting tools past the round
        limit, a fixed apology text is returned. API errors propagate to the caller.
        """
        msgs = [
            {"role": "system", "content": agent.instructions},
            {"role": "user", "content": user_input},
        ]
        tools = agent.tool_specs()
        tool_map = {t.__name__: t for t in (agent.tools or []) if getattr(t, "_is_tool", False)}
        # Only send tool parameters when there are tools: an explicit None is serialized
        # as JSON null, which OpenAI-compatible endpoints may reject outright.
        tool_kwargs: Dict[str, Any] = {"tools": tools, "tool_choice": "auto"} if tools else {}

        for _ in range(_MAX_TOOL_ROUNDS):
            resp = await agent.model.client.chat.completions.create(
                model=agent.model.model,
                messages=msgs,
                **tool_kwargs,
            )
            msg = resp.choices[0].message

            if not msg.tool_calls:
                return _build_final_result(agent, msg.content or "", context)

            msgs.append({
                "role": "assistant",
                "content": msg.content or "",
                "tool_calls": _serialize_tool_calls(msg.tool_calls),
            })
            for call in msg.tool_calls:
                fn_name = call.function.name
                result = _invoke_tool(tool_map, fn_name, call.function.arguments)
                msgs.append({
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": fn_name,
                    "content": json.dumps(result),
                })
            # next round with tool outputs

        return _new_result("Sorry, I couldn't complete the request.", context)


# =============================
# App configuration
# =============================
load_dotenv()
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    raise RuntimeError("Missing GEMINI_API_KEY (or OPENAI_API_KEY). Add it in Space secrets or .env.")
HF_TOKEN = os.environ.get("HF_TOKEN")  # for SD-XL generation

set_tracing_disabled(True)

external_client: AsyncOpenAI = AsyncOpenAI(
    api_key=API_KEY,
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
llm_model: OpenAIChatCompletionsModel = OpenAIChatCompletionsModel(
    model="gemini-2.5-flash",
    openai_client=external_client,
)


# =============================
# A&P tools
# =============================
@function_tool
def topic_reference_guide(topic: str) -> dict:
    """
    Educational bullets for common Anatomy & Physiology topics.
    Returns dict with anatomy, physiology, misconceptions, study_tips.
    """
    # Hyphens become spaces so e.g. "excitation-contraction" matches "excitation contraction".
    t = (topic or "").lower().replace("-", " ")

    def pack(anat, phys, misc, tips):
        return {"anatomy": anat, "physiology": phys, "misconceptions": misc, "study_tips": tips}

    if any(k in t for k in ["cardiac conduction", "sa node", "av node", "bundle", "purkinje", "ecg"]):
        return pack(
            anat=[
                "SA node (RA) → AV node → His bundle → R/L bundle branches → Purkinje fibers.",
                "Fibrous skeleton insulates atria from ventricles; AV node is the gateway.",
            ],
            phys=[
                "Pacemaker automaticity (If current) sets heart rate.",
                "AV nodal delay allows ventricular filling; His–Purkinje enables synchronized contraction.",
            ],
            misc=[
                "Misconception: all tissues pace equally — SA node dominates.",
                "PR interval ≈ AV nodal delay.",
            ],
            tips=[
                "Map ECG waves to mechanics (P/QRS/T).",
                "Relate ion channels to nodal vs myocyte AP phases.",
            ],
        )
    if any(k in t for k in ["nephron", "kidney", "gfr", "raas", "countercurrent"]):
        return pack(
            anat=[
                "Segments: Bowman’s → PCT → DLH → ALH (thin/thick) → DCT → collecting duct.",
                "Cortex (glomeruli/PCT/DCT) vs medulla (loops/collecting ducts).",
            ],
            phys=[
                "GFR via Starling forces; PCT bulk reabsorption; ALH generates gradient; DCT/CD fine-tune (ADH/aldosterone).",
                "Countercurrent multiplication + vasa recta maintain medullary gradient.",
            ],
            misc=[
                "Misconception: water reabsorption is equal everywhere — hormones control CD concentration.",
                "Urea also supports medullary osmolality.",
            ],
            tips=[
                "Sketch transporters per segment.",
                "Practice ‘what if’ with ↑ADH/↑Aldosterone/↑GFR.",
            ],
        )
    if any(k in t for k in ["alveolar", "gas exchange", "v/q", "ventilation perfusion", "oxygen dissociation"]):
        return pack(
            anat=[
                "Conducting vs respiratory zones; Type I (exchange) vs Type II (surfactant) pneumocytes.",
            ],
            phys=[
                "V/Q matching optimizes exchange; extremes: dead space (high V/Q) vs shunt (low V/Q).",
                "O2–Hb curve shifts with pH, CO2, temp, 2,3-BPG (Bohr effect).",
            ],
            misc=[
                "Misconception: uniform V/Q across lung — gravity/disease alter distribution.",
                "Assuming PaO2–SaO2 linearity (it’s sigmoidal).",
            ],
            tips=[
                "Draw V/Q along lung height.",
                "Link spirometry patterns to mechanics.",
            ],
        )
    if any(k in t for k in ["neuron", "synapse", "action potential", "neurotransmitter"]):
        return pack(
            anat=[
                "Neuron: dendrites, soma, axon hillock, axon; myelin & nodes.",
            ],
            phys=[
                "AP: Na+ depolarization → K+ repolarization; refractory periods.",
                "Chemical synapse: Ca2+-dependent vesicle release; EPSPs/IPSPs summate.",
            ],
            misc=[
                "Misconception: any EPSP fires AP — threshold & summation matter.",
            ],
            tips=[
                "Relate channel states to AP phases.",
                "Compare ionotropic vs metabotropic effects.",
            ],
        )
    if any(k in t for k in ["muscle contraction", "excitation contraction", "sarcomere"]):
        return pack(
            anat=[
                "Sarcomere: Z to Z; thin (actin/troponin/tropomyosin) vs thick (myosin).",
                "T-tubules & SR triads (skeletal).",
            ],
            phys=[
                "AP → DHPR → RyR → Ca2+ release → troponin C → cross-bridge cycling.",
                "Force–length & force–velocity relationships; motor unit recruitment.",
            ],
            misc=[
                "ATP also needed for detachment & Ca2+ resequestration.",
            ],
            tips=[
                "Diagram cross-bridge cycle.",
                "Predict effects of length on force.",
            ],
        )
    # Generic fallback for topics without a dedicated entry.
    return pack(
        anat=["Identify major structures & relationships."],
        phys=["Describe inputs → key steps → outputs; control loops."],
        misc=["Clarify commonly conflated terms; distinguish variation from pathology (education-only)."],
        tips=["Make a labeled sketch; use ‘if-this-then-that’ scenarios to test understanding."],
    )


@function_tool
def study_outline(topic: str) -> dict:
    """Return a suggested outline and practice prompts for the topic."""
    return {
        "topic": topic,
        "subtopics": [
            "Key structures & relationships",
            "Mechanism (step-by-step)",
            "Control & feedback",
            "Quantitative intuition (flows, pressures, potentials)",
            "Common misconceptions",
        ],
        "practice_prompts": [
            f"Explain {topic} in 5 steps to a first-year student.",
            f"Draw a block diagram of {topic} with inputs/outputs.",
            f"What changes if a key parameter increases/decreases in {topic}?",
        ],
    }


# =============================
# Agents
# =============================
tutor_instructions = (
    "You are an Anatomy & Physiology Tutor. TEACH, do not diagnose.\n"
    "Given a topic (e.g., 'cardiac conduction', 'nephron physiology'), produce concise bullet points:\n"
    "1) Anatomy (structures/relationships)\n"
    "2) Physiology (inputs → steps → outputs)\n"
    "3) Common misconceptions\n"
    "4) Study tips\n"
    "5) Caution: education only, no diagnosis\n"
    "Use tools (topic_reference_guide, study_outline) to ground the response.\n"
    "Avoid clinical diagnosis or treatment advice."
)

tutor_agent = Agent(
    name="A&P Tutor",
    instructions=tutor_instructions,
    model=llm_model,
    tools=[topic_reference_guide, study_outline],
)

guardrail_agent = Agent(
    name="Safety Classifier",
    instructions=(
        "Classify if the user's message requests medical diagnosis or unsafe medical advice, "
        "and if it includes personal identifiers. Respond as JSON with fields: "
        "{unsafe_medical_advice: bool, requests_diagnosis: bool, pii_included: bool, reasoning: string}."
    ),
    model=llm_model,
)


# =============================
# SD-XL helper (HF Inference API) -> PNG path
# =============================
def sdxl_png(prompt: str, negative: str = "") -> str:
    """Generate a diagram with SD-XL via the HF Inference API and save it as a PNG.

    Args:
        prompt: User description; a fixed style suffix is appended.
        negative: Optional negative prompt; a default list is used when empty.

    Returns:
        Filesystem path of the saved PNG.

    Raises:
        RuntimeError: If ``HF_TOKEN`` is not set. Errors from the inference call propagate.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        raise RuntimeError("Missing HF_TOKEN. Add it in Space secrets.")
    client = InferenceClient("stabilityai/stable-diffusion-xl-base-1.0", token=token)
    img: Image.Image = client.text_to_image(
        prompt=prompt + " | educational, clean vector style, white background, no labels, safe-for-work",
        negative_prompt=negative or "text, watermark, logo, gore, photorealistic patient, clutter",
        width=1024,
        height=768,
        guidance_scale=7.5,
        num_inference_steps=30,
        seed=42,
    )
    out_dir = os.environ.get("CHAINLIT_FILES_DIR") or os.path.join(os.getcwd(), ".files")
    os.makedirs(out_dir, exist_ok=True)
    # Random suffix: two generations within the same second must not overwrite each other.
    path = os.path.join(out_dir, f"sdxl-{int(time.time())}-{uuid.uuid4().hex[:8]}.png")
    img.save(path)
    return path


# =============================
# Chainlit flows
# =============================
WELCOME = (
    "🧠 **Anatomy & Physiology Tutor** + 🎨 **/gen Diagram (SD-XL)**\n\n"
    "• Ask any A&P topic (e.g., *cardiac conduction*, *nephron physiology*, *gas exchange*).\n"
    "• Generate a clean educational diagram (PNG) with:\n"
    " `/gen isometric nephron diagram, flat vector, white background, no labels`\n\n"
    "⚠️ Education only — no diagnosis or clinical advice."
)

_JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)


def _parse_json_object(raw: Any) -> Dict[str, Any]:
    """Best-effort parse of a JSON object from model output; returns ``{}`` on failure.

    Tolerates Markdown code fences and surrounding prose (falls back to the outermost
    ``{...}`` span), which LLMs frequently add despite being asked for JSON.
    """
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {}
    cleaned = raw.strip()
    fenced = _JSON_FENCE_RE.match(cleaned)
    if fenced:
        cleaned = fenced.group(1)
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            return {}
        try:
            data = json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            return {}
    return data if isinstance(data, dict) else {}


def _is_true(value: Any) -> bool:
    """Interpret a classifier flag; string values like ``"false"`` are not truthy."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "1")
    return bool(value)


def _bullets(arr: Any) -> str:
    """Render a list as Markdown bullets, or ``- (n/a)`` for non-list input."""
    return "\n".join([f"- {b}" for b in arr]) if isinstance(arr, list) else "- (n/a)"


@cl.on_chat_start
async def on_chat_start():
    """Greet the user with usage instructions."""
    await cl.Message(content=WELCOME).send()


async def _handle_gen_command(desc: str) -> None:
    """Generate and send an SD-XL diagram for *desc*, or a usage hint if it is empty."""
    if not desc:
        await cl.Message(content="Please provide a description after `/gen`.\nExample: `/gen isometric nephron diagram, flat vector`").send()
        return
    try:
        # sdxl_png makes a blocking network call; keep it off the event loop.
        path = await asyncio.to_thread(sdxl_png, desc)
    except Exception as e:
        await cl.Message(content=f"Generation failed: {e}").send()
        return
    await cl.Message(
        content="🎨 **Generated diagram** (education only):",
        elements=[cl.Image(path=path, name=os.path.basename(path), display="inline")],
    ).send()


async def _is_blocked_by_safety_check(text: str) -> bool:
    """Return True when the classifier flags a diagnosis or unsafe-advice request.

    Deliberately fails open (returns False) if the classifier call errors, but logs
    the failure instead of silently swallowing it.
    """
    try:
        safety = await Runner.run(guardrail_agent, text)
    except Exception:
        logger.warning("Safety classifier call failed; continuing without it.", exc_info=True)
        return False
    data = _parse_json_object(safety.final_output)
    return _is_true(data.get("unsafe_medical_advice")) or _is_true(data.get("requests_diagnosis"))


async def _answer_as_tutor(text: str) -> None:
    """Send the quick reference, study outline and the tutor agent's explanation."""
    topic = text or "anatomy and physiology overview"
    fallback_answer = "I couldn’t generate an explanation."
    try:
        result = await Runner.run(tutor_agent, topic)
        answer = result.final_output or fallback_answer
    except Exception:
        logger.exception("Tutor agent failed for topic %r", topic)
        answer = fallback_answer

    # Tool-based quick reference + outline
    try:
        guide = topic_reference_guide(topic)
    except Exception:
        guide = {}
    try:
        outline = study_outline(topic)
    except Exception:
        outline = {}

    quick_ref = (
        f"### 📚 Quick Reference: {topic}\n"
        f"**Anatomy**\n{_bullets(guide.get('anatomy', []))}\n\n"
        f"**Physiology**\n{_bullets(guide.get('physiology', []))}\n\n"
        f"**Common Misconceptions**\n{_bullets(guide.get('misconceptions', []))}\n\n"
        f"**Study Tips**\n{_bullets(guide.get('study_tips', []))}\n\n"
    )

    outline_md = ""
    if outline:
        subs = _bullets(outline.get("subtopics", []))
        qs = _bullets(outline.get("practice_prompts", []))
        outline_md = f"### 🗂️ Suggested Outline\n{subs}\n\n### 📝 Practice Prompts\n{qs}\n\n"

    caution = "> ⚠️ Education only — no diagnosis or treatment advice."
    await cl.Message(content=f"{quick_ref}{outline_md}---\n{answer}\n\n{caution}").send()


@cl.on_message
async def on_message(message: cl.Message):
    """Route a user message: ``/gen`` diagrams, safety screening, then tutoring."""
    text = (message.content or "").strip()

    # Quick command: /gen → SD-XL PNG. The text is already stripped, so a bare "/gen"
    # must be matched explicitly for the usage hint to be reachable.
    lowered = text.lower()
    if lowered == "/gen" or lowered.startswith("/gen "):
        await _handle_gen_command(text[4:].strip())
        return

    # Light safety check (blocks diagnosis/treatment requests)
    if await _is_blocked_by_safety_check(text):
        await cl.Message(
            content="🚫 I can’t provide diagnosis or treatment advice. I can teach A&P concepts and generate **educational** diagrams."
        ).send()
        return

    # Tutor mode
    await _answer_as_tutor(text)