# app.py — Hugging Face Space file (author: ibadhasnain, commit 4f7d15b, verified)
# -----------------------------
# Single-file Chainlit app with inline "agents" shim
# Project: Multimodal Biomedical Imaging Tutor (education only)
# -----------------------------
import os, json
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from dotenv import load_dotenv
from pydantic import BaseModel, Field
import chainlit as cl
from openai import AsyncOpenAI as _SDKAsyncOpenAI
# =============================
# Inline "agents" shim
# =============================
def set_tracing_disabled(disabled: bool = True):
    """No-op stand-in for the agents SDK tracing toggle; simply echoes its flag."""
    return disabled
def function_tool(func: Callable):
    """Tag *func* with ``_is_tool`` so the Agent shim can expose it as a tool."""
    setattr(func, "_is_tool", True)
    return func
def handoff(*args, **kwargs):
    """Inert shim for the agents SDK ``handoff`` helper: accepts anything, returns None."""
    return None
class InputGuardrail:
    """Holds a guardrail callable under the attribute name the SDK API expects."""

    def __init__(self, guardrail_function: Callable):
        # Stored as-is; invocation is the caller's responsibility.
        self.guardrail_function = guardrail_function
@dataclass
class GuardrailFunctionOutput:
    """Result of running a guardrail check over user input."""
    # Arbitrary payload produced by the guardrail function.
    output_info: Any
    # True when the guardrail decided the input must be blocked.
    tripwire_triggered: bool = False
    # Optional human-readable reason for the block.
    tripwire_message: str = ""
class InputGuardrailTripwireTriggered(Exception):
    """Raised when an input guardrail trips and the request must be rejected."""
    pass
