"""Gradio demo for a ResNet-18 MRI anatomy classifier hosted on the Hugging Face Hub.

At import time the script downloads the weights, ``config.json`` and
``classes.json`` from ``MODEL_REPO``, rebuilds the network described by the
config, and loads the weights. ``create_interface`` builds the Gradio UI whose
"Analyze" button calls ``predict_mri``.
"""

import json
import os
from typing import Dict, Optional, Tuple

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from huggingface_hub import hf_hub_download, login
from PIL import Image

# Authenticate only when a token is actually configured: calling
# login(token=None) falls back to an interactive prompt, which cannot be
# answered when the app runs headless.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# =========================
# MODEL CONFIG
# =========================
MODEL_REPO = "InfoBayAI/resnet18-mri-anatomy-classifier"

weights_path = hf_hub_download(repo_id=MODEL_REPO, filename="pytorch_model.bin")
config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json")
classes_path = hf_hub_download(repo_id=MODEL_REPO, filename="classes.json")

# JSON is UTF-8 by specification; do not depend on the platform default.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

with open(classes_path, "r", encoding="utf-8") as f:
    classes = json.load(f)

# Fail fast with a clear message: a mismatch would otherwise only surface as
# an IndexError on the first prediction.
if len(classes) != config["num_classes"]:
    raise ValueError(
        f"classes.json lists {len(classes)} labels but config.json declares "
        f"num_classes={config['num_classes']}"
    )

# =========================
# DEVICE
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# LOAD MODEL
# =========================
# No arguments => randomly initialised network on both old (`pretrained=`)
# and new (`weights=`) torchvision APIs, without a deprecation warning.
model = models.resnet18()

# Replace the stem and the head so they match the channel count and the
# number of classes declared in config.json.
model.conv1 = nn.Conv2d(
    config["num_channels"],
    64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = nn.Linear(model.fc.in_features, config["num_classes"])

# NOTE(security): torch.load unpickles the checkpoint. The file comes from a
# remote repository, so consider passing weights_only=True on torch versions
# that support it.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
model.eval()

print("✅ MRI Model Loaded Successfully")


# =========================
# PREPROCESS
# =========================
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalised model input tensor.

    The image is converted to 8-bit grayscale, resized to the square
    ``config["input_size"]``, scaled from [0, 255] to [-1, 1], and given
    leading batch and channel dimensions.

    Args:
        image: Input picture as a PIL image.

    Returns:
        A float32 tensor of shape ``(1, 1, input_size, input_size)`` placed
        on ``device``.
    """
    side = config["input_size"]
    gray = image.convert("L").resize((side, side))
    pixels = np.array(gray)

    # Normalisation is stated by the original author to match training:
    # mean 0.5, std 0.5 on [0, 1]-scaled pixels.
    normalized = (pixels / 255.0 - 0.5) / 0.5

    # Always produces a single channel.
    # NOTE(review): presumably config["num_channels"] is 1 - confirm.
    batch = torch.tensor(normalized).unsqueeze(0).unsqueeze(0).float()
    return batch.to(device)


# =========================
# PREDICT
# =========================
def predict_mri(
    image: Optional[Image.Image],
) -> Tuple[str, Optional[Dict[str, float]]]:
    """Classify an image and format the outcome for the UI.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(report, probabilities)`` pair. ``report`` is the text shown in
        the result box; ``probabilities`` maps every class label to its
        softmax probability in [0, 1], or is ``None`` when no image was given.
    """
    if image is None:
        return "⚠️ Please upload an MRI image.", None

    batch = preprocess_image(image)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)

    probs_np = probs.cpu().numpy()[0]

    # Class indices ordered from most to least probable.
    ranking = np.argsort(probs_np)[::-1]
    best = int(ranking[0])
    top_label = classes[best]
    top_conf = probs_np[best] * 100

    formatted_probs = "\n".join(
        f"{classes[int(i)]} — {probs_np[int(i)] * 100:.2f}%" for i in ranking
    )

    # The leading and trailing empty entries keep the blank first/last line
    # that the original triple-quoted template produced.
    result_text = "\n".join(
        [
            "",
            "🧠 MRI Classification Result",
            "",
            f"📌 Prediction: {top_label}",
            f"🔍 Confidence: {top_conf:.2f}%",
            "",
            "📊 Probabilities:",
            formatted_probs,
            "",
            "⚠️ AI-assisted output — not a medical diagnosis",
            "",
        ]
    )

    # Label -> probability mapping consumed by the gr.Label chart.
    probs_dict = {label: float(p) for label, p in zip(classes, probs_np)}

    return result_text, probs_dict


# =========================
# WIDE LANDSCAPE UI
# =========================
_CUSTOM_CSS = """
.gradio-container {
    max-width: 1600px !important;
    margin: auto;
}
.gr-image {
    min-height: 500px !important;
}
textarea {
    font-size: 16px !important;
}
h1, h2, h3 {
    text-align: center;
}
"""


def create_interface() -> gr.Blocks:
    """Build the two-column Gradio UI and wire up its buttons.

    The left column holds the image upload plus the Analyze/Clear buttons;
    the wider right column shows the text report and the per-class
    confidence chart.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(theme=gr.themes.Soft(), css=_CUSTOM_CSS) as interface:
        gr.Markdown("# 🧠 MRI Anatomy Classifier")
        gr.Markdown("### Upload an MRI scan to identify the body region")

        with gr.Row(equal_height=True):
            # LEFT PANEL (IMAGE)
            # Gradio documents `scale` as an integer; 2:3 is the same ratio
            # as the 1.2:1.8 originally requested.
            with gr.Column(scale=2):
                image_input = gr.Image(
                    type="pil",
                    label="📤 Upload MRI Image",
                    height=500,
                    sources=["upload"],
                )
                with gr.Row():
                    predict_btn = gr.Button("🚀 Analyze MRI", variant="primary")
                    clear_btn = gr.Button("🧹 Clear")

            # RIGHT PANEL (RESULT - WIDER)
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="📋 Result", lines=18)
                output_chart = gr.Label(label="📊 Confidence Breakdown")

        gr.Markdown("---")
        gr.Markdown("💡 Tip: Use clear MRI scans for better predictions.")

        predict_btn.click(
            fn=predict_mri,
            inputs=image_input,
            outputs=[output_text, output_chart],
        )

        # Reset the image, the report text and the chart in one go.
        clear_btn.click(
            fn=lambda: (None, "", None),
            inputs=[],
            outputs=[image_input, output_text, output_chart],
        )

    return interface


# =========================
# LAUNCH
# =========================
if __name__ == "__main__":
    app = create_interface()
    app.launch(share=True)