# Sirivennela's picture
# Update app.py
# 457157a verified
import os
import numpy as np
import torch
from PIL import Image
import torchxrayvision as xrv
from transformers import pipeline
import gradio as gr
# -----------------------------
# Setup
# -----------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_models():
    """Load the chest X-ray classifier and the CLIP scan-type router.

    Both models are placed on ``DEVICE`` so inference does not silently fall
    back to CPU for one of them.

    Returns:
        (chest_model, clip): the torchxrayvision DenseNet in eval mode and a
        Hugging Face zero-shot image-classification pipeline.
    """
    chest_model = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
    # Without an explicit ``device=`` the pipeline defaults to CPU, even when
    # the chest model above has been moved to a GPU.
    clip = pipeline(
        "zero-shot-image-classification",
        model="openai/clip-vit-base-patch32",
        device=DEVICE,
    )
    return chest_model, clip


# Loaded once at import time and shared by every request.
CHEST_MODEL, CLIP = load_models()
# Output index i of CHEST_MODEL corresponds to LABELS[i].
LABELS = CHEST_MODEL.pathologies
# -----------------------------
# Helper Functions
# -----------------------------
def detect_scan_type(pil_img):
    """Guess which body region an image shows, using zero-shot CLIP.

    Args:
        pil_img: the uploaded image as a PIL image.

    Returns:
        The candidate label ("chest", "bone", "brain", "abdomen", "dental" or
        "cardiac") with the highest CLIP score.
    """
    region_labels = ["chest", "bone", "brain", "abdomen", "dental", "cardiac"]
    scored = CLIP(pil_img, candidate_labels=region_labels)
    top_entry = max(scored, key=lambda entry: entry['score'])
    return top_entry['label']
def preprocess_chest(pil_img):
    """Turn a PIL image into the batched tensor the chest classifier expects.

    Steps: greyscale -> float32 -> xrv intensity normalisation (8-bit input)
    -> add a channel axis -> xrv centre crop -> resize to 224 -> add a batch
    axis -> move to ``DEVICE``.

    Args:
        pil_img: a PIL image in any mode; it is converted to "L" if needed.

    Returns:
        A torch tensor of shape (1, 1, 224, 224) on ``DEVICE``.
    """
    grey = pil_img if pil_img.mode == "L" else pil_img.convert("L")
    pixels = xrv.datasets.normalize(np.array(grey).astype(np.float32), 255)
    # "L" mode gives a 2-D array; the xrv transforms work on (1, H, W).
    pixels = pixels[np.newaxis, ...]
    for transform in (xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)):
        pixels = transform(pixels)
    batch = torch.from_numpy(pixels).unsqueeze(0)
    return batch.to(DEVICE)
def analyse_chest(pil_img, min_confidence=40.0, top_k=7):
    """Run the chest X-ray classifier and return the most likely pathologies.

    Args:
        pil_img: PIL image of a chest radiograph (preprocessed by
            ``preprocess_chest``).
        min_confidence: only findings strictly above this percentage are kept.
        top_k: maximum number of findings returned.

    Returns:
        Up to ``top_k`` ``(label, confidence)`` tuples, with confidence as a
        float percentage (0-100), sorted by descending confidence.
    """
    x = preprocess_chest(pil_img)
    with torch.no_grad():
        output = CHEST_MODEL(x)

    # torchxrayvision's pretrained models already apply a sigmoid (plus
    # operating-point normalisation) inside forward() when ``op_threshs`` is
    # set or ``apply_sigmoid`` is true. A second torch.sigmoid squashed every
    # score into ~[50%, 73%]. That meant every label cleared the cut-off and
    # none could ever reach the 75% "strong" tier.
    already_probabilities = (
        getattr(CHEST_MODEL, "op_threshs", None) is not None
        or bool(getattr(CHEST_MODEL, "apply_sigmoid", False))
    )
    scores = output if already_probabilities else torch.sigmoid(output)
    probs = scores[0].cpu().numpy() * 100

    findings = [
        (label, float(p))
        for label, p in zip(LABELS, probs)
        if label  # skip blank placeholder labels, if the weights define any
    ]
    findings.sort(key=lambda item: item[1], reverse=True)
    return [(cond, conf) for cond, conf in findings if conf > min_confidence][:top_k]
# -----------------------------
# Clinical Interpretation
# -----------------------------
def interpret_findings(findings):
    """Split (condition, confidence%) pairs into three confidence tiers.

    Tiers: strong (>= 75), moderate (>= 55), mild (>= 40). Pairs below 40 are
    dropped. Input order is preserved within each tier.

    Returns:
        A tuple ``(strong, moderate, mild)`` of lists of ``(condition, confidence)``.
    """
    strong, moderate, mild = [], [], []
    # Checked top-down: the first floor a score reaches decides its tier.
    tiers = ((75, strong), (55, moderate), (40, mild))
    for label, score in findings:
        bucket = next((target for floor, target in tiers if score >= floor), None)
        if bucket is not None:
            bucket.append((label, score))
    return strong, moderate, mild
# Follow-up advice keyed by pathology label, as used by the chest report.
_CONDITION_RECOMMENDATIONS = {
    "Fibrosis": "Pulmonology referral, consider HRCT chest.",
    "Infiltration": "Suggest clinical correlation (cough, fever, infection signs). Consider antibiotics and follow-up imaging.",
    "Mass": "Oncology referral; biopsy or CT chest may be warranted.",
    "Nodule": "Follow Fleischner guidelines; consider repeat imaging or biopsy.",
    "Pleural_Thickening": "Consider TB or asbestos exposure; recommend pulmonology referral.",
    "Consolidation": "Likely pneumonia; recommend clinical correlation, antibiotics, and repeat X-ray in 6–8 weeks.",
    "Effusion": "Evaluate cardiac/renal function; consider thoracentesis if large.",
    "Atelectasis": "Encourage physiotherapy, bronchoscopy if persistent.",
    "Cardiomegaly": "Cardiology referral, ECG, and echocardiogram recommended.",
}
# Returned for any label without a specific entry above (exact, case-sensitive match).
_DEFAULT_RECOMMENDATION = "Correlation with symptoms and specialist referral advised."


