| | |
| | |
| | |
| | |
| | import os, json, time |
| | from dataclasses import dataclass, field |
| | from typing import Any, Callable, Dict, List, Optional |
| |
|
| | import chainlit as cl |
| | from dotenv import load_dotenv |
| | from pydantic import BaseModel |
| | from openai import AsyncOpenAI as _SDKAsyncOpenAI |
| | from huggingface_hub import InferenceClient |
| | from PIL import Image |
| |
|
| | |
| | |
| | |
def set_tracing_disabled(disabled: bool = True):
    """Stand-in for the agents-SDK tracing switch.

    Tracing is not wired up in this shim, so the call simply reports the
    flag it was given back to the caller.
    """
    return disabled
| |
|
def function_tool(func: Callable):
    """Decorator that marks *func* as a callable tool.

    Sets a ``_is_tool`` attribute (checked by Agent.tool_specs / Runner)
    and returns the function unchanged.
    """
    setattr(func, "_is_tool", True)
    return func
| |
|
def handoff(*args, **kwargs):
    """Placeholder for the agents-SDK handoff helper.

    Accepts anything and always yields None; handoffs are not implemented
    in this shim.
    """
    _ = (args, kwargs)  # accepted but deliberately ignored
    return None
| |
|
class InputGuardrail:
    """Wraps a guardrail callable so agents can carry it as configuration.

    The callable is stored verbatim; invoking it is left to the runner.
    """

    def __init__(self, guardrail_function: Callable):
        self.guardrail_function = guardrail_function
| |
|
@dataclass
class GuardrailFunctionOutput:
    """Result of a guardrail check."""

    # Arbitrary payload describing what the guardrail observed or decided.
    output_info: Any
    # True when the guardrail wants to block the request.
    tripwire_triggered: bool = False
    # Optional human-readable reason for a triggered tripwire.
    tripwire_message: str = ""
| |
|
class InputGuardrailTripwireTriggered(Exception):
    """Raised when an input guardrail blocks a request."""
    pass
