ibadhasnain commited on
Commit
4f7d15b
·
verified ·
1 Parent(s): 0afe5e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +460 -0
app.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -----------------------------
2
+ # Single-file Chainlit app with inline "agents" shim
3
+ # Project: Multimodal Biomedical Imaging Tutor (education only)
4
+ # -----------------------------
5
+ import os, json
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Callable, Dict, List, Optional
8
+ from dotenv import load_dotenv
9
+ from pydantic import BaseModel, Field
10
+ import chainlit as cl
11
+ from openai import AsyncOpenAI as _SDKAsyncOpenAI
12
+
13
+ # =============================
14
+ # Inline "agents" shim
15
+ # =============================
16
+ def set_tracing_disabled(disabled: bool = True):
17
+ return disabled
18
+
19
+ def function_tool(func: Callable):
20
+ func._is_tool = True
21
+ return func
22
+
23
+ def handoff(*args, **kwargs):
24
+ return None
25
+
26
+ class InputGuardrail:
27
+ def __init__(self, guardrail_function: Callable):
28
+ self.guardrail_function = guardrail_function
29
+
30
+ @dataclass
31
+ class GuardrailFunctionOutput:
32
+ output_info: Any
33
+ tripwire_triggered: bool = False
34
+ tripwire_message: str = ""
35
+
36
+ class InputGuardrailTripwireTriggered(Exception):
37
+ pass
38
+
39
+ class AsyncOpenAI:
40
+ def __init__(self, api_key: str, base_url: Optional[str] = None):
41
+ kwargs = {"api_key": api_key}
42
+ if base_url:
43
+ kwargs["base_url"] = base_url
44
+ self._client = _SDKAsyncOpenAI(**kwargs)
45
+
46
+ @property
47
+ def client(self):
48
+ return self._client
49
+
50
+ class OpenAIChatCompletionsModel:
51
+ def __init__(self, model: str, openai_client: AsyncOpenAI):
52
+ self.model = model
53
+ self.client = openai_client.client
54
+
55
+ @dataclass
56
+ class Agent:
57
+ name: str
58
+ instructions: str
59
+ model: OpenAIChatCompletionsModel
60
+ tools: Optional[List[Callable]] = field(default_factory=list)
61
+ handoff_description: Optional[str] = None
62
+ output_type: Optional[type] = None # optional Pydantic model class
63
+ input_guardrails: Optional[List[InputGuardrail]] = field(default_factory=list)
64
+
65
+ def tool_specs(self) -> List[Dict[str, Any]]:
66
+ specs = []
67
+ for t in (self.tools or []):
68
+ if getattr(t, "_is_tool", False):
69
+ specs.append({
70
+ "type": "function",
71
+ "function": {
72
+ "name": t.__name__,
73
+ "description": (t.__doc__ or "")[:512],
74
+ "parameters": {
75
+ "type": "object",
76
+ "properties": {
77
+ p: {"type": "string"}
78
+ for p in t.__code__.co_varnames[:t.__code__.co_argcount]
79
+ },
80
+ "required": list(t.__code__.co_varnames[:t.__code__.co_argcount]),
81
+ },
82
+ },
83
+ })
84
+ return specs
85
+
86
+ class Runner:
87
+ @staticmethod
88
+ async def run(agent: Agent, user_input: str, context: Optional[Dict[str, Any]] = None):
89
+ msgs = [
90
+ {"role": "system", "content": agent.instructions},
91
+ {"role": "user", "content": user_input},
92
+ ]
93
+ tools = agent.tool_specs()
94
+ tool_map = {t.__name__: t for t in (agent.tools or []) if getattr(t, "_is_tool", False)}
95
+
96
+ # simple tool loop
97
+ for _ in range(4):
98
+ resp = await agent.model.client.chat.completions.create(
99
+ model=agent.model.model,
100
+ messages=msgs,
101
+ tools=tools if tools else None,
102
+ tool_choice="auto" if tools else None,
103
+ )
104
+
105
+ choice = resp.choices[0]
106
+ msg = choice.message
107
+ msgs.append({"role": "assistant", "content": msg.content or "", "tool_calls": msg.tool_calls})
108
+
109
+ if msg.tool_calls:
110
+ for call in msg.tool_calls:
111
+ fn_name = call.function.name
112
+ args = json.loads(call.function.arguments or "{}")
113
+ if fn_name in tool_map:
114
+ try:
115
+ result = tool_map[fn_name](**args)
116
+ except Exception as e:
117
+ result = {"error": str(e)}
118
+ else:
119
+ result = {"error": f"Unknown tool: {fn_name}"}
120
+ msgs.append({
121
+ "role": "tool",
122
+ "tool_call_id": call.id,
123
+ "name": fn_name,
124
+ "content": json.dumps(result),
125
+ })
126
+ continue # let the model use tool outputs
127
+
128
+ # finalize
129
+ final_text = msg.content or ""
130
+ final_obj = type("Result", (), {})()
131
+ final_obj.final_output = final_text
132
+ final_obj.context = context or {}
133
+ if agent.output_type and issubclass(agent.output_type, BaseModel):
134
+ try:
135
+ data = agent.output_type.model_validate_json(final_text)
136
+ final_obj.final_output = data.model_dump_json()
137
+ final_obj.final_output_as = lambda t: data
138
+ except Exception:
139
+ final_obj.final_output_as = lambda t: final_text
140
+ else:
141
+ final_obj.final_output_as = lambda t: final_text
142
+ return final_obj
143
+
144
+ final_obj = type("Result", (), {})()
145
+ final_obj.final_output = "Sorry, I couldn't complete the request."
146
+ final_obj.context = context or {}
147
+ final_obj.final_output_as = lambda t: final_obj.final_output
148
+ return final_obj
149
+
150
+ # =============================
151
+ # App configuration
152
+ # =============================
153
+ load_dotenv()
154
+ API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("OPENAI_API_KEY")
155
+ if not API_KEY:
156
+ raise RuntimeError(
157
+ "Missing GEMINI_API_KEY (or OPENAI_API_KEY). "
158
+ "Add it in the Space secrets or a .env file."
159
+ )
160
+
161
+ set_tracing_disabled(True)
162
+
163
+ external_client: AsyncOpenAI = AsyncOpenAI(
164
+ api_key=API_KEY,
165
+ base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
166
+ )
167
+ llm_model: OpenAIChatCompletionsModel = OpenAIChatCompletionsModel(
168
+ model="gemini-2.5-flash",
169
+ openai_client=external_client,
170
+ )
171
+
172
+ # =============================
173
+ # Domain models for tutor
174
+ # =============================
175
+ class Section(BaseModel):
176
+ title: str
177
+ bullets: List[str]
178
+
179
+ class TutorResponse(BaseModel):
180
+ modality: str
181
+ acquisition_overview: Section
182
+ common_artifacts: Section
183
+ preprocessing_methods: Section
184
+ study_tips: Section
185
+ caution: str
186
+
187
+ # =============================
188
+ # Tools
189
+ # =============================
190
+ @function_tool
191
+ def infer_modality_from_filename(filename: str) -> dict:
192
+ """
193
+ Guess modality (MRI/X-ray/CT/Ultrasound) from filename keywords.
194
+ Returns: {"modality": "<guess or unknown>"}
195
+ """
196
+ f = (filename or "").lower()
197
+ guess = "unknown"
198
+ mapping = {
199
+ "xray": "X-ray", "x_ray": "X-ray", "xr": "X-ray", "cxr": "X-ray",
200
+ "mri": "MRI", "t1": "MRI", "t2": "MRI", "flair": "MRI", "dwi": "MRI", "adc": "MRI",
201
+ "ct": "CT", "cta": "CT",
202
+ "ultrasound": "Ultrasound", "usg": "Ultrasound", "echo": "Ultrasound",
203
+ }
204
+ for key, mod in mapping.items():
205
+ if key in f:
206
+ guess = mod
207
+ break
208
+ return {"modality": guess}
209
+
210
+ @function_tool
211
+ def imaging_reference_guide(modality: str) -> dict:
212
+ """
213
+ Educational points for acquisition, artifacts, preprocessing, and study tips by modality.
214
+ Education only (no diagnosis).
215
+ """
216
+ mod = (modality or "").strip().lower()
217
+ if mod in ["xray", "x-ray", "x_ray"]:
218
+ return {
219
+ "acquisition": [
220
+ "Projection radiography using ionizing radiation.",
221
+ "Common views: AP, PA, lateral; exposure (kVp/mAs) and positioning matter.",
222
+ "Grids/collimation reduce scatter and improve contrast."
223
+ ],
224
+ "artifacts": [
225
+ "Motion blur; under/overexposure affecting contrast.",
226
+ "Grid cut-off; foreign objects (buttons, jewelry).",
227
+ "Magnification/distortion from object–detector distance."
228
+ ],
229
+ "preprocessing": [
230
+ "Denoising (median/NLM), histogram equalization.",
231
+ "Window/level selection (bone vs soft tissue) for teaching.",
232
+ "Edge enhancement (unsharp mask) with caution (halo artifacts)."
233
+ ],
234
+ "study_tips": [
235
+ "Use a systematic approach (e.g., ABCDE for chest X-ray).",
236
+ "Compare sides; verify devices, labels, positioning.",
237
+ "Correlate with clinical scenario; keep a checklist."
238
+ ],
239
+ }
240
+ if mod in ["mri", "mr"]:
241
+ return {
242
+ "acquisition": [
243
+ "MR uses RF pulses in a strong magnetic field; sequences set contrast.",
244
+ "Key sequences: T1, T2, FLAIR, DWI/ADC, GRE/SWI.",
245
+ "TR/TE/flip angle shape SNR, contrast, time."
246
+ ],
247
+ "artifacts": [
248
+ "Motion/ghosting (movement, pulsation).",
249
+ "Susceptibility (metal, air-bone interfaces).",
250
+ "Chemical shift, Gibbs ringing.",
251
+ "B0/B1 inhomogeneity causing intensity bias."
252
+ ],
253
+ "preprocessing": [
254
+ "Bias-field correction (N4).",
255
+ "Denoising (non-local means), registration/normalization.",
256
+ "Skull stripping (brain), intensity standardization."
257
+ ],
258
+ "study_tips": [
259
+ "Know sequence intent (T1 anatomy, T2 fluid, FLAIR edema).",
260
+ "Check diffusion for acute ischemia (with ADC).",
261
+ "Use consistent windowing for longitudinal comparison."
262
+ ],
263
+ }
264
+ if mod == "ct":
265
+ return {
266
+ "acquisition": [
267
+ "Helical CT reconstructs attenuation in Hounsfield Units.",
268
+ "Kernels (bone vs soft) change sharpness/noise.",
269
+ "Contrast phases (arterial/venous) match the task."
270
+ ],
271
+ "artifacts": [
272
+ "Beam hardening (streaks), partial volume.",
273
+ "Motion (breathing/cardiac).",
274
+ "Metal artifacts; consider MAR algorithms."
275
+ ],
276
+ "preprocessing": [
277
+ "Denoising (bilateral/NLM) while preserving edges.",
278
+ "Appropriate window/level (lung, mediastinum, bone).",
279
+ "Iterative reconstruction / metal artifact reduction."
280
+ ],
281
+ "study_tips": [
282
+ "Use standard planes; scroll systematically.",
283
+ "Compare windows; document sizes/HU as needed.",
284
+ "Correlate phase with the clinical question."
285
+ ],
286
+ }
287
+ return {
288
+ "acquisition": [
289
+ "Acquisition parameters define contrast, resolution, and noise.",
290
+ "Positioning and motion control are crucial for quality."
291
+ ],
292
+ "artifacts": [
293
+ "Motion blur/ghosting; foreign objects and hardware.",
294
+ "Parameter misconfiguration harms interpretability."
295
+ ],
296
+ "preprocessing": [
297
+ "Denoising and contrast normalization for clarity.",
298
+ "Registration to standard planes for comparison."
299
+ ],
300
+ "study_tips": [
301
+ "Adopt a checklist; compare across time or sides.",
302
+ "Learn modality-specific knobs (window/level, sequences)."
303
+ ],
304
+ }
305
+
306
+ @function_tool
307
+ def file_facts(filename: str, size_bytes: str) -> dict:
308
+ """
309
+ Returns lightweight file facts: filename and byte size (as string).
310
+ """
311
+ try:
312
+ size = int(size_bytes)
313
+ except Exception:
314
+ size = -1
315
+ return {"filename": filename, "size_bytes": size}
316
+
317
+ # =============================
318
+ # Agents
319
+ # =============================
320
+ tutor_instructions = (
321
+ "You are a Biomedical Imaging Education Tutor. TEACH, do not diagnose.\n"
322
+ "Given an uploaded MRI or X-ray, provide:\n"
323
+ "1) Acquisition overview\n"
324
+ "2) Common artifacts\n"
325
+ "3) Preprocessing methods\n"
326
+ "4) Study tips\n"
327
+ "5) A caution line: education only, no diagnosis\n"
328
+ "Use tools to infer modality from filename and to fetch a modality reference guide.\n"
329
+ "If unclear, provide a generic overview and ask for clarification.\n"
330
+ "Always respond as concise, well-structured bullet points.\n"
331
+ "Absolutely avoid clinical diagnosis, disease identification, or treatment advice."
332
+ )
333
+
334
+ tutor_agent = Agent(
335
+ name="Biomedical Imaging Tutor",
336
+ instructions=tutor_instructions,
337
+ model=llm_model,
338
+ tools=[infer_modality_from_filename, imaging_reference_guide, file_facts],
339
+ )
340
+
341
+ class SafetyCheck(BaseModel):
342
+ unsafe_medical_advice: bool
343
+ requests_diagnosis: bool
344
+ pii_included: bool
345
+ reasoning: str
346
+
347
+ guardrail_agent = Agent(
348
+ name="Safety Classifier",
349
+ instructions=(
350
+ "Classify if the user's message requests medical diagnosis or unsafe medical advice, "
351
+ "and if it includes personal identifiers. Respond as JSON with fields: "
352
+ "{unsafe_medical_advice: bool, requests_diagnosis: bool, pii_included: bool, reasoning: string}."
353
+ ),
354
+ model=llm_model,
355
+ )
356
+
357
+ # =============================
358
+ # Chainlit flows
359
+ # =============================
360
+ WELCOME = (
361
+ "🎓 **Multimodal Biomedical Imaging Tutor**\n\n"
362
+ "Upload an **MRI** or **X-ray** image (PNG/JPG). I’ll explain:\n"
363
+ "• Acquisition (how it’s made)\n"
364
+ "• Common artifacts (what to watch for)\n"
365
+ "• Preprocessing for study/teaching\n\n"
366
+ "⚠️ *Education only — I do not provide diagnosis. For clinical concerns, consult a professional.*"
367
+ )
368
+
369
+ @cl.on_chat_start
370
+ async def on_chat_start():
371
+ await cl.Message(content=WELCOME).send()
372
+ files = await cl.AskFileMessage(
373
+ content="Please upload an **MRI or X-ray** image (PNG/JPG).",
374
+ accept=["image/png", "image/jpeg"],
375
+ max_size_mb=15,
376
+ max_files=1,
377
+ timeout=180,
378
+ ).send()
379
+
380
+ if not files:
381
+ await cl.Message(content="No file uploaded. You can still ask general imaging questions.").send()
382
+ return
383
+
384
+ f = files[0]
385
+ cl.user_session.set("last_file_path", f.path)
386
+ cl.user_session.set("last_file_name", f.name)
387
+ cl.user_session.set("last_file_size", f.size)
388
+
389
+ await cl.Message(
390
+ content=f"Received **{f.name}** ({f.size} bytes). "
391
+ "Ask: *“Explain acquisition & artifacts for this image.”*"
392
+ ).send()
393
+
394
+ @cl.on_message
395
+ async def on_message(message: cl.Message):
396
+ # Safety check
397
+ try:
398
+ safety = await Runner.run(guardrail_agent, message.content)
399
+ # parse best-effort
400
+ parsed = safety.final_output
401
+ try:
402
+ data = json.loads(parsed) if isinstance(parsed, str) else parsed
403
+ except Exception:
404
+ data = {}
405
+ if isinstance(data, dict):
406
+ if data.get("unsafe_medical_advice") or data.get("requests_diagnosis"):
407
+ await cl.Message(
408
+ content=(
409
+ "🚫 I can’t provide medical diagnoses or treatment advice.\n"
410
+ "I’m happy to explain **imaging concepts**, **artifacts**, and **preprocessing** for learning."
411
+ )
412
+ ).send()
413
+ return
414
+ except Exception:
415
+ pass # continue gracefully
416
+
417
+ # Context from last upload
418
+ file_name = cl.user_session.get("last_file_name")
419
+ file_size = cl.user_session.get("last_file_size")
420
+
421
+ context_note = ""
422
+ if file_name:
423
+ context_note += f"The user uploaded a file named '{file_name}'.\n"
424
+ if file_size is not None:
425
+ context_note += f"File size: {file_size} bytes.\n"
426
+
427
+ user_query = message.content
428
+ if context_note:
429
+ user_query = f"{user_query}\n\n[Context]\n{context_note}"
430
+
431
+ # Run tutor
432
+ result = await Runner.run(tutor_agent, user_query)
433
+
434
+ # Quick reference facts
435
+ facts_md = ""
436
+ try:
437
+ modality = infer_modality_from_filename(file_name or "").get("modality", "unknown")
438
+ guide = imaging_reference_guide(modality)
439
+ acq = "\n".join([f"- {b}" for b in guide.get("acquisition", [])])
440
+ art = "\n".join([f"- {b}" for b in guide.get("artifacts", [])])
441
+ prep = "\n".join([f"- {b}" for b in guide.get("preprocessing", [])])
442
+ tips = "\n".join([f"- {b}" for b in guide.get("study_tips", [])])
443
+
444
+ facts_md = (
445
+ f"### 📁 File\n"
446
+ f"- Name: `{file_name or 'unknown'}`\n"
447
+ f"- Size: `{file_size if file_size is not None else 'unknown'} bytes`\n\n"
448
+ f"### 🔎 Modality (guess)\n- {modality}\n\n"
449
+ f"### 📚 Reference Guide (study)\n"
450
+ f"**Acquisition**\n{acq or '- (general)'}\n\n"
451
+ f"**Common Artifacts**\n{art or '- (general)'}\n\n"
452
+ f"**Preprocessing Ideas**\n{prep or '- (general)'}\n\n"
453
+ f"**Study Tips**\n{tips or '- (general)'}\n\n"
454
+ f"> ⚠️ Education only — no diagnosis.\n"
455
+ )
456
+ except Exception:
457
+ pass
458
+
459
+ text = result.final_output or "I couldn’t generate an explanation."
460
+ await cl.Message(content=f"{facts_md}\n---\n{text}").send()