class AsyncOpenAI:
    """Thin wrapper over the OpenAI SDK async client.

    Accepts an API key and an optional custom base URL (e.g. the Gemini
    OpenAI-compatibility endpoint) and exposes the raw SDK client object.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        # Only forward base_url when provided so the SDK default is preserved.
        if base_url:
            self._client = _SDKAsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            self._client = _SDKAsyncOpenAI(api_key=api_key)

    @property
    def client(self):
        """The underlying SDK ``AsyncOpenAI`` instance."""
        return self._client
class OpenAIChatCompletionsModel:
    """Pairs a chat-completions model name with the SDK client used to call it."""

    def __init__(self, model: str, openai_client: AsyncOpenAI):
        # Unwrap the shim so Runner can talk to the raw SDK client directly.
        self.model = model
        self.client = openai_client.client
@dataclass
class Agent:
    """Declarative bundle of instructions, model, tools, and guardrails."""

    name: str
    instructions: str
    model: OpenAIChatCompletionsModel
    tools: Optional[List[Callable]] = field(default_factory=list)
    handoff_description: Optional[str] = None
    output_type: Optional[type] = None  # optional Pydantic model class
    input_guardrails: Optional[List[InputGuardrail]] = field(default_factory=list)

    def tool_specs(self) -> List[Dict[str, Any]]:
        """Build OpenAI function-calling specs for every ``_is_tool``-tagged tool."""

        def spec_for(fn: Callable) -> Dict[str, Any]:
            # Positional parameter names straight from the code object; the
            # shim advertises every argument as a required string.
            arg_names = list(fn.__code__.co_varnames[:fn.__code__.co_argcount])
            return {
                "type": "function",
                "function": {
                    "name": fn.__name__,
                    "description": (fn.__doc__ or "")[:512],
                    "parameters": {
                        "type": "object",
                        "properties": {name: {"type": "string"} for name in arg_names},
                        "required": arg_names,
                    },
                },
            }

        return [spec_for(fn) for fn in (self.tools or []) if getattr(fn, "_is_tool", False)]
class Runner:
    """Minimal agent executor: drives a chat loop and resolves tool calls."""

    # Hard cap on model round-trips so a misbehaving model cannot loop forever.
    _MAX_TOOL_ROUNDS = 4

    @staticmethod
    def _make_result(text: str, context: Optional[Dict[str, Any]]):
        """Build the ad-hoc result object callers expect (final_output,
        context, and a best-effort ``final_output_as`` accessor)."""
        result = type("Result", (), {})()
        result.final_output = text
        result.context = context or {}
        result.final_output_as = lambda t: text
        return result

    @staticmethod
    async def run(agent: "Agent", user_input: str, context: Optional[Dict[str, Any]] = None):
        """Run *agent* on *user_input*, resolving tool calls for a few rounds.

        Returns an object with ``final_output`` (str), ``context`` (dict) and
        ``final_output_as(type)``. If ``agent.output_type`` is a Pydantic model
        and the reply parses, ``final_output_as`` yields the parsed instance.
        """
        msgs: List[Dict[str, Any]] = [
            {"role": "system", "content": agent.instructions},
            {"role": "user", "content": user_input},
        ]
        tools = agent.tool_specs()
        tool_map = {
            t.__name__: t for t in (agent.tools or []) if getattr(t, "_is_tool", False)
        }

        for _ in range(Runner._MAX_TOOL_ROUNDS):
            # Only send tool kwargs when tools exist: some OpenAI-compatible
            # endpoints (e.g. Gemini) reject explicit nulls for these fields.
            kwargs: Dict[str, Any] = {"model": agent.model.model, "messages": msgs}
            if tools:
                kwargs["tools"] = tools
                kwargs["tool_choice"] = "auto"
            resp = await agent.model.client.chat.completions.create(**kwargs)
            msg = resp.choices[0].message

            if msg.tool_calls:
                # Record the assistant turn as plain dicts: the SDK returns
                # pydantic objects, which are serialized for portability.
                msgs.append({
                    "role": "assistant",
                    "content": msg.content or "",
                    "tool_calls": [
                        tc.model_dump() if hasattr(tc, "model_dump") else tc
                        for tc in msg.tool_calls
                    ],
                })
                for call in msg.tool_calls:
                    fn_name = call.function.name
                    args = json.loads(call.function.arguments or "{}")
                    if fn_name in tool_map:
                        try:
                            result = tool_map[fn_name](**args)
                        except Exception as e:
                            # Tool failures are reported to the model, not raised.
                            result = {"error": str(e)}
                    else:
                        result = {"error": f"Unknown tool: {fn_name}"}
                    msgs.append({
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": fn_name,
                        "content": json.dumps(result),
                    })
                continue  # let the model use tool outputs

            # No tool calls: this is the final answer.
            final_text = msg.content or ""
            final_obj = Runner._make_result(final_text, context)
            if agent.output_type and issubclass(agent.output_type, BaseModel):
                try:
                    data = agent.output_type.model_validate_json(final_text)
                    final_obj.final_output = data.model_dump_json()
                    final_obj.final_output_as = lambda t: data
                except Exception:
                    pass  # leave the plain-text accessors in place
            return final_obj

        # Loop budget exhausted without a tool-free reply.
        return Runner._make_result("Sorry, I couldn't complete the request.", context)
# =============================
# App configuration
# =============================
# Load environment variables from a local .env file (no-op when absent).
load_dotenv()
# Prefer a Gemini key; fall back to an OpenAI key so either secret name works.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    # Fail fast at startup: nothing below works without a key.
    raise RuntimeError(
        "Missing GEMINI_API_KEY (or OPENAI_API_KEY). "
        "Add it in the Space secrets or a .env file."
    )
# Tracing is a no-op in the inline shim; kept for API parity with the SDK.
set_tracing_disabled(True)
# Point the OpenAI SDK at Google's OpenAI-compatibility endpoint.
external_client: AsyncOpenAI = AsyncOpenAI(
    api_key=API_KEY,
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
# Model wrapper shared by every agent below.
llm_model: OpenAIChatCompletionsModel = OpenAIChatCompletionsModel(
    model="gemini-2.5-flash",
    openai_client=external_client,
)
# =============================
# Domain models for tutor
# =============================
class Section(BaseModel):
    """One titled list of bullet points inside a tutor response."""
    title: str
    bullets: List[str]
class TutorResponse(BaseModel):
    """Structured tutor output: one section per teaching topic plus a caution."""
    modality: str  # e.g. "MRI", "X-ray", "CT", "Ultrasound", or "unknown"
    acquisition_overview: Section
    common_artifacts: Section
    preprocessing_methods: Section
    study_tips: Section
    caution: str  # education-only disclaimer line
# =============================
# Tools
# =============================
@function_tool
def infer_modality_from_filename(filename: str) -> dict:
    """
    Guess modality (MRI/X-ray/CT/Ultrasound) from filename keywords.
    Returns: {"modality": "<guess or unknown>"}
    """
    lowered = (filename or "").lower()
    keyword_to_modality = {
        "xray": "X-ray", "x_ray": "X-ray", "xr": "X-ray", "cxr": "X-ray",
        "mri": "MRI", "t1": "MRI", "t2": "MRI", "flair": "MRI", "dwi": "MRI", "adc": "MRI",
        "ct": "CT", "cta": "CT",
        "ultrasound": "Ultrasound", "usg": "Ultrasound", "echo": "Ultrasound",
    }
    # First matching keyword wins (insertion order of the mapping above).
    guess = next(
        (modality for keyword, modality in keyword_to_modality.items() if keyword in lowered),
        "unknown",
    )
    return {"modality": guess}
@function_tool
def imaging_reference_guide(modality: str) -> dict:
    """
    Educational points for acquisition, artifacts, preprocessing, and study tips by modality.
    Education only (no diagnosis).
    """
    normalized = (modality or "").strip().lower()

    xray_guide = {
        "acquisition": [
            "Projection radiography using ionizing radiation.",
            "Common views: AP, PA, lateral; exposure (kVp/mAs) and positioning matter.",
            "Grids/collimation reduce scatter and improve contrast."
        ],
        "artifacts": [
            "Motion blur; under/overexposure affecting contrast.",
            "Grid cut-off; foreign objects (buttons, jewelry).",
            "Magnification/distortion from object–detector distance."
        ],
        "preprocessing": [
            "Denoising (median/NLM), histogram equalization.",
            "Window/level selection (bone vs soft tissue) for teaching.",
            "Edge enhancement (unsharp mask) with caution (halo artifacts)."
        ],
        "study_tips": [
            "Use a systematic approach (e.g., ABCDE for chest X-ray).",
            "Compare sides; verify devices, labels, positioning.",
            "Correlate with clinical scenario; keep a checklist."
        ],
    }
    mri_guide = {
        "acquisition": [
            "MR uses RF pulses in a strong magnetic field; sequences set contrast.",
            "Key sequences: T1, T2, FLAIR, DWI/ADC, GRE/SWI.",
            "TR/TE/flip angle shape SNR, contrast, time."
        ],
        "artifacts": [
            "Motion/ghosting (movement, pulsation).",
            "Susceptibility (metal, air-bone interfaces).",
            "Chemical shift, Gibbs ringing.",
            "B0/B1 inhomogeneity causing intensity bias."
        ],
        "preprocessing": [
            "Bias-field correction (N4).",
            "Denoising (non-local means), registration/normalization.",
            "Skull stripping (brain), intensity standardization."
        ],
        "study_tips": [
            "Know sequence intent (T1 anatomy, T2 fluid, FLAIR edema).",
            "Check diffusion for acute ischemia (with ADC).",
            "Use consistent windowing for longitudinal comparison."
        ],
    }
    ct_guide = {
        "acquisition": [
            "Helical CT reconstructs attenuation in Hounsfield Units.",
            "Kernels (bone vs soft) change sharpness/noise.",
            "Contrast phases (arterial/venous) match the task."
        ],
        "artifacts": [
            "Beam hardening (streaks), partial volume.",
            "Motion (breathing/cardiac).",
            "Metal artifacts; consider MAR algorithms."
        ],
        "preprocessing": [
            "Denoising (bilateral/NLM) while preserving edges.",
            "Appropriate window/level (lung, mediastinum, bone).",
            "Iterative reconstruction / metal artifact reduction."
        ],
        "study_tips": [
            "Use standard planes; scroll systematically.",
            "Compare windows; document sizes/HU as needed.",
            "Correlate phase with the clinical question."
        ],
    }
    generic_guide = {
        "acquisition": [
            "Acquisition parameters define contrast, resolution, and noise.",
            "Positioning and motion control are crucial for quality."
        ],
        "artifacts": [
            "Motion blur/ghosting; foreign objects and hardware.",
            "Parameter misconfiguration harms interpretability."
        ],
        "preprocessing": [
            "Denoising and contrast normalization for clarity.",
            "Registration to standard planes for comparison."
        ],
        "study_tips": [
            "Adopt a checklist; compare across time or sides.",
            "Learn modality-specific knobs (window/level, sequences)."
        ],
    }

    # Dispatch table: every accepted spelling maps to its modality guide;
    # anything unrecognized falls back to the generic guide.
    alias_to_guide = {
        "xray": xray_guide, "x-ray": xray_guide, "x_ray": xray_guide,
        "mri": mri_guide, "mr": mri_guide,
        "ct": ct_guide,
    }
    return alias_to_guide.get(normalized, generic_guide)
@function_tool
def file_facts(filename: str, size_bytes: str) -> dict:
    """
    Returns lightweight file facts: filename and byte size (as string).
    """
    # Tool arguments always arrive as strings; -1 flags an unparseable size.
    try:
        parsed_size = int(size_bytes)
    except Exception:
        parsed_size = -1
    return {"filename": filename, "size_bytes": parsed_size}
# =============================
# Agents
# =============================
# System prompt for the tutor: five fixed sections, strictly education-only.
tutor_instructions = (
    "You are a Biomedical Imaging Education Tutor. TEACH, do not diagnose.\n"
    "Given an uploaded MRI or X-ray, provide:\n"
    "1) Acquisition overview\n"
    "2) Common artifacts\n"
    "3) Preprocessing methods\n"
    "4) Study tips\n"
    "5) A caution line: education only, no diagnosis\n"
    "Use tools to infer modality from filename and to fetch a modality reference guide.\n"
    "If unclear, provide a generic overview and ask for clarification.\n"
    "Always respond as concise, well-structured bullet points.\n"
    "Absolutely avoid clinical diagnosis, disease identification, or treatment advice."
)
# Primary agent; exposes the three shim tools defined above.
tutor_agent = Agent(
    name="Biomedical Imaging Tutor",
    instructions=tutor_instructions,
    model=llm_model,
    tools=[infer_modality_from_filename, imaging_reference_guide, file_facts],
)
class SafetyCheck(BaseModel):
    """Expected shape of the safety classifier's JSON verdict."""
    unsafe_medical_advice: bool  # message seeks unsafe medical advice
    requests_diagnosis: bool     # message asks for a diagnosis
    pii_included: bool           # message contains personal identifiers
    reasoning: str               # classifier's short justification
