Preetham22 committed
Commit 42e56c5 · Parent(s): 6803326

Add demo UI, token attention rollout & top5 table; clean ignores
.gitignore CHANGED
@@ -1,43 +1,45 @@
- # Ignore Data files for tracking
- data/
- checkpoints/
  __pycache__/
  *.py[cod]
  .coverage

- # Weights & Biases
- wandb/
- wandb/*
- wandb_logs/
-
- # Optuna study databases
- *.db
-
- # Checkpoints / logs
- checkpoints/
  logs/
  *.pt
- .ipynb_checkpoints
  *.ckpt

- # Model files
- *.pth

- # Environment files
- .env

- # logs
- *.log
- *.tmp

- # --- EXCEPTIONS (ALLOW) ---
  !data/dummy_images/
  !data/dummy_images/COVID/*.png
  !data/dummy_images/NORMAL/*.png
  !data/dummy_images/VIRAL PNEUMONIA/*.png
-
- # Allow test CSV for CI
- !data/test_emr_records.csv
-
- # Mac system files
- .DS_Store

+ # --- Python ---
  __pycache__/
  *.py[cod]
+ *.pyo
+ *.so
+ *.dylib
+ .venv/
+ .env
  .coverage
+ .ipynb_checkpoints/

+ # --- Data & Artifacts ---
+ data/
+ results/
  logs/
+ checkpoints/
+ *.log
+ *.tmp
+ *.db
+
+ # Models / weights
  *.pt
+ *.pth
  *.ckpt

+ # W&B & experiment outputs
+ wandb/
+ wandb_logs/

+ # Predictions / exports
+ predictions_*.csv
+ app/demo/uploads/
+ app/demo/exports/

+ # OS junk
+ .DS_Store

+ # --- Exceptions (allow small fixtures/samples) ---
  !data/dummy_images/
  !data/dummy_images/COVID/*.png
  !data/dummy_images/NORMAL/*.png
  !data/dummy_images/VIRAL PNEUMONIA/*.png
+ !sample_data/**
+ !tests/**
+ !config/config.yaml.example
app/demo/demo.py CHANGED
@@ -1,8 +1,10 @@
  import os
  import sys
  import time
  import gradio as gr
  import pandas as pd
  from pathlib import Path

  # Adds root directory to sys.path
@@ -14,91 +16,808 @@ from app.utils.inference_utils import load_model, predict
  # Initial default values
  DEFAULT_MODE = "multimodal"
  MODEL_PATHS = {
-     "text": "medi_llm_state_dict_text.pth",
-     "image": "medi_llm_state_dict_image.pth",
-     "multimodal": "medi_llm_state_dict_multimodal.pth"
  }

  model_cache = {}
- prediction_log = []


- def classify(mode, emr_text, image):
      if mode not in model_cache:
          model_cache[mode] = load_model(mode, MODEL_PATHS[mode])
      model = model_cache[mode]
-     pred_text, cam_image, token_attn = predict(model, mode, emr_text=emr_text, image=image)

      # Save image to file if uploaded
-     img_rel_path = None
-     img_abs_path = None
-     if image is not None:
-         timestamp = time.strftime("%Y%m%d_%H%M%S")
-         img_rel_path = f"app/demo/uploads/xray_{timestamp}.png"
          img_abs_path = os.path.abspath(img_rel_path)
          os.makedirs(os.path.dirname(img_abs_path), exist_ok=True)
-         image.Save(img_abs_path)

      # Append to log
-     prediction_log.append({
          "mode": mode,
-         "emr": emr_text,
-         "image_path": img_rel_path,  # logged as relative path
-         "prediction": pred_text
-     })

-     return pred_text, cam_image, token_attn


- def export_csv(filename):
      if not filename.strip():
          timestamp = time.strftime("%Y%m%d_%H%M%S")
-         filename = f"demo_{timestamp}.csv"
      elif not filename.endswith(".csv"):
          filename += ".csv"

-     csv_path = os.path.abspath(os.path.join("app/demo/exports", filename))
      os.makedirs(os.path.dirname(csv_path), exist_ok=True)

-     df = pd.DataFrame(prediction_log)
      df.to_csv(csv_path, index=False)

-     return csv_path


- with gr.Blocks(theme=gr.themes.Glass(), css=".centered {text-align: center;}") as demo:
      # Centered title and subtitle
      gr.Markdown("<h2 class='centered'>🩺 Medi-LLM: Clinical Triage Assistant 🩻</h2>")
      gr.Markdown("<p class='centered'>Upload a chest X-ray and/or enter EMR text to get a triage level prediction.</p>")

-     # Mode selection
-     with gr.Row():
-         mode = gr.Radio(["text", "image", "multimodal"], value=DEFAULT_MODE, label="Select Input Mode")

      # Input: EMR text and/or image
      with gr.Row():
-         emr_text = gr.Textbox(lines=6, label="EMR Text", placeholder="Enter clinical notes here...")
-         image = gr.Image(type="pil", label="Chest X-ray")

      with gr.Row():
-         submit_btn = gr.Button("Run Inference")

-     result = gr.Textbox(label="Prediction")

-     submit_btn.click(fn=classify, inputs=[mode, emr_text, image], outputs=result)

      # CSV Export UI
      gr.Markdown("### 📁 Export Prediction Log")

-     with gr.Row():
-         filename_input = gr.Textbox(label="CSV filename (optional)", placeholder="e.g., my_predictions.csv")
-         download_btn = gr.Button("Export CSV")
-         csv_output = gr.File(label="Download Link")

      download_btn.click(
          fn=export_csv,
-         inputs=[filename_input],
          outputs=[csv_output]
      )

  if __name__ == "__main__":
  import os
  import sys
  import time
+ import shutil
  import gradio as gr
  import pandas as pd
+ from PIL import Image
  from pathlib import Path

  # Adds root directory to sys.path

  # Initial default values
  DEFAULT_MODE = "multimodal"
  MODEL_PATHS = {
+     "text": ROOT_DIR / "medi_llm_state_dict_text.pth",
+     "image": ROOT_DIR / "medi_llm_state_dict_image.pth",
+     "multimodal": ROOT_DIR / "medi_llm_state_dict_multimodal.pth"
  }

  model_cache = {}
+ prediction_log_user = []
+ prediction_log_doctor = []


+ def classify(role, mode, normalize_mode, emr_text, image, use_rollout):
+     grad_cam_path = "N/A"
+     token_attn_path = "N/A"
+
+     # Control output visibility
+     show_tabs = (role == "Doctor")
+     show_gradcam = (role == "Doctor" and mode in ["image", "multimodal"])
+     show_attention = (role == "Doctor" and mode in ["text", "multimodal"])
+
+     # ✅ Skip inference if no input is provided for the selected mode
+     text_missing = mode in ["text", "multimodal"] and (not emr_text or not emr_text.strip())
+     image_missing = mode in ["image", "multimodal"] and image is None
+     if (mode == "text" and text_missing) or (mode == "image" and image_missing) or (mode == "multimodal" and text_missing and image_missing):
+         count = len(prediction_log_doctor) if role == "Doctor" else len(prediction_log_user)
+         return (
+             gr.Textbox(value="⚠️ Please enter EMR text or upload an image to run inference."),
+             gr.Image(visible=False),
+             gr.HighlightedText(visible=False),
+             gr.HTML(value="", visible=False),
+             gr.Label(visible=False),
+             gr.Tabs(visible=False),
+             gr.Textbox(value=f"Predictions: {count}", interactive=False),
+             gr.JSON(value={}, visible=True)  # JSON visible, but empty
+         )
+
+     # Image size guard + load
+     if image is not None:
+         image_path = Path(image)
+         image_size = image_path.stat().st_size
+         # Enforce 5MB limit (5 * 1024 * 1024 bytes)
+         if image_size > 5 * 1024 * 1024:
+             count = len(prediction_log_doctor) if role == "Doctor" else len(prediction_log_user)
+             return (
+                 gr.Textbox(value="❌ Image exceeds 5MB size limit."),
+                 gr.Image(visible=False),
+                 gr.HighlightedText(visible=False),
+                 gr.HTML(value="", visible=False),
+                 gr.Label(visible=False),
+                 gr.Tabs(visible=False),  # Hide insights tab on error
+                 gr.Textbox(value=f"Predictions: {count}", interactive=False),
+                 gr.JSON(value={}, visible=True)
+             )
+         image = Image.open(image).convert("RGB")
+
+     # Model caching
      if mode not in model_cache:
          model_cache[mode] = load_model(mode, MODEL_PATHS[mode])
      model = model_cache[mode]
+
+     # Run prediction
+     try:
+         print("🧪 classify() passing normalize_mode:", normalize_mode, "| use_rollout:", use_rollout)
+         pred_text, cam_image, token_attn, confidence, probs, top5 = predict(
+             model,
+             mode,
+             emr_text=emr_text,
+             image=image,
+             normalize_mode=normalize_mode,
+             need_token_vis=show_attention,
+             use_rollout=use_rollout,
+         )
+
+         top5 = top5 or []
+     except ValueError as e:
+         print(f"⚠️ Inference failed: {e}")
+         count = len(prediction_log_doctor) if role == "Doctor" else len(prediction_log_user)
+         return (
+             gr.Textbox(value=f"❌ {str(e)}"),
+             gr.Image(visible=False),
+             gr.HighlightedText(visible=False),
+             gr.HTML(value="", visible=False),
+             gr.Label(visible=False),
+             gr.Tabs(visible=False),
+             gr.Textbox(value=f"Predictions: {count}", interactive=False),
+             gr.JSON(value={}, visible=True)
+         )
+
+     # Class probabilities (ensure there are always 3)
+     flat_probs = probs[0] if isinstance(probs[0], list) else probs
+     if len(flat_probs) != 3:
+         class_probs = {"low": 0.0, "medium": 0.0, "high": 0.0}
+     else:
+         class_probs = {label: round(prob, 3) for label, prob in zip(["low", "medium", "high"], flat_probs)}
+
+     # Save uploads (relative path in logs)
+     timestamp = time.strftime("%Y%m%d_%H%M%S")
+     img_rel_path = f"app/demo/uploads/xray_{timestamp}.png" if image else "N/A"

      # Save image to file if uploaded
+     if image:
          img_abs_path = os.path.abspath(img_rel_path)
          os.makedirs(os.path.dirname(img_abs_path), exist_ok=True)
+         image.save(img_abs_path)
+
+     # Save Grad-CAM if Doctor and mode uses image
+     if cam_image and role == "Doctor" and mode in ["image", "multimodal"]:
+         cam_rel_path = f"app/demo/exports/{role.lower()}/gradcam/gradcam_{pred_text}_{timestamp}.png"
+         cam_abs_path = os.path.abspath(cam_rel_path)
+         os.makedirs(os.path.dirname(cam_abs_path), exist_ok=True)
+         cam_image.save(cam_abs_path)
+         grad_cam_path = cam_rel_path
+
+     # Save token attention if Doctor and mode uses text
+     if token_attn and role == "Doctor" and mode in ["text", "multimodal"]:
+         attn_rel_path = f"app/demo/exports/{role.lower()}/tokenattention/token_attn_{pred_text}_{timestamp}.txt"
+         attn_abs_path = os.path.abspath(attn_rel_path)
+         os.makedirs(os.path.dirname(attn_abs_path), exist_ok=True)
+         with open(attn_abs_path, "w") as f:
+             f.write(f"Normalization Mode: {normalize_mode}\n")
+             f.write(f"Use Rollout: {use_rollout}\n")
+             f.write("Token Attention (word | score):\n")
+             f.write(str(token_attn) + "\n\n")
+             f.write("Top 5 tokens (token | % contribution):\n")
+             if top5:
+                 for tok, pct in top5:
+                     f.write(f"{tok}\t{pct:.2f}%\n")
+             else:
+                 f.write("(none)\n")
+         token_attn_path = attn_rel_path

      # Append to log
+     log_entry = {
          "mode": mode,
+         "normalize_mode": normalize_mode,
+         "use_rollout": bool(use_rollout),
+         "emr_text": emr_text or "N/A",
+         "image_path": img_rel_path if mode in ["image", "multimodal"] else "N/A",  # logged as relative path
+         "prediction": pred_text,
+         "confidence": round(confidence, 3),
+         "grad_cam_path": grad_cam_path if role == "Doctor" else "N/A",
+         "token_attention_path": token_attn_path if role == "Doctor" else "N/A",
+         "top5_tokens": "; ".join([f"{tok}:{pct:.1f}%" for tok, pct in (top5 or [])])
+     }
+
+     if role == "Doctor":
+         prediction_log_doctor.append(log_entry)
+         count = len(prediction_log_doctor)
+     else:
+         prediction_log_user.append(log_entry)
+         count = len(prediction_log_user)
+
+     glow_class = f"prediction-{pred_text.lower()}"  # 'high', 'medium', 'low'
+
+     return (
+         gr.Textbox(value=pred_text, elem_classes=[glow_class]),
+         gr.Image(value=cam_image, visible=show_gradcam),
+         gr.HighlightedText(value=token_attn, visible=show_attention),
+         render_top5_html(top5),
+         gr.Label(value=f"{confidence:.2f}", visible=True),
+         gr.Tabs(visible=show_tabs),
+         gr.Textbox(value=f"Predictions: {count}", interactive=False),
+         gr.JSON(value=class_probs, visible=True)
+     )

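+ # Note: classify() always returns the same 8-tuple, in the same order as the
+ # outputs list wired to submit_btn.click() further down: result_box, gradcam_img,
+ # token_attention, top5_html, confidence_label, insights_tab,
+ # prediction_count_box, class_probs_json. Keep the two in sync.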
+ def render_inputs(mode):
+     is_text = mode in ["text", "multimodal"]
+     is_image = mode in ["image", "multimodal"]
+
+     emr_text = gr.Textbox(
+         visible=is_text,
+         lines=6,
+         label="EMR Text",
+         placeholder="Enter clinical notes here...",
+         elem_id="emr_textbox"
+     )
+
+     image = gr.Image(
+         visible=is_image,
+         type="filepath",
+         label="Chest X-ray",
+         image_mode="RGB",
+         show_label=True,
+         height=224,
+         elem_id="xray_image"
+     )
+
+     max_note = gr.HTML(
+         "<p style='font-size: 0.9em; color: #a9b1d6;'>Maximum file size: 5MB</p>",
+         visible=is_image
+     )
+
+     return emr_text, image, max_note
+
+
+ def render_top5_html(top5):
+     """
+     top5: list[(token: str, pct: float)] where pct is 0..100.
+     Returns a gr.update with an HTML table colored by contribution (continuous gradient).
+     """
+     if not top5:
+         return gr.update(value="", visible=False)
+
+     def _lerp(a, b, t):  # linear interpolation
+         return a + (b - a) * t
+
+     def _rgb_to_hex(rgb):  # (r, g, b) -> "#rrggbb"
+         r, g, b = (max(0, min(255, int(round(x)))) for x in rgb)
+         return f"#{r:02x}{g:02x}{b:02x}"
+
+     def _interp_color(stops, t):
+         """
+         stops: list[(pos, (r, g, b))], pos in [0, 1], sorted.
+         t in [0, 1] -> interpolate between the nearest stops.
+         """
+         t = max(0.0, min(1.0, float(t)))
+         for i in range(len(stops) - 1):
+             p0, c0 = stops[i]
+             p1, c1 = stops[i + 1]
+             if t <= p1:
+                 # local interpolation factor
+                 if p1 == p0:
+                     w = 0.0
+                 else:
+                     w = (t - p0) / (p1 - p0)
+                 return (
+                     _lerp(c0[0], c1[0], w),
+                     _lerp(c0[1], c1[1], w),
+                     _lerp(c0[2], c1[2], w),
+                 )
+         return stops[-1][-1]
+
+     def _text_color_for_bg(rgb):
+         # YIQ luma for contrast; threshold ~128
+         r, g, b = rgb
+         yiq = (r * 299 + g * 587 + b * 114) / 1000.0
+         return "#000000" if yiq >= 128 else "#ffffff"
+
+     # --- gradient (low->high): green -> chartreuse -> orange -> red ---
+     # tweak the mid stops to taste
+     color_stops = [
+         (0.00, (27, 67, 50)),    # deep green
+         (0.40, (128, 170, 30)),  # chartreuse-ish
+         (0.70, (255, 165, 0)),   # orange
+         (1.00, (208, 0, 0)),     # red
+     ]
+
+     # Normalize to [0, 1] over the 5 items so colors spread even if skewed
+     vals = [pct for _, pct in top5]
+     vmin, vmax = min(vals), max(vals)
+     if vmax - vmin < 1e-9:
+         norms = [0.5] * len(vals)  # all equal -> neutral middle color
+     else:
+         norms = [(v - vmin) / (vmax - vmin) for v in vals]
+
+     # Build rows
+     row_html = []
+     for (tok, pct), t in zip(top5, norms):
+         rgb = _interp_color(color_stops, t)
+         bg = _rgb_to_hex(rgb)
+         fg = _text_color_for_bg(rgb)
+         row_html.append(
+             f"<tr style='background:{bg}; color:{fg};'>"
+             f"<td style='padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.06);'>{tok}</td>"
+             f"<td style='padding:10px 12px; text-align:right; border-bottom:1px solid rgba(255,255,255,0.06);'>{pct:.1f}%</td>"
+             "</tr>"
+         )
+
+     table = (
+         "<div class='top5-box' style='margin-top:10px;'>"
+         "<h4 style='margin:0 0 8px; color:#e5e7eb;'>Top 5 tokens (by contribution)</h4>"
+         "<table class='top5-table' style='width:100%; border-collapse:collapse;"
+         " background:#11131a; border:1px solid #2a2f3a; border-radius:10px; overflow:hidden;'>"
+         "<thead>"
+         "<tr style='background:#0f1320; color:#cbd5e1;'>"
+         "<th style='text-align:left; padding:10px 12px; font-weight:600;'>Token</th>"
+         "<th style='text-align:right; padding:10px 12px; font-weight:600;'>Contribution</th>"
+         "</tr>"
+         "</thead>"
+         f"<tbody>{''.join(row_html)}</tbody>"
+         "</table>"
+         "</div>"
+     )
+
+     return gr.update(value=table, visible=True)
+
+
+ def export_csv(filename, role):
+     log = prediction_log_doctor if role == "Doctor" else prediction_log_user
+     if not log:
+         # Hide the download link and show a warning; prevents empty exports
+         return None, gr.update(visible=False), gr.update(value="⚠️ No predictions to export.", visible=True)

      if not filename.strip():
          timestamp = time.strftime("%Y%m%d_%H%M%S")
+         filename = f"{role.lower()}_predictions_{timestamp}.csv"
      elif not filename.endswith(".csv"):
          filename += ".csv"

+     csv_path = os.path.abspath(os.path.join(f"app/demo/exports/{role.lower()}", filename))
      os.makedirs(os.path.dirname(csv_path), exist_ok=True)

+     df = pd.DataFrame(log)
+     if role == "Doctor":
+         columns = [
+             "mode", "normalize_mode", "use_rollout", "emr_text", "image_path",
+             "prediction", "confidence",
+             "grad_cam_path", "token_attention_path",
+             "top5_tokens"
+         ]
+     else:
+         columns = ["mode", "emr_text", "image_path", "prediction", "confidence"]
+
+     df = df[columns]
      df.to_csv(csv_path, index=False)

+     return (
+         csv_path,  # path string -> goes into csv_output (gr.File)
+         csv_path,  # same path string again -> reused by blink_box_effect()
+         gr.update(value=f"✅ Exported to: {csv_path}", visible=True)  # status string -> goes into export_status_box
+     )
+
+
+ def safe_delete_dir(path):
+     try:
+         if os.path.exists(path) and os.path.isdir(path):
+             shutil.rmtree(path)
+     except Exception as e:
+         print(f"⚠️ Failed to delete {path}: {e}")
+
+
+ def clear_logs(role):
+     # Step 1: Delete logged image files
+     log = prediction_log_doctor if role == "Doctor" else prediction_log_user
+     for entry in log:
+         # Delete X-ray image if it exists and is not "N/A"
+         if entry["image_path"] != "N/A":
+             image_file_path = ROOT_DIR / Path(entry["image_path"])
+             if image_file_path.exists():
+                 try:
+                     image_file_path.unlink()
+                 except Exception as e:
+                     print(f"⚠️ Failed to delete image: {image_file_path}: {e}")
+
+         # Delete Grad-CAM
+         if role == "Doctor" and entry.get("grad_cam_path") not in [None, "N/A"]:
+             grad_path = ROOT_DIR / Path(entry["grad_cam_path"])
+             if grad_path.exists():
+                 try:
+                     grad_path.unlink()
+                 except Exception as e:
+                     print(f"⚠️ Failed to delete Grad-CAM: {grad_path}: {e}")
+
+         # Delete token attention
+         if role == "Doctor" and entry.get("token_attention_path") not in [None, "N/A"]:
+             attn_path = ROOT_DIR / Path(entry["token_attention_path"])
+             if attn_path.exists():
+                 try:
+                     attn_path.unlink()
+                 except Exception as e:
+                     print(f"⚠️ Failed to delete token attention: {attn_path}: {e}")
+
+     # Step 2: Delete folders safely
+     if role == "Doctor":
+         safe_delete_dir(ROOT_DIR / "app/demo/uploads")
+         safe_delete_dir(ROOT_DIR / "app/demo/exports/doctor/gradcam")
+         safe_delete_dir(ROOT_DIR / "app/demo/exports/doctor/tokenattention")
+         safe_delete_dir(ROOT_DIR / "app/demo/exports/doctor")
+     else:
+         safe_delete_dir(ROOT_DIR / "app/demo/exports/user")
+         safe_delete_dir(ROOT_DIR / "app/demo/uploads")

+     # Step 3: Clear in-memory logs
+     if role == "Doctor":
+         prediction_log_doctor.clear()
+     else:
+         prediction_log_user.clear()
+
+     return gr.Textbox(value="Predictions: 0", interactive=False)
+
+
+ # Confirm before clearing logs
+ def confirm_clear():
+     return gr.Textbox(
+         value="⚠️ Are you sure you want to clear the logs? Click again to confirm.",
+         visible=True,
+         interactive=False,
+         label=""
+     )
+
+
+ def clear_confirmed(role):
+     cleared = clear_logs(role)
+     return (
+         cleared,
+         gr.Textbox(value="✅ Logs cleared successfully!", visible=True),
+         gr.update(value=None, visible=False),  # csv_output
+         gr.update(interactive=True)            # filename_input
+     )

+
+ def reset_confirm_box():
+     return gr.Textbox(value="", visible=False)
+
+
+ def disable_filename_input():
+     return gr.Textbox(interactive=False)
+
+
+ def show_loading_msg():
+     return gr.update(value="⏳ Running inference...", visible=True)
+
+
+ def blink_box_effect(path):
+     # return the file component with a blinking class
+     return gr.File(value=path, elem_classes=["download_box", "blink-csv"], visible=True, interactive=True)
+
+
+ def update_role_state(r):
+     # hide insights + token box when switching to User
+     tabs_vis = (r == "Doctor")
+     return (
+         r,                            # role_state
+         gr.update(visible=tabs_vis),  # normalize_mode_column
+         gr.update(visible=tabs_vis),  # insights_tab
+         gr.update(visible=False),     # token_attention
+         gr.update(visible=False),     # gradcam_img
+         gr.update(visible=tabs_vis),  # use_rollout
+         gr.update(visible=False),     # top5_html
+     )
+
+
+ def rerun_if_done(ran, role, mode, normalize_mode, emr_text, image, use_rollout):
+     if not ran or role != "Doctor":
+         return (
+             gr.Textbox(visible=False),
+             gr.Image(visible=False),
+             gr.HighlightedText(visible=False),
+             gr.HTML(visible=False),
+             gr.Label(visible=False),
+             gr.Tabs(visible=False),
+             gr.Textbox(value="", interactive=False),
+             gr.JSON(value={}, visible=True)
+         )
+     # Re-run classify() if inference has already happened once
+     return classify(role, mode, normalize_mode, emr_text, image, use_rollout)
+
+
+ def inject_tooltips():
+     return gr.HTML(
+         """
+         <script>
+         const observer = new MutationObserver(() => {
+             document.querySelectorAll(".token-attn-box .token").forEach(token => {
+                 const text = token.innerText;
+                 const pipeIndex = text.indexOf("|");
+                 if (pipeIndex > -1) {
+                     const display = text.slice(0, pipeIndex).trim();
+                     const tooltip = text.slice(pipeIndex + 1).trim();
+                     token.innerText = display;
+                     token.setAttribute("data-tooltip", tooltip);
+                 }
+             });
+         });
+         observer.observe(document.body, { childList: true, subtree: true });
+         </script>
+         """
+     )
+
+
+ def reset_ui():
+     is_text = DEFAULT_MODE in ["text", "multimodal"]
+     is_image = DEFAULT_MODE in ["image", "multimodal"]
+
+     return (
+         # Inputs (text/image areas)
+         gr.update(value="", visible=is_text),     # emr_text
+         gr.update(value=None, visible=is_image),  # image
+         gr.update(visible=is_image),              # max_file_note
+
+         # Prediction/result area
+         gr.update(value="", visible=True),     # result_box
+         gr.update(value=None, visible=False),  # gradcam_img
+         gr.update(value=None, visible=False),  # token_attention
+         gr.update(value="", visible=False),    # top5_html
+         gr.update(value="", visible=False),    # confidence_label
+         gr.update(visible=False),              # insights_tab
+         gr.update(value={}, visible=True),     # class_probs_json
+
+         # Role/mode controls + states
+         "User",                         # role_state
+         DEFAULT_MODE,                   # mode_state
+         "visual",                       # normalize_mode_state
+         gr.update(value="User"),        # role (radio)
+         gr.update(value=DEFAULT_MODE),  # mode (radio)
+         gr.update(value="visual"),      # normalize_mode (radio)
+         gr.update(visible=False),       # normalize_mode_column (hide in User)
+         gr.update(visible=False),       # use_rollout
+         False,                          # rollout_state
+
+         # Loading + inference state
+         gr.update(value="", visible=False),  # loading_msg
+         False,                               # inference_done
+         gr.update(value="", visible=False)   # export_status_box
+     )
+
+
+ # --- Gradio UI ---
+ style_path = Path(__file__).resolve().parent / "style.css"
+ with open(style_path, "r") as f:
+     custom_css = f.read()
+
+ with gr.Blocks(css=custom_css) as demo:
      # Centered title and subtitle
      gr.Markdown("<h2 class='centered'>🩺 Medi-LLM: Clinical Triage Assistant 🩻</h2>")
      gr.Markdown("<p class='centered'>Upload a chest X-ray and/or enter EMR text to get a triage level prediction.</p>")
+     gr.HTML(
+         """
+         <div class='welcome-banner' style="background-color: #24283b; border-left: 4px solid #7aa2f7; padding: 16px; border-radius: 8px; margin-bottom: 16px;">
+             <h3 style="margin-top: 0; color: #c0caf5;">👋 Welcome to Medi-LLM</h3>
+             <p style="color: #a9b1d6; line-height: 1.6;">
+                 This AI assistant helps triage patients using <strong>EMR text</strong> and <strong>chest X-rays</strong>.<br>
+                 📝 Enter EMR notes, 📷 upload a chest X-ray, or use both for a multimodal diagnosis.<br>
+                 👩‍⚕️ Select <strong>Doctor</strong> mode to view insights like Grad-CAM heatmaps and token-level attention.<br>
+                 💾 Save your results for later by exporting them to a CSV file.
+             </p>
+         </div>
+         """
+     )

+     # Hidden state
+     role_state = gr.State(value="User")
+     mode_state = gr.State(value=DEFAULT_MODE)
+     rollout_state = gr.State(value=False)
+     normalize_mode_state = gr.State(value="visual")
+     inference_done = gr.State(value=False)
+
+     # Role and Mode selection
+     with gr.Row(equal_height=True):
+         with gr.Column():
+             role = gr.Radio(["User", "Doctor"], value="User", label="Select Role", info="Doctors see insights like Grad-CAM and token attention", elem_id="role_selector")
+             mode = gr.Radio(["text", "image", "multimodal"], value=DEFAULT_MODE, label="Select Input Mode", info="Choose the diagnosis input type", elem_id="mode_selector")
+         with gr.Column(visible=False) as normalize_mode_column:
+             normalize_mode = gr.Radio(
+                 ["visual", "probabilistic"],
+                 value="visual",
+                 label="Attention Normalization",
+                 info="Softmax sums to 1 (probabilistic). Visual uses gamma-boosted scaling for color clarity."
+             )
+             use_rollout = gr.Checkbox(
+                 label="Use attention rollout (CLS -> inputs)",
+                 value=False,
+                 info="Includes residuals and multiplies attention across layers. Slower but often more faithful."
+             )
+
+     normalize_mode.change(
+         fn=lambda val: val,
+         inputs=[normalize_mode],
+         outputs=[normalize_mode_state]
+     )
+
+     use_rollout.change(
+         fn=lambda v: v,
+         inputs=[use_rollout],
+         outputs=[rollout_state]
+     )

      # Input: EMR text and/or image
      with gr.Row():
+         with gr.Column(scale=3, elem_id="text_col") as text_col:
+             emr_text, image, max_file_note = render_inputs(DEFAULT_MODE)
+
+     # Submit button
+     with gr.Row():
+         submit_btn = gr.Button(
+             "🔍 Run Inference",
+             elem_id="inference_btn"
+         )
+         reset_btn = gr.Button(
+             "↩️ Reset",
+             elem_id="reset_btn"
+         )
+
+     # Outputs
+     with gr.Column(elem_classes=["output-box"]):
+         result_box = gr.Textbox(label="🧪 Triage Prediction", interactive=False)
+         confidence_label = gr.Label(label="📊 Confidence", visible=False)
+         prediction_count_box = gr.Textbox(value="Predictions: 0", interactive=False, label="🧮 Count", elem_id="prediction_count_box")
+         insights_tab = gr.Tabs(visible=False)
+         class_probs_json = gr.JSON(label="🔍 Class Probabilities", visible=True, elem_classes=["json-box"])
+         with insights_tab:
+             with gr.Tab("📷 Grad-CAM"):
+                 gradcam_img = gr.Image(visible=False, elem_classes=["gr-image-box"])
+             with gr.Tab("🔬 Token Attention"):
+                 token_attention = gr.HighlightedText(
+                     visible=False,
+                     show_legend=False,
+                     color_map={
+                         "0.0": "#7aa2f7",   # blue
+                         "0.25": "#80deea",  # cyan
+                         "0.5": "#fbc02d",   # yellow
+                         "0.75": "#ff8a65",  # orange
+                         "1.0": "#f7768e",   # red
+                     },
+                     elem_classes=["token-attn-box"]
+                 )
+                 top5_html = gr.HTML(value="", visible=False)
+
+         inject_tooltips()
+
+         gr.HTML("""
+         <div class="attention-legend">
+             <div style="display: flex; align-items: center; gap: 8px;">
+                 <span style="font-size: 14px; color: #c0caf5;">0.0</span>
+                 <div class="attention-gradient-bar"></div>
+                 <span style="font-size: 14px; color: #c0caf5;">1.0</span>
+             </div>
+         </div>
+         """)

      with gr.Row():
+         loading_msg = gr.Markdown(value="", visible=False, elem_classes=["loading-msg"])

+     # Bind inference
+     submit_btn.click(
+         fn=show_loading_msg,
+         outputs=[loading_msg]
+     ).then(
+         fn=classify,
+         inputs=[role_state, mode_state, normalize_mode_state, emr_text, image, rollout_state],
+         outputs=[
+             result_box,
+             gradcam_img,
+             token_attention,
+             top5_html,
+             confidence_label,
+             insights_tab,
+             prediction_count_box,
+             class_probs_json,
+         ]
+     ).then(
+         fn=lambda: gr.update(value="", visible=False),
+         outputs=[loading_msg]
+     ).then(
+         fn=lambda: True,
+         outputs=[inference_done]
+     )

+     # Input updates
+     mode.change(
+         fn=lambda m: (*render_inputs(m), m),
+         inputs=[mode],
+         outputs=[emr_text, image, max_file_note, mode_state]
+     )
+
+     role.change(
+         fn=update_role_state,
+         inputs=[role],
+         outputs=[role_state, normalize_mode_column, insights_tab, token_attention, gradcam_img, use_rollout, top5_html]
+     )
+
+     normalize_mode.change(
+         fn=rerun_if_done,
+         inputs=[inference_done, role_state, mode_state, normalize_mode, emr_text, image, rollout_state],
+         outputs=[
+             result_box,
+             gradcam_img,
+             token_attention,
+             top5_html,
+             confidence_label,
+             insights_tab,
+             prediction_count_box,
+             class_probs_json,
+         ]
+     )
+
+     use_rollout.change(
+         fn=rerun_if_done,
+         inputs=[inference_done, role_state, mode_state, normalize_mode, emr_text, image, rollout_state],
+         outputs=[
+             result_box,
+             gradcam_img,
+             token_attention,
+             top5_html,
+             confidence_label,
+             insights_tab,
+             prediction_count_box,
+             class_probs_json
+         ]
+     )

      # CSV Export UI
      gr.Markdown("### 📁 Export Prediction Log")

+     with gr.Row(equal_height=True):
+         with gr.Column(scale=3):
+             filename_input = gr.Textbox(
+                 label="CSV filename (optional)",
+                 placeholder="e.g., triage_results.csv",
+                 info="Set a filename or leave blank for auto-naming",
+                 elem_id="csv_filename"
+             )
+
+             export_status_box = gr.Textbox(
+                 value="",
+                 visible=False,
+                 interactive=False,
+                 label="",
+                 elem_id="export_status"
+             )
+
+         with gr.Column(scale=4):
+             gr.Markdown(
+                 "📑 **Summary**\n\nDownload your triage results for clinical review or research.",
+                 elem_classes="centered"
+             )
+             with gr.Row():
+                 with gr.Column(scale=1, min_width=200):
+                     download_btn = gr.Button("💾 Export CSV", elem_id="export_button")
+                 with gr.Column(scale=1, min_width=200):
+                     clear_btn = gr.Button("🗑️ Clear Logs", elem_id="clear_button")
+                     confirm_clear_btn = gr.Button("✅ Confirm Clear", visible=False, elem_id="confirm_button")
+                     confirm_box = gr.Textbox(label="Status", interactive=False, visible=False, elem_id="confirm_box")
+
+         with gr.Column(scale=3):
+             csv_output = gr.File(label="📂 Download Link", elem_id="download_box")

      download_btn.click(
          fn=export_csv,
+         inputs=[filename_input, role_state],
+         outputs=[
+             csv_output,
+             csv_output,
+             export_status_box
+         ]
+     ).then(
+         fn=blink_box_effect,
+         inputs=[csv_output],
          outputs=[csv_output]
+     ).then(
+         fn=disable_filename_input,
+         outputs=[filename_input]
+     )
+
+     clear_btn.click(
+         fn=lambda: (
+             confirm_clear(),
+             gr.Button(visible=True),
+         ),
+         outputs=[confirm_box, confirm_clear_btn]
+     )
+
+     confirm_clear_btn.click(
+         fn=clear_confirmed,
+         inputs=[role_state],
+         outputs=[
+             prediction_count_box,  # reset prediction count
+             confirm_box,           # show success message
+             csv_output,            # hide CSV output file
+             filename_input         # re-enable input box
+         ]
+     ).then(
+         fn=lambda: gr.update(visible=False),  # hide confirm button
+         outputs=[confirm_clear_btn]
+     ).then(
+         fn=reset_confirm_box,
+         outputs=[confirm_box]
+     )
+
+     # Reset UI
+     reset_btn.click(
+         fn=reset_ui,
+         outputs=[
+             emr_text,               # 1
+             image,                  # 2
+             max_file_note,          # 3
+             result_box,             # 4
+             gradcam_img,            # 5
+             token_attention,        # 6
+             top5_html,              # 7
+             confidence_label,       # 8
+             insights_tab,           # 9
+             class_probs_json,       # 10
+             role_state,             # 11
+             mode_state,             # 12
+             normalize_mode_state,   # 13
+             role,                   # 14 (radio)
+             mode,                   # 15 (radio)
+             normalize_mode,         # 16 (radio)
+             normalize_mode_column,  # 17 (column visibility)
+             use_rollout,            # 18
+             rollout_state,          # 19
+             loading_msg,            # 20
+             inference_done,         # 21
+             export_status_box       # 22
+         ]
      )

  if __name__ == "__main__":
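For reference, a self-contained sketch of the normalization and color math that render_top5_html() applies to each row (the token names and percentages below are hypothetical):

    # Illustration only: min-max normalize the five percentages, then map each
    # through the same piecewise-linear green->red gradient used above.
    def lerp(a, b, t):
        return a + (b - a) * t

    def interp_color(stops, t):
        t = max(0.0, min(1.0, t))
        for (p0, c0), (p1, c1) in zip(stops, stops[1:]):
            if t <= p1:
                w = 0.0 if p1 == p0 else (t - p0) / (p1 - p0)
                return tuple(lerp(a, b, w) for a, b in zip(c0, c1))
        return stops[-1][1]

    stops = [(0.00, (27, 67, 50)), (0.40, (128, 170, 30)),
             (0.70, (255, 165, 0)), (1.00, (208, 0, 0))]
    top5 = [("fever", 31.2), ("cough", 24.8), ("hypoxia", 18.9),
            ("infiltrate", 14.6), ("dyspnea", 10.5)]  # hypothetical tokens

    vmin = min(p for _, p in top5)
    vmax = max(p for _, p in top5)
    for tok, pct in top5:
        t = (pct - vmin) / (vmax - vmin)  # highest token -> red, lowest -> green
        r, g, b = (int(round(x)) for x in interp_color(stops, t))
        print(f"{tok:12s} {pct:5.1f}%  #{r:02x}{g:02x}{b:02x}")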
app/demo/style.css ADDED
@@ -0,0 +1,379 @@
+ /* === Base Layout === */
+ body {
+     background-color: #1a1b26 !important;
+     color: #c0caf5 !important;
+     font-family: 'Fira Code', monospace;
+ }
+
+ /* === Welcome Banner Hover Glow === */
+ .welcome-banner:hover {
+     box-shadow: 0 0 12px 3px #7aa2f7 !important;
+     transition: 0.3s ease-in-out;
+     cursor: pointer;
+ }
+
+ /* === Text Inputs Focus & Hover === */
+ #emr_textbox textarea:hover,
+ #emr_textbox textarea:focus {
+     border: 1px solid #7aa2f7 !important;
+     box-shadow: 0 0 6px 2px #7aa2f7 !important;
+ }
+
+ /* === Image Upload Hover & Focus === */
+ #xray_image:hover,
+ #xray_image:focus {
+     border: 1px solid #9ece6a !important;
+     box-shadow: 0 0 6px 2px #9ece6a !important;
+ }
+
+ /* === Grad-CAM Image Hover & Focus === */
+ .gr-image-box:hover,
+ .gr-image-box:focus {
+     border: 1px solid #f7768e !important;
+     box-shadow: 0 0 6px 2px #f7768e !important;
+ }
+
+ /* === Token Attention Hover & Focus Enhancements === */
+ .token-attn-box:hover,
+ .token-attn-box:focus {
+     border: 1px solid #bb9af7 !important;
+     box-shadow: 0 0 6px 2px #bb9af7 !important;
+ }
+
+ .token-attn-box .token {
+     transition: background-color 0.3s ease-in-out, box-shadow 0.3s ease-in-out, color 0.3s ease-in-out;
+     padding: 4px 8px;
+     border-radius: 4px;
+     font-weight: 500;
+     margin: 2px;
+     display: inline-block;
+     position: relative;
+     cursor: help;
+ }
+
+ /* Custom tooltip on hover using the data-tooltip attribute */
+ /* === Tooltip decoding for attention === */
+ .token-attn-box .token::after {
+     content: attr(data-tooltip);
+     position: absolute;
+     background: #1e1e2e;
+     color: #c0caf5;
+     padding: 4px 8px;
+     border-radius: 4px;
+     top: -30px;
+     left: 0;
+     white-space: nowrap;
+     font-size: 0.85em;
+     box-shadow: 0 0 6px rgba(0, 0, 0, 0.5);
+     z-index: 10;
+     opacity: 0;
+     pointer-events: none;
+     transition: opacity 0.2s ease-in-out;
+ }
+
+ .token-attn-box .token:hover::after {
+     opacity: 1;
+ }
+
+ /* === Tooltip arrow for custom data-tooltip === */
+ .token-attn-box .token[data-tooltip]:hover::before {
+     content: "";
+     position: absolute;
+     top: -12px;
+     left: 50%;
+     transform: translateX(-50%);
+     border-left: 6px solid transparent;
+     border-right: 6px solid transparent;
+     border-bottom: 6px solid #1e1e2e; /* Match tooltip background */
+     z-index: 9;
+ }
+
+ /* Hover and active styles */
+ .token-attn-box .token:hover {
+     outline: 2px solid #bb9af7 !important;
+     box-shadow: 0 0 8px 2px #bb9af7 !important;
+     cursor: pointer;
+ }
+
+ /* === Highlight top-attention token with glow === */
+ .token-attn-box .token[style*="rgba(247, 118, 142, 1)"] {
+     box-shadow: 0 0 10px 5px rgba(247, 118, 142, 0.85);
+     border-radius: 6px;
+     font-weight: 600;
+ }
+
+ /* === Attention-based text color tinting for stronger contrast === */
+ .token-attn-box .token[style*="rgba(255, 138, 101"],
+ .token-attn-box .token[style*="rgba(255, 138, 101, 1)"] {
+     color: #ff8a65;
+ }
+
+ .token-attn-box .token[style*="rgba(251, 192, 45"],
+ .token-attn-box .token[style*="rgba(251, 192, 45, 1)"] {
+     color: #fbc02d;
+ }
+
+ .token-attn-box .token[style*="rgba(128, 222, 234"],
+ .token-attn-box .token[style*="rgba(128, 222, 234, 1)"] {
+     color: #80deea;
+ }
+
+ .token-attn-box .token[style*="rgba(122, 162, 247"],
+ .token-attn-box .token[style*="rgba(122, 162, 247, 1)"] {
+     color: #7aa2f7;
+ }
+
+ .token-attn-box .token[style*="rgba(247, 118, 142"],
+ .token-attn-box .token[style*="rgba(247, 118, 142, 1)"] {
+     color: #f7768e;
+ }
+
+ /* === Token Attention Gradient Bar === */
+ .attention-gradient-bar {
+     flex-grow: 1;
+     height: 14px;
+     border-radius: 8px;
+     margin-top: 8px;
+     background: linear-gradient(
+         to right,
+         #7aa2f7 0%,
+         #80deea 25%,
+         #fbc02d 50%,
+         #ff8a65 75%,
+         #f7768e 100%
+     );
+     box-shadow: 0 0 3px rgba(0,0,0,0.4) inset;
+ }
+
+ /* === Top5 tokens box === */
+ .top5-box .top5-table {
+     box-shadow: 0 6px 16px rgba(0,0,0,0.25);
+     border-radius: 10px;
+ }
+ .top5-box h4 { letter-spacing: .2px; }
+
+ /* === Triage Prediction Box Glow (based on class) === */
+ .prediction-high {
+     border: 2px solid #f7768e !important;
+     box-shadow: 0 0 8px 3px #f7768e !important;
+ }
+
+ .prediction-medium {
+     border: 2px solid #e0af68 !important;
+     box-shadow: 0 0 8px 3px #e0af68 !important;
+ }
+
+ .prediction-low {
+     border: 2px solid #e0af68 !important;
+     box-shadow: 0 0 8px 3px #e0af68 !important;
+ }
+
+ /* === Basic Radio Button Styling (Role + Mode) === */
+ #role_selector label,
+ #mode_selector label {
+     display: block;
+     margin: 6px 0;
+     padding: 8px 12px;
+     border-radius: 6px;
+     border: 1px solid #3b4261;
+     background-color: #1f2335;
+     color: #c0caf5;
+     font-weight: 500;
+     transition: all 0.2s ease-in-out;
+     cursor: pointer;
+ }
+
+ /* Hover and Focus Glow */
+ #role_selector label:hover,
+ #role_selector input:focus + label,
+ #mode_selector label:hover,
+ #mode_selector input:focus + label {
+     border: 1px solid #a0cfff !important;
+     box-shadow: 0 0 6px 2px #a0cfff !important;
+ }
+
+ /* Selected Option */
+ #role_selector input:checked + label,
+ #mode_selector input:checked + label {
+     background-color: #3d59a1 !important;
+     border: 1px solid #7aa2f7 !important;
+     color: white !important;
+     box-shadow: 0 0 6px 2px #7aa2f7 !important;
+ }
+
+ /* Optional: Ensure radio circles are visible */
+ #role_selector input,
+ #mode_selector input {
+     margin-right: 8px;
+     transform: scale(1.1);
+ }
+
+ /* === Buttons === */
+ .gr-button {
+     border-radius: 8px !important;
+     font-weight: 500;
+ }
+
+ /* === Primary/Secondary Buttons via IDs === */
+ /* Inference (blue) */
+ #inference_btn {
+     background-color: #7aa2f7 !important;
+     color: #ffffff !important;
+ }
+ #inference_btn:hover {
+     background-color: #409eff !important;
+     transform: translateY(-1px);
+     box-shadow: 0 0 10px rgba(122,162,255,0.55) !important;
+ }
+
+ /* Reset (red/coral) */
+ #reset_btn {
+     background-color: #f7768e !important;
+     color: #ffffff !important;
+ }
+ #reset_btn:hover {
+     background-color: #ff5c7a !important;
+     transform: translateY(-1px);
+     box-shadow: 0 0 12px rgba(255,92,122,0.75) !important;
+ }
+
+ /* Export (blue, same as inference) */
+ #export_button {
+     background-color: #7aa2f7 !important;
+     color: #ffffff !important;
+ }
+ #export_button:hover {
+     background-color: #409eff !important;
+     transform: translateY(-1px);
+     box-shadow: 0 0 10px rgba(122,162,255,0.45) !important;
+ }
+
+ /* Clear Logs (red, same as reset) */
+ #clear_button {
+     background-color: #f7768e !important;
+     color: #ffffff !important;
+ }
+ #clear_button:hover {
+     background-color: #ff5c7a !important;
+     transform: translateY(-1px);
+     box-shadow: 0 0 12px rgba(255,92,122,0.75) !important;
+ }
+
+ /* Confirm Clear (yellow base, GREEN glow on hover) */
+ #confirm_button {
+     background-color: #e0af68 !important;
+     color: #ffffff !important;
+     border-radius: 8px !important;
+     padding: 10px 14px !important;
+     font-weight: 600 !important;
+     border: none !important;
+     cursor: pointer !important;
+ }
+
+ #confirm_button:hover {
+     background-color: #d9a147 !important;
+     box-shadow: 0 0 10px 3px rgba(158,206,106,0.9) !important; /* green glow */
+     border: 1px solid #9ece6a !important;
+ }
+
+ /* === Tab Panels === */
+ .gr-tabitem {
+     background-color: #1f2335 !important;
+     color: #c0caf5 !important;
+ }
+
+ /* === Markdown (centered) === */
+ .centered {
+     text-align: center;
+ }
+
+ /* === Loading message === */
+ .loading-msg {
+     text-align: center;
+     color: #7aa2f7;
+     font-weight: bold;
+     font-size: 1.1em;
+ }
+
+ /* === Hover & Focus Glow for CSV filename input === */
+ #csv_filename input:hover,
+ #csv_filename textarea:hover,
+ #csv_filename input:focus,
+ #csv_filename textarea:focus {
+     border: 1px solid #7aa2f7 !important;
+     box-shadow: 0 0 6px 2px #7aa2f7 !important;
+     transition: 0.3s ease-in-out;
+ }
+
+ /* === Blinking effect for CSV download box === */
+ @keyframes blink-box {
+     0% { box-shadow: 0 0 6px 2px #7aa2f7; }
+     50% { box-shadow: 0 0 12px 4px #7aa2f7; }
+     100% { box-shadow: 0 0 6px 2px #7aa2f7; }
+ }
+
+ .blink-csv {
+     animation: blink-box 1.5s ease-in-out 3;
+     border-radius: 8px;
+     border: 1px solid #7aa2f7 !important;
+ }
+
+ /* === Prediction Count Box === */
+ #prediction_count_box {
+     font-size: 1em;
+     padding: 10px;
+     border-radius: 6px;
+     background-color: #1f2335;
+     color: #c0caf5;
+     border: 1px solid #7aa2f7;
+     transition: border-color 0.3s, box-shadow 0.3s;
+ }
+
+ #prediction_count_box:hover,
+ #prediction_count_box:focus {
+     border: 1px solid #a0cfff !important;
+     box-shadow: 0 0 6px 2px #7aa2f7 !important;
+ }
+
+ /* === Clear Logs Confirmation Box === */
+ #confirm_box {
+     font-size: 0.95em;
+     padding: 10px;
+     border-radius: 6px;
+     background-color: #1f2335;
+     color: #c0caf5;
+     border: 1px solid #e0af68;
+     transition: border-color 0.3s, box-shadow 0.3s;
+ }
+
+ #confirm_box:hover,
+ #confirm_box:focus {
+     border: 1px solid #e0af68 !important;
+     box-shadow: 0 0 6px 2px #e0af68 !important;
+ }
+
+ /* Export status message styling */
+ #export_status {
+     color: #9ece6a; /* Greenish success color */
+     font-weight: bold;
+     padding: 8px 12px;
+     border: 1px solid #9ece6a;
+     background-color: #1a1b26; /* Match the dark background */
+     border-radius: 6px;
+     margin-top: 8px;
+     transition: opacity 0.5s ease;
+ }
+
+ /* Optional fade-out animation (if using JS, or if Gradio later supports it natively) */
+ #export_status.fade-out {
+     opacity: 0;
+ }
+
+ /* === Class-level Probabilities === */
+ .json-box {
+     background-color: #1e222e;
+     padding: 12px;
+     border-radius: 8px;
+     border: 1px solid #7aa2f7;
+ }
app/utils/attention_utils.py DELETED
@@ -1,20 +0,0 @@
- def extract_token_attention(model, tokenizer, input_ids, attention_mask):
-     if hasattr(model.text_encoder, 'bert'):
-         try:
-             outputs = model.text_encoder.bert(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 output_attentions=True
-             )
-             last_attn = outputs.attentions[-1]  # (B, H, S, S), final layer
-             weights = last_attn.mean(dim=1)[0, 0, :]  # mean over heads for batch item 0; CLS row, i.e. CLS attention to every token
-
-             weights = weights.detach().cpu().numpy()
-             weights = (weights - weights.min()) / (weights.max() - weights.min() + 1e-8)
-
-             tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
-             return [(tok, float(round(weights[i], 3))) for i, tok in enumerate(tokens)]
-
-         except Exception as e:
-             print("Attention extraction failed:", e)
-             return None
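The helper deleted above read CLS attention from the final layer only (mean over heads). Its replacement, attention_rollout() in app/utils/inference_utils.py below, adds a residual term and multiplies row-normalized attention across layers. A toy sketch of the difference, using random attention tensors:

    import torch

    # Toy attentions: 3 layers, batch 1, 2 heads, sequence length 4.
    attns = [torch.softmax(torch.randn(1, 2, 4, 4), dim=-1) for _ in range(3)]

    # Deleted approach: final layer only, mean over heads, CLS row.
    last_layer_cls = attns[-1].mean(dim=1)[0, 0, :]     # [S]

    # Rollout-style: add residual, row-normalize, multiply across layers.
    S = attns[0].shape[-1]
    eye = torch.eye(S)
    A = None
    for a in attns:
        a = a.mean(dim=1)[0] + 0.5 * eye                # [S, S] with residual
        a = a / a.sum(dim=-1, keepdim=True)
        A = a if A is None else A @ a
    rollout_cls = A[0, :]                               # CLS row after rollout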
app/utils/gradcam_utils.py CHANGED
@@ -16,7 +16,7 @@ def register_hooks(model):
      layer = model.image_encoder.layer4
      fwd_handle = layer.register_forward_hook(forward_hook)
-     bwd_handle = layer.register_backward_hook(backward_hook)
+     bwd_handle = layer.register_full_backward_hook(backward_hook)

      return activations, gradients, fwd_handle, bwd_handle

@@ -25,16 +25,24 @@ def generate_gradcam(image_pil, activations, gradients):
      grads = gradients["value"]
      acts = activations["value"]

+     # Weight each activation channel by its pooled gradient
      pooled_grads = torch.mean(grads, dim=[0, 2, 3])
      for i in range(acts.shape[1]):
          acts[:, i, :, :] *= pooled_grads[i]

-     heatmap = torch.mean(acts, dim=1).squeeze().cpu().numpy()
+     # Normalize heatmap
+     heatmap = torch.mean(acts, dim=1).squeeze().detach().cpu().numpy()
      heatmap = np.maximum(heatmap, 0)
-     heatmap /= heatmap.max()
+     heatmap /= heatmap.max() + 1e-8

-     heatmap = Image.fromarray(np.uint8(255 * heatmap)).resize((224, 224)).convert("L")
-     image_np = np.array(image_pil.resize((224, 224)).convert("RGB"))
-     overlay = np.uint8(0.6 * image_np + 0.4 * plt.cm.jet(heatmap / 255.0)[:, :, :3] * 255)
+     # Convert to image and overlay
+     heatmap_resized = Image.fromarray(np.uint8(255 * heatmap)).resize((224, 224))
+     heatmap_array = np.array(heatmap_resized)
+     colormap = plt.cm.jet(heatmap_array / 255.0)[..., :3]  # shape (H, W, 3), RGB

-     return Image.fromarray(overlay.astype(np.uint8))
+     # Combine with original image
+     image_np = np.array(image_pil.resize((224, 224)).convert("RGB")) / 255.0
+     overlay = (0.6 * image_np + 0.4 * colormap) * 255
+     overlay = overlay.astype(np.uint8)
+
+     return Image.fromarray(overlay)
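For context, a minimal sketch of how these helpers are driven, mirroring the predict() flow in app/utils/inference_utils.py below (the model, PIL image, and input tensor are assumed to exist):

    # Grad-CAM needs gradients, so no torch.no_grad() around this.
    activations, gradients, fwd_handle, bwd_handle = register_hooks(model)
    model.zero_grad()
    logits = model(input_ids=None, attention_mask=None, image=img_tensor)["logits"]
    pred = logits.argmax(dim=1).item()
    logits[0, pred].backward(retain_graph=True)   # fills gradients["value"]
    cam_image = generate_gradcam(image_pil, activations, gradients)
    fwd_handle.remove()
    bwd_handle.remove()                           # always detach the hooks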
app/utils/inference_utils.py CHANGED
@@ -1,6 +1,7 @@
  import sys
  import torch
  import yaml
  from pathlib import Path
  from transformers import AutoTokenizer
  from torchvision import transforms
@@ -9,6 +10,8 @@ ROOT_DIR = Path(__file__).resolve().parent.parent.parent
  sys.path.append(str(ROOT_DIR))

  from src.multimodal_model import MediLLMModel

  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -33,32 +36,222 @@ def load_model(mode, model_path, config_path=str(Path("config/config.yaml").resolve())):
          dropout=config["dropout"],
          hidden_dim=config["hidden_dim"]
      )
-     model.load_state_dict(torch.load(model_path, map_location=DEVICE))
      model.to(DEVICE)
      model.eval()
      return model


- def predict(model, mode, emr_text=None, image=None):
-     with torch.no_grad():
-         input_ids = attention_mask = img_tensor = None
-
-         if mode in ["text", "multimodal"] and emr_text:
-             text_tokens = tokenizer(
-                 emr_text,
-                 return_tensors="pt",
-                 truncation=True,
-                 padding="max_length",
-                 max_length=128,
              )
-             input_ids = text_tokens["input_ids"].to(DEVICE)
-             attention_mask = text_tokens["attention_mask"].to(DEVICE)

-         if mode in ["image", "multimodal"] and image:
-             img_tensor = image_transform(image).unsqueeze(0).to(DEVICE)

-         output = model(input_ids=input_ids, attention_mask=attention_mask, image=img_tensor)
-         pred = torch.argmax(output, dim=1).item()
-         confidence = torch.softmax(output, dim=1).squeeze()[pred].item()

-         return f"{inv_map[pred]} (Confidence: {confidence:.2f})"
  import sys
  import torch
  import yaml
+ import numpy as np
  from pathlib import Path
  from transformers import AutoTokenizer
  from torchvision import transforms

  sys.path.append(str(ROOT_DIR))

  from src.multimodal_model import MediLLMModel
+ from app.utils.gradcam_utils import register_hooks, generate_gradcam
+

  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

          dropout=config["dropout"],
          hidden_dim=config["hidden_dim"]
      )
+     state = torch.load(model_path, map_location=DEVICE)
+     model.load_state_dict(state)
      model.to(DEVICE)
      model.eval()
      return model


+ def attention_rollout(attentions, last_k=4, residual_alpha=0.5):
+     """
+     attentions: tuple/list of per-layer attention tensors; each is (B, H, S, S)
+     last_k: only roll back through the last k layers (keeps contrast)
+     residual_alpha: how much identity to add before normalizing (preserves token self-info)
+     returns: [B, S, S] rollout matrix, or None if the input is invalid
+     """
+     if attentions is None:
+         return None
+     if isinstance(attentions, (list, tuple)) and len(attentions) == 0:
+         return None
+
+     first = attentions[0]
+     if first is None or first.ndim != 4:
+         return None  # expect [B, H, S, S]
+
+     B, H, S, _ = first.shape
+     eye = torch.eye(S, device=first.device).unsqueeze(0).expand(B, S, S)  # [B, S, S]
+
+     L = len(attentions)
+     if last_k is None:
+         last_k = L
+     if last_k <= 0:
+         # No layers selected -> return identity (no propagation)
+         return eye.clone()
+
+     start = max(0, L - last_k)
+     A = None
+     for layer in range(start, L):
+         a = attentions[layer]
+         if a is None or a.ndim != 4 or a.shape[0] != B or a.shape[-1] != S:
+             # Skip malformed layer
+             continue
+         a = a.mean(dim=1)  # [B, S, S] (average over heads)
+         a = a + float(residual_alpha) * eye
+         a = a / (a.sum(dim=-1, keepdim=True) + 1e-12)  # row-normalize
+         A = a if A is None else torch.bmm(A, a)
+
+     # If we never multiplied (e.g. every layer was skipped), fall back to identity
+     return A if A is not None else eye.clone()  # [B, S, S]
+
+
+ def merge_wordpieces(tokens, scores):
+     merged_tokens, merged_scores = [], []
+     cur_tok, cur_scores = "", []
+     for t, s in zip(tokens, scores):
+         if t.startswith("##"):
+             cur_tok += t[2:]
+             cur_scores.append(s)
+         else:
+             if cur_tok:
+                 merged_tokens.append(cur_tok)
+                 merged_scores.append(sum(cur_scores) / max(1, len(cur_scores)))
+             cur_tok, cur_scores = t, [s]
+     if cur_tok:
+         merged_tokens.append(cur_tok)
+         merged_scores.append(sum(cur_scores) / max(1, len(cur_scores)))
+     return merged_tokens, merged_scores
+
+ def _normalize_for_display_wordlevel(attn_scores, normalize_mode="visual", temperature=0.30):
+     """
+     Convert raw *word-level* token scores into:
+       - probabilistic mode: probabilities that sum to 1.0 (100%), with labels like "0.237 | 23.7% (contrib)"
+       - visual mode: min-max + gamma scaling (contrast, not sum-to-100), with labels like "0.68 | visual score"
+
+     Returns:
+       attn_final: np.ndarray of floats in [0, 1] for the color scale
+       labels: list[str] per token (tooltip text; the first number stays up front for the color_map bucketing)
+     """
+     attn_array = np.array(attn_scores, dtype=float)
+
+     if normalize_mode == "probabilistic":
+         # ---- percentage view that sums to 100% ----
+         attn_array = np.maximum(attn_array, 0.0)
+         if attn_array.max() > 0:
+             attn_array = attn_array / (attn_array.max() + 1e-12)  # scale to [0, 1] for stability
+         # sharpen (lower temperature => peakier)
+         attn_array = np.power(attn_array + 1e-12, 1.0 / max(1e-6, float(temperature)))
+         prob = attn_array / (attn_array.sum() + 1e-12)
+         percent = prob * 100.0
+
+         # keep prob (0..1) for the color scale; label with % contribution
+         labels = [f"{prob[i]:.3f} | {percent[i]:.1f}% (contrib)" for i in range(len(prob))]
+         return prob, labels
+     else:
+         # ---- visual: min-max + gamma (contrast, not sum-to-100) ----
+         if attn_array.max() > attn_array.min():
+             attn_array0 = (attn_array - attn_array.min()) / (attn_array.max() - attn_array.min() + 1e-8)
+             attn_array0 = np.clip(np.power(attn_array0, 0.75), 0.1, 1.0)
+         else:
+             attn_array0 = np.zeros_like(attn_array)
+         labels = [f"{attn_array0[i]:.2f} | visual score" for i in range(len(attn_array0))]
+         return attn_array0, labels
+
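+ # Worked example (probabilistic, temperature=0.30): raw word scores
+ # [0.1, 0.2, 0.4] are first scaled by the max to [0.25, 0.5, 1.0], then
+ # sharpened via x**(1/0.30) to roughly [0.0098, 0.099, 1.0]; after
+ # normalization the displayed contributions are about 0.9%, 8.9%, 90.2%.
+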
142
+ def predict(
+     model,
+     mode,
+     emr_text=None,
+     image=None,
+     normalize_mode="visual",
+     need_token_vis=False,
+     use_rollout=False
+ ):
+     """
+     normalize_mode: "visual" (min-max + gamma boost) or "probabilistic" (sharpened contributions summing to 100%)
+     need_token_vis: request/compute token-level attentions (Doctor mode + text/multimodal)
+     use_rollout: use attention rollout across layers instead of last-layer CLS attention
+     """
+     input_ids = attention_mask = img_tensor = None
+     cam_image = None
+     highlighted_tokens = None
+     top5 = []
+
+     if mode in ["text", "multimodal"] and emr_text:
+         text_tokens = tokenizer(
+             emr_text,
+             return_tensors="pt",
+             truncation=True,
+             padding="max_length",
+             max_length=128,
+         )
+         input_ids = text_tokens["input_ids"].to(DEVICE)
+         attention_mask = text_tokens["attention_mask"].to(DEVICE)
+
+     if mode in ["image", "multimodal"] and image is not None:
+         img_tensor = image_transform(image).unsqueeze(0).to(DEVICE)
+
+     # Only register hooks for Grad-CAM when an image is involved
+     if mode in ["image", "multimodal"]:
+         activations, gradients, fwd_handle, bwd_handle = register_hooks(model)
+         model.zero_grad()
+
+     # === Forward ===
+     # Only enable attentions when we plan to visualize them
+     outputs = model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         image=img_tensor,
+         output_attentions=bool(need_token_vis and (mode in ["text", "multimodal"])),
+         return_raw_attentions=bool(use_rollout and need_token_vis)
+     )
+
+     logits = outputs["logits"]
+     if logits.numel() == 0:
+         raise ValueError("Model returned empty logits. Check input format.")
+
+     probs = torch.softmax(logits, dim=1)
+     pred = torch.argmax(probs, dim=1).item()
+     confidence = probs.squeeze()[pred].item()
+
+     # === Grad-CAM ===
+     if mode in ["image", "multimodal"]:
+         # Backpropagate from the predicted logit only (gradients are needed just for Grad-CAM)
+         logits[0, pred].backward(retain_graph=True)
+         cam_image = generate_gradcam(image, activations, gradients)
+         fwd_handle.remove()
+         bwd_handle.remove()
+
+     # === Token-level attention ===
+     if need_token_vis and (mode in ["text", "multimodal"]):
+         token_attn_scores = None
+
+         if use_rollout and outputs.get("raw_attentions") is not None:
+             # Partial rollout over the last few layers;
+             # roll: [B, S, S]; roll[b, 0, :] is CLS-to-all-tokens for batch item b
+             roll = attention_rollout(outputs["raw_attentions"], last_k=4, residual_alpha=0.5)  # [B, S, S]
+             if roll is not None:
+                 token_attn_scores = roll[0, 0].detach().cpu().numpy().tolist()  # CLS row
+         elif outputs.get("token_attentions") is not None:
+             token_attn_scores = outputs["token_attentions"].squeeze().tolist()
+
+         if token_attn_scores is not None:
+             # Filter out specials/padding and align scores to wordpieces
+             ids = input_ids[0].tolist()
+             amask = attention_mask[0].tolist() if attention_mask is not None else [1] * len(ids)
+             wp_all = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
+             special_ids = set(tokenizer.all_special_ids)
+             keep_idx = [i for i, (tid, m) in enumerate(zip(ids, amask)) if (tid not in special_ids) and (m == 1)]
+             wp_tokens = [wp_all[i] for i in keep_idx]
+             wp_scores = [token_attn_scores[i] if i < len(token_attn_scores) else 0.0 for i in keep_idx]
+
+             # Merge wordpieces into words
+             word_tokens, attn_scores = merge_wordpieces(wp_tokens, wp_scores)
+
+             # Build Top-5 (probabilistic normalization for ranking)
+             _probs_for_rank, _ = _normalize_for_display_wordlevel(
+                 attn_scores, normalize_mode="probabilistic", temperature=0.30
+             )
+             pairs = list(zip(word_tokens, _probs_for_rank))
+             pairs.sort(key=lambda x: x[1], reverse=True)
+             top5 = [(tok, float(p * 100.0)) for tok, p in pairs[:5]]
+
+             # Final display (probabilistic or visual)
+             attn_final, labels = _normalize_for_display_wordlevel(
+                 attn_scores,
+                 normalize_mode=normalize_mode,
+                 temperature=0.30,
              )
 
+             highlighted_tokens = [(tok, labels[i]) for i, tok in enumerate(word_tokens)]
 
+     print("🧪 Normalization Mode Received:", normalize_mode)
+     if highlighted_tokens:
+         print("🟣 Highlighted tokens sample:", highlighted_tokens[:5])
+     else:
+         print("🟣 No highlighted tokens (no text, or attentions unavailable).")
 
+     return inv_map[pred], cam_image, highlighted_tokens, confidence, probs.tolist(), top5
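A hedged end-to-end sketch of the updated call (checkpoint name as used in app/utils/test.py below; the X-ray path is hypothetical):

from PIL import Image

model = load_model("multimodal", "medi_llm_state_dict_multimodal.pth")
xray = Image.open("data/dummy_images/COVID/sample.png").convert("RGB")  # hypothetical sample

label, cam, highlighted, conf, all_probs, top5 = predict(
    model,
    "multimodal",
    emr_text="Fever, dry cough, and low SpO2 for three days.",
    image=xray,
    normalize_mode="probabilistic",
    need_token_vis=True,
    use_rollout=True,
)
print(label, f"{conf:.2%}", top5)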
app/utils/test.py ADDED
@@ -0,0 +1,26 @@
+ import sys
+ import torch
+ from pathlib import Path
+ from transformers import AutoTokenizer
+
+ ROOT_DIR = Path(__file__).resolve().parent.parent.parent  # repo root, three parents up from app/utils/test.py
+ sys.path.append(str(ROOT_DIR))
+
+ from app.utils.inference_utils import load_model
+ from app.utils.attention_utils import extract_token_attention
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+ # Load the multimodal checkpoint
+ model = load_model("multimodal", "medi_llm_state_dict_multimodal.pth")
+
+ # Test input
+ text = "Patient-A reports shortness of breath and low oxygen levels."
+ tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
+ input_ids = tokens["input_ids"].to(DEVICE)
+ mask = tokens["attention_mask"].to(DEVICE)
+
+ # Extract and print token attention as a quick sanity check
+ attention = extract_token_attention(model, tokenizer, input_ids, mask)
+ print(attention)
config/config.yaml.example ADDED
@@ -0,0 +1,18 @@
+ text:
+   lr: 1.8711332079056742e-05
+   dropout: 0.33274218952802376
+   hidden_dim: 512
+   batch_size: 8
+   epochs: 5
+ image:
+   lr: 9.99473327273459e-05
+   dropout: 0.4451972461446767
+   hidden_dim: 256
+   batch_size: 4
+   epochs: 5
+ multimodal:
+   lr: 3.7443867882936816e-05
+   dropout: 0.29940046032586376
+   hidden_dim: 512
+   batch_size: 4
+   epochs: 5
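A minimal sketch of consuming this file (assumes PyYAML and that the example is copied to config/config.yaml; the read path is an assumption, not confirmed by the repo):

import yaml

with open("config/config.yaml") as f:  # assumed location after copying the .example file
    cfg = yaml.safe_load(f)

mm = cfg["multimodal"]
print(mm["lr"], mm["hidden_dim"], mm["batch_size"], mm["epochs"])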
sample_data/.DS_Store ADDED
Binary file (6.15 kB).
 
src/multimodal_model.py CHANGED
@@ -86,39 +86,51 @@ class MediLLMModel(nn.Module):
             nn.Linear(hidden_dim, num_classes),  # Final Classification output
         )
 
-    def forward(self, input_ids=None, attention_mask=None, image=None):
+    def forward(self, input_ids=None, attention_mask=None, image=None, output_attentions=False, return_raw_attentions=False):
         # input_ids shape: [batch, seq_length]
         # attention_mask: mask to ignore padding, same shape as input_ids
         # image: [batch, 3, 224, 224]
         # Text features
-        if self.mode == "text":
+        if self.mode in ["text", "multimodal"]:
             text_outputs = self.text_encoder(
-                input_ids=input_ids, attention_mask=attention_mask
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
             )
             # feed tokenized text into the BERT Model which returns a
             # dictionary with last_hidden_state: [batch_size, seq_len,
             # hidden_size], pooler_output: [batch_size, hidden_size]
             # (CLS embeddings), hidden_states: List of tensors,
             # attentions(weights): List of Tensors
-            features = text_outputs.last_hidden_state[
-                :, 0, :
-            ]  # CLS token, return CLS tokens from all batches, position 0,
+            last_hidden = text_outputs.last_hidden_state  # CLS token sits at position 0 of every sequence;
             # a batch of 3 sentences has 3 CLS tokens
+            cls_embedding = last_hidden[:, 0, :]  # CLS tokens of all batches [batch, hidden_dim]
 
+            # Real token attention using last-layer CLS attention weights
+            # attentions = List[12 tensors] -> each [batch, heads, seq_len, seq_len]
+            token_attn_scores = None
+            raw_attentions = None
+            if output_attentions:
+                attention_maps = text_outputs.attentions
+                last_layer_attn = attention_maps[-1]  # [batch, heads, seq_len, seq_len]
+                avg_attn = last_layer_attn.mean(dim=1)  # Average across heads -> [batch, seq_len, seq_len]
+                token_attn_scores = avg_attn[:, 0, :]  # CLS attends to all tokens -> [batch, seq_len]
+                if return_raw_attentions:
+                    raw_attentions = attention_maps
+        else:
+            cls_embedding = None
+            token_attn_scores = None
+            raw_attentions = None
 
-        # Image features
-        elif self.mode == "image":
-            features = self.image_encoder(
-                image
-            )  # pass the image through ResNet, returns a [batch, 2048] tensor
-
+        # Image features
+        if self.mode == "image":
+            features = self.image_encoder(image)  # pass the image through ResNet, returns a [batch, 2048] tensor
+        elif self.mode == "text":
+            features = cls_embedding
         else:  # multimodal
-            text_outputs = self.text_encoder(
-                input_ids=input_ids, attention_mask=attention_mask
-            )
-            text_feat = text_outputs.last_hidden_state[:, 0, :]  # CLS token
             image_feat = self.image_encoder(image)
             features = torch.cat(
-                (text_feat, image_feat), dim=1
+                (cls_embedding, image_feat), dim=1
             )  # Concatenates text and image features along feature dimension
             # [CLS vector from BERT] + [ResNet image vector]
             # -> [batch_size, 2816]
@@ -143,4 +155,9 @@ class MediLLMModel(nn.Module):
         # return self.classifier(fused)
 
         # Return logits for each class, later apply softmax during evaluation
-        return self.classifier(features)
+        logits = self.classifier(features)
+        return {
+            "logits": logits,
+            "token_attentions": token_attn_scores,  # [batch, seq_len] or None
+            "raw_attentions": raw_attentions if return_raw_attentions else None,
+        }
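A hedged smoke check of the new dict-returning forward (dummy tensors; the MediLLMModel constructor arguments are assumed, since they are not shown in this hunk):

import torch

model = MediLLMModel(mode="multimodal")  # constructor signature assumed; see src/multimodal_model.py
model.eval()
with torch.no_grad():
    out = model(
        input_ids=torch.randint(0, 100, (2, 128)),
        attention_mask=torch.ones(2, 128, dtype=torch.long),
        image=torch.randn(2, 3, 224, 224),
        output_attentions=True,
    )
assert out["logits"].shape[0] == 2
assert out["token_attentions"].shape == (2, 128)  # CLS-to-token weights
assert out["raw_attentions"] is None  # only populated when return_raw_attentions=True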
tests/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes).
 
tests/__pycache__/test_dummy.cpython-310-pytest-8.4.1.pyc ADDED
Binary file (746 Bytes).
 
tests/__pycache__/test_generate_emr_csv.cpython-310-pytest-8.4.1.pyc ADDED
Binary file (15.8 kB).
 
tests/__pycache__/test_multimodal_model.cpython-310-pytest-8.4.1.pyc ADDED
Binary file (4.79 kB).
 
tests/__pycache__/test_triage_dataset.cpython-310-pytest-8.4.1.pyc ADDED
Binary file (4.47 kB).