|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect
import json
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

import chainlit as cl
from dotenv import load_dotenv
from openai import AsyncOpenAI as _SDKAsyncOpenAI
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_tracing_disabled(disabled: bool = True):
    """No-op stand-in for the agents-SDK tracing switch.

    This shim has no tracing backend; the flag is simply echoed back so
    call sites keep working unchanged.
    """
    return disabled
|
|
|
|
|
def function_tool(func: Callable):
    """Mark *func* as an agent tool (mirrors the agents-SDK decorator).

    The callable is returned unmodified; only a `_is_tool` marker
    attribute is attached, which `Agent.tool_specs` and `Runner` look for.
    """
    setattr(func, "_is_tool", True)
    return func
|
|
|
|
|
def handoff(*args, **kwargs):
    """Placeholder for the agents-SDK `handoff` helper.

    Accepts anything and always returns None; handoffs are not
    implemented in this shim.
    """
    return None
|
|
|
|
|
class InputGuardrail:
    """Holds a guardrail callable an agent may declare for input checks."""

    def __init__(self, guardrail_function: Callable):
        # Stored as-is; nothing in this shim invokes it automatically.
        self.guardrail_function = guardrail_function
|
|
|
|
|
@dataclass
class GuardrailFunctionOutput:
    """Result of a guardrail check (mirrors the agents-SDK shape).

    NOTE(review): declared for API parity but not constructed anywhere
    in this file.
    """
    # Arbitrary classifier payload kept for logging/inspection.
    output_info: Any
    # True when the guardrail decided to block the request.
    tripwire_triggered: bool = False
    # Optional human-readable reason for the block.
    tripwire_message: str = ""
|
|
|
|
|
class InputGuardrailTripwireTriggered(Exception):
    """Raised when an input guardrail trips (kept for API parity; not raised in this file)."""
