# -----------------------------
# Single-file Chainlit app with inline "agents" shim
# Project: Multimodal Biomedical Imaging Tutor (education only)
# -----------------------------
import inspect
import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv
from pydantic import BaseModel, Field
import chainlit as cl
from openai import AsyncOpenAI as _SDKAsyncOpenAI

logger = logging.getLogger(__name__)


# =============================
# Inline "agents" shim
# =============================
def set_tracing_disabled(disabled: bool = True) -> bool:
    """No-op stand-in for the agents SDK tracing toggle; returns its argument."""
    return disabled


def function_tool(func: Callable) -> Callable:
    """Mark *func* as a tool so ``Agent.tool_specs`` exposes it to the model."""
    func._is_tool = True
    return func


def handoff(*args, **kwargs):
    """Placeholder for agent handoffs; this shim does not implement them."""
    return None


class InputGuardrail:
    """Holds a guardrail callable (kept for API parity; ``Runner`` never invokes it)."""

    def __init__(self, guardrail_function: Callable):
        self.guardrail_function = guardrail_function


@dataclass
class GuardrailFunctionOutput:
    output_info: Any
    tripwire_triggered: bool = False
    tripwire_message: str = ""


class InputGuardrailTripwireTriggered(Exception):
    pass


class AsyncOpenAI:
    """Builds an OpenAI SDK async client, optionally pointed at a custom base URL."""

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        self._client = _SDKAsyncOpenAI(**kwargs)

    @property
    def client(self):
        return self._client


class OpenAIChatCompletionsModel:
    """Pairs a model name with the SDK client used to call it."""

    def __init__(self, model: str, openai_client: AsyncOpenAI):
        self.model = model
        self.client = openai_client.client


@dataclass
class Agent:
    name: str
    instructions: str
    model: OpenAIChatCompletionsModel
    tools: Optional[List[Callable]] = field(default_factory=list)
    handoff_description: Optional[str] = None
    output_type: Optional[type] = None  # optional Pydantic model class
    input_guardrails: Optional[List[InputGuardrail]] = field(default_factory=list)

    def tool_specs(self) -> List[Dict[str, Any]]:
        """Build Chat Completions ``tools`` entries for every ``@function_tool`` in ``tools``.

        Every positional parameter is declared as a required string, which is
        what the tools in this file accept.
        """
        specs = []
        for t in self.tools or []:
            if not getattr(t, "_is_tool", False):
                continue
            code = t.__code__
            params = list(code.co_varnames[: code.co_argcount])
            specs.append({
                "type": "function",
                "function": {
                    "name": t.__name__,
                    # getdoc strips the docstring's leading indentation.
                    "description": (inspect.getdoc(t) or "")[:512],
                    "parameters": {
                        "type": "object",
                        "properties": {p: {"type": "string"} for p in params},
                        "required": params,
                    },
                },
            })
        return specs


