Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| from monai.networks.nets import DenseNet121 | |
| from torchvision import transforms | |
# ---------- Load models ----------
def _load_densenet(checkpoint_path, num_classes):
    """Build a single-channel 2D DenseNet121 and load its weights from disk.

    The network is created with ``spatial_dims=2, in_channels=1`` and
    ``out_channels=num_classes``. The checkpoint is loaded onto the CPU with
    ``weights_only=True``, so only tensors/containers are unpickled, never
    arbitrary objects. The model is switched to eval mode before it is
    returned.

    Args:
        checkpoint_path: path to a saved ``state_dict`` file.
        num_classes: number of output logits the checkpoint was trained with.

    Returns:
        The DenseNet121 model in eval mode.
    """
    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_classes)
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Body-part classifier; output index i corresponds to body_classes[i].
body_classes = ["Abdomen-CT", "Breast-MRI", "Chest-CT", "ChestX-RAY", "HandX-RAY", "Head-CT"]
model_body = _load_densenet("body_part_classifier.pth", len(body_classes))

# Chest X-ray pathology classifier; output index i corresponds to chest_classes[i].
chest_classes = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]
model_chest = _load_densenet("chest_xray_fast.pth", len(chest_classes))

# Two-output head CT model; predict() reads index 1 as the "abnormal" probability.
model_head = _load_densenet("head_ct_2d_model.pth", 2)

# Preprocessing shared by all three models: resize to 128x128, convert to a
# tensor in [0, 1] (ToTensor), then Normalize(0.5, 0.5) maps values to [-1, 1].
# NOTE(review): presumably mirrors the training-time preprocessing — confirm.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
def get_chest_report(pathology, confidence):
    """Render the canned chest X-ray report for a predicted pathology.

    Args:
        pathology: one of the chest classifier labels ("COVID",
            "Lung_Opacity", "Normal", "Viral Pneumonia").
        confidence: confidence as a percentage, printed with two decimals.

    Returns:
        A Markdown-formatted report with Findings / Impression /
        Recommendation / Confidence lines, or a short "Unknown" report for an
        unrecognised label.
    """
    # label -> (findings, impression, recommendation)
    templates = {
        "COVID": (
            "Bilateral ground-glass opacities, peripheral distribution.",
            "COVID-19 pneumonia.",
            "PCR testing.",
        ),
        "Lung_Opacity": (
            "Patchy opacity.",
            "Lung opacity, ? infection.",
            "Follow-up imaging.",
        ),
        "Normal": (
            "Clear lungs.",
            "Normal chest X‑ray.",
            "None.",
        ),
        "Viral Pneumonia": (
            "Interstitial markings.",
            "Viral pneumonia.",
            "Clinical correlation.",
        ),
    }
    sections = templates.get(pathology)
    if sections is None:
        return f"**Report:** Unknown. Confidence: {confidence:.2f}%"
    findings, impression, recommendation = sections
    return (
        f"**Findings:** {findings}\n"
        f"**Impression:** {impression}\n"
        f"**Recommendation:** {recommendation}\n"
        f"**Confidence:** {confidence:.2f}%"
    )
def get_head_ct_report(prob_abnormal):
    """Render the head CT report from the model's abnormal-class probability.

    A probability strictly greater than 0.5 yields the hemorrhage report;
    anything else (including exactly 0.5) yields the normal report. The
    Confidence line shows the probability of whichever outcome was reported,
    formatted as a percentage with two decimals.

    Args:
        prob_abnormal: probability in [0, 1] for the abnormal class.

    Returns:
        A Markdown-formatted report string.
    """
    flagged = prob_abnormal > 0.5
    findings, impression, plan, shown = (
        ("Hyperdense area, acute hemorrhage.", "Acute hemorrhage.",
         "Neurosurgical consultation.", prob_abnormal)
        if flagged
        else ("No hemorrhage.", "Normal head CT.",
              "Routine follow-up.", 1 - prob_abnormal)
    )
    return (
        f"**Findings:** {findings}\n"
        f"**Impression:** {impression}\n"
        f"**Recommendation:** {plan}\n"
        f"**Confidence:** {shown:.2%}"
    )
def get_generic_report(body_part, confidence):
    """Build the report for modalities that have no dedicated pathology model.

    For these modalities only the body-part classifier is run, so nothing is
    known about pathology. The report therefore states the detected body part
    and its confidence, and says explicitly that no findings were assessed
    (it previously claimed "Normal study" without any analysis).

    Args:
        body_part: label predicted by the body-part classifier.
        confidence: classifier confidence as a percentage, printed with two
            decimals.

    Returns:
        A Markdown-formatted report string.
    """
    return (
        f"**Body part:** {body_part} ({confidence:.2f}%)\n"
        "**Report:** No pathology model is available for this modality; "
        "no findings were assessed."
    )