|
|
|
|
|
class AsyncOpenAI:
    """Thin wrapper giving the agents-SDK client shape over the real OpenAI SDK.

    Wraps `openai.AsyncOpenAI`, optionally pointed at a custom
    OpenAI-compatible endpoint via *base_url*.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        # Only pass base_url through when it is truthy, matching the
        # SDK's default endpoint otherwise.
        if base_url:
            self._client = _SDKAsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            self._client = _SDKAsyncOpenAI(api_key=api_key)

    @property
    def client(self):
        """The underlying `openai.AsyncOpenAI` instance."""
        return self._client
|
|
|
|
|
class OpenAIChatCompletionsModel:
    """Pairs a model name with the async client Runner will call."""

    def __init__(self, model: str, openai_client: AsyncOpenAI):
        self.model = model
        # Unwrap to the raw SDK client; Runner invokes
        # `.chat.completions.create(...)` on it directly.
        self.client = openai_client.client
|
|
|
|
|
@dataclass
class Agent:
    """Configuration for one LLM agent: prompt, model, tools, guardrails.

    Mirrors the agents-SDK `Agent` shape closely enough for this module's
    `Runner` to drive it.
    """

    name: str
    instructions: str  # system prompt sent on every run
    model: OpenAIChatCompletionsModel
    tools: Optional[List[Callable]] = field(default_factory=list)
    handoff_description: Optional[str] = None
    output_type: Optional[type] = None  # optional pydantic model for structured output
    input_guardrails: Optional[List[InputGuardrail]] = field(default_factory=list)

    def tool_specs(self) -> List[Dict[str, Any]]:
        """Build OpenAI function-calling specs for tools tagged @function_tool.

        Every parameter is advertised as a string (the tools in this app
        take string arguments). Fix over the previous `__code__`-based
        version: parameters are read via `inspect.signature`, so only
        parameters *without* defaults are marked required, and bound
        methods no longer leak `self` into the schema.
        """
        named_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        specs: List[Dict[str, Any]] = []
        for tool in self.tools or []:
            if not getattr(tool, "_is_tool", False):
                continue
            params = [
                p
                for p in inspect.signature(tool).parameters.values()
                if p.kind in named_kinds
            ]
            specs.append({
                "type": "function",
                "function": {
                    "name": tool.__name__,
                    # The model sees the docstring as the tool description.
                    "description": (tool.__doc__ or "")[:512],
                    "parameters": {
                        "type": "object",
                        "properties": {p.name: {"type": "string"} for p in params},
                        "required": [
                            p.name
                            for p in params
                            if p.default is inspect.Parameter.empty
                        ],
                    },
                },
            })
        return specs
|
|
|
|
|
class Runner:
    """Drives one agent turn: chat completion plus local tool dispatch.

    Loops for at most `_MAX_ROUNDS` model calls; each round either
    executes the tool calls the model requested (feeding results back as
    `role="tool"` messages) or returns the model's final answer.
    """

    # Hard cap on model round-trips per run, preventing infinite tool loops.
    _MAX_ROUNDS = 4

    @staticmethod
    def _make_result(text: str, context, parsed=None):
        """Build the ad-hoc result object callers read.

        `parsed` is a validated pydantic instance when structured output
        succeeded; otherwise the raw text is exposed.
        """
        obj = type("Result", (), {})()
        obj.context = context or {}
        if parsed is not None:
            obj.final_output = parsed.model_dump_json()
            obj.final_output_as = lambda t, _p=parsed: _p
        else:
            obj.final_output = text
            obj.final_output_as = lambda t, _text=text: _text
        return obj

    @staticmethod
    async def run(agent: Agent, user_input: str, context: Optional[Dict[str, Any]] = None):
        """Run *agent* on *user_input* and return a result object.

        The result exposes `.final_output` (str), `.context` (dict), and
        `.final_output_as(type)`.
        """
        msgs = [
            {"role": "system", "content": agent.instructions},
            {"role": "user", "content": user_input},
        ]
        tools = agent.tool_specs()
        tool_map = {t.__name__: t for t in (agent.tools or []) if getattr(t, "_is_tool", False)}

        for _ in range(Runner._MAX_ROUNDS):
            request: Dict[str, Any] = {
                "model": agent.model.model,
                "messages": msgs,
            }
            if tools:
                # Fix: only include tool params when tools exist. The SDK
                # omits NOT_GIVEN but serializes an explicit None as JSON
                # null, which the API rejects.
                request["tools"] = tools
                request["tool_choice"] = "auto"
            resp = await agent.model.client.chat.completions.create(**request)

            msg = resp.choices[0].message
            assistant: Dict[str, Any] = {"role": "assistant", "content": msg.content or ""}
            if msg.tool_calls:
                # Fix: serialize SDK tool-call objects to plain dicts so
                # the history stays JSON-serializable on the next request,
                # and omit the key entirely for ordinary turns (the old
                # code always sent "tool_calls", including null).
                assistant["tool_calls"] = [
                    tc.model_dump() if hasattr(tc, "model_dump") else tc
                    for tc in msg.tool_calls
                ]
            msgs.append(assistant)

            if msg.tool_calls:
                for call in msg.tool_calls:
                    fn_name = call.function.name
                    args = json.loads(call.function.arguments or "{}")
                    if fn_name in tool_map:
                        try:
                            result = tool_map[fn_name](**args)
                        except Exception as e:
                            # Surface tool failures to the model instead of crashing.
                            result = {"error": str(e)}
                    else:
                        result = {"error": f"Unknown tool: {fn_name}"}
                    msgs.append({
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": fn_name,
                        "content": json.dumps(result),
                    })
                # Let the model read the tool results in the next round.
                continue

            # No tool calls: this is the final answer.
            final_text = msg.content or ""
            parsed = None
            if agent.output_type is not None and isinstance(agent.output_type, type) \
                    and issubclass(agent.output_type, BaseModel):
                try:
                    parsed = agent.output_type.model_validate_json(final_text)
                except Exception:
                    # Fall back to raw text when validation fails.
                    parsed = None
            return Runner._make_result(final_text, context, parsed)

        # Round budget exhausted without a final (non-tool) answer.
        return Runner._make_result("Sorry, I couldn't complete the request.", context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Module bootstrap: runs at import time -------------------------------
load_dotenv()

# Prefer GEMINI_API_KEY, fall back to OPENAI_API_KEY.
# NOTE(review): either key is sent to the Gemini endpoint below — an
# OpenAI key would only work if that endpoint accepts it; confirm.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    # Fail fast at startup rather than on the first chat message.
    raise RuntimeError(
        "Missing GEMINI_API_KEY (or OPENAI_API_KEY). "
        "Add it in the Space secrets or a .env file."
    )

# No-op in this shim; kept for parity with the real agents SDK.
set_tracing_disabled(True)

# Gemini exposes an OpenAI-compatible API; point the SDK client at it.
external_client: AsyncOpenAI = AsyncOpenAI(
    api_key=API_KEY,
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
# Single shared model handle used by both agents below.
llm_model: OpenAIChatCompletionsModel = OpenAIChatCompletionsModel(
    model="gemini-2.5-flash",
    openai_client=external_client,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Section(BaseModel):
    """One titled group of bullet points (used by TutorResponse)."""
    # Section heading, e.g. an acquisition or artifacts header.
    title: str
    # Bullet lines without leading markers.
    bullets: List[str]
|
|
|
|
|
class TutorResponse(BaseModel):
    """Structured five-part tutor answer.

    NOTE(review): defined but never set as any Agent's `output_type` in
    this file, so tutor replies currently arrive as free-form text.
    """
    # Modality label, e.g. "MRI" / "X-ray" / "CT" / "unknown".
    modality: str
    acquisition_overview: Section
    common_artifacts: Section
    preprocessing_methods: Section
    study_tips: Section
    # Education-only disclaimer line.
    caution: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@function_tool
def infer_modality_from_filename(filename: str) -> dict:
    """
    Guess modality (MRI/X-ray/CT/Ultrasound) from filename keywords.
    Returns: {"modality": "<guess or unknown>"}
    """
    lowered = (filename or "").lower()
    # Insertion order matters: the first keyword found wins.
    # NOTE(review): matching is by substring, so e.g. "doctor.png"
    # contains "ct" — short keys can false-positive on unrelated names.
    keyword_to_modality = {
        "xray": "X-ray", "x_ray": "X-ray", "xr": "X-ray", "cxr": "X-ray",
        "mri": "MRI", "t1": "MRI", "t2": "MRI", "flair": "MRI", "dwi": "MRI", "adc": "MRI",
        "ct": "CT", "cta": "CT",
        "ultrasound": "Ultrasound", "usg": "Ultrasound", "echo": "Ultrasound",
    }
    guess = next(
        (
            modality
            for keyword, modality in keyword_to_modality.items()
            if keyword in lowered
        ),
        "unknown",
    )
    return {"modality": guess}
|
|
|
|
|
@function_tool
def imaging_reference_guide(modality: str) -> dict:
    """
    Educational points for acquisition, artifacts, preprocessing, and study tips by modality.
    Education only (no diagnosis).
    """
    # Normalize so "X-Ray", " CT " etc. all match the branch keys below.
    mod = (modality or "").strip().lower()
    # --- X-ray ------------------------------------------------------------
    if mod in ["xray", "x-ray", "x_ray"]:
        return {
            "acquisition": [
                "Projection radiography using ionizing radiation.",
                "Common views: AP, PA, lateral; exposure (kVp/mAs) and positioning matter.",
                "Grids/collimation reduce scatter and improve contrast."
            ],
            "artifacts": [
                "Motion blur; under/overexposure affecting contrast.",
                "Grid cut-off; foreign objects (buttons, jewelry).",
                "Magnification/distortion from object–detector distance."
            ],
            "preprocessing": [
                "Denoising (median/NLM), histogram equalization.",
                "Window/level selection (bone vs soft tissue) for teaching.",
                "Edge enhancement (unsharp mask) with caution (halo artifacts)."
            ],
            "study_tips": [
                "Use a systematic approach (e.g., ABCDE for chest X-ray).",
                "Compare sides; verify devices, labels, positioning.",
                "Correlate with clinical scenario; keep a checklist."
            ],
        }
    # --- MRI --------------------------------------------------------------
    if mod in ["mri", "mr"]:
        return {
            "acquisition": [
                "MR uses RF pulses in a strong magnetic field; sequences set contrast.",
                "Key sequences: T1, T2, FLAIR, DWI/ADC, GRE/SWI.",
                "TR/TE/flip angle shape SNR, contrast, time."
            ],
            "artifacts": [
                "Motion/ghosting (movement, pulsation).",
                "Susceptibility (metal, air-bone interfaces).",
                "Chemical shift, Gibbs ringing.",
                "B0/B1 inhomogeneity causing intensity bias."
            ],
            "preprocessing": [
                "Bias-field correction (N4).",
                "Denoising (non-local means), registration/normalization.",
                "Skull stripping (brain), intensity standardization."
            ],
            "study_tips": [
                "Know sequence intent (T1 anatomy, T2 fluid, FLAIR edema).",
                "Check diffusion for acute ischemia (with ADC).",
                "Use consistent windowing for longitudinal comparison."
            ],
        }
    # --- CT ---------------------------------------------------------------
    if mod == "ct":
        return {
            "acquisition": [
                "Helical CT reconstructs attenuation in Hounsfield Units.",
                "Kernels (bone vs soft) change sharpness/noise.",
                "Contrast phases (arterial/venous) match the task."
            ],
            "artifacts": [
                "Beam hardening (streaks), partial volume.",
                "Motion (breathing/cardiac).",
                "Metal artifacts; consider MAR algorithms."
            ],
            "preprocessing": [
                "Denoising (bilateral/NLM) while preserving edges.",
                "Appropriate window/level (lung, mediastinum, bone).",
                "Iterative reconstruction / metal artifact reduction."
            ],
            "study_tips": [
                "Use standard planes; scroll systematically.",
                "Compare windows; document sizes/HU as needed.",
                "Correlate phase with the clinical question."
            ],
        }
    # Generic fallback for anything else. NOTE(review): there is no
    # dedicated "ultrasound" branch, so Ultrasound also lands here.
    return {
        "acquisition": [
            "Acquisition parameters define contrast, resolution, and noise.",
            "Positioning and motion control are crucial for quality."
        ],
        "artifacts": [
            "Motion blur/ghosting; foreign objects and hardware.",
            "Parameter misconfiguration harms interpretability."
        ],
        "preprocessing": [
            "Denoising and contrast normalization for clarity.",
            "Registration to standard planes for comparison."
        ],
        "study_tips": [
            "Adopt a checklist; compare across time or sides.",
            "Learn modality-specific knobs (window/level, sequences)."
        ],
    }
|
|
|
|
|
@function_tool
def file_facts(filename: str, size_bytes: str) -> dict:
    """
    Returns lightweight file facts: filename and byte size (as string).
    """
    # Tool arguments arrive as JSON strings; -1 flags an unparseable size.
    try:
        parsed_size = int(size_bytes)
    except Exception:
        parsed_size = -1
    return {"filename": filename, "size_bytes": parsed_size}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt for the tutor agent: teaching-only scope, a fixed
# five-part answer structure, tool-usage guidance, and an explicit
# no-diagnosis rule.
tutor_instructions = (
    "You are a Biomedical Imaging Education Tutor. TEACH, do not diagnose.\n"
    "Given an uploaded MRI or X-ray, provide:\n"
    "1) Acquisition overview\n"
    "2) Common artifacts\n"
    "3) Preprocessing methods\n"
    "4) Study tips\n"
    "5) A caution line: education only, no diagnosis\n"
    "Use tools to infer modality from filename and to fetch a modality reference guide.\n"
    "If unclear, provide a generic overview and ask for clarification.\n"
    "Always respond as concise, well-structured bullet points.\n"
    "Absolutely avoid clinical diagnosis, disease identification, or treatment advice."
)
|
|
|
|
|
# Main teaching agent: Gemini-backed, with the filename-modality guesser,
# the per-modality reference guide, and the file-facts tool.
# No output_type is set, so replies are free-form markdown.
tutor_agent = Agent(
    name="Biomedical Imaging Tutor",
    instructions=tutor_instructions,
    model=llm_model,
    tools=[infer_modality_from_filename, imaging_reference_guide, file_facts],
)
|
|
|
|
|
class SafetyCheck(BaseModel):
    """Expected JSON shape of the Safety Classifier's reply.

    NOTE(review): not enforced via `output_type`; `on_message` parses the
    classifier's raw text with `json.loads` instead.
    """
    unsafe_medical_advice: bool
    requests_diagnosis: bool
    pii_included: bool
    # Classifier's free-text justification.
    reasoning: str
|
|
|
|
|
# Lightweight classifier agent used by on_message as a pre-flight safety
# check. It has no tools and is expected to answer with bare JSON matching
# the SafetyCheck fields.
guardrail_agent = Agent(
    name="Safety Classifier",
    instructions=(
        "Classify if the user's message requests medical diagnosis or unsafe medical advice, "
        "and if it includes personal identifiers. Respond as JSON with fields: "
        "{unsafe_medical_advice: bool, requests_diagnosis: bool, pii_included: bool, reasoning: string}."
    ),
    model=llm_model,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Greeting markdown shown once per session by on_chat_start.
WELCOME = (
    "🎓 **Multimodal Biomedical Imaging Tutor**\n\n"
    "Upload an **MRI** or **X-ray** image (PNG/JPG). I’ll explain:\n"
    "• Acquisition (how it’s made)\n"
    "• Common artifacts (what to watch for)\n"
    "• Preprocessing for study/teaching\n\n"
    "⚠️ *Education only — I do not provide diagnosis. For clinical concerns, consult a professional.*"
)
|
|
|
|
|
@cl.on_chat_start
async def on_chat_start():
    """Greet the user and prompt for a single MRI/X-ray image upload."""
    await cl.Message(content=WELCOME).send()

    files = await cl.AskFileMessage(
        content="Please upload an **MRI or X-ray** image (PNG/JPG).",
        accept=["image/png", "image/jpeg"],
        max_size_mb=15,
        max_files=1,
        timeout=180,
    ).send()

    if not files:
        # Uploading is optional; the tutor still answers general questions.
        await cl.Message(content="No file uploaded. You can still ask general imaging questions.").send()
        return

    uploaded = files[0]
    # Stash file metadata in the session so on_message can build context.
    for key, value in (
        ("last_file_path", uploaded.path),
        ("last_file_name", uploaded.name),
        ("last_file_size", uploaded.size),
    ):
        cl.user_session.set(key, value)

    await cl.Message(
        content=(
            f"Received **{uploaded.name}** ({uploaded.size} bytes). "
            "Ask: *“Explain acquisition & artifacts for this image.”*"
        )
    ).send()
|
|
|
|
|
@cl.on_message
async def on_message(message: cl.Message):
    """Handle a chat message: safety-screen it, then answer via the tutor agent.

    Flow: (1) best-effort guardrail classification; (2) build a context
    note from the last uploaded file; (3) run the tutor agent; (4) prepend
    a locally-computed reference guide and send the combined markdown.
    """
    # --- 1) Best-effort safety screen. Any failure (network, bad JSON)
    # falls through to normal handling rather than blocking the user.
    try:
        safety = await Runner.run(guardrail_agent, message.content)

        parsed = safety.final_output
        try:
            # The classifier is asked to reply with bare JSON; non-JSON
            # replies simply yield an empty dict below.
            data = json.loads(parsed) if isinstance(parsed, str) else parsed
        except Exception:
            data = {}
        if isinstance(data, dict):
            if data.get("unsafe_medical_advice") or data.get("requests_diagnosis"):
                await cl.Message(
                    content=(
                        "🚫 I can’t provide medical diagnoses or treatment advice.\n"
                        "I’m happy to explain **imaging concepts**, **artifacts**, and **preprocessing** for learning."
                    )
                ).send()
                return
    except Exception:
        # Deliberate swallow: the guardrail is advisory, not load-bearing.
        pass

    # --- 2) Context from the file stored by on_chat_start (may be absent).
    file_name = cl.user_session.get("last_file_name")
    file_size = cl.user_session.get("last_file_size")

    context_note = ""
    if file_name:
        context_note += f"The user uploaded a file named '{file_name}'.\n"
    if file_size is not None:
        context_note += f"File size: {file_size} bytes.\n"

    user_query = message.content
    if context_note:
        user_query = f"{user_query}\n\n[Context]\n{context_note}"

    # --- 3) Main agent run (may itself invoke the tools).
    result = await Runner.run(tutor_agent, user_query)

    # --- 4) Deterministic local summary: call the two tools directly
    # (they are plain functions despite the decorator) so the reference
    # block does not depend on the model choosing to use them.
    facts_md = ""
    try:
        modality = infer_modality_from_filename(file_name or "").get("modality", "unknown")
        guide = imaging_reference_guide(modality)
        acq = "\n".join([f"- {b}" for b in guide.get("acquisition", [])])
        art = "\n".join([f"- {b}" for b in guide.get("artifacts", [])])
        prep = "\n".join([f"- {b}" for b in guide.get("preprocessing", [])])
        tips = "\n".join([f"- {b}" for b in guide.get("study_tips", [])])

        facts_md = (
            f"### 📁 File\n"
            f"- Name: `{file_name or 'unknown'}`\n"
            f"- Size: `{file_size if file_size is not None else 'unknown'} bytes`\n\n"
            f"### 🔎 Modality (guess)\n- {modality}\n\n"
            f"### 📚 Reference Guide (study)\n"
            f"**Acquisition**\n{acq or '- (general)'}\n\n"
            f"**Common Artifacts**\n{art or '- (general)'}\n\n"
            f"**Preprocessing Ideas**\n{prep or '- (general)'}\n\n"
            f"**Study Tips**\n{tips or '- (general)'}\n\n"
            f"> ⚠️ Education only — no diagnosis.\n"
        )
    except Exception:
        # Best-effort block: the agent's answer is still sent without it.
        pass

    text = result.final_output or "I couldn’t generate an explanation."
    await cl.Message(content=f"{facts_md}\n---\n{text}").send()
|
|
|