# Lightweight safety classifier run before the tutor on every message.
guardrail_agent = Agent(
    name="Safety Classifier",
    instructions=(
        "Classify if the user's message requests medical diagnosis or unsafe medical advice, "
        "and if it includes personal identifiers. Respond as JSON with fields: "
        "{unsafe_medical_advice: bool, requests_diagnosis: bool, pii_included: bool, reasoning: string}."
    ),
    model=llm_model,
)
# =============================
# Chainlit flows
# =============================
# Markdown greeting shown at the start of every chat session.
WELCOME = (
    "🎓 **Multimodal Biomedical Imaging Tutor**\n\n"
    "Upload an **MRI** or **X-ray** image (PNG/JPG). I’ll explain:\n"
    "• Acquisition (how it’s made)\n"
    "• Common artifacts (what to watch for)\n"
    "• Preprocessing for study/teaching\n\n"
    "⚠️ *Education only — I do not provide diagnosis. For clinical concerns, consult a professional.*"
)
@cl.on_chat_start
async def on_chat_start():
    """Greet the user, prompt for an image upload, and cache file metadata."""
    await cl.Message(content=WELCOME).send()
    files = await cl.AskFileMessage(
        content="Please upload an **MRI or X-ray** image (PNG/JPG).",
        accept=["image/png", "image/jpeg"],
        max_size_mb=15,
        max_files=1,
        timeout=180,
    ).send()
    if not files:
        # Timed out or cancelled: the chat still works for general questions.
        await cl.Message(content="No file uploaded. You can still ask general imaging questions.").send()
        return
    upload = files[0]
    # Stash metadata in the session so on_message can add upload context.
    cl.user_session.set("last_file_path", upload.path)
    cl.user_session.set("last_file_name", upload.name)
    cl.user_session.set("last_file_size", upload.size)
    await cl.Message(
        content=f"Received **{upload.name}** ({upload.size} bytes). "
        "Ask: *“Explain acquisition & artifacts for this image.”*"
    ).send()