def _to_grayscale(image_input):
    """Return *image_input* as a single-channel ("L") PIL image.

    Accepts either a file path, which is opened and closed here, or an object
    with a PIL-style ``convert`` method, such as the PIL image Gradio passes
    for ``type="pil"``.
    """
    if isinstance(image_input, str):
        # The context manager closes the file once the converted copy exists;
        # convert() loads the pixel data, so the result outlives the file.
        with Image.open(image_input) as img:
            return img.convert('L')
    return image_input.convert('L')


def _infer_probs(model, img_tensor):
    """Run *model* on *img_tensor* without autograd and return softmax over dim 1."""
    with torch.no_grad():
        logits = model(img_tensor)
    return F.softmax(logits, dim=1)


def predict(image_input, modality):
    """Classify an uploaded image and return a Markdown-formatted report.

    Args:
        image_input: a PIL image (from Gradio) or a path to an image file;
            ``None`` when nothing has been uploaded.
        modality: the modality chosen in the UI. "Chest X‑ray" runs the chest
            pathology model, "Head CT" runs the head CT model, and any other
            value runs the body-part classifier only.

    Returns:
        The report text; "Upload an image." when no image was given.
    """
    if image_input is None:
        return "Upload an image."
    # Batch of one for the models; the transform yields a (1, H, W) tensor.
    img_tensor = transform(_to_grayscale(image_input)).unsqueeze(0)

    if modality == "Chest X‑ray":
        prob_chest = _infer_probs(model_chest, img_tensor)
        conf_chest, idx_chest = torch.max(prob_chest, 1)
        pathology = chest_classes[idx_chest.item()]
        report = get_chest_report(pathology, conf_chest.item() * 100)
        return f"**Modality:** Chest X‑ray\n**Prediction:** {pathology}\n**Report:**\n{report}"

    if modality == "Head CT":
        prob_head = _infer_probs(model_head, img_tensor)
        # Index 1 of the two-class output is treated as "abnormal".
        prob_abnormal = prob_head[0][1].item()
        head_report = get_head_ct_report(prob_abnormal)
        return f"**Modality:** Head CT\n**Report:**\n{head_report}"

    # All other modalities: identify the body part only.
    prob_body = _infer_probs(model_body, img_tensor)
    conf_body, idx_body = torch.max(prob_body, 1)
    body_part = body_classes[idx_body.item()]
    generic_report = get_generic_report(body_part, conf_body.item() * 100)
    return f"**Selected modality:** {modality}\n{generic_report}"
# ---------- Gradio interface with logo using Blocks ----------
def _logo_html(path="insightray_icon.svg"):
    """Return the centered logo HTML with the SVG inlined as a data URI.

    A bare relative ``src`` is not served by Gradio's web server, so the file
    is embedded directly instead. Returns "" when the file is missing, so the
    app still starts without a logo.
    """
    import base64

    try:
        with open(path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("ascii")
    except FileNotFoundError:
        return ""
    return (
        '<div style="text-align: center; margin-bottom: 20px;">'
        f'<img src="data:image/svg+xml;base64,{encoded}" width="100" style="display: inline-block;"/>'
        '</div>'
    )


def _predict_markdown(image_input, modality):
    """Call predict() and turn single newlines into Markdown hard line breaks.

    The reports use ``**bold**`` Markdown with one item per line. In Markdown,
    a lone newline does not break the line, so "  \\n" (two trailing spaces)
    is used to keep the original line layout.
    """
    return predict(image_input, modality).replace("\n", "  \n")


with gr.Blocks(title="InsightRay – Multi‑Modality Medical AI") as demo:
    logo_html = _logo_html()
    if logo_html:
        gr.HTML(logo_html)
    gr.Markdown("## InsightRay – Multi‑Modality Medical AI")
    gr.Markdown("Choose the image type and upload a medical image. For chest X‑ray and head CT, detailed AI reports are generated.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Medical Image")
            modality = gr.Radio(choices=["Chest X‑ray", "Head CT", "Hand X‑ray", "Abdomen CT", "Breast MRI", "Chest CT"],
                                label="Select Modality", value="Chest X‑ray")
            submit_btn = gr.Button("Analyse")
        with gr.Column():
            # Markdown output so the **bold** report headings render instead of
            # showing as literal asterisks in a plain Textbox.
            gr.Markdown("### AI Report")
            output = gr.Markdown()
    submit_btn.click(fn=_predict_markdown, inputs=[image_input, modality], outputs=output)
demo.launch()