Spaces:
Build error
Build error
| import torch | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.cm as cm | |
| from PIL import Image | |
| import gradio as gr | |
| import tempfile, os | |
| from grad_cam import grad_cam # <-- uses the helper we wrote earlier | |
# --------- MODEL SETUP ---------
# Output labels. The i-th entry names the i-th logit of the classifier head;
# presumably this matches the label order used at training time (it is
# alphabetical, like torchvision's ImageFolder default) — TODO confirm.
classes = ["glioma", "meningioma", "notumor", "pituitary"]

# ResNet18 backbone built without a pretrained download; every weight comes
# from the local checkpoint loaded below (load_state_dict is strict).
model = models.resnet18(weights=None)
# Size the head from `classes` so the two can never drift out of sync.
model.fc = torch.nn.Linear(model.fc.in_features, len(classes))
# weights_only=True restricts unpickling to tensors and plain containers
# (all a state_dict needs) instead of allowing arbitrary pickled objects.
model.load_state_dict(
    torch.load("best_model.pth", map_location="cpu", weights_only=True)
)
model.eval()

# Image preprocessing: resize to 224x224, convert to a float tensor, then
# normalise with the standard ImageNet per-channel mean/std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])
| # --------- PDF REPORT GENERATOR --------- | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as RLImage, Table, TableStyle | |
| from reportlab.lib.styles import getSampleStyleSheet | |
| from reportlab.lib import colors | |
def _remove_quietly(path):
    """Delete *path*, ignoring errors (best-effort temp-file cleanup)."""
    try:
        os.remove(path)
    except OSError:
        pass


def generate_report(predictions, grad_cam_image):
    """Render a one-page PDF summarising a prediction and return its path.

    Args:
        predictions: Mapping of class name -> probability. Each value is
            formatted with ``:.2%``, so it is expected to lie in [0, 1].
        grad_cam_image: PIL image of the Grad-CAM overlay, or None to omit
            the visualisation section.

    Returns:
        Filesystem path of the generated PDF. The file is intentionally left
        on disk for the caller (Gradio serves it as a download).

    Raises:
        Whatever reportlab raises while building; in that case the partially
        written PDF is removed before the exception propagates.
    """
    # mkstemp creates the file and returns an OS handle; close it at once so
    # reportlab can reopen the path itself. Keeping a NamedTemporaryFile open
    # here would block the reopen on Windows and hold a descriptor meanwhile.
    fd, pdf_path = tempfile.mkstemp(suffix=".pdf")
    os.close(fd)

    doc = SimpleDocTemplate(pdf_path, pagesize=letter)
    styles = getSampleStyleSheet()
    story = []

    # Title. No emoji: reportlab's built-in Helvetica font has no glyph for
    # them, so they would render as a black box.
    story.append(Paragraph("<b>Brain Tumor Classification Report</b>", styles["Title"]))
    story.append(Spacer(1, 20))

    # Table of predictions: a header row followed by one row per class.
    data = [["Class", "Confidence"]]
    data.extend([cls.capitalize(), f"{prob:.2%}"] for cls, prob in predictions.items())
    table = Table(data, colWidths=[200, 150])
    table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor("#e0e0e0")),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.black),
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
        ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
    ]))
    story.append(table)
    story.append(Spacer(1, 20))

    # Grad-CAM image. It is written to a temp PNG that must survive until
    # doc.build() has embedded it; it is deleted afterwards.
    image_path = None
    if grad_cam_image is not None:
        fd, image_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        grad_cam_image.save(image_path)
        story.append(Paragraph("<b>Grad-CAM Visualization</b>", styles["Heading2"]))
        story.append(RLImage(image_path, width=300, height=300))

    # Footer
    story.append(Spacer(1, 30))
    story.append(Paragraph("Generated by Brain Tumor Classifier | © 2025", styles["Normal"]))

    try:
        doc.build(story)
    except BaseException:
        # Don't leave a half-written PDF behind on failure.
        _remove_quietly(pdf_path)
        raise
    finally:
        # The PNG is only an intermediate; it is embedded in the PDF by now.
        if image_path is not None:
            _remove_quietly(image_path)
    return pdf_path
