Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,7 +30,7 @@ import gradio
|
|
| 30 |
from functools import partial
|
| 31 |
from transformers import AutoModel
|
| 32 |
|
| 33 |
-
"""###
|
| 34 |
|
| 35 |
configs = {
|
| 36 |
"IMAGE_SIZE": (224, 224), # Resize images to (W, H)
|
|
@@ -46,13 +46,40 @@ configs = {
|
|
| 46 |
"THRESHOLD": 0.2
|
| 47 |
}
|
| 48 |
|
| 49 |
-
|
| 50 |
"CheXFormer-small": "m42-health/CXformer-small",
|
| 51 |
# "CheXFormer-base": "m42-health/CXformer-base",
|
| 52 |
-
"ViT-base-16": "google/vit-base-patch16-224"
|
| 53 |
-
}
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
"""### Define helper functions"""
|
| 58 |
|
|
@@ -69,7 +96,7 @@ class get_pretrained_model(nn.Module):
|
|
| 69 |
print(f"Loading pretrained [{model_name}] model")
|
| 70 |
|
| 71 |
self.backbone = AutoModel.from_pretrained(
|
| 72 |
-
|
| 73 |
# model_name,
|
| 74 |
trust_remote_code=True)
|
| 75 |
|
|
@@ -243,9 +270,38 @@ class modelModule(torch_light.LightningModule):
|
|
| 243 |
|
| 244 |
"""### Create function for running inference (i.e., assistive medical diagnosis)"""
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
@torch.inference_mode()
|
| 247 |
def run_diagnosis(
|
| 248 |
backbone_name,
|
|
|
|
| 249 |
input_image,
|
| 250 |
threshold,
|
| 251 |
preprocess_fn=None,
|
|
@@ -261,14 +317,16 @@ def run_diagnosis(
|
|
| 261 |
if not os.path.exists(ckpt_path):
|
| 262 |
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
| 263 |
|
| 264 |
-
# Load model (cache for speed)
|
| 265 |
if backbone_name not in MODEL_CACHE:
|
| 266 |
-
|
| 267 |
ckpt_path, backbone_model_name=backbone_name, num_layers_to_unfreeze = 2)
|
| 268 |
-
model =
|
| 269 |
|
| 270 |
model.eval()
|
| 271 |
|
|
|
|
|
|
|
| 272 |
# Forward
|
| 273 |
logits = model(x)
|
| 274 |
probs = torch.sigmoid(logits)[0].cpu().numpy()
|
|
@@ -278,10 +336,24 @@ def run_diagnosis(
|
|
| 278 |
}
|
| 279 |
|
| 280 |
predicted_classes = [
|
| 281 |
-
Idx2labels[i] for i, p in enumerate(probs) if p >= threshold
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
-
return "\n".join(predicted_classes), output_probs
|
| 285 |
|
| 286 |
"""### Gradio app"""
|
| 287 |
|
|
@@ -299,12 +371,14 @@ example_list = [
|
|
| 299 |
gradio_app = gradio.Interface(
|
| 300 |
fn = partial(run_diagnosis, preprocess_fn = preprocess_fxn, Idx2labels = labels_dict),
|
| 301 |
|
| 302 |
-
inputs = [gradio.Dropdown(["CheXFormer-small", "ViT-base-16"], value="CheXFormer-small", label="Select
|
|
|
|
| 303 |
gradio.Image(type="pil", label="Load chest-X-ray image here"),
|
| 304 |
gradio.Slider(minimum = 0.1, maximum = 0.9, step = 0.05, value = 0.2, label = "Set Prediction Threshold")],
|
| 305 |
|
| 306 |
outputs = [gradio.Textbox(label="Predicted Medical Condition(s)"),
|
| 307 |
-
|
|
|
|
| 308 |
|
| 309 |
examples = example_list,
|
| 310 |
cache_examples = False,
|
|
|
|
| 30 |
from functools import partial
|
| 31 |
from transformers import AutoModel
|
| 32 |
|
| 33 |
+
"""### Initialize Containers"""
|
| 34 |
|
| 35 |
configs = {
|
| 36 |
"IMAGE_SIZE": (224, 224), # Resize images to (W, H)
|
|
|
|
| 46 |
"THRESHOLD": 0.2
|
| 47 |
}
|
| 48 |
|
| 49 |
+
# Classifier backbones: display name -> Hugging Face Hub repo id.
# The value is what gets handed to AutoModel.from_pretrained.
ViT_REGISTRY = {
    "CheXFormer-small": "m42-health/CXformer-small",
    # "CheXFormer-base": "m42-health/CXformer-base",
    "ViT-base-16": "google/vit-base-patch16-224",
}

# Explanation (vision-language) models: display name -> Hub repo id.
# The value is passed as `model=` when the image-text-to-text pipeline is built.
VLM_REGISTRY = {
    "MedMO": "MBZUAI/MedMO-8B",
    "Qwen3-VL-2B": "Qwen/Qwen3-VL-2B-Instruct",
    "Lingshu-7B": "lingshu-medical-mllm/Lingshu-7B",
    "MedGemma-4b": "google/medgemma-1.5-4b-it",
}

# System prompt sent to the explanation model. The text is runtime data:
# the leading space and the blank lines are part of the prompt as written.
VLM_SYSTEM_PROMPT = """ You are a medical imaging assistant specializing in chest radiography.

A trained multi-label classifier analyzed a chest X-ray and made a prediction, including predicted medical condition(s) and their associated probabilities:

Your task:
1. Analyze the chest X-ray image to identify key features supporting the predicted condition(s).
2. Do NOT introduce new diagnoses.
3. Only explain radiographic findings that could support the listed prediction(s).
4. Use cautious, uncertainty-aware language.
5. If probability < 0.50, emphasize uncertainty.
6. Do NOT contradict the classifier.

Structure your answer as:

Observed Radiographic Findings:
...

How Chest X-ray Features Support the Predicted Conditions:
...
"""

# Process-wide caches so each model is loaded at most once per name.
ViT_MODEL_CACHE = {}  # classifier name -> loaded classification model
VLM_MODEL_CACHE = {}  # explanation-model key -> loaded pipeline
|
| 83 |
|
| 84 |
"""### Define helper functions"""
|
| 85 |
|
|
|
|
| 96 |
print(f"Loading pretrained [{model_name}] model")
|
| 97 |
|
| 98 |
self.backbone = AutoModel.from_pretrained(
|
| 99 |
+
ViT_REGISTRY[model_name],
|
| 100 |
# model_name,
|
| 101 |
trust_remote_code=True)
|
| 102 |
|
|
|
|
| 270 |
|
| 271 |
"""### Create function for running inference (i.e., assistive medical diagnosis)"""
|
| 272 |
|
| 273 |
+
def generate_query(formatted_predictions):
    """Build the user-turn question sent to the explanation model.

    Args:
        formatted_predictions: The predictions to embed in the prompt. It is
            interpolated with default string formatting, so a dict appears
            in its ``repr``-like form.

    Returns:
        A multi-line prompt string that states the predictions and asks which
        chest X-ray features support them.
    """
    prompt_lines = (
        "",
        "The predicted conditions and their corresponding probabilities are given by the following dictionary:",
        "",
        f"{formatted_predictions}",
        "",
        "What features of the chest X-ray image support the predicted condition(s)?",
        "",
    )
    return "\n".join(prompt_lines)
|
| 281 |
+
|
| 282 |
+
def predictionReportGenerator(vlm_model, image_path, system_prompt, query_prompt):
    """Ask a vision-language model to explain the classifier's prediction.

    Args:
        vlm_model: Callable invoked as ``vlm_model(text=messages,
            max_new_tokens=350)``. Its result is indexed as
            ``output[0]["generated_text"][-1]["content"]``.
        image_path: Either an already-loaded image object (anything exposing
            ``convert``, e.g. a PIL image) or a path / file object that
            ``Image.open`` accepts.
        system_prompt: Text placed in the system message.
        query_prompt: Text placed after the image in the user message.

    Returns:
        The ``content`` of the last entry in the model's generated chat.
    """
    # The visible caller forwards the image object it received from the UI,
    # not a path; calling Image.open on an image object fails, so only open
    # from disk when the argument is not already an image.
    if hasattr(image_path, "convert"):
        image_ = image_path.convert("RGB")
    else:
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as opened_image:
            image_ = opened_image.convert("RGB")

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": f"{system_prompt}"}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_},
                {"type": "text", "text": f"{query_prompt}"},
            ],
        },
    ]

    output = vlm_model(text=messages, max_new_tokens=350)
    # The generated chat ends with the model's reply; return its text.
    prediction_explanation = output[0]["generated_text"][-1]["content"]

    return prediction_explanation
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
|
| 301 |
@torch.inference_mode()
|
| 302 |
def run_diagnosis(
|
| 303 |
backbone_name,
|
| 304 |
+
vlm_name,
|
| 305 |
input_image,
|
| 306 |
threshold,
|
| 307 |
preprocess_fn=None,
|
|
|
|
| 317 |
if not os.path.exists(ckpt_path):
|
| 318 |
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
| 319 |
|
| 320 |
+
# Load classification model (cache for speed)
|
| 321 |
if backbone_name not in MODEL_CACHE:
|
| 322 |
+
ViT_MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
|
| 323 |
ckpt_path, backbone_model_name=backbone_name, num_layers_to_unfreeze = 2)
|
| 324 |
+
model = ViT_MODEL_CACHE[backbone_name]
|
| 325 |
|
| 326 |
model.eval()
|
| 327 |
|
| 328 |
+
# device = 0 if torch.cuda.is_available() else -1
|
| 329 |
+
|
| 330 |
# Forward
|
| 331 |
logits = model(x)
|
| 332 |
probs = torch.sigmoid(logits)[0].cpu().numpy()
|
|
|
|
| 336 |
}
|
| 337 |
|
| 338 |
predicted_classes = [
|
| 339 |
+
Idx2labels[i] for i, p in enumerate(probs) if p >= threshold]
|
| 340 |
+
|
| 341 |
+
explanation_ = "No prediction was made."
|
| 342 |
+
if predicted_classes not []:
|
| 343 |
+
# Load model (cache for speed)
|
| 344 |
+
if model_key not in MODEL_CACHE:
|
| 345 |
+
VLM_MODEL_CACHE[model_key] = pipeline(task = "image-text-to-text",
|
| 346 |
+
model = VLM_REGISTRY[vlm_name],
|
| 347 |
+
trust_remote_code = True)
|
| 348 |
+
VLM_model = VLM_MODEL_CACHE[model_key]
|
| 349 |
+
|
| 350 |
+
formatted_predictions = {label: output_probs[label] for label in predicted_classes}
|
| 351 |
+
|
| 352 |
+
query_prompt = generate_query(formatted_predictions)
|
| 353 |
+
explanation_ = predictionReportGenerator(vlm_model = VLM_model, image_path = input_image,
|
| 354 |
+
system_prompt = VLM_SYSTEM_PROMPT, query_prompt = query_prompt)
|
| 355 |
|
| 356 |
+
return "\n".join(predicted_classes), explanation_, output_probs
|
| 357 |
|
| 358 |
"""### Gradio app"""
|
| 359 |
|
|
|
|
| 371 |
gradio_app = gradio.Interface(
|
| 372 |
fn = partial(run_diagnosis, preprocess_fn = preprocess_fxn, Idx2labels = labels_dict),
|
| 373 |
|
| 374 |
+
inputs = [gradio.Dropdown(["CheXFormer-small", "ViT-base-16"], value="CheXFormer-small", label="Select Classification Model"),
|
| 375 |
+
gradio.Dropdown(["MedGemma-4b", "MedMO", "Lingshu-7B", "Qwen3-VL-2B"], value="Lingshu-7B", label="Select Explanation Model"),
|
| 376 |
gradio.Image(type="pil", label="Load chest-X-ray image here"),
|
| 377 |
gradio.Slider(minimum = 0.1, maximum = 0.9, step = 0.05, value = 0.2, label = "Set Prediction Threshold")],
|
| 378 |
|
| 379 |
outputs = [gradio.Textbox(label="Predicted Medical Condition(s)"),
|
| 380 |
+
gradio.Textbox(label="Prediction Report"),
|
| 381 |
+
gradio.Label(label="Predicted Probabilities", show_label=False)],
|
| 382 |
|
| 383 |
examples = example_list,
|
| 384 |
cache_examples = False,
|