def _strip_code_fences(text: str) -> str:
    """Remove a surrounding markdown code fence (``` or ```json) if present."""
    stripped = text.strip()
    if stripped.startswith("```"):
        stripped = stripped[3:]
        if stripped.lower().startswith("json"):
            stripped = stripped[4:]
        if stripped.endswith("```"):
            stripped = stripped[:-3]
    return stripped.strip()


async def _is_blocked_by_guardrail(user_text: str) -> bool:
    """Run the safety classifier; True when the message asks for diagnosis/advice.

    Best-effort: any failure (network error, unparseable JSON) lets the message
    through so the tutor can still answer general questions.
    """
    try:
        safety = await Runner.run(guardrail_agent, user_text)
        parsed = safety.final_output
        if isinstance(parsed, str):
            # Models often wrap JSON in markdown fences; strip before parsing,
            # otherwise json.loads fails and the safety check silently passes.
            parsed = json.loads(_strip_code_fences(parsed))
        if isinstance(parsed, dict):
            return bool(parsed.get("unsafe_medical_advice") or parsed.get("requests_diagnosis"))
    except Exception:
        pass  # continue gracefully
    return False


def _reference_facts(file_name, file_size) -> str:
    """Build the static reference-guide markdown block; empty string on failure."""
    try:
        modality = infer_modality_from_filename(file_name or "").get("modality", "unknown")
        guide = imaging_reference_guide(modality)

        def bullets(key: str) -> str:
            return "\n".join(f"- {b}" for b in guide.get(key, []))

        acq = bullets("acquisition")
        art = bullets("artifacts")
        prep = bullets("preprocessing")
        tips = bullets("study_tips")
        return (
            f"### 📁 File\n"
            f"- Name: `{file_name or 'unknown'}`\n"
            f"- Size: `{file_size if file_size is not None else 'unknown'} bytes`\n\n"
            f"### 🔎 Modality (guess)\n- {modality}\n\n"
            f"### 📚 Reference Guide (study)\n"
            f"**Acquisition**\n{acq or '- (general)'}\n\n"
            f"**Common Artifacts**\n{art or '- (general)'}\n\n"
            f"**Preprocessing Ideas**\n{prep or '- (general)'}\n\n"
            f"**Study Tips**\n{tips or '- (general)'}\n\n"
            f"> ⚠️ Education only — no diagnosis.\n"
        )
    except Exception:
        return ""


@cl.on_message
async def on_message(message: cl.Message):
    """Handle a chat turn: guardrail check, context assembly, tutor run, reply."""
    # Safety check — refuse diagnosis/advice requests up front.
    if await _is_blocked_by_guardrail(message.content):
        await cl.Message(
            content=(
                "🚫 I can’t provide medical diagnoses or treatment advice.\n"
                "I’m happy to explain **imaging concepts**, **artifacts**, and **preprocessing** for learning."
            )
        ).send()
        return

    # Context from the last upload (if any) is appended to the query.
    file_name = cl.user_session.get("last_file_name")
    file_size = cl.user_session.get("last_file_size")
    context_note = ""
    if file_name:
        context_note += f"The user uploaded a file named '{file_name}'.\n"
    if file_size is not None:
        context_note += f"File size: {file_size} bytes.\n"
    user_query = message.content
    if context_note:
        user_query = f"{user_query}\n\n[Context]\n{context_note}"

    # Run the tutor, then prepend the static reference facts to its answer.
    result = await Runner.run(tutor_agent, user_query)
    facts_md = _reference_facts(file_name, file_size)
    text = result.final_output or "I couldn’t generate an explanation."
    await cl.Message(content=f"{facts_md}\n---\n{text}").send()