File size: 4,989 Bytes
c47f58c
 
 
 
 
 
 
 
 
 
 
 
f0a0ef0
 
 
 
 
 
c47f58c
f0a0ef0
c47f58c
 
f0a0ef0
c47f58c
f0a0ef0
c47f58c
f0a0ef0
c47f58c
 
 
f0a0ef0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c47f58c
 
f0a0ef0
c47f58c
 
 
 
 
f0a0ef0
 
 
 
c47f58c
f0a0ef0
c47f58c
f0a0ef0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c47f58c
 
f0a0ef0
c47f58c
f0a0ef0
c47f58c
 
f0a0ef0
 
 
c47f58c
f0a0ef0
c47f58c
 
f0a0ef0
 
 
c47f58c
f0a0ef0
 
 
c47f58c
f0a0ef0
 
 
c47f58c
f0a0ef0
c47f58c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image, ImageDraw, ImageFont
import torch
import numpy as np

# Hugging Face Hub checkpoint used for both preprocessing and inference.
model_name = "Hemgg/brain-tumor-classification"
# Loaded once at import time; classify_tumor reuses these module-level objects.
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# Display labels, indexed by the predicted class id.
# NOTE(review): this order is hard-coded here — presumably it mirrors the
# checkpoint's label mapping; confirm against model.config.id2label.
class_names = ["🧠 Glioma", "🎯 Meningioma", "✅ No Tumor", "⚡ Pituitary"]

# Dark-theme stylesheet handed to gr.Blocks(css=...).
# "#result-box" targets the <div> emitted by classify_tumor, and
# "#result-text" targets the elem_id given to the results HTML component.
custom_css = """
.gradio-container {background-color: #050505 !important; color: #ffffff !important;}
.md, .md p, .md h1, .md h2, .md h3, span, label {color: #ffffff !important;}
#result-box {
    background-color: #1a1a1a !important; 
    border: 2px solid #3b82f6 !important; 
    border-radius: 12px; 
    padding: 20px;
    color: #ffffff !important;
}
.gr-button-primary {
    background: linear-gradient(135deg, #1e40af, #7e22ce) !important;
    border: none !important;
    font-weight: bold !important;
}
#result-text * { color: #ffffff !important; }
footer {display: none !important;}
"""

def get_medical_info(tumor_type):
    """Look up the descriptive text shown for a predicted class label.

    Args:
        tumor_type: One of the display labels (emoji prefix included).

    Returns:
        A new dict with two string entries: ``"desc"`` (short clinical
        overview) and ``"next"`` (suggested follow-up). Both strings are
        empty when the label is not recognised.
    """
    # label -> (overview, follow-up); unpacked into the result dict below.
    summaries = {
        "🧠 Glioma": (
            "Gliomas originate in the glial cells that support neurons. They can be fast-growing and may involve surrounding brain tissue.",
            "Urgent consultation with a neuro-oncologist. An MRI with contrast or a biopsy is typically the next diagnostic step.",
        ),
        "🎯 Meningioma": (
            "These tumors arise from the meninges, the layers covering the brain. Most are slow-growing and benign but can cause pressure.",
            "Consult a neurosurgeon to evaluate the mass effect. Treatment ranges from 'watchful waiting' to surgical resection.",
        ),
        "⚡ Pituitary": (
            "Pituitary adenomas occur in the master gland at the base of the brain. They often affect hormone regulation and vision.",
            "An endocrine workup (blood tests) and a visual field test are recommended to assess hormonal and optic nerve impact.",
        ),
        "✅ No Tumor": (
            "The neural network did not detect significant signs of the three primary tumor types in this scan.",
            "If symptoms (headaches, seizures, vision loss) persist, please consult a neurologist for a comprehensive evaluation.",
        ),
    }
    overview, follow_up = summaries.get(tumor_type, ("", ""))
    return {"desc": overview, "next": follow_up}