# --------- INFERENCE FUNCTION ---------
def _overlay_cam(cam, image, alpha, size=(224, 224)):
    """Blend a Grad-CAM map over *image* and return the result as a PIL image.

    ``cam`` is the raw map returned by ``grad_cam``. It is resized to *size*
    and colourised with matplotlib's ``jet`` colormap, which maps values in
    [0, 1] (presumably ``grad_cam`` normalises its output — TODO confirm).
    *image* must be RGB so it matches the 3-channel heatmap. *alpha* is
    clamped to [0, 1] and is the heatmap's weight in the blend.
    """
    cam_resized = cv2.resize(cam, size)
    heatmap = cm.jet(cam_resized)[:, :, :3]  # drop the colormap's alpha channel
    img_array = np.array(image.resize(size)) / 255.0
    alpha = float(max(0.0, min(1.0, alpha)))
    blended = np.clip(heatmap * alpha + img_array * (1.0 - alpha), 0, 1)
    return Image.fromarray((blended * 255).astype(np.uint8))


def predict_with_gradcam(image, heatmap_alpha=0.5):
    """Classify an MRI image and build its Grad-CAM overlay and PDF report.

    Args:
        image: PIL image to classify, or None when nothing was uploaded.
        heatmap_alpha: Weight of the heatmap in the overlay, clamped to [0, 1].

    Returns:
        ``(predictions, grad_cam_image, report_path)``:
        ``predictions`` maps each class name to its softmax probability.
        ``grad_cam_image`` is the overlay PIL image, or None if Grad-CAM
        failed or produced nothing. ``report_path`` is the generated PDF.
        For a None *image* the result is ``({}, None, None)``, matching the
        sample-gallery handler's "nothing to show" value.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading.
        return {}, None, None

    # Force 3 channels: RGBA or greyscale inputs would break both the
    # 3-channel Normalize in `transform` and the blend with the RGB heatmap.
    image = image.convert("RGB")

    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_tensor)
    probs = torch.softmax(outputs, dim=1).numpy()[0]
    predicted_class_idx = int(np.argmax(probs))
    predictions = {cls: float(prob) for cls, prob in zip(classes, probs)}

    # Grad-CAM is best-effort: a failure is logged and the prediction and
    # report are still returned, just without the visualisation.
    grad_cam_image = None
    try:
        cam = grad_cam(model, input_tensor, predicted_class_idx)
        if cam is not None:
            grad_cam_image = _overlay_cam(cam, image, heatmap_alpha)
    except Exception as e:
        print(f"Grad-CAM error: {e}")

    report_path = generate_report(predictions, grad_cam_image)
    return predictions, grad_cam_image, report_path
# --------- SAMPLE IMAGES TAB ---------
SAMPLE_IMAGES_FOLDER = "examples"


def load_sample_images(folder=None):
    """Return the sorted paths of the sample images in *folder*.

    Args:
        folder: Directory to scan. Defaults to ``SAMPLE_IMAGES_FOLDER``.

    Returns:
        Sorted list of ``os.path.join(folder, name)`` for every regular file
        whose name ends in .png, .jpg or .jpeg (case-insensitive). Returns an
        empty list when *folder* is missing or is not a directory.
    """
    if folder is None:
        folder = SAMPLE_IMAGES_FOLDER
    # isdir (not exists): a plain file at this path would make listdir raise.
    if not os.path.isdir(folder):
        return []
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
        and os.path.isfile(os.path.join(folder, name))
    )


sample_paths = load_sample_images()
with gr.Blocks() as sample_tab:
    gr.Markdown("## 🖼️ Try with Sample MRI Images\nClick an image below to run a prediction:")
    sample_gallery = gr.Gallery(
        value=sample_paths,
        label="Sample MRI Images",
        columns=[4],
        height="auto",
    )
    output_probs = gr.Label(label="Predicted Probabilities")
    output_cam = gr.Image(type="pil", label="Grad-CAM Visualization", height=256, width=256)
    output_file = gr.File(label="Download Full Report")

    def predict_from_sample(evt: gr.SelectData):
        """Run a prediction on the gallery image the user clicked.

        The file is looked up by ``evt.index`` in the server-side
        ``sample_paths`` list, which is the order the gallery was populated
        in. A path echoed back by the browser is not trusted, so only the
        bundled sample files can ever be opened.

        Returns:
            The ``(probabilities, grad_cam_image, report_path)`` triple from
            ``predict_with_gradcam``, or ``({}, None, None)`` when the
            selection does not resolve to an existing sample file.
        """
        index = evt.index
        # Defensive: a gallery is expected to report a plain int index.
        if not isinstance(index, int) or not 0 <= index < len(sample_paths):
            return {}, None, None
        path = sample_paths[index]
        if not os.path.isfile(path):
            return {}, None, None
        # Close the file handle once the pixels are decoded by convert().
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return predict_with_gradcam(rgb, heatmap_alpha=0.5)

    # No `inputs`: everything needed arrives via the SelectData argument.
    sample_gallery.select(
        fn=predict_from_sample,
        outputs=[output_probs, output_cam, output_file],
    )
# --------- MAIN APP ---------
# Upload tab: one RGB image in; probabilities, Grad-CAM overlay and PDF out.
# fn / inputs / outputs are passed positionally, in gr.Interface's order.
main_tab = gr.Interface(
    predict_with_gradcam,
    [gr.Image(type="pil", label="Upload MRI Image", image_mode="RGB", height=356, width=356)],
    [
        gr.Label(label="Predicted Probabilities"),
        gr.Image(
            type="pil",
            label="Grad-CAM Visualization",
            height=256,
            width=256,
            show_download_button=True,
        ),
        gr.File(label="Download Full Report"),
    ],
    title="🧠 Brain Tumor Classification (ResNet18)",
    description="Upload a brain MRI image to get predictions, a Grad-CAM heatmap, and download a detailed report.",
)
# --------- ABOUT TAB ---------
# Static project description. The model named here must match the one built
# in MODEL SETUP (models.resnet18).
with gr.Blocks() as about_tab:
    gr.Markdown("""
