# Page-header residue from the hosting site, kept as comments so the file parses:
# Rihan1729's picture
# Update app.py
# f0a0ef0 verified
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image, ImageDraw, ImageFont
import torch
import numpy as np
# Hugging Face Hub checkpoint; the image preprocessor and the classifier are
# both loaded from it once, at module import time.
model_name = "Hemgg/brain-tumor-classification"
processor, model = (
    AutoImageProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)

# Display labels, indexed by the position of the model's output logit.
# NOTE(review): this order is assumed to match model.config.id2label for the
# checkpoint above; verify before trusting predictions.
class_names = [
    "🧠 Glioma",
    "🎯 Meningioma",
    "✅ No Tumor",
    "⚡ Pituitary",
]
# Stylesheet passed to gr.Blocks(css=...) to give the app a dark look.
# - .gradio-container / .md ... / span / label: near-black page, white text.
# - #result-box: the findings card (<div id="result-box">) that
#   classify_tumor returns as HTML.
# - .gr-button-primary: gradient styling for the primary button.
#   NOTE(review): this looks like a Gradio 3.x class name; confirm it still
#   matches the button markup of the installed Gradio version.
# - #result-text *: forces white text inside the component created with
#   elem_id="result-text".
#   NOTE(review): because these rules (and the `span` rule above) are
#   `!important`, they override the inline colours in the findings HTML
#   (e.g. #60a5fa heading, #4ade80 confidence) - confirm that is intended.
# - footer: hidden.
custom_css = """
.gradio-container {background-color: #050505 !important; color: #ffffff !important;}
.md, .md p, .md h1, .md h2, .md h3, span, label {color: #ffffff !important;}
#result-box {
background-color: #1a1a1a !important;
border: 2px solid #3b82f6 !important;
border-radius: 12px;
padding: 20px;
color: #ffffff !important;
}
.gr-button-primary {
background: linear-gradient(135deg, #1e40af, #7e22ce) !important;
border: none !important;
font-weight: bold !important;
}
#result-text * { color: #ffffff !important; }
footer {display: none !important;}
"""
def get_medical_info(tumor_type):
    """Look up the clinical summary and suggested follow-up for a label.

    Args:
        tumor_type: a display label such as those in ``class_names``.

    Returns:
        A new dict with string entries ``"desc"`` and ``"next"``; both are
        empty strings when the label is not recognised.
    """
    # label -> (clinical overview, recommended next steps)
    notes = {
        "🧠 Glioma": (
            "Gliomas originate in the glial cells that support neurons. They can be fast-growing and may involve surrounding brain tissue.",
            "Urgent consultation with a neuro-oncologist. An MRI with contrast or a biopsy is typically the next diagnostic step.",
        ),
        "🎯 Meningioma": (
            "These tumors arise from the meninges, the layers covering the brain. Most are slow-growing and benign but can cause pressure.",
            "Consult a neurosurgeon to evaluate the mass effect. Treatment ranges from 'watchful waiting' to surgical resection.",
        ),
        "⚡ Pituitary": (
            "Pituitary adenomas occur in the master gland at the base of the brain. They often affect hormone regulation and vision.",
            "An endocrine workup (blood tests) and a visual field test are recommended to assess hormonal and optic nerve impact.",
        ),
        "✅ No Tumor": (
            "The neural network did not detect significant signs of the three primary tumor types in this scan.",
            "If symptoms (headaches, seizures, vision loss) persist, please consult a neurologist for a comprehensive evaluation.",
        ),
    }
    desc, next_steps = notes.get(tumor_type, ("", ""))
    return {"desc": desc, "next": next_steps}
def classify_tumor(image):
    """Run the classifier on one scan and build the findings panel.

    Args:
        image: a PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.

    Returns:
        ``(html, overlay)``: the findings-panel HTML string and the image
        returned by ``create_overlay`` (``None`` when ``image`` is ``None``).
    """
    if image is None:
        return "<p style='color:white;'>Please upload a scan.</p>", None
    # Normalise to 3-channel RGB so grayscale ("L") or RGBA images passed by
    # direct callers reach the processor as ordinary RGB input.
    # NOTE(review): gr.Image may already deliver RGB; this is defensive.
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # A single image goes in, so keep only the first row of probabilities.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    idx = probs.argmax().item()
    conf = probs[idx].item()
    name = class_names[idx]
    med = get_medical_info(name)
    overlay = create_overlay(image, idx, conf)
    return _render_findings_html(name, conf, med, probs.tolist()), overlay


