Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| from torchvision.transforms import InterpolationMode | |
| from torchvision.models import efficientnet_b3 | |
# Model setup
# Output labels; predict() maps the argmax logit index into this list, so the
# order is assumed to match the label order used in training (TODO confirm).
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

# EfficientNet-B3 architecture with no pretrained weights: every parameter is
# supplied by the fine-tuned checkpoint loaded below.
model = efficientnet_b3(weights=None)

# Replace the final linear layer with one output per class. Reading
# in_features from the existing layer avoids hard-coding B3's width (1536).
model.classifier[1] = torch.nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=len(class_names),
)

# Nothing in this file moves the model off the CPU, so map the checkpoint
# tensors straight to the CPU rather than staging them on a GPU first.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(
    torch.load("Eff_net_b3_01_brain_tumor.pth", map_location=torch.device("cpu"))
)

# Inference mode for layers such as dropout and batch-norm.
model.eval()
# Image transform
# Preprocessing applied to each PIL image before it reaches the model:
#   1. bicubic resize so the shorter side is 320 px,
#   2. centred 300x300 crop,
#   3. conversion to a float tensor with values in [0, 1],
#   4. per-channel normalisation with the standard ImageNet mean/std.
# These sizes mirror torchvision's EfficientNet-B3 preset; presumably the
# checkpoint was trained with the same pipeline (TODO confirm).
_preprocessing_steps = [
    transforms.Resize(320, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(300),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
img_transform = transforms.Compose(_preprocessing_steps)
# Prediction function
def predict(image):
    """Classify a single brain MRI image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Predict is clicked before anything has been uploaded.

    Returns:
        A ``(class_label, confidence)`` tuple: the entry of ``class_names``
        with the highest softmax probability, and that probability as a
        Python float.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable
            message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image before predicting.")

    # img_transform normalises exactly three channels, so greyscale, RGBA or
    # palette images must be converted first.
    rgb_image = image.convert("RGB")

    # Build a batch of one and place it on the same device as the weights.
    device = next(model.parameters()).device
    batch = img_transform(rgb_image).unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        label_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0, label_idx].item()

    class_label = class_names[label_idx]
    return class_label, confidence
# Gradio Blocks UI
# Layout: an input column (image upload + Predict/Clear buttons) next to an
# output column (predicted class label + confidence slider).
# NOTE(review): the emoji below were rebuilt from mis-encoded characters in
# the source; 🧠 and 🧹 are certain, 🔍 and 🤖 are best guesses — confirm.
with gr.Blocks(title="🧠 Brain Tumor MRI Classifier") as demo:
    gr.Markdown("## 🧠 Brain Tumor Classifier (EfficientNet-B3)")
    # Blank lines keep the list separate from the surrounding paragraphs, so
    # the closing sentence is not rendered as part of the last list item.
    gr.Markdown("""
    Upload an MRI scan of the brain, and this model will classify it as one of:

    - **Glioma Tumor**
    - **Meningioma Tumor**
    - **Pituitary Tumor**
    - **No Tumor**

    Uses EfficientNet-B3 trained on labeled brain MRI dataset.
    """)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload MRI Image")
            predict_button = gr.Button("🔍 Predict")
            clear_button = gr.Button("🧹 Clear")
        with gr.Column():
            output_label = gr.Label(label="Predicted Class")
            confidence_slider = gr.Slider(
                minimum=0, maximum=1, step=0.01, label="Confidence Score"
            )

    # predict returns (class_label, confidence), one value per output here.
    predict_button.click(
        fn=predict, inputs=image_input, outputs=[output_label, confidence_slider]
    )
    # Gradio requires one return value per output component (three here):
    # clear the image and the label, and reset the slider to its minimum.
    clear_button.click(
        lambda: (None, None, 0),
        inputs=[],
        outputs=[image_input, output_label, confidence_slider],
    )

    gr.Markdown("---")
    gr.Markdown(
        "<center>🤖 Developed by [Sagar Bisht](https://www.linkedin.com/in/sagarbisht123)</center>",
        elem_id="footer",
    )

# share=True asks Gradio for a public share link when run locally.
demo.launch(share=True)