def classify_tumor(image):
    """Classify an uploaded scan and render the findings card and overlay.

    Args:
        image: PIL image supplied by the upload widget, or ``None`` when
            nothing has been uploaded yet.

    Returns:
        A ``(html, overlay)`` tuple. ``html`` is the findings card as an
        HTML string and ``overlay`` is the image returned by
        ``create_overlay``. When ``image`` is ``None`` the tuple is a short
        prompt message and ``None``.
    """
    if image is None:
        return "<p style='color:white;'>Please upload a scan.</p>", None

    # NOTE(review): presumably the checkpoint expects 3-channel input; the
    # scan is converted defensively so grayscale/RGBA images passed by a
    # direct caller are handled the same way as RGB ones.
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # A single image is classified per call, so row 0 is the only row.
    idx = probs.argmax(-1).item()
    conf = probs[0][idx].item()
    name = class_names[idx]
    med = get_medical_info(name)

    overlay = create_overlay(image, idx, conf)

    # Pair labels with scores directly instead of assuming exactly four classes.
    breakdown = ", ".join(
        f"{label}: {score:.1%}"
        for label, score in zip(class_names, probs[0].tolist())
    )

    # "padding-top: 10px" replaces the former "pt-10", which is a utility
    # class name and has no effect inside a style attribute.
    html_res = f"""
    <div id="result-box">
        <h2 style="color: #60a5fa; margin-top:0;">{name} Detected</h2>
        <p style="font-size: 1.2em;"><b>Confidence Score:</b> <span style="color: #4ade80;">{conf:.1%}</span></p>
        <div style="margin: 15px 0; padding: 10px; background: #262626; border-radius: 8px;">
            <p style="color: #93c5fd; margin-bottom: 5px;"><b>Clinical Overview:</b></p>
            <p style="color: #ffffff;">{med['desc']}</p>
        </div>
        <div style="margin: 15px 0; padding: 10px; background: #1e3a8a; border-radius: 8px;">
            <p style="color: #bfdbfe; margin-bottom: 5px;"><b>Recommended Next Steps:</b></p>
            <p style="color: #ffffff;">{med['next']}</p>
        </div>
        <p style="font-size: 0.8em; color: #94a3b8; border-top: 1px solid #444; padding-top: 10px;">
            Probabilities: {breakdown}
        </p>
    </div>
    """
    return html_res, overlay

def create_overlay(image, idx, conf, no_tumor_idx=2):
    """Return an RGB copy of the scan with a translucent marker at its centre.

    The marker is a fixed, centred circle (radius = a quarter of the shorter
    side); it is not derived from any localisation output of the model.

    Args:
        image: PIL image to annotate. It is not modified.
        idx: Predicted class index.
        conf: Confidence of the prediction; scales the opacity of the fill.
        no_tumor_idx: Class index for which no marker is drawn. Defaults to
            2, the position of the "No Tumor" label in ``class_names``.

    Returns:
        A new RGB PIL image, annotated unless ``idx == no_tumor_idx``.
    """
    # convert() yields a new image, so the caller's scan is left untouched.
    overlay = image.convert("RGB")
    if idx == no_tumor_idx:
        return overlay

    # Drawing in "RGBA" mode on an RGB image makes Pillow alpha-blend the
    # fill into the pixels. Drawing on an RGBA image instead overwrites the
    # pixels, and the later conversion to RGB would discard the alpha,
    # leaving an opaque red disc that hides the scan.
    draw = ImageDraw.Draw(overlay, "RGBA")
    w, h = overlay.size
    r = min(w, h) // 4
    cx, cy = w // 2, h // 2
    # Clamp so an out-of-range confidence cannot produce an invalid alpha.
    alpha = int(150 * max(0.0, min(1.0, conf)))
    draw.ellipse(
        [cx - r, cy - r, cx + r, cy + r],
        fill=(255, 0, 0, alpha),
        outline=(255, 255, 0, 255),
        width=5,
    )
    return overlay

# Assemble the UI: upload widget and trigger button in the first column,
# findings card and overlay image in the second.
with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
    gr.HTML("<h1 style='text-align:center; color:white; font-size: 32px;'>NEURO-DIAGNOSTIC AI STATION</h1>")
    
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" so classify_tumor receives a PIL image.
            img_input = gr.Image(label="Upload Patient MRI", type="pil")
            run_btn = gr.Button("PERFORM NEURAL SCAN", variant="primary")
            
        with gr.Column(scale=1):
            # elem_id matches the "#result-text" rule in custom_css.
            res_html = gr.HTML(label="Diagnostic Findings", elem_id="result-text")
            img_output = gr.Image(label="Visualization Overlay")
            
    # classify_tumor returns (html, image), mapped in order onto the outputs.
    run_btn.click(classify_tumor, img_input, [res_html, img_output])

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()