File size: 5,177 Bytes
7778c34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3e881d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
import torch
from PIL import Image
import torch.nn.functional as F
from monai.networks.nets import DenseNet121
from torchvision import transforms

# ---------- Load models ----------
def _load_classifier(weights_path, num_classes):
    """Build a 2-D, single-channel DenseNet121 and load its weights on CPU.

    Args:
        weights_path: Path of the saved state dict passed to ``torch.load``.
        num_classes: Number of output channels of the network.

    Returns:
        The network with the state dict loaded and eval mode switched on.
    """
    network = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_classes)
    state = torch.load(weights_path, map_location="cpu")
    network.load_state_dict(state)
    network.eval()
    return network


# Body-part classifier: six outputs, named by body_classes.
model_body = _load_classifier("body_part_classifier.pth", 6)
body_classes = ["Abdomen-CT", "Breast-MRI", "Chest-CT", "ChestX-RAY", "HandX-RAY", "Head-CT"]

# Chest X-ray pathology classifier: four outputs, named by chest_classes.
model_chest = _load_classifier("chest_xray_fast.pth", 4)
chest_classes = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]

# Head CT classifier: two outputs (no label list is defined for it here).
model_head = _load_classifier("head_ct_2d_model.pth", 2)

# Preprocessing shared by all three models: resize to 128x128, convert to a
# tensor, then normalize with mean 0.5 and std 0.5.
_preprocessing_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocessing_steps)

def get_chest_report(pathology, confidence):
    """Return the canned Markdown-style report for a chest X-ray class.

    Args:
        pathology: Predicted class label. "COVID", "Lung_Opacity", "Normal"
            and "Viral Pneumonia" each have a dedicated report.
        confidence: Number rendered with two decimals and a trailing "%";
            it is not rescaled here.

    Returns:
        A four-line Findings / Impression / Recommendation / Confidence
        string for a known label, otherwise a one-line "Unknown" report.
    """
    # (label, findings, impression, recommendation) for every known class.
    known_reports = (
        ("COVID",
         "Bilateral ground-glass opacities, peripheral distribution.",
         "COVID-19 pneumonia.",
         "PCR testing."),
        ("Lung_Opacity",
         "Patchy opacity.",
         "Lung opacity, ? infection.",
         "Follow-up imaging."),
        ("Normal",
         "Clear lungs.",
         "Normal chest X‑ray.",
         "None."),
        ("Viral Pneumonia",
         "Interstitial markings.",
         "Viral pneumonia.",
         "Clinical correlation."),
    )
    for label, findings, impression, recommendation in known_reports:
        if pathology == label:
            return (
                f"**Findings:** {findings}\n"
                f"**Impression:** {impression}\n"
                f"**Recommendation:** {recommendation}\n"
                f"**Confidence:** {confidence:.2f}%"
            )
    return f"**Report:** Unknown. Confidence: {confidence:.2f}%"

def get_head_ct_report(prob_abnormal):
    """Return the canned head CT report for an abnormal-class probability.

    Args:
        prob_abnormal: Probability of the abnormal class. Values strictly
            above 0.5 produce the hemorrhage report; anything else produces
            the normal report.

    Returns:
        A four-line Findings / Impression / Recommendation / Confidence
        string. The confidence shown is ``prob_abnormal`` for the hemorrhage
        report and ``1 - prob_abnormal`` for the normal one, as a percentage
        with two decimals.
    """
    hemorrhage_suspected = prob_abnormal > 0.5
    if hemorrhage_suspected:
        lines = [
            "**Findings:** Hyperdense area, acute hemorrhage.",
            "**Impression:** Acute hemorrhage.",
            "**Recommendation:** Neurosurgical consultation.",
            f"**Confidence:** {prob_abnormal:.2%}",
        ]
    else:
        prob_normal = 1 - prob_abnormal
        lines = [
            "**Findings:** No hemorrhage.",
            "**Impression:** Normal head CT.",
            "**Recommendation:** Routine follow-up.",
            f"**Confidence:** {prob_normal:.2%}",
        ]
    return "\n".join(lines)

def get_generic_report(body_part, confidence):
    """Return the two-line generic report for a detected body part.

    Args:
        body_part: Label inserted after "**Body part:**".
        confidence: Number shown in parentheses with two decimals and a
            trailing "%"; it is not rescaled here.

    Returns:
        The body-part line followed by a fixed "Normal study." report line.
    """
    header = "**Body part:** {} ({:.2f}%)".format(body_part, confidence)
    return header + "\n**Report:** Normal study."

