Spaces:
Running
Running
File size: 5,177 Bytes
7778c34 e3e881d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | import gradio as gr
import torch
from PIL import Image
import torch.nn.functional as F
from monai.networks.nets import DenseNet121
from torchvision import transforms
# ---------- Load models ----------
def _load_densenet(weights_path, out_channels):
    """Build a 2-D, single-channel DenseNet121 and load its weights from *weights_path*.

    The checkpoint is loaded straight into ``load_state_dict``, so it must be a
    plain state dict. It is mapped onto the CPU and read with
    ``weights_only=True`` so that only tensors and plain containers are
    unpickled; passing the flag explicitly also silences the FutureWarning that
    recent PyTorch releases emit when it is omitted.

    Args:
        weights_path: Path of the ``.pth`` state-dict file.
        out_channels: Number of output classes of the classifier head.

    Returns:
        The model with weights loaded, switched to eval mode.
    """
    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=out_channels)
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


body_classes = ["Abdomen-CT", "Breast-MRI", "Chest-CT", "ChestX-RAY", "HandX-RAY", "Head-CT"]
# Head width is derived from the label list so the two cannot drift apart.
model_body = _load_densenet("body_part_classifier.pth", len(body_classes))

chest_classes = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]
model_chest = _load_densenet("chest_xray_fast.pth", len(chest_classes))

# Binary head-CT model. There is no label list for it; the prediction code
# reads output index 1 as the "abnormal" probability.
# NOTE(review): confirm that index order against the training labels.
model_head = _load_densenet("head_ct_2d_model.pth", 2)

# Preprocessing shared by all three models: resize to 128x128, convert to a
# float tensor in [0, 1], then map to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# One row per known chest-X-ray class: (label, findings, impression, recommendation).
_CHEST_REPORT_TEMPLATES = (
    (
        "COVID",
        "Bilateral ground-glass opacities, peripheral distribution.",
        "COVID-19 pneumonia.",
        "PCR testing.",
    ),
    (
        "Lung_Opacity",
        "Patchy opacity.",
        "Lung opacity, ? infection.",
        "Follow-up imaging.",
    ),
    (
        "Normal",
        "Clear lungs.",
        "Normal chest X‑ray.",
        "None.",
    ),
    (
        "Viral Pneumonia",
        "Interstitial markings.",
        "Viral pneumonia.",
        "Clinical correlation.",
    ),
)


def get_chest_report(pathology, confidence):
    """Render the canned markdown report for a predicted chest-X-ray class.

    Args:
        pathology: Predicted class label; compared with ``==`` against the
            known labels in template order.
        confidence: Confidence as a percentage, printed with two decimals.

    Returns:
        A four-line markdown report (findings, impression, recommendation,
        confidence), or a one-line "Unknown" report for an unrecognised label.
    """
    for label, findings, impression, recommendation in _CHEST_REPORT_TEMPLATES:
        if pathology == label:
            return (
                f"**Findings:** {findings}\n"
                f"**Impression:** {impression}\n"
                f"**Recommendation:** {recommendation}\n"
                f"**Confidence:** {confidence:.2f}%"
            )
    return f"**Report:** Unknown. Confidence: {confidence:.2f}%"
def get_head_ct_report(prob_abnormal):
    """Return a markdown head-CT report from the model's abnormal-class probability.

    A probability strictly above 0.5 produces the hemorrhage report, with
    ``prob_abnormal`` as its confidence. Anything else (exactly 0.5 included)
    produces the normal report, with ``1 - prob_abnormal`` as its confidence.
    Confidence is formatted as a percentage with two decimals.
    """
    if not prob_abnormal > 0.5:
        normal_confidence = 1 - prob_abnormal
        return (
            "**Findings:** No hemorrhage.\n"
            "**Impression:** Normal head CT.\n"
            "**Recommendation:** Routine follow-up.\n"
            f"**Confidence:** {normal_confidence:.2%}"
        )
    return (
        "**Findings:** Hyperdense area, acute hemorrhage.\n"
        "**Impression:** Acute hemorrhage.\n"
        "**Recommendation:** Neurosurgical consultation.\n"
        f"**Confidence:** {prob_abnormal:.2%}"
    )
