Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,45 +1,145 @@
|
|
| 1 |
-
from flask import Flask, request, jsonify
|
| 2 |
-
import openai
|
| 3 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
return jsonify({"message": "Radiology AI API is running."})
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
model="gpt-4o-mini",
|
| 33 |
-
messages=[
|
| 34 |
-
{"role": "system", "content": "You are a helpful medical radiology AI assistant."},
|
| 35 |
-
{"role": "user", "content": prompt}
|
| 36 |
-
],
|
| 37 |
-
max_tokens=600,
|
| 38 |
-
temperature=0.7
|
| 39 |
-
)
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torchxrayvision as xrv
|
| 6 |
+
from transformers import pipeline
|
| 7 |
+
import gradio as gr
|
| 8 |
|
| 9 |
+
# -----------------------------
|
| 10 |
+
# Setup
|
| 11 |
+
# -----------------------------
|
| 12 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
|
| 14 |
+
def load_models():
|
| 15 |
+
chest_model = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
|
| 16 |
+
clip = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")
|
| 17 |
+
return chest_model, clip
|
| 18 |
|
| 19 |
+
CHEST_MODEL, CLIP = load_models()
|
| 20 |
+
LABELS = CHEST_MODEL.pathologies
|
|
|
|
| 21 |
|
| 22 |
+
# -----------------------------
|
| 23 |
+
# Helper Functions
|
| 24 |
+
# -----------------------------
|
| 25 |
+
def detect_scan_type(pil_img):
|
| 26 |
+
candidates = ["chest", "bone", "brain", "abdomen", "dental", "cardiac"]
|
| 27 |
+
result = CLIP(pil_img, candidate_labels=candidates)
|
| 28 |
+
return max(result, key=lambda x: x['score'])['label']
|
| 29 |
|
| 30 |
+
def preprocess_chest(pil_img):
|
| 31 |
+
if pil_img.mode != "L":
|
| 32 |
+
pil_img = pil_img.convert("L")
|
| 33 |
+
img = np.array(pil_img).astype(np.float32)
|
| 34 |
+
img = xrv.datasets.normalize(img, 255)
|
| 35 |
+
img = img[None, ...]
|
| 36 |
+
img = xrv.datasets.XRayCenterCrop()(img)
|
| 37 |
+
img = xrv.datasets.XRayResizer(224)(img)
|
| 38 |
+
return torch.from_numpy(img).unsqueeze(0).to(DEVICE)
|
| 39 |
|
| 40 |
+
def analyse_chest(pil_img):
|
| 41 |
+
x = preprocess_chest(pil_img)
|
| 42 |
+
with torch.no_grad():
|
| 43 |
+
output = CHEST_MODEL(x)
|
| 44 |
+
probs = torch.sigmoid(output)[0].cpu().numpy() * 100
|
| 45 |
+
findings = [(LABELS[i], p) for i, p in enumerate(probs)]
|
| 46 |
+
findings = sorted(findings, key=lambda x: x[1], reverse=True)
|
| 47 |
+
return [(cond, conf) for cond, conf in findings if conf > 40][:7]
|
| 48 |
|
| 49 |
+
# -----------------------------
|
| 50 |
+
# Clinical Interpretation
|
| 51 |
+
# -----------------------------
|
| 52 |
+
def interpret_findings(findings):
|
| 53 |
+
strong, moderate, mild = [], [], []
|
| 54 |
|
| 55 |
+
for cond, conf in findings:
|
| 56 |
+
if conf >= 75:
|
| 57 |
+
strong.append((cond, conf))
|
| 58 |
+
elif conf >= 55:
|
| 59 |
+
moderate.append((cond, conf))
|
| 60 |
+
elif conf >= 40:
|
| 61 |
+
mild.append((cond, conf))
|
| 62 |
|
| 63 |
+
return strong, moderate, mild
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
def recommendations_for_condition(cond):
|
| 66 |
+
recs = {
|
| 67 |
+
"Fibrosis": "Pulmonology referral, consider HRCT chest.",
|
| 68 |
+
"Infiltration": "Suggest clinical correlation (cough, fever, infection signs). Consider antibiotics and follow-up imaging.",
|
| 69 |
+
"Mass": "Oncology referral; biopsy or CT chest may be warranted.",
|
| 70 |
+
"Nodule": "Follow Fleischner guidelines; consider repeat imaging or biopsy.",
|
| 71 |
+
"Pleural_Thickening": "Consider TB or asbestos exposure; recommend pulmonology referral.",
|
| 72 |
+
"Consolidation": "Likely pneumonia; recommend clinical correlation, antibiotics, and repeat X-ray in 6–8 weeks.",
|
| 73 |
+
"Effusion": "Evaluate cardiac/renal function; consider thoracentesis if large.",
|
| 74 |
+
"Atelectasis": "Encourage physiotherapy, bronchoscopy if persistent.",
|
| 75 |
+
"Cardiomegaly": "Cardiology referral, ECG, and echocardiogram recommended."
|
| 76 |
+
}
|
| 77 |
+
return recs.get(cond, "Correlation with symptoms and specialist referral advised.")
|
| 78 |
|
| 79 |
+
def format_report(scan_type, findings):
|
| 80 |
+
if scan_type == "chest":
|
| 81 |
+
if not findings:
|
| 82 |
+
return "✅ **Chest X-ray Report**\n\nNormal chest. No significant abnormalities detected."
|
| 83 |
+
|
| 84 |
+
strong, moderate, mild = interpret_findings(findings)
|
| 85 |
+
|
| 86 |
+
report = "## 🫁 Chest X-ray Report\n\n"
|
| 87 |
+
|
| 88 |
+
# Summary
|
| 89 |
+
if strong:
|
| 90 |
+
report += f"**Overall Impression:** ⚠️ Abnormal – Strong evidence of **{', '.join([c for c, _ in strong])}**.\n\n"
|
| 91 |
+
elif moderate:
|
| 92 |
+
report += f"**Overall Impression:** ❗ Suspicious – Possible **{', '.join([c for c, _ in moderate])}**.\n\n"
|
| 93 |
+
else:
|
| 94 |
+
report += "**Overall Impression:** ✅ No major abnormalities detected.\n\n"
|
| 95 |
+
|
| 96 |
+
# Key Findings
|
| 97 |
+
report += "### 🔍 Key Findings\n"
|
| 98 |
+
if strong:
|
| 99 |
+
report += "\n**Significant Abnormalities:**\n"
|
| 100 |
+
for cond, conf in strong:
|
| 101 |
+
report += f"- {cond} ({conf:.1f}%) → {recommendations_for_condition(cond)}\n"
|
| 102 |
+
if moderate:
|
| 103 |
+
report += "\n**Possible Abnormalities:**\n"
|
| 104 |
+
for cond, conf in moderate:
|
| 105 |
+
report += f"- {cond} ({conf:.1f}%) → {recommendations_for_condition(cond)}\n"
|
| 106 |
+
if mild:
|
| 107 |
+
report += "\n**Minor / Non-specific Findings:**\n"
|
| 108 |
+
for cond, conf in mild:
|
| 109 |
+
report += f"- {cond} ({conf:.1f}%) → Monitor clinically.\n"
|
| 110 |
+
|
| 111 |
+
return report
|
| 112 |
+
|
| 113 |
+
# ------------------ Other Scans -------------------
|
| 114 |
+
elif scan_type == "bone":
|
| 115 |
+
return "🦴 **Bone X-ray Report**\n\n- Possible fracture or degenerative changes.\n- Recommendation: Orthopedic consultation."
|
| 116 |
+
elif scan_type == "brain":
|
| 117 |
+
return "🧠 **Brain MRI/CT Report**\n\n- Possible tumor, lesion, or stroke signs.\n- Recommendation: Neurology referral."
|
| 118 |
+
elif scan_type == "abdomen":
|
| 119 |
+
return "🩺 **Abdomen Scan Report**\n\n- Possible mass or opacity detected.\n- Recommendation: Gastroenterology referral."
|
| 120 |
+
elif scan_type == "dental":
|
| 121 |
+
return "😁 **Dental X-ray Report**\n\n- Possible cavities or gum disease.\n- Recommendation: Dentist referral."
|
| 122 |
+
elif scan_type == "cardiac":
|
| 123 |
+
return "❤️ **Cardiac Scan Report**\n\n- Possible cardiomegaly or valve abnormality.\n- Recommendation: Cardiology referral."
|
| 124 |
+
else:
|
| 125 |
+
return "⚠️ Unknown scan type. Please upload a valid medical scan."
|
| 126 |
+
|
| 127 |
+
# -----------------------------
|
| 128 |
+
# Gradio Interface
|
| 129 |
+
# -----------------------------
|
| 130 |
+
def analyse_scan(image):
|
| 131 |
+
pil_img = image.convert("RGB")
|
| 132 |
+
scan_type = detect_scan_type(pil_img)
|
| 133 |
+
findings = analyse_chest(pil_img) if scan_type == "chest" else None
|
| 134 |
+
return format_report(scan_type, findings)
|
| 135 |
+
|
| 136 |
+
demo = gr.Interface(
|
| 137 |
+
fn=analyse_scan,
|
| 138 |
+
inputs=gr.Image(type="pil"),
|
| 139 |
+
outputs="markdown",
|
| 140 |
+
title="🩻 Universal Radiology AI",
|
| 141 |
+
description="Upload a scan (Chest, Bone, Brain, Abdomen, Dental, Cardiac) to get an AI-generated structured radiology report. ⚠️ Not a medical device."
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if __name__ == "__main__":
|
| 145 |
+
demo.launch()
|