| import os |
| from huggingface_hub import login |
|
|
# Authenticate with the Hugging Face Hub only when a token is configured.
# huggingface_hub.login(token=None) falls back to an interactive prompt,
# which cannot be answered when the app runs headless (e.g. in a Space).
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.models as models |
| import numpy as np |
| from PIL import Image |
| import json |
| from huggingface_hub import hf_hub_download |
|
|
| |
| |
| |
|
|
# Hugging Face Hub repository that holds the trained checkpoint and metadata.
MODEL_REPO = "InfoBayAI/resnet18-mri-anatomy-classifier"

# Fetch the model artifacts (hf_hub_download returns a local cached path).
weights_path = hf_hub_download(repo_id=MODEL_REPO, filename="pytorch_model.bin")
config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json")
classes_path = hf_hub_download(repo_id=MODEL_REPO, filename="classes.json")

# Model hyper-parameters; keys read elsewhere in this file are
# "num_channels", "num_classes" and "input_size".
# JSON is UTF-8 by spec, so do not depend on the platform's default encoding.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Class labels.
# NOTE(review): presumably a list ordered by class index (it is indexed by the
# model's output position) — confirm against classes.json.
with open(classes_path, "r", encoding="utf-8") as f:
    classes = json.load(f)

# Predictions index ``classes`` by output position, so a size mismatch would
# make every request fail; surface it once, at startup, instead.
if len(classes) != config["num_classes"]:
    raise ValueError(
        f"classes.json lists {len(classes)} labels but config.json declares "
        f"num_classes={config['num_classes']}"
    )
|
|
| |
| |
| |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
| |
| |
| |
# ResNet-18 architecture with no pretrained weights; the trained weights are
# loaded from the Hub checkpoint below.
model = models.resnet18(weights=None)

# Replace the stem convolution (64 filters, 7x7, stride 2, no bias) so that it
# accepts config["num_channels"] input channels.
model.conv1 = nn.Conv2d(
    config["num_channels"],
    64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)

# Replace the classification head so it emits one logit per class.
model.fc = nn.Linear(model.fc.in_features, config["num_classes"])

# The checkpoint is downloaded from the network: weights_only=True restricts
# unpickling to tensors/primitive containers so it cannot execute code.
model.load_state_dict(
    torch.load(weights_path, map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # inference mode (affects BatchNorm/Dropout behaviour)

print("✅ MRI Model Loaded Successfully")
|
|
| |
| |
| |
def preprocess_image(image):
    """Turn a PIL image into the model's input tensor.

    The image is converted to single-channel grayscale, resized to a square of
    side ``config["input_size"]``, scaled to [0, 1] and then mapped to
    [-1, 1] via ``(x - 0.5) / 0.5``.

    Args:
        image: A PIL image.

    Returns:
        A float32 tensor of shape (1, 1, input_size, input_size) on ``device``.
    """
    # NOTE(review): this always yields one channel, while the model's conv1 is
    # built with config["num_channels"] — confirm the two agree.
    side = config["input_size"]
    grey = image.convert("L").resize((side, side))

    pixels = np.array(grey) / 255.0
    pixels = (pixels - 0.5) / 0.5

    # Add batch and channel axes in front of (H, W).
    batch = torch.tensor(pixels)[None, None, ...].float()
    return batch.to(device)
|
|
| |
| |
| |
def predict_mri(image):
    """Classify the body region shown in an uploaded MRI image.

    Args:
        image: A PIL image from the upload widget, or ``None`` when nothing
            has been uploaded.

    Returns:
        A ``(result_text, probs_dict)`` tuple:

        - ``result_text``: human-readable summary with the top label, its
          confidence, every class probability sorted high-to-low, and a
          disclaimer.
        - ``probs_dict``: class name -> softmax probability (for ``gr.Label``).

        When ``image`` is ``None``, returns a prompt message and ``None``.
    """
    # Emoji/punctuation below were restored from mis-decoded UTF-8; where the
    # exact original glyph was unrecoverable a representative one is used.
    if image is None:
        return "⚠️ Please upload an MRI image.", None

    batch = preprocess_image(image)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)

    # Probabilities for the single image in the batch, on the host.
    probs_np = probs.cpu().numpy()[0]

    # Class indices ordered from most to least probable.
    ranked = np.argsort(probs_np)[::-1]

    # NOTE(review): assumes ``classes`` is a list indexed by class id.
    top_idx = int(ranked[0])
    top_label = classes[top_idx]
    top_conf = probs_np[top_idx] * 100

    formatted_probs = "\n".join(
        f"{classes[int(i)]} → {probs_np[i] * 100:.2f}%" for i in ranked
    )

    result_text = "\n".join([
        "🧠 MRI Classification Result",
        "",
        f"📌 Prediction: {top_label}",
        f"📈 Confidence: {top_conf:.2f}%",
        "",
        "📋 Probabilities:",
        formatted_probs,
        "",
        "⚠️ AI-assisted output — not a medical diagnosis",
    ])

    probs_dict = {label: float(probs_np[i]) for i, label in enumerate(classes)}

    return result_text, probs_dict
|
|
| |
| |
| |
def create_interface():
    """Build the Gradio Blocks UI for the MRI anatomy classifier.

    Layout: an upload panel with Analyze/Clear buttons on the left, and the
    textual result plus a per-class confidence chart on the right. Clicking
    Analyze runs ``predict_mri``; Clear resets the image, text and chart.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # UI strings had mis-decoded emoji; they are restored here (a representative
    # emoji is used where the original glyph could not be recovered).
    with gr.Blocks(
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1600px !important;
            margin: auto;
        }

        .gr-image {
            min-height: 500px !important;
        }

        textarea {
            font-size: 16px !important;
        }

        h1, h2, h3 {
            text-align: center;
        }
        """,
    ) as interface:
        gr.Markdown("# 🧠 MRI Anatomy Classifier")
        gr.Markdown("### Upload an MRI scan to identify the body region")

        with gr.Row(equal_height=True):
            # Gradio expects integer ``scale`` values; 2:3 keeps the original
            # 1.2:1.8 width ratio between the two columns.

            # Left column: image upload and action buttons.
            with gr.Column(scale=2):
                image_input = gr.Image(
                    type="pil",
                    label="📤 Upload MRI Image",
                    height=500,
                    sources=["upload"],
                )

                with gr.Row():
                    predict_btn = gr.Button("🔍 Analyze MRI", variant="primary")
                    clear_btn = gr.Button("🧹 Clear")

            # Right column: textual result and confidence chart.
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="📄 Result", lines=18)
                output_chart = gr.Label(label="📊 Confidence Breakdown")

        gr.Markdown("---")
        gr.Markdown("💡 Tip: Use clear MRI scans for better predictions.")

        predict_btn.click(
            fn=predict_mri,
            inputs=image_input,
            outputs=[output_text, output_chart],
        )

        # Reset the uploaded image, the result text and the chart.
        clear_btn.click(
            fn=lambda: (None, "", None),
            inputs=[],
            outputs=[image_input, output_text, output_chart],
        )

    return interface
|
|
| |
| |
| |
if __name__ == "__main__":
    # Script entry point: build the UI and serve it, requesting a public share link.
    create_interface().launch(share=True)