| |
|
class AsyncOpenAI:
    """Thin adapter around the OpenAI async SDK client.

    Takes an API key and an optional custom base URL (e.g. the Gemini
    OpenAI-compatible endpoint) and constructs the underlying SDK client.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        if base_url:
            self._client = _SDKAsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            self._client = _SDKAsyncOpenAI(api_key=api_key)

    @property
    def client(self):
        """The wrapped SDK client instance."""
        return self._client
| |
|
class OpenAIChatCompletionsModel:
    """Pairs a model name with the async client used to reach it.

    Unwraps the AsyncOpenAI adapter once at construction so callers talk
    to the SDK client directly.
    """

    def __init__(self, model: str, openai_client: AsyncOpenAI):
        self.model = model
        self.client = openai_client.client
| |
|
@dataclass
class Agent:
    """Declarative agent configuration: prompt, model, tools, guardrails."""

    name: str
    instructions: str
    model: OpenAIChatCompletionsModel
    tools: Optional[List[Callable]] = field(default_factory=list)
    handoff_description: Optional[str] = None
    output_type: Optional[type] = None
    input_guardrails: Optional[List[InputGuardrail]] = field(default_factory=list)

    def tool_specs(self) -> List[Dict[str, Any]]:
        """Build OpenAI function-calling specs for tools tagged by @function_tool.

        Every parameter is advertised as a string (no richer schema is
        derivable from the bare callable). Fix: parameters that carry a
        default value are no longer listed as "required" — the original
        marked all positional parameters required.
        """
        specs: List[Dict[str, Any]] = []
        for t in (self.tools or []):
            if not getattr(t, "_is_tool", False):
                continue
            code = t.__code__
            # Positional parameter names, computed once.
            params = list(code.co_varnames[:code.co_argcount])
            # Trailing params covered by __defaults__ are optional.
            n_defaults = len(t.__defaults__ or ())
            required = params[:len(params) - n_defaults] if n_defaults else list(params)
            specs.append({
                "type": "function",
                "function": {
                    "name": t.__name__,
                    "description": (t.__doc__ or "")[:512],
                    "parameters": {
                        "type": "object",
                        "properties": {p: {"type": "string"} for p in params},
                        "required": required,
                    },
                },
            })
        return specs
| |
|
class Runner:
    """Minimal tool-calling loop that drives an Agent against the chat API."""

    @staticmethod
    async def run(agent: Agent, user_input: str, context: Optional[Dict[str, Any]] = None):
        """Run *agent* on *user_input*, resolving tool calls for up to 4 rounds.

        Returns an ad-hoc result object exposing ``final_output`` (str),
        ``context`` (dict) and ``final_output_as`` (callable).
        """
        msgs = [
            {"role": "system", "content": agent.instructions},
            {"role": "user", "content": user_input},
        ]
        tools = agent.tool_specs()
        # Only decorated tools are invocable; keyed by function name.
        tool_map = {t.__name__: t for t in (agent.tools or []) if getattr(t, "_is_tool", False)}

        # Bounded loop: at most 4 model round-trips, tool rounds included.
        for _ in range(4):
            resp = await agent.model.client.chat.completions.create(
                model=agent.model.model,
                messages=msgs,
                tools=tools if tools else None,
                tool_choice="auto" if tools else None,
            )
            choice = resp.choices[0]
            msg = choice.message
            # NOTE(review): tool_calls is appended even when None, and as SDK
            # objects rather than plain dicts — confirm the backend accepts this.
            msgs.append({"role": "assistant", "content": msg.content or "", "tool_calls": msg.tool_calls})

            if msg.tool_calls:
                for call in msg.tool_calls:
                    fn_name = call.function.name
                    args = json.loads(call.function.arguments or "{}")
                    if fn_name in tool_map:
                        try:
                            result = tool_map[fn_name](**args)
                        except Exception as e:
                            # Surface tool failures to the model instead of crashing.
                            result = {"error": str(e)}
                    else:
                        result = {"error": f"Unknown tool: {fn_name}"}
                    msgs.append({
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": fn_name,
                        "content": json.dumps(result),
                    })
                continue  # let the model read tool results in the next round

            # No tool calls: treat this message as the final answer.
            final_text = msg.content or ""
            final_obj = type("Result", (), {})()
            final_obj.final_output = final_text
            final_obj.context = context or {}
            if agent.output_type and issubclass(agent.output_type, BaseModel):
                try:
                    # Re-serialize through the declared pydantic model when it parses.
                    data = agent.output_type.model_validate_json(final_text)
                    final_obj.final_output = data.model_dump_json()
                    final_obj.final_output_as = lambda t: data
                except Exception:
                    final_obj.final_output_as = lambda t: final_text
            else:
                final_obj.final_output_as = lambda t: final_text
            return final_obj

        # Round budget exhausted without a tool-free reply.
        final_obj = type("Result", (), {})()
        final_obj.final_output = "Sorry, I couldn't complete the request."
        final_obj.context = context or {}
        final_obj.final_output_as = lambda t: final_obj.final_output
        return final_obj
| |
|
| | |
| | |
| | |
# ---- Environment / client bootstrap -------------------------------------
load_dotenv()
# Prefer a Gemini key; fall back to a generic OpenAI key.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    raise RuntimeError("Missing GEMINI_API_KEY (or OPENAI_API_KEY). Add it in Space secrets or .env.")
# Used by the SD-XL image tool; sdxl_png re-checks it at call time.
HF_TOKEN = os.environ.get("HF_TOKEN")

set_tracing_disabled(True)

# Gemini exposed through its OpenAI-compatible endpoint.
external_client: AsyncOpenAI = AsyncOpenAI(
    api_key=API_KEY,
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
llm_model: OpenAIChatCompletionsModel = OpenAIChatCompletionsModel(
    model="gemini-2.5-flash",
    openai_client=external_client,
)
| |
|
| | |
| | |
| | |
@function_tool
def topic_reference_guide(topic: str) -> dict:
    """
    Educational bullets for common Anatomy & Physiology topics.
    Returns dict with anatomy, physiology, misconceptions, study_tips.

    Routing is plain keyword matching on the lowercased topic; topics that
    match no branch fall through to a generic study scaffold.
    """
    t = (topic or "").lower()

    # Uniform return shape for every branch below.
    def pack(anat, phys, misc, tips):
        return {"anatomy": anat, "physiology": phys, "misconceptions": misc, "study_tips": tips}

    # Cardiac conduction / ECG.
    if any(k in t for k in ["cardiac conduction", "sa node", "av node", "bundle", "purkinje", "ecg"]):
        return pack(
            anat=[
                "SA node (RA) → AV node → His bundle → R/L bundle branches → Purkinje fibers.",
                "Fibrous skeleton insulates atria from ventricles; AV node is the gateway."
            ],
            phys=[
                "Pacemaker automaticity (If current) sets heart rate.",
                "AV nodal delay allows ventricular filling; His–Purkinje enables synchronized contraction."
            ],
            misc=[
                "Misconception: all tissues pace equally — SA node dominates.",
                "PR interval ≈ AV nodal delay."
            ],
            tips=[
                "Map ECG waves to mechanics (P/QRS/T).",
                "Relate ion channels to nodal vs myocyte AP phases."
            ],
        )

    # Renal / nephron physiology.
    if any(k in t for k in ["nephron", "kidney", "gfr", "raas", "countercurrent"]):
        return pack(
            anat=[
                "Segments: Bowman’s → PCT → DLH → ALH (thin/thick) → DCT → collecting duct.",
                "Cortex (glomeruli/PCT/DCT) vs medulla (loops/collecting ducts)."
            ],
            phys=[
                "GFR via Starling forces; PCT bulk reabsorption; ALH generates gradient; DCT/CD fine-tune (ADH/aldosterone).",
                "Countercurrent multiplication + vasa recta maintain medullary gradient."
            ],
            misc=[
                "Misconception: water reabsorption is equal everywhere — hormones control CD concentration.",
                "Urea also supports medullary osmolality."
            ],
            tips=[
                "Sketch transporters per segment.",
                "Practice ‘what if’ with ↑ADH/↑Aldosterone/↑GFR."
            ],
        )

    # Respiratory gas exchange / V-Q matching.
    if any(k in t for k in ["alveolar", "gas exchange", "v/q", "ventilation perfusion", "oxygen dissociation"]):
        return pack(
            anat=[
                "Conducting vs respiratory zones; Type I (exchange) vs Type II (surfactant) pneumocytes."
            ],
            phys=[
                "V/Q matching optimizes exchange; extremes: dead space (high V/Q) vs shunt (low V/Q).",
                "O2–Hb curve shifts with pH, CO2, temp, 2,3-BPG (Bohr effect)."
            ],
            misc=[
                "Misconception: uniform V/Q across lung — gravity/disease alter distribution.",
                "Assuming PaO2–SaO2 linearity (it’s sigmoidal)."
            ],
            tips=[
                "Draw V/Q along lung height.",
                "Link spirometry patterns to mechanics."
            ],
        )

    # Neurophysiology.
    if any(k in t for k in ["neuron", "synapse", "action potential", "neurotransmitter"]):
        return pack(
            anat=[
                "Neuron: dendrites, soma, axon hillock, axon; myelin & nodes."
            ],
            phys=[
                "AP: Na+ depolarization → K+ repolarization; refractory periods.",
                "Chemical synapse: Ca2+-dependent vesicle release; EPSPs/IPSPs summate."
            ],
            misc=[
                "Misconception: any EPSP fires AP — threshold & summation matter."
            ],
            tips=[
                "Relate channel states to AP phases.",
                "Compare ionotropic vs metabotropic effects."
            ],
        )

    # Muscle physiology / excitation-contraction coupling.
    if any(k in t for k in ["muscle contraction", "excitation contraction", "sarcomere"]):
        return pack(
            anat=[
                "Sarcomere: Z to Z; thin (actin/troponin/tropomyosin) vs thick (myosin).",
                "T-tubules & SR triads (skeletal)."
            ],
            phys=[
                "AP → DHPR → RyR → Ca2+ release → troponin C → cross-bridge cycling.",
                "Force–length & force–velocity relationships; motor unit recruitment."
            ],
            misc=[
                "ATP also needed for detachment & Ca2+ resequestration."
            ],
            tips=[
                "Diagram cross-bridge cycle.",
                "Predict effects of length on force."
            ],
        )

    # Generic fallback for unrecognized topics.
    return pack(
        anat=["Identify major structures & relationships."],
        phys=["Describe inputs → key steps → outputs; control loops."],
        misc=["Clarify commonly conflated terms; distinguish variation from pathology (education-only)."],
        tips=["Make a labeled sketch; use ‘if-this-then-that’ scenarios to test understanding."],
    )
| |
|
@function_tool
def study_outline(topic: str) -> dict:
    """Return a suggested outline and practice prompts for the topic."""
    subtopics = [
        "Key structures & relationships",
        "Mechanism (step-by-step)",
        "Control & feedback",
        "Quantitative intuition (flows, pressures, potentials)",
        "Common misconceptions",
    ]
    prompts = [
        f"Explain {topic} in 5 steps to a first-year student.",
        f"Draw a block diagram of {topic} with inputs/outputs.",
        f"What changes if a key parameter increases/decreases in {topic}?",
    ]
    return {"topic": topic, "subtopics": subtopics, "practice_prompts": prompts}
| |
|
| | |
| | |
| | |
# System prompt for the tutor: teaching-only scope, tool-grounded,
# explicitly non-clinical.
tutor_instructions = (
    "You are an Anatomy & Physiology Tutor. TEACH, do not diagnose.\n"
    "Given a topic (e.g., 'cardiac conduction', 'nephron physiology'), produce concise bullet points:\n"
    "1) Anatomy (structures/relationships)\n"
    "2) Physiology (inputs → steps → outputs)\n"
    "3) Common misconceptions\n"
    "4) Study tips\n"
    "5) Caution: education only, no diagnosis\n"
    "Use tools (topic_reference_guide, study_outline) to ground the response.\n"
    "Avoid clinical diagnosis or treatment advice."
)
# Main tutoring agent; both tools are exposed for function calling.
tutor_agent = Agent(
    name="A&P Tutor",
    instructions=tutor_instructions,
    model=llm_model,
    tools=[topic_reference_guide, study_outline],
)
| |
|
# Safety classifier agent: its JSON verdict is used by on_message to block
# diagnosis / unsafe-advice requests before tutoring.
guardrail_agent = Agent(
    name="Safety Classifier",
    instructions=(
        "Classify if the user's message requests medical diagnosis or unsafe medical advice, "
        "and if it includes personal identifiers. Respond as JSON with fields: "
        "{unsafe_medical_advice: bool, requests_diagnosis: bool, pii_included: bool, reasoning: string}."
    ),
    model=llm_model,
)
| |
|
| | |
| | |
| | |
def sdxl_png(prompt: str, negative: str = "") -> str:
    """Generate an educational PNG via SD-XL on the HF Inference API.

    Appends style guidance to *prompt*, saves the image under the Chainlit
    files dir (or ./.files) and returns the saved file path.

    Raises:
        RuntimeError: when HF_TOKEN is not configured.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        raise RuntimeError("Missing HF_TOKEN. Add it in Space secrets.")
    client = InferenceClient("stabilityai/stable-diffusion-xl-base-1.0", token=token)

    # Fixed seed keeps output reproducible for identical prompts.
    # NOTE(review): kwarg support (e.g. `seed`) varies across huggingface_hub
    # versions — confirm against the installed InferenceClient.text_to_image.
    img: Image.Image = client.text_to_image(
        prompt = prompt + " | educational, clean vector style, white background, no labels, safe-for-work",
        negative_prompt = negative or "text, watermark, logo, gore, photorealistic patient, clutter",
        width = 1024, height = 768,
        guidance_scale = 7.5,
        num_inference_steps = 30,
        seed = 42
    )
    # Timestamped filename avoids collisions between requests.
    out_dir = os.environ.get("CHAINLIT_FILES_DIR") or os.path.join(os.getcwd(), ".files")
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"sdxl-{int(time.time())}.png")
    img.save(path)
    return path