def get_generic_report(body_part, confidence):
    """Build the report for image types that have no dedicated pathology model.

    On this path only the body-part classifier has run. The report therefore
    states the detected body part and its confidence, and says explicitly that
    no abnormality assessment was made. It does not claim the study is normal.

    Args:
        body_part: Label predicted by the body-part classifier.
        confidence: Classifier confidence as a percentage (0-100), printed with
            two decimals.

    Returns:
        A two-line markdown report string.
    """
    return (
        f"**Body part:** {body_part} ({confidence:.2f}%)\n"
        "**Report:** No pathology model is available for this image type; "
        "the image was not assessed for abnormalities."
    )
def _run_classifier(model, img_tensor):
    """Return softmax class probabilities of *model* for a preprocessed batch.

    The forward pass runs under ``torch.no_grad()`` because this is inference
    only.
    """
    with torch.no_grad():
        logits = model(img_tensor)
    return F.softmax(logits, dim=1)


def _top_prediction(probs, class_names):
    """Return ``(label, confidence_percent)`` for the most probable class.

    ``probs`` is expected to hold a single sample (batch size 1), because the
    winning index is read with ``.item()``.
    """
    conf, idx = torch.max(probs, 1)
    return class_names[idx.item()], conf.item() * 100


def _to_grayscale_pil(image_input):
    """Convert a file path, PIL image or array-like image to a mode-'L' PIL image."""
    if isinstance(image_input, str):
        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded image.
        with Image.open(image_input) as opened:
            return opened.convert('L')
    if not hasattr(image_input, "convert"):
        # Not a PIL image, e.g. a numpy array from gr.Image(type="numpy").
        image_input = Image.fromarray(image_input)
    return image_input.convert('L')


def predict(image_input, modality):
    """Classify an uploaded medical image and return a markdown report.

    Args:
        image_input: None, a file path, a PIL image, or an array accepted by
            ``PIL.Image.fromarray``.
        modality: The modality label chosen in the UI. The chest X-ray and
            head CT choices use their dedicated models. Any other value falls
            back to the body-part classifier.

    Returns:
        The report text, or "Upload an image." when no image was supplied.
    """
    if image_input is None:
        return "Upload an image."
    # Add a leading batch dimension of 1 for the models.
    img_tensor = transform(_to_grayscale_pil(image_input)).unsqueeze(0)

    if modality == "Chest X‑ray":
        probs = _run_classifier(model_chest, img_tensor)
        pathology, confidence = _top_prediction(probs, chest_classes)
        report = get_chest_report(pathology, confidence)
        return f"**Modality:** Chest X‑ray\n**Prediction:** {pathology}\n**Report:**\n{report}"

    if modality == "Head CT":
        probs = _run_classifier(model_head, img_tensor)
        # NOTE(review): assumes output index 1 is the abnormal class; confirm
        # against the head model's training labels.
        prob_abnormal = probs[0][1].item()
        head_report = get_head_ct_report(prob_abnormal)
        return f"**Modality:** Head CT\n**Report:**\n{head_report}"

    probs = _run_classifier(model_body, img_tensor)
    body_part, confidence = _top_prediction(probs, body_classes)
    generic_report = get_generic_report(body_part, confidence)
    return f"**Selected modality:** {modality}\n{generic_report}"
# ---------- Gradio interface with logo using Blocks ----------
with gr.Blocks(title="InsightRay – Multi‑Modality Medical AI") as demo:
    # NOTE(review): the browser resolves this relative src against the page
    # URL, not the app directory. Depending on the Gradio version, the icon may
    # need to be served through Gradio's file route (and allowed_paths).
    # Confirm that the logo actually renders.
    gr.HTML('<div style="text-align: center; margin-bottom: 20px;"><img src="insightray_icon.svg" width="100" style="display: inline-block;"/></div>')
    gr.Markdown("## InsightRay – Multi‑Modality Medical AI")
    gr.Markdown("Choose the image type and upload a medical image. For chest X‑ray and head CT, detailed AI reports are generated.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Medical Image")
            modality = gr.Radio(
                choices=["Chest X‑ray", "Head CT", "Hand X‑ray", "Abdomen CT", "Breast MRI", "Chest CT"],
                label="Select Modality",
                value="Chest X‑ray",
            )
            submit_btn = gr.Button("Analyse")
        with gr.Column():
            # The reports are markdown: **bold** labels, one item per "\n".
            # A Textbox would show the raw asterisks. line_breaks=True keeps
            # each single newline as a visible line break.
            output = gr.Markdown(label="AI Report", line_breaks=True)
    submit_btn.click(fn=predict, inputs=[image_input, modality], outputs=output)
demo.launch()