RohitManglik's picture
Update app.py
e4f384b verified
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login(token=None) falls back to an interactive prompt, which cannot
# be answered when the app runs headless (e.g. in a hosted Space).
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download
# =========================
# MODEL CONFIG
# =========================
# Hub repository that hosts the trained weights and their metadata files.
MODEL_REPO = "InfoBayAI/resnet18-mri-anatomy-classifier"

# Each call yields a local file path that is opened/loaded further below.
weights_path = hf_hub_download(repo_id=MODEL_REPO, filename="pytorch_model.bin")
config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json")
classes_path = hf_hub_download(repo_id=MODEL_REPO, filename="classes.json")

# Decode explicitly as UTF-8 so parsing does not depend on the host locale.
# `config` is read elsewhere in this file for "num_channels", "num_classes"
# and "input_size"; `classes` is indexed by predicted class index.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
with open(classes_path, "r", encoding="utf-8") as f:
    classes = json.load(f)
# =========================
# DEVICE
# =========================
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# =========================
# LOAD MODEL
# =========================
# Start from a randomly initialised ResNet-18. `weights=None` is the current
# spelling of the deprecated `pretrained=False` flag.
model = models.resnet18(weights=None)

# Replace the first convolution so it accepts config["num_channels"] input
# channels.
model.conv1 = nn.Conv2d(
    config["num_channels"],
    64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)

# Replace the final fully connected layer so it emits one logit per class.
model.fc = nn.Linear(model.fc.in_features, config["num_classes"])

# The checkpoint is a downloaded pickle file: weights_only=True limits
# unpickling to tensors and plain containers instead of arbitrary objects.
model.load_state_dict(
    torch.load(weights_path, map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # switch layers such as batch norm to inference behaviour
print("✅ MRI Model Loaded Successfully")
# =========================
# PREPROCESS
# =========================
def preprocess_image(image):
    """Turn an uploaded image into a normalized single-channel model input.

    The image is converted to 8-bit grayscale, resized to a square of
    ``config["input_size"]`` pixels per side, rescaled from [0, 255] to
    [-1, 1], and given leading batch and channel axes.

    Args:
        image: Object exposing the PIL ``convert``/``resize`` API.

    Returns:
        A float32 tensor with two leading singleton axes (batch, channel),
        moved to the module-level ``device``.
    """
    side = config["input_size"]
    gray = image.convert("L").resize((side, side))
    pixels = np.array(gray)

    # [0, 255] -> [0, 1] -> [-1, 1]. The original author's note says this
    # matches the normalization used during training.
    normalized = (pixels / 255.0 - 0.5) / 0.5

    # Prepend the batch and channel axes, then cast to float32.
    batch = torch.tensor(normalized)[None, None].float()
    return batch.to(device)
# =========================
# PREDICT
# =========================
def predict_mri(image):
    """Classify an uploaded MRI image and report per-class probabilities.

    Args:
        image: Image object accepted by ``preprocess_image``, or ``None``
            when nothing has been uploaded.

    Returns:
        A ``(result_text, probs_dict)`` tuple. ``result_text`` is a
        human-readable summary listing every class from most to least
        probable; ``probs_dict`` maps each class label to its softmax
        probability as a plain ``float``. When ``image`` is ``None`` the
        tuple is ``(warning message, None)``.
    """
    if image is None:
        return "⚠️ Please upload an MRI image.", None

    batch = preprocess_image(image)

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)

    # Single-image batch, so keep only the first row.
    probs_np = probs.cpu().numpy()[0]

    # Class indices ordered from highest to lowest probability.
    sorted_indices = np.argsort(probs_np)[::-1]
    best = sorted_indices[0]
    top_label = classes[best]
    top_conf = probs_np[best] * 100

    # One "label — percent" line per class, most probable first.
    formatted_probs = "\n".join(
        f"{classes[i]} — {probs_np[i] * 100:.2f}%" for i in sorted_indices
    )

    result_text = (
        "\n"
        "🧠 MRI Classification Result\n"
        f"📌 Prediction: {top_label}\n"
        f"🔍 Confidence: {top_conf:.2f}%\n"
        "📊 Probabilities:\n"
        f"{formatted_probs}\n"
        "⚠️ AI-assisted output — not a medical diagnosis\n"
    )

    # Label -> probability mapping for the chart component; float() turns
    # NumPy scalars into plain Python floats.
    probs_dict = {
        label: float(probs_np[i]) for i, label in enumerate(classes)
    }
    return result_text, probs_dict
# =========================
# WIDE LANDSCAPE UI
# =========================
def create_interface():
    """Build the Gradio Blocks UI for the MRI classifier.

    The layout is a two-column row: the image upload with the Analyze and
    Clear buttons on the left, and the textual result plus a confidence
    label chart on the right (the wider column).

    Returns:
        The configured ``gr.Blocks`` instance, not yet launched.
    """
    custom_css = """
    .gradio-container {
        max-width: 1600px !important;
        margin: auto;
    }
    .gr-image {
        min-height: 500px !important;
    }
    textarea {
        font-size: 16px !important;
    }
    h1, h2, h3 {
        text-align: center;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as interface:
        gr.Markdown("# 🧠 MRI Anatomy Classifier")
        gr.Markdown("### Upload an MRI scan to identify the body region")

        with gr.Row(equal_height=True):
            # LEFT PANEL (IMAGE)
            # Column scale should be an integer; 2:3 keeps the original
            # 1.2:1.8 width ratio without Gradio's non-integer warning.
            with gr.Column(scale=2):
                image_input = gr.Image(
                    type="pil",
                    label="📤 Upload MRI Image",
                    height=500,
                    sources=["upload"],
                )
                with gr.Row():
                    predict_btn = gr.Button("🚀 Analyze MRI", variant="primary")
                    clear_btn = gr.Button("🧹 Clear")

            # RIGHT PANEL (RESULT - WIDER)
            with gr.Column(scale=3):
                output_text = gr.Textbox(
                    label="📋 Result",
                    lines=18,
                )
                output_chart = gr.Label(
                    label="📊 Confidence Breakdown",
                )

        gr.Markdown("---")
        gr.Markdown("💡 Tip: Use clear MRI scans for better predictions.")

        # Run the classifier and fill both result widgets.
        predict_btn.click(
            fn=predict_mri,
            inputs=image_input,
            outputs=[output_text, output_chart],
        )
        # Reset the image, the text box and the chart in one step.
        clear_btn.click(
            fn=lambda: (None, "", None),
            inputs=[],
            outputs=[image_input, output_text, output_chart],
        )

    return interface
# =========================
# LAUNCH
# =========================
# Build and serve the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app = create_interface()
    # share=True requests a public share link in addition to the local server.
    app.launch(share=True)