| |
|
| | |
| | |
| | |
# Markdown banner posted at the start of every chat session.
WELCOME = (
    "🧠 **Anatomy & Physiology Tutor** + 🎨 **/gen Diagram (SD-XL)**\n\n"
    "• Ask any A&P topic (e.g., *cardiac conduction*, *nephron physiology*, *gas exchange*).\n"
    "• Generate a clean educational diagram (PNG) with:\n"
    "  `/gen isometric nephron diagram, flat vector, white background, no labels`\n\n"
    "⚠️ Education only — no diagnosis or clinical advice."
)
| |
|
@cl.on_chat_start
async def on_chat_start():
    """Post the welcome banner when a new chat session opens."""
    await cl.Message(content=WELCOME).send()
| |
|
@cl.on_message
async def on_message(message: cl.Message):
    """Handle one chat turn: /gen → SD-XL image, otherwise guardrail + tutor."""
    text = (message.content or "").strip()

    # ---- Image command path: "/gen <description>" ----
    if text.lower().startswith("/gen "):
        desc = text[5:].strip()
        if not desc:
            await cl.Message(content="Please provide a description after `/gen`.\nExample: `/gen isometric nephron diagram, flat vector`").send()
            return
        try:
            path = sdxl_png(desc)
        except Exception as e:
            # Report generation errors in-chat rather than crashing the handler.
            await cl.Message(content=f"Generation failed: {e}").send()
            return
        await cl.Message(
            content="🎨 **Generated diagram** (education only):",
            elements=[cl.Image(path=path, name=os.path.basename(path), display="inline")],
        ).send()
        return

    # ---- Safety gate: block diagnosis / unsafe-advice requests. ----
    # Best-effort: any classifier or parsing failure falls through to tutoring.
    try:
        safety = await Runner.run(guardrail_agent, text)
        parsed = safety.final_output
        try:
            data = json.loads(parsed) if isinstance(parsed, str) else parsed
        except Exception:
            data = {}
        if isinstance(data, dict) and (data.get("unsafe_medical_advice") or data.get("requests_diagnosis")):
            await cl.Message(
                content="🚫 I can’t provide diagnosis or treatment advice. I can teach A&P concepts and generate **educational** diagrams."
            ).send()
            return
    except Exception:
        pass

    # ---- Tutor path: LLM answer plus locally-built reference sections. ----
    topic = text or "anatomy and physiology overview"
    result = await Runner.run(tutor_agent, topic)

    # Call the tools directly as well, so the reply always carries the quick
    # reference even if the model skipped tool use.
    try:
        guide = topic_reference_guide(topic)
    except Exception:
        guide = {}
    try:
        outline = study_outline(topic)
    except Exception:
        outline = {}

    # Render a list as markdown bullets; non-lists get a placeholder.
    def bullets(arr):
        return "\n".join([f"- {b}" for b in arr]) if isinstance(arr, list) else "- (n/a)"

    quick_ref = (
        f"### 📚 Quick Reference: {topic}\n"
        f"**Anatomy**\n{bullets(guide.get('anatomy', []))}\n\n"
        f"**Physiology**\n{bullets(guide.get('physiology', []))}\n\n"
        f"**Common Misconceptions**\n{bullets(guide.get('misconceptions', []))}\n\n"
        f"**Study Tips**\n{bullets(guide.get('study_tips', []))}\n\n"
    )

    outline_md = ""
    if outline:
        subs = bullets(outline.get("subtopics", []))
        qs = bullets(outline.get("practice_prompts", []))
        outline_md = f"### 🗂️ Suggested Outline\n{subs}\n\n### 📝 Practice Prompts\n{qs}\n\n"

    caution = "> ⚠️ Education only — no diagnosis or treatment advice."
    answer = result.final_output or "I couldn’t generate an explanation."
    await cl.Message(content=f"{quick_ref}{outline_md}---\n{answer}\n\n{caution}").send()