def _parse_json_object(text: Any) -> Dict[str, Any]:
    """Best-effort extraction of a JSON object from model output.

    Models frequently wrap JSON in Markdown fences (```json ... ```) or add a
    sentence around it, both of which plain ``json.loads`` rejects.
    Returns ``{}`` when no JSON object can be recovered.
    """
    if isinstance(text, dict):
        return text
    if not isinstance(text, str):
        return {}
    candidate = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)```", candidate, re.DOTALL)
    if fenced:
        candidate = fenced.group(1).strip()
    try:
        data = json.loads(candidate)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, e.g. "Here you go: {...}".
        start, end = candidate.find("{"), candidate.rfind("}")
        if start == -1 or end <= start:
            return {}
        try:
            data = json.loads(candidate[start : end + 1])
        except json.JSONDecodeError:
            return {}
    return data if isinstance(data, dict) else {}


class RunResult:
    """Outcome of ``Runner.run`` (mirrors the fields the agents SDK exposes)."""

    def __init__(self, final_output: str, context: Optional[Dict[str, Any]], parsed: Any = None):
        self.final_output = final_output
        self.context = context or {}
        self._parsed = parsed

    def final_output_as(self, _type: Any) -> Any:
        """Return the validated Pydantic object when available, else the raw text."""
        return self._parsed if self._parsed is not None else self.final_output


class Runner:
    @staticmethod
    async def run(
        agent: Agent,
        user_input: str,
        context: Optional[Dict[str, Any]] = None,
        max_turns: int = 4,
    ) -> RunResult:
        """Run *agent* on *user_input*, executing tool calls for up to *max_turns* model turns.

        Returns a ``RunResult`` whose ``final_output`` is the model's last text
        reply (re-serialised through ``agent.output_type`` when that is a Pydantic
        model and the reply validates). If the model is still requesting tools
        after ``max_turns`` turns, or returns no choices, a fixed apology is
        returned instead. API errors from the SDK propagate to the caller.
        """
        msgs: List[Dict[str, Any]] = [
            {"role": "system", "content": agent.instructions},
            {"role": "user", "content": user_input},
        ]
        tools = agent.tool_specs()
        tool_map = {t.__name__: t for t in (agent.tools or []) if getattr(t, "_is_tool", False)}

        # Only send tools/tool_choice when there are tools: passing None makes the
        # SDK serialise an explicit JSON null instead of omitting the field.
        request_kwargs: Dict[str, Any] = {"model": agent.model.model}
        if tools:
            request_kwargs["tools"] = tools
            request_kwargs["tool_choice"] = "auto"

        for _ in range(max_turns):
            resp = await agent.model.client.chat.completions.create(messages=msgs, **request_kwargs)
            if not resp.choices:
                break
            msg = resp.choices[0].message

            if not msg.tool_calls:
                return Runner._finalize(agent, msg.content or "", context)

            msgs.append({
                "role": "assistant",
                "content": msg.content or "",
                "tool_calls": [call.model_dump(exclude_none=True) for call in msg.tool_calls],
            })
            for call in msg.tool_calls:
                fn_name = call.function.name
                result = Runner._invoke_tool(tool_map, fn_name, call.function.arguments)
                msgs.append({
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": fn_name,
                    "content": json.dumps(result),
                })
            # loop again so the model can use the tool outputs

        return RunResult("Sorry, I couldn't complete the request.", context)

    @staticmethod
    def _invoke_tool(tool_map: Dict[str, Callable], name: str, raw_args: Optional[str]) -> Any:
        """Execute one tool call; unknown tools, malformed arguments and tool errors become ``{"error": ...}``."""
        if name not in tool_map:
            return {"error": f"Unknown tool: {name}"}
        try:
            args = json.loads(raw_args or "{}")
            if not isinstance(args, dict):
                raise ValueError("tool arguments must be a JSON object")
            return tool_map[name](**args)
        except Exception as e:  # report back to the model rather than aborting the run
            return {"error": str(e)}

    @staticmethod
    def _finalize(agent: Agent, final_text: str, context: Optional[Dict[str, Any]]) -> RunResult:
        """Wrap the final reply, validating it against ``agent.output_type`` when that is a Pydantic model."""
        output_type = agent.output_type
        if isinstance(output_type, type) and issubclass(output_type, BaseModel):
            try:
                data = output_type.model_validate_json(final_text)
            except ValueError:  # pydantic ValidationError subclasses ValueError
                return RunResult(final_text, context)
            return RunResult(data.model_dump_json(), context, parsed=data)
        return RunResult(final_text, context)


# =============================
# App configuration
# =============================
load_dotenv()

_GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
API_KEY = _GEMINI_API_KEY or os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    raise RuntimeError(
        "Missing GEMINI_API_KEY (or OPENAI_API_KEY). "
        "Add it in the Space secrets or a .env file."
    )

set_tracing_disabled(True)

# The Gemini OpenAI-compatible endpoint only accepts Gemini keys, so when only
# OPENAI_API_KEY is configured use the SDK's default endpoint and an OpenAI model.
# LLM_MODEL optionally overrides the model name in either case.
_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
_BASE_URL = _GEMINI_BASE_URL if _GEMINI_API_KEY else None
_MODEL_NAME = os.environ.get("LLM_MODEL") or ("gemini-2.5-flash" if _GEMINI_API_KEY else "gpt-4o-mini")

external_client: AsyncOpenAI = AsyncOpenAI(
    api_key=API_KEY,
    base_url=_BASE_URL,
)

llm_model: OpenAIChatCompletionsModel = OpenAIChatCompletionsModel(
    model=_MODEL_NAME,
    openai_client=external_client,
)


# =============================
# Domain models for tutor
# =============================
class Section(BaseModel):
    title: str
    bullets: List[str]


class TutorResponse(BaseModel):
    modality: str
    acquisition_overview: Section
    common_artifacts: Section
    preprocessing_methods: Section
    study_tips: Section
    caution: str


# =============================
# Tools
# =============================

# (keyword, modality, match_anywhere) checked in order; first hit wins.
# Distinctive keywords match anywhere in the filename. Short codes must start a
# token, so "ct" does not fire on "picture"/"doctor" and "t1" not on "test1".
_MODALITY_KEYWORDS = (
    ("xray", "X-ray", True),
    ("x_ray", "X-ray", True),
    ("x-ray", "X-ray", True),
    ("xr", "X-ray", False),
    ("cxr", "X-ray", False),
    ("mri", "MRI", True),
    ("t1", "MRI", False),
    ("t2", "MRI", False),
    ("flair", "MRI", True),
    ("dwi", "MRI", False),
    ("adc", "MRI", False),
    ("ct", "CT", False),
    ("cta", "CT", False),
    ("ultrasound", "Ultrasound", True),
    ("usg", "Ultrasound", False),
    ("echo", "Ultrasound", True),
)


@function_tool
def infer_modality_from_filename(filename: str) -> dict:
    """
    Guess modality (MRI/X-ray/CT/Ultrasound) from filename keywords.
    Returns: {"modality": "MRI" | "X-ray" | "CT" | "Ultrasound" | "unknown"}
    """
    f = (filename or "").lower()
    # "ct_chest_01.png" -> ["ct", "chest", "01", "png"]
    tokens = [tok for tok in re.split(r"[^a-z0-9]+", f) if tok]
    for key, modality, match_anywhere in _MODALITY_KEYWORDS:
        if match_anywhere:
            if key in f:
                return {"modality": modality}
        elif any(tok.startswith(key) for tok in tokens):
            return {"modality": modality}
    return {"modality": "unknown"}


@function_tool
def imaging_reference_guide(modality: str) -> dict:
    """
    Educational points for acquisition, artifacts, preprocessing, and study tips by modality.
    Education only (no diagnosis).
    """
    mod = (modality or "").strip().lower()
    if mod in ["xray", "x-ray", "x_ray", "x ray"]:
        return {
            "acquisition": [
                "Projection radiography using ionizing radiation.",
                "Common views: AP, PA, lateral; exposure (kVp/mAs) and positioning matter.",
                "Grids/collimation reduce scatter and improve contrast."
            ],
            "artifacts": [
                "Motion blur; under/overexposure affecting contrast.",
                "Grid cut-off; foreign objects (buttons, jewelry).",
                "Magnification/distortion from object–detector distance."
            ],
            "preprocessing": [
                "Denoising (median/NLM), histogram equalization.",
                "Window/level selection (bone vs soft tissue) for teaching.",
                "Edge enhancement (unsharp mask) with caution (halo artifacts)."
            ],
            "study_tips": [
                "Use a systematic approach (e.g., ABCDE for chest X-ray).",
                "Compare sides; verify devices, labels, positioning.",
                "Correlate with clinical scenario; keep a checklist."
            ],
        }
    if mod in ["mri", "mr"]:
        return {
            "acquisition": [
                "MR uses RF pulses in a strong magnetic field; sequences set contrast.",
                "Key sequences: T1, T2, FLAIR, DWI/ADC, GRE/SWI.",
                "TR/TE/flip angle shape SNR, contrast, time."
            ],
            "artifacts": [
                "Motion/ghosting (movement, pulsation).",
                "Susceptibility (metal, air-bone interfaces).",
                "Chemical shift, Gibbs ringing.",
                "B0/B1 inhomogeneity causing intensity bias."
            ],
            "preprocessing": [
                "Bias-field correction (N4).",
                "Denoising (non-local means), registration/normalization.",
                "Skull stripping (brain), intensity standardization."
            ],
            "study_tips": [
                "Know sequence intent (T1 anatomy, T2 fluid, FLAIR edema).",
                "Check diffusion for acute ischemia (with ADC).",
                "Use consistent windowing for longitudinal comparison."
            ],
        }
    if mod == "ct":
        return {
            "acquisition": [
                "Helical CT reconstructs attenuation in Hounsfield Units.",
                "Kernels (bone vs soft) change sharpness/noise.",
                "Contrast phases (arterial/venous) match the task."
            ],
            "artifacts": [
                "Beam hardening (streaks), partial volume.",
                "Motion (breathing/cardiac).",
                "Metal artifacts; consider MAR algorithms."
            ],
            "preprocessing": [
                "Denoising (bilateral/NLM) while preserving edges.",
                "Appropriate window/level (lung, mediastinum, bone).",
                "Iterative reconstruction / metal artifact reduction."
            ],
            "study_tips": [
                "Use standard planes; scroll systematically.",
                "Compare windows; document sizes/HU as needed.",
                "Correlate phase with the clinical question."
            ],
        }
    # infer_modality_from_filename can return "Ultrasound", so give it its own guide.
    if mod in ["ultrasound", "us", "usg", "sonography"]:
        return {
            "acquisition": [
                "Pulse-echo imaging with high-frequency sound from a transducer; no ionizing radiation.",
                "Higher transducer frequency gives finer resolution but shallower penetration.",
                "Gain, time-gain compensation (TGC), depth and focus are set live by the operator."
            ],
            "artifacts": [
                "Acoustic shadowing and posterior acoustic enhancement.",
                "Reverberation, comet-tail and mirror-image artifacts.",
                "Speckle noise; strong dependence on probe angle and operator."
            ],
            "preprocessing": [
                "Speckle reduction (e.g., anisotropic diffusion, NLM) while preserving edges.",
                "Gain/contrast normalization across frames or exams.",
                "Frame selection from cine loops for teaching sets."
            ],
            "study_tips": [
                "Note probe orientation, frequency, and depth shown on the image.",
                "Recognize artifacts and how probe angle changes them.",
                "Review cine loops when available; single frames hide context."
            ],
        }
    return {
        "acquisition": [
            "Acquisition parameters define contrast, resolution, and noise.",
            "Positioning and motion control are crucial for quality."
        ],
        "artifacts": [
            "Motion blur/ghosting; foreign objects and hardware.",
            "Parameter misconfiguration harms interpretability."
        ],
        "preprocessing": [
            "Denoising and contrast normalization for clarity.",
            "Registration to standard planes for comparison."
        ],
        "study_tips": [
            "Adopt a checklist; compare across time or sides.",
            "Learn modality-specific knobs (window/level, sequences)."
        ],
    }


@function_tool
def file_facts(filename: str, size_bytes: str) -> dict:
    """
    Returns lightweight file facts: filename and byte size (as string).
    """
    try:
        size = int(size_bytes)
    except (TypeError, ValueError):
        size = -1  # sentinel for "not a whole number"
    return {"filename": filename, "size_bytes": size}


# =============================
# Agents
# =============================
tutor_instructions = (
    "You are a Biomedical Imaging Education Tutor. TEACH, do not diagnose.\n"
    "Given an uploaded MRI or X-ray, provide:\n"
    "1) Acquisition overview\n"
    "2) Common artifacts\n"
    "3) Preprocessing methods\n"
    "4) Study tips\n"
    "5) A caution line: education only, no diagnosis\n"
    "Use tools to infer modality from filename and to fetch a modality reference guide.\n"
    "If unclear, provide a generic overview and ask for clarification.\n"
    "Always respond as concise, well-structured bullet points.\n"
    "Absolutely avoid clinical diagnosis, disease identification, or treatment advice."
)

tutor_agent = Agent(
    name="Biomedical Imaging Tutor",
    instructions=tutor_instructions,
    model=llm_model,
    tools=[infer_modality_from_filename, imaging_reference_guide, file_facts],
)


class SafetyCheck(BaseModel):
    unsafe_medical_advice: bool
    requests_diagnosis: bool
    pii_included: bool
    reasoning: str


guardrail_agent = Agent(
    name="Safety Classifier",
    instructions=(
        "Classify if the user's message requests medical diagnosis or unsafe medical advice, "
        "and if it includes personal identifiers. Respond as JSON with fields: "
        "{unsafe_medical_advice: bool, requests_diagnosis: bool, pii_included: bool, reasoning: string}."
    ),
    model=llm_model,
    output_type=SafetyCheck,
)


# =============================
# Chainlit flows
# =============================
WELCOME = (
    "🎓 **Multimodal Biomedical Imaging Tutor**\n\n"
    "Upload an **MRI** or **X-ray** image (PNG/JPG). I’ll explain:\n"
    "• Acquisition (how it’s made)\n"
    "• Common artifacts (what to watch for)\n"
    "• Preprocessing for study/teaching\n\n"
    "⚠️ *Education only — I do not provide diagnosis. For clinical concerns, consult a professional.*"
)


@cl.on_chat_start
async def on_chat_start():
    """Greet the user, ask for one PNG/JPG upload, and remember its path/name/size in the session."""
    await cl.Message(content=WELCOME).send()

    files = await cl.AskFileMessage(
        content="Please upload an **MRI or X-ray** image (PNG/JPG).",
        accept=["image/png", "image/jpeg"],
        max_size_mb=15,
        max_files=1,
        timeout=180,
    ).send()

    if not files:
        await cl.Message(content="No file uploaded. You can still ask general imaging questions.").send()
        return

    f = files[0]
    cl.user_session.set("last_file_path", f.path)
    cl.user_session.set("last_file_name", f.name)
    cl.user_session.set("last_file_size", f.size)

    await cl.Message(
        content=f"Received **{f.name}** ({f.size} bytes). "
        "Ask: *“Explain acquisition & artifacts for this image.”*"
    ).send()


def _is_flag_set(value: Any) -> bool:
    """Interpret a classifier flag; models sometimes emit "true"/"false" strings instead of booleans."""
    if isinstance(value, str):
        return value.strip().lower() == "true"
    return value is True


def _bullets(items: List[str]) -> str:
    """Render *items* as a Markdown bullet list (empty string for no items)."""
    return "\n".join(f"- {b}" for b in items)


def _reference_facts_md(file_name: Optional[str], file_size: Any) -> str:
    """Markdown summary of the uploaded file plus the static reference guide for its guessed modality."""
    modality = infer_modality_from_filename(file_name or "").get("modality", "unknown")
    guide = imaging_reference_guide(modality)
    acq = _bullets(guide.get("acquisition", []))
    art = _bullets(guide.get("artifacts", []))
    prep = _bullets(guide.get("preprocessing", []))
    tips = _bullets(guide.get("study_tips", []))
    return (
        f"### 📁 File\n"
        f"- Name: `{file_name or 'unknown'}`\n"
        f"- Size: `{file_size if file_size is not None else 'unknown'} bytes`\n\n"
        f"### 🔎 Modality (guess)\n- {modality}\n\n"
        f"### 📚 Reference Guide (study)\n"
        f"**Acquisition**\n{acq or '- (general)'}\n\n"
        f"**Common Artifacts**\n{art or '- (general)'}\n\n"
        f"**Preprocessing Ideas**\n{prep or '- (general)'}\n\n"
        f"**Study Tips**\n{tips or '- (general)'}\n\n"
        f"> ⚠️ Education only — no diagnosis.\n"
    )


@cl.on_message
async def on_message(message: cl.Message):
    """Screen the message with the safety classifier, then answer with the tutor agent.

    The safety check fails open: if the classifier call errors, the tutor still
    runs (the error is logged). When a file was uploaded, a static reference
    section for its guessed modality is prepended to the tutor's answer.
    """
    # Safety check
    try:
        safety = await Runner.run(guardrail_agent, message.content)
    except Exception:
        logger.exception("Safety classifier failed; continuing without it")
    else:
        verdict = _parse_json_object(safety.final_output)
        if _is_flag_set(verdict.get("unsafe_medical_advice")) or _is_flag_set(verdict.get("requests_diagnosis")):
            await cl.Message(
                content=(
                    "🚫 I can’t provide medical diagnoses or treatment advice.\n"
                    "I’m happy to explain **imaging concepts**, **artifacts**, and **preprocessing** for learning."
                )
            ).send()
            return

    # Context from last upload
    file_name = cl.user_session.get("last_file_name")
    file_size = cl.user_session.get("last_file_size")
    context_note = ""
    if file_name:
        context_note += f"The user uploaded a file named '{file_name}'.\n"
    if file_size is not None:
        context_note += f"File size: {file_size} bytes.\n"

    user_query = message.content
    if context_note:
        user_query = f"{user_query}\n\n[Context]\n{context_note}"

    # Run tutor
    try:
        result = await Runner.run(tutor_agent, user_query)
        text = result.final_output or "I couldn’t generate an explanation."
    except Exception:
        logger.exception("Tutor run failed")
        text = "I couldn’t generate an explanation."

    # Quick reference facts only make sense when a file was actually uploaded;
    # otherwise they would just show an "unknown" file and a generic guide.
    facts_md = _reference_facts_md(file_name, file_size) if file_name else ""
    content = f"{facts_md}\n---\n{text}" if facts_md else text
    await cl.Message(content=content).send()