# ℹ️ About This Project

A comprehensive deep learning solution for automated brain tumor classification using PyTorch and ResNet18. This full-stack application classifies brain MRI scans into four categories: Glioma, Meningioma, No Tumor, and Pituitary tumors, with explainable AI features for medical professionals.

## 1. Overall Model Performance and Comparison

The ResNet18 model trained using transfer learning (fine-tuned) outperformed both a custom Convolutional Neural Network (CNN) and a ResNet18 model trained from scratch. The targets of at least 90% precision and a mean F1-score of at least 0.88 were exceeded.

| Model                   | Accuracy (%) | Precision | F1-score |
| :---------------------- | :----------- | :-------- | :------- |
| Custom CNN              | 98.7         | 0.986     | 0.986    |
| ResNet18 Scratch        | 97.9         | 0.978     | 0.978    |
| **ResNet18 Fine-Tuned** | **99.39**    | **0.993** | **0.993** |

*Table based on source data.*

The top result of **99.39% accuracy** is highly competitive, approaching the performance achieved by computationally heavier architectures, such as EfficientNet-B4 (99.76%) and hybrid CNN–XGBoost models (99.77%).

## 2. Detailed Performance Metrics by Class

The model demonstrated a high degree of reliability and safety across all four classification categories: Glioma, Meningioma, Pituitary tumor, and Notumor.

| Class       | Precision  | Recall (Sensitivity) | F1-score   |
| :---------- | :--------- | :------------------- | :--------- |
| Glioma      | 0.9966     | 0.9900               | 0.9933     |
| Meningioma  | 0.9870     | 0.9902               | 0.9886     |
| **Notumor** | **0.9951** | **1.0000**           | **0.9975** |
| Pituitary   | 0.9967     | 0.9933               | 0.9950     |

*Table based on source data.*

- **High Reliability (Precision):** Precision scores of at least 98.70% keep the number of false positives extremely low, reducing erroneous diagnoses.
- **Patient Safety (Recall):** Recall scores of 99% or higher indicate that the model misses virtually no tumors, which is critical for patient safety.
- **Notumor Class:** The model achieved **100% recall** for the *Notumor* class, meaning no normal image was mistaken for a tumor, minimizing unnecessary stress and supplemental examinations.

### 👨‍💻 Author

Created by Him

🌐 [Portfolio](https://www.ai-by-him.me/) • 💼 [LinkedIn](https://www.linkedin.com/in/ibrahim-oubelkas-9b8a6a362/)
""")
# --------- COMBINE ---------
# Assemble the three tabs into the app. Only the module-level `demo` object
# is built on import; launching is left to direct execution.
demo = gr.TabbedInterface(
    [main_tab, sample_tab, about_tab],
    ["Prediction", "Sample Images", "About"],
)

# Guarded so importing this module (tests, `gradio app.py` reload mode, or
# reuse from another script) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()