cfoli committed on
Commit
0bbb9a2
·
verified ·
1 Parent(s): e2def54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -14
app.py CHANGED
@@ -30,7 +30,7 @@ import gradio
30
  from functools import partial
31
  from transformers import AutoModel
32
 
33
- """### Set parameters"""
34
 
35
  configs = {
36
  "IMAGE_SIZE": (224, 224), # Resize images to (W, H)
@@ -46,13 +46,40 @@ configs = {
46
  "THRESHOLD": 0.2
47
  }
48
 
49
- MODEL_REGISTRY = {
50
  "CheXFormer-small": "m42-health/CXformer-small",
51
  # "CheXFormer-base": "m42-health/CXformer-base",
52
- "ViT-base-16": "google/vit-base-patch16-224",
53
- }
54
 
55
- MODEL_CACHE = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  """### Define helper functions"""
58
 
@@ -69,7 +96,7 @@ class get_pretrained_model(nn.Module):
69
  print(f"Loading pretrained [{model_name}] model")
70
 
71
  self.backbone = AutoModel.from_pretrained(
72
- MODEL_REGISTRY[model_name],
73
  # model_name,
74
  trust_remote_code=True)
75
 
@@ -243,9 +270,38 @@ class modelModule(torch_light.LightningModule):
243
 
244
  """### Create function for running inference (i.e., assistive medical diagnosis)"""
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  @torch.inference_mode()
247
  def run_diagnosis(
248
  backbone_name,
 
249
  input_image,
250
  threshold,
251
  preprocess_fn=None,
@@ -261,14 +317,16 @@ def run_diagnosis(
261
  if not os.path.exists(ckpt_path):
262
  raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
263
 
264
- # Load model (cache for speed)
265
  if backbone_name not in MODEL_CACHE:
266
- MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
267
  ckpt_path, backbone_model_name=backbone_name, num_layers_to_unfreeze = 2)
268
- model = MODEL_CACHE[backbone_name]
269
 
270
  model.eval()
271
 
 
 
272
  # Forward
273
  logits = model(x)
274
  probs = torch.sigmoid(logits)[0].cpu().numpy()
@@ -278,10 +336,24 @@ def run_diagnosis(
278
  }
279
 
280
  predicted_classes = [
281
- Idx2labels[i] for i, p in enumerate(probs) if p >= threshold
282
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
- return "\n".join(predicted_classes), output_probs
285
 
286
  """### Gradio app"""
287
 
@@ -299,12 +371,14 @@ example_list = [
299
  gradio_app = gradio.Interface(
300
  fn = partial(run_diagnosis, preprocess_fn = preprocess_fxn, Idx2labels = labels_dict),
301
 
302
- inputs = [gradio.Dropdown(["CheXFormer-small", "ViT-base-16"], value="CheXFormer-small", label="Select Backbone Model"),
 
303
  gradio.Image(type="pil", label="Load chest-X-ray image here"),
304
  gradio.Slider(minimum = 0.1, maximum = 0.9, step = 0.05, value = 0.2, label = "Set Prediction Threshold")],
305
 
306
  outputs = [gradio.Textbox(label="Predicted Medical Condition(s)"),
307
- gradio.Label(label="Predicted Probabilities", show_label=False)],
 
308
 
309
  examples = example_list,
310
  cache_examples = False,
 
30
  from functools import partial
31
  from transformers import AutoModel
32
 
33
+ """### Initialize Containers"""
34
 
35
  configs = {
36
  "IMAGE_SIZE": (224, 224), # Resize images to (W, H)
 
46
  "THRESHOLD": 0.2
47
  }
48
 
49
+ ViT_REGISTRY = {
50
  "CheXFormer-small": "m42-health/CXformer-small",
51
  # "CheXFormer-base": "m42-health/CXformer-base",
52
+ "ViT-base-16": "google/vit-base-patch16-224"}
 
53
 
54
+ VLM_REGISTRY = {
55
+ "MedMO": "MBZUAI/MedMO-8B",
56
+ "Qwen3-VL-2B": "Qwen/Qwen3-VL-2B-Instruct",
57
+ "Lingshu-7B": "lingshu-medical-mllm/Lingshu-7B",
58
+ "MedGemma-4b": "google/medgemma-1.5-4b-it"}
59
+
60
+ VLM_SYSTEM_PROMPT = """ You are a medical imaging assistant specializing in chest radiography.
61
+
62
+ A trained multi-label classifier analyzed a chest X-ray and made a prediction, including predicted medical condition(s) and their associated probabilities:
63
+
64
+ Your task:
65
+ 1. Analyze the chest X-ray image to identify key features supporting the predicted condition(s).
66
+ 2. Do NOT introduce new diagnoses.
67
+ 3. Only explain radiographic findings that could support the listed prediction(s).
68
+ 4. Use cautious, uncertainty-aware language.
69
+ 5. If probability < 0.50, emphasize uncertainty.
70
+ 6. Do NOT contradict the classifier.
71
+
72
+ Structure your answer as:
73
+
74
+ Observed Radiographic Findings:
75
+ ...
76
+
77
+ How Chest X-ray Features Support the Predicted Conditions:
78
+ ...
79
+ """
80
+
81
+ ViT_MODEL_CACHE = {}
82
+ VLM_MODEL_CACHE = {}
83
 
84
  """### Define helper functions"""
85
 
 
96
  print(f"Loading pretrained [{model_name}] model")
97
 
98
  self.backbone = AutoModel.from_pretrained(
99
+ ViT_REGISTRY[model_name],
100
  # model_name,
101
  trust_remote_code=True)
102
 
 
270
 
271
  """### Create function for running inference (i.e., assistive medical diagnosis)"""
272
 
273
+ def generate_query(formatted_predictions):
274
+ return f"""
275
+ The predicted conditions and their corresponding probabilities are given by the following dictionary:
276
+
277
+ {formatted_predictions}
278
+
279
+ What features of the chest X-ray image support the predicted condition(s)?
280
+ """
281
+
282
+ def predictionReportGenerator(vlm_model, image_path, system_prompt, query_prompt):
283
+ image_ = (image_path if isinstance(image_path, Image.Image) else Image.open(image_path)).convert("RGB")
284
+ messages = [
285
+ {
286
+ "role": "system",
287
+ "content": [{"type": "text", "text": f"{system_prompt}"}]},
288
+ {
289
+ "role": "user",
290
+ "content": [
291
+ {"type": "image", "image": image_},
292
+ {"type": "text", "text": f"{query_prompt}"}]}]
293
+
294
+ output = vlm_model(text=messages, max_new_tokens=350)
295
+ prediction_explanation = output[0]["generated_text"][-1]["content"]
296
+
297
+ return prediction_explanation
298
+
299
+
300
+
301
  @torch.inference_mode()
302
  def run_diagnosis(
303
  backbone_name,
304
+ vlm_name,
305
  input_image,
306
  threshold,
307
  preprocess_fn=None,
 
317
  if not os.path.exists(ckpt_path):
318
  raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
319
 
320
+ # Load classification model (cache for speed)
321
  if backbone_name not in ViT_MODEL_CACHE:
322
+ ViT_MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
323
  ckpt_path, backbone_model_name=backbone_name, num_layers_to_unfreeze = 2)
324
+ model = ViT_MODEL_CACHE[backbone_name]
325
 
326
  model.eval()
327
 
328
+ # device = 0 if torch.cuda.is_available() else -1
329
+
330
  # Forward
331
  logits = model(x)
332
  probs = torch.sigmoid(logits)[0].cpu().numpy()
 
336
  }
337
 
338
  predicted_classes = [
339
+ Idx2labels[i] for i, p in enumerate(probs) if p >= threshold]
340
+
341
+ explanation_ = "No prediction was made."
342
+ if predicted_classes:
343
+ # Load model (cache for speed)
344
+ if vlm_name not in VLM_MODEL_CACHE:
345
+ VLM_MODEL_CACHE[vlm_name] = pipeline(task = "image-text-to-text",
346
+ model = VLM_REGISTRY[vlm_name],
347
+ trust_remote_code = True)
348
+ VLM_model = VLM_MODEL_CACHE[vlm_name]
349
+
350
+ formatted_predictions = {label: output_probs[label] for label in predicted_classes}
351
+
352
+ query_prompt = generate_query(formatted_predictions)
353
+ explanation_ = predictionReportGenerator(vlm_model = VLM_model, image_path = input_image,
354
+ system_prompt = VLM_SYSTEM_PROMPT, query_prompt = query_prompt)
355
 
356
+ return "\n".join(predicted_classes), explanation_, output_probs
357
 
358
  """### Gradio app"""
359
 
 
371
  gradio_app = gradio.Interface(
372
  fn = partial(run_diagnosis, preprocess_fn = preprocess_fxn, Idx2labels = labels_dict),
373
 
374
+ inputs = [gradio.Dropdown(["CheXFormer-small", "ViT-base-16"], value="CheXFormer-small", label="Select Classification Model"),
375
+ gradio.Dropdown(["MedGemma-4b", "MedMO", "Lingshu-7B", "Qwen3-VL-2B"], value="Lingshu-7B", label="Select Explanation Model"),
376
  gradio.Image(type="pil", label="Load chest-X-ray image here"),
377
  gradio.Slider(minimum = 0.1, maximum = 0.9, step = 0.05, value = 0.2, label = "Set Prediction Threshold")],
378
 
379
  outputs = [gradio.Textbox(label="Predicted Medical Condition(s)"),
380
+ gradio.Textbox(label="Prediction Report"),
381
+ gradio.Label(label="Predicted Probabilities", show_label=False)],
382
 
383
  examples = example_list,
384
  cache_examples = False,