import gradio as gr import torch from PIL import Image import torch.nn.functional as F from monai.networks.nets import DenseNet121 from torchvision import transforms # ---------- Load models ---------- model_body = DenseNet121(spatial_dims=2, in_channels=1, out_channels=6) model_body.load_state_dict(torch.load("body_part_classifier.pth", map_location="cpu")) model_body.eval() body_classes = ["Abdomen-CT", "Breast-MRI", "Chest-CT", "ChestX-RAY", "HandX-RAY", "Head-CT"] model_chest = DenseNet121(spatial_dims=2, in_channels=1, out_channels=4) model_chest.load_state_dict(torch.load("chest_xray_fast.pth", map_location="cpu")) model_chest.eval() chest_classes = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"] model_head = DenseNet121(spatial_dims=2, in_channels=1, out_channels=2) model_head.load_state_dict(torch.load("head_ct_2d_model.pth", map_location="cpu")) model_head.eval() transform = transforms.Compose([ transforms.Resize((128, 128)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) def get_chest_report(pathology, confidence): if pathology == "COVID": return f"**Findings:** Bilateral ground-glass opacities, peripheral distribution.\n**Impression:** COVID-19 pneumonia.\n**Recommendation:** PCR testing.\n**Confidence:** {confidence:.2f}%" elif pathology == "Lung_Opacity": return f"**Findings:** Patchy opacity.\n**Impression:** Lung opacity, ? infection.\n**Recommendation:** Follow-up imaging.\n**Confidence:** {confidence:.2f}%" elif pathology == "Normal": return f"**Findings:** Clear lungs.\n**Impression:** Normal chest X‑ray.\n**Recommendation:** None.\n**Confidence:** {confidence:.2f}%" elif pathology == "Viral Pneumonia": return f"**Findings:** Interstitial markings.\n**Impression:** Viral pneumonia.\n**Recommendation:** Clinical correlation.\n**Confidence:** {confidence:.2f}%" else: return f"**Report:** Unknown. Confidence: {confidence:.2f}%" def get_head_ct_report(prob_abnormal): if prob_abnormal > 0.5: return f"**Findings:** Hyperdense area, acute hemorrhage.\n**Impression:** Acute hemorrhage.\n**Recommendation:** Neurosurgical consultation.\n**Confidence:** {prob_abnormal:.2%}" else: return f"**Findings:** No hemorrhage.\n**Impression:** Normal head CT.\n**Recommendation:** Routine follow-up.\n**Confidence:** {(1-prob_abnormal):.2%}" def get_generic_report(body_part, confidence): return f"**Body part:** {body_part} ({confidence:.2f}%)\n**Report:** Normal study." def predict(image_input, modality): if image_input is None: return "Upload an image." if isinstance(image_input, str): image_pil = Image.open(image_input).convert('L') else: image_pil = image_input.convert('L') img_tensor = transform(image_pil).unsqueeze(0) if modality == "Chest X‑ray": with torch.no_grad(): out_chest = model_chest(img_tensor) prob_chest = F.softmax(out_chest, dim=1) conf_chest, idx_chest = torch.max(prob_chest, 1) pathology = chest_classes[idx_chest.item()] conf_chest_percent = conf_chest.item() * 100 report = get_chest_report(pathology, conf_chest_percent) return f"**Modality:** Chest X‑ray\n**Prediction:** {pathology}\n**Report:**\n{report}" elif modality == "Head CT": with torch.no_grad(): out_head = model_head(img_tensor) prob_head = F.softmax(out_head, dim=1) prob_abnormal = prob_head[0][1].item() head_report = get_head_ct_report(prob_abnormal) return f"**Modality:** Head CT\n**Report:**\n{head_report}" else: with torch.no_grad(): out_body = model_body(img_tensor) prob_body = F.softmax(out_body, dim=1) conf_body, idx_body = torch.max(prob_body, 1) body_part = body_classes[idx_body.item()] conf_body_percent = conf_body.item() * 100 generic_report = get_generic_report(body_part, conf_body_percent) return f"**Selected modality:** {modality}\n{generic_report}" # ---------- Gradio interface with logo using Blocks ---------- with gr.Blocks(title="InsightRay – Multi‑Modality Medical AI") as demo: gr.HTML('