def recommendations_for_condition(cond):
    """Return the follow-up advice text for a pathology label."""
    return _CONDITION_RECOMMENDATIONS.get(cond, _DEFAULT_RECOMMENDATION)
# Fixed report texts for scan types that have no dedicated model.
# NOTE(review): these strings are constant. Nothing in them is derived from the
# uploaded image, yet they read like image findings. Consider labelling them
# as generic checklists, or removing them, before showing them to users.
_STATIC_REPORTS = {
    "bone": """🦴 **Bone X-ray Report**
- Possible fracture or degenerative changes.
- Signs of joint space narrowing suggesting arthritis.
- Mild osteopenia changes noted.
- Recommendation: Orthopedic consultation and further imaging if required.
""",
    "brain": """🧠 **Brain MRI/CT Report**
- Possible lesion or abnormal density detected.
- Early ischemic changes cannot be ruled out.
- Signs suggestive of mass effect or edema.
- Recommendation: Neurology referral and MRI correlation.
""",
    "abdomen": """🩺 **Abdomen Scan Report**
- Abnormal soft tissue shadow or opacity.
- Possible hepatomegaly or splenomegaly noted.
- Signs of fluid collection (ascites) suspected.
- Recommendation: Gastroenterology referral.
""",
    "dental": """😁 **Dental X-ray Report**
- Possible dental cavities identified.
- Signs of gum disease or bone loss.
- Malalignment or impacted tooth noted.
- Recommendation: Dentist referral.
""",
    "cardiac": """❀️ **Cardiac Scan Report**
- Possible cardiomegaly (enlarged heart).
- Suspicious valve abnormality detected.
- Pulmonary congestion signs may be present.
- Recommendation: Cardiology referral.
""",
}

_UNKNOWN_SCAN_REPORT = "⚠️ Unknown scan type. Please upload a valid medical scan."

_NORMAL_CHEST_REPORT = "✅ **Chest X-ray Report**\n\nNormal chest. No significant abnormalities detected."


def _format_chest_report(findings):
    """Render the Markdown chest report for a non-empty list of (label, confidence%)."""
    strong, moderate, mild = interpret_findings(findings)
    parts = ["## 🫁 Chest X-ray Report\n\n"]

    # Overall impression is driven by the highest non-empty tier.
    if strong:
        parts.append(f"**Overall Impression:** ⚠️ Abnormal – Strong evidence of **{', '.join(c for c, _ in strong)}**.\n\n")
    elif moderate:
        parts.append(f"**Overall Impression:** ❗ Suspicious – Possible **{', '.join(c for c, _ in moderate)}**.\n\n")
    else:
        parts.append("**Overall Impression:** ✅ No major abnormalities detected.\n\n")

    parts.append("### 🔍 Key Findings\n")
    # (heading, tier items, advice function); mild findings get generic advice.
    sections = (
        ("Significant Abnormalities", strong, recommendations_for_condition),
        ("Possible Abnormalities", moderate, recommendations_for_condition),
        ("Minor / Non-specific Findings", mild, lambda _cond: "Monitor clinically."),
    )
    for heading, items, advice in sections:
        if items:
            parts.append(f"\n**{heading}:**\n")
            parts.extend(f"- {cond} ({conf:.1f}%) → {advice(cond)}\n" for cond, conf in items)
    return "".join(parts)


def format_report(scan_type, findings):
    """Build the Markdown report shown to the user.

    Args:
        scan_type: label from ``detect_scan_type`` (e.g. "chest", "bone").
        findings: for "chest", a list of (label, confidence%) pairs (empty or
            None means nothing passed the threshold); ignored otherwise.

    Returns:
        Markdown text. Chest reports are built from ``findings``; other known
        scan types return a fixed template; anything else gets an
        "unknown scan type" message.
    """
    if scan_type == "chest":
        if not findings:
            return _NORMAL_CHEST_REPORT
        return _format_chest_report(findings)
    return _STATIC_REPORTS.get(scan_type, _UNKNOWN_SCAN_REPORT)
# -----------------------------
# Gradio Interface
# -----------------------------
def analyse_scan(image):
    """Gradio callback: route the image to the right analysis and build a report.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            if nothing was uploaded.

    Returns:
        Markdown report text.
    """
    # Guard against a missing upload; ``image.convert`` would raise on None.
    if image is None:
        return "⚠️ Please upload an image to analyse."
    pil_img = image.convert("RGB")
    scan_type = detect_scan_type(pil_img)
    # Only chest scans have a trained model; other types use fixed templates.
    findings = analyse_chest(pil_img) if scan_type == "chest" else None
    return format_report(scan_type, findings)
# Single-image in, Markdown report out. The description carries the
# "not a medical device" disclaimer shown under the title.
demo = gr.Interface(
    fn=analyse_scan,
    inputs=gr.Image(type="pil"),
    outputs="markdown",
    title="🩻 Universal Radiology AI",
    description=(
        "Upload a scan (Chest, Bone, Brain, Abdomen, Dental, Cardiac) to get an "
        "AI-generated structured radiology report. ⚠️ Not a medical device."
    ),
)

if __name__ == "__main__":
    demo.launch()