def _render_findings_html(name, conf, med, class_probs):
    """Build the findings-panel HTML for a single prediction.

    ``class_probs`` holds one probability per model output, in the same order
    as ``class_names``. Pairing them with ``zip`` means the breakdown never
    indexes past either list, whatever the model's number of labels.
    The footer uses ``padding-top: 10px`` (a real CSS declaration; a bare
    ``pt-10`` is a utility-class name that browsers ignore in a style
    attribute).
    """
    breakdown = ", ".join(
        f"{label}: {p:.1%}" for label, p in zip(class_names, class_probs)
    )
    return f"""
<div id="result-box">
<h2 style="color: #60a5fa; margin-top:0;">{name} Detected</h2>
<p style="font-size: 1.2em;"><b>Confidence Score:</b> <span style="color: #4ade80;">{conf:.1%}</span></p>
<div style="margin: 15px 0; padding: 10px; background: #262626; border-radius: 8px;">
<p style="color: #93c5fd; margin-bottom: 5px;"><b>Clinical Overview:</b></p>
<p style="color: #ffffff;">{med['desc']}</p>
</div>
<div style="margin: 15px 0; padding: 10px; background: #1e3a8a; border-radius: 8px;">
<p style="color: #bfdbfe; margin-bottom: 5px;"><b>Recommended Next Steps:</b></p>
<p style="color: #ffffff;">{med['next']}</p>
</div>
<p style="font-size: 0.8em; color: #94a3b8; border-top: 1px solid #444; padding-top: 10px;">
Probabilities: {breakdown}
</p>
</div>
"""
def create_overlay(image, idx, conf, no_tumor_idx=2):
    """Return an RGB copy of *image* with a translucent marker for a finding.

    NOTE(review): the marker is always a circle centred in the image, with a
    radius of a quarter of the shorter side. It is not derived from the model
    and does not indicate where a tumour is.

    Args:
        image: PIL image to annotate; the caller's image is not modified.
        idx: predicted class index.
        conf: confidence of that prediction, expected in [0, 1]; it scales the
            fill opacity to 0..150 out of 255.
        no_tumor_idx: class index for which no marker is drawn. Defaults to 2,
            the position of "✅ No Tumor" in ``class_names``.

    Returns:
        A new RGB image.
    """
    # convert() always returns a new image, so the input stays untouched.
    canvas = image.convert("RGB")
    if idx != no_tumor_idx:
        w, h = canvas.size
        r = min(w, h) // 4
        cx, cy = w // 2, h // 2
        # Drawing directly onto an RGBA image stores the fill's alpha
        # verbatim instead of blending, and a later RGB conversion then
        # discards it, leaving an opaque disk. A Draw in "RGBA" mode on an
        # RGB image blends each primitive into the pixels underneath.
        draw = ImageDraw.Draw(canvas, "RGBA")
        alpha = max(0, min(255, int(150 * conf)))
        draw.ellipse(
            [cx - r, cy - r, cx + r, cy + r],
            fill=(255, 0, 0, alpha),
            outline=(255, 255, 0, 255),
            width=5,
        )
    return canvas
# Gradio front-end: upload controls on the left, findings HTML and overlay on
# the right. Component creation order inside the Blocks/Row/Column contexts
# determines the on-screen layout, so the statements below are order-sensitive.
with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
    gr.HTML(
        "<h1 style='text-align:center; color:white; font-size: 32px;'>"
        "NEURO-DIAGNOSTIC AI STATION</h1>"
    )
    with gr.Row():
        # Left column: input image and trigger button.
        with gr.Column(scale=1):
            img_input = gr.Image(label="Upload Patient MRI", type="pil")
            run_btn = gr.Button("PERFORM NEURAL SCAN", variant="primary")
        # Right column: classify_tumor's two outputs, in return order.
        with gr.Column(scale=1):
            res_html = gr.HTML(label="Diagnostic Findings", elem_id="result-text")
            img_output = gr.Image(label="Visualization Overlay")
    run_btn.click(
        fn=classify_tumor,
        inputs=img_input,
        outputs=[res_html, img_output],
    )

if __name__ == "__main__":
    demo.launch()