# Hyphen look-alikes mapped to the ASCII hyphen so that modality labels can be
# compared reliably regardless of which hyphen character the caller typed.
_HYPHEN_VARIANTS = str.maketrans({
    "\u2010": "-",  # hyphen
    "\u2011": "-",  # non-breaking hyphen
    "\u2012": "-",  # figure dash
    "\u2013": "-",  # en dash
    "\u2212": "-",  # minus sign
})


def _modality_key(modality):
    """Return *modality* with hyphen look-alikes replaced by ASCII '-'."""
    if not isinstance(modality, str):
        return modality
    return modality.translate(_HYPHEN_VARIANTS)


def _to_grayscale(image_input):
    """Return *image_input* (a path or an image object) as an 'L' mode image."""
    if isinstance(image_input, str):
        # convert() returns a new image, so the file opened here can be
        # closed deterministically instead of being left to the GC.
        with Image.open(image_input) as opened:
            return opened.convert('L')
    return image_input.convert('L')


def _class_probabilities(model, img_tensor):
    """Run *model* without gradient tracking and softmax its output over dim 1."""
    with torch.no_grad():
        return F.softmax(model(img_tensor), dim=1)


def _top_prediction(model, class_names, img_tensor):
    """Return (label, confidence in percent) of the highest-probability class."""
    probabilities = _class_probabilities(model, img_tensor)
    confidence, index = torch.max(probabilities, 1)
    return class_names[index.item()], confidence.item() * 100


def predict(image_input, modality):
    """Run the model matching *modality* on an image and return a text report.

    Args:
        image_input: ``None``, a file path (``str``), or an object exposing
            ``convert`` such as a PIL image.
        modality: Selected modality label. "Chest X-ray" routes to the chest
            model and "Head CT" to the head model; hyphen look-alikes in the
            label are treated as an ASCII hyphen so that differently typed
            spellings reach the same branch. Every other value is handled by
            the body-part classifier.

    Returns:
        "Upload an image." when *image_input* is ``None``; otherwise a
        Markdown-style report string for the selected branch.
    """
    if image_input is None:
        return "Upload an image."

    # Add a leading batch dimension to the preprocessed grayscale image.
    img_tensor = transform(_to_grayscale(image_input)).unsqueeze(0)
    modality_key = _modality_key(modality)

    if modality_key == "Chest X-ray":
        pathology, conf_chest_percent = _top_prediction(model_chest, chest_classes, img_tensor)
        report = get_chest_report(pathology, conf_chest_percent)
        return f"**Modality:** Chest X‑ray\n**Prediction:** {pathology}\n**Report:**\n{report}"

    if modality_key == "Head CT":
        # Index 1 of the two-class output is used as the abnormal probability.
        prob_abnormal = _class_probabilities(model_head, img_tensor)[0][1].item()
        head_report = get_head_ct_report(prob_abnormal)
        return f"**Modality:** Head CT\n**Report:**\n{head_report}"

    # Any other modality: report the body part the classifier detects.
    body_part, conf_body_percent = _top_prediction(model_body, body_classes, img_tensor)
    generic_report = get_generic_report(body_part, conf_body_percent)
    return f"**Selected modality:** {modality}\n{generic_report}"

# ---------- Gradio interface with logo using Blocks ----------
_APP_TITLE = "InsightRay – Multi‑Modality Medical AI"
# NOTE(review): the logo is referenced by a relative path; confirm that the
# running Gradio server actually serves this file.
_LOGO_HTML = (
    '<div style="text-align: center; margin-bottom: 20px;">'
    '<img src="insightray_icon.svg" width="100" style="display: inline-block;"/>'
    '</div>'
)
_INTRO_TEXT = (
    "Choose the image type and upload a medical image. "
    "For chest X‑ray and head CT, detailed AI reports are generated."
)
# The first entry is the default selection of the radio group.
_MODALITY_CHOICES = [
    "Chest X‑ray",
    "Head CT",
    "Hand X‑ray",
    "Abdomen CT",
    "Breast MRI",
    "Chest CT",
]

with gr.Blocks(title=_APP_TITLE) as demo:
    gr.HTML(_LOGO_HTML)
    gr.Markdown(f"## {_APP_TITLE}")
    gr.Markdown(_INTRO_TEXT)

    with gr.Row():
        # First column: image upload, modality selector and trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Medical Image")
            modality = gr.Radio(
                choices=_MODALITY_CHOICES,
                label="Select Modality",
                value=_MODALITY_CHOICES[0],
            )
            submit_btn = gr.Button("Analyse")
        # Second column: the report text returned by predict.
        with gr.Column():
            output = gr.Textbox(label="AI Report", lines=20)

    # Clicking the button sends the image and modality to predict and writes
    # its return value into the report box.
    submit_btn.click(
        fn=predict,
        inputs=[image_input, modality],
        outputs=output,
    )

demo.launch()