Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, Dataset | |
| from torchvision import transforms, models | |
| from datasets import load_dataset | |
| import numpy as np | |
| import os | |
| from PIL import Image as PILImage | |
| from sklearn.metrics import classification_report, confusion_matrix | |
| import pandas as pd | |
# --- Configuration ---------------------------------------------------------
# Base file name shared by the checkpoint, final weights and model card.
CUSTOM_MODEL_NAME = "GoGenix_Brain_MRI_Model"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# --- Dataset information ---------------------------------------------------
DATASET_NAME = "PranomVignesh/MRI-Images-of-Brain-Tumor"

# The position of a name in this list is its integer class index.
CLASS_NAMES = [
    "glioma",
    "meningioma",
    "no-tumor",
    "pituitary",
]
NUM_CLASSES = len(CLASS_NAMES)
# Enhanced CNN Architecture for 4-Class Classification
class BrainTumorCNN(nn.Module):
    """Four-block convolutional classifier that outputs raw class logits.

    Each block is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    2x2 max-pool -> Dropout(0.5)``, with channel widths 64, 128, 256, 512.
    The final feature map is reduced by global average pooling to a
    512-vector and passed through a three-layer fully connected head.

    Args:
        num_classes: Size of the output layer (number of logits per sample).

    Input is a float tensor of shape ``(N, 3, H, W)``.  ``H`` and ``W`` must be
    at least 16 so that four successive 2x halvings leave at least one pixel.
    Output has shape ``(N, num_classes)``; no softmax is applied.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        # Feature extraction: four conv + batch-norm pairs of growing width.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)

        # Global average pooling collapses the spatial dims in place of a
        # flatten, so the head's input size does not depend on image size.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected classification head.
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

        # Regularization / activation shared by every stage.
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)`` for image batch ``x``."""
        conv_blocks = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in conv_blocks:
            x = self.relu(bn(conv(x)))
            # Functional pooling: max-pool has no parameters, so there is no
            # need to build a fresh nn.MaxPool2d module on every forward call.
            x = nn.functional.max_pool2d(x, 2)
            x = self.dropout(x)

        # (N, 512, 1, 1) -> (N, 512)
        x = self.gap(x).flatten(1)

        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        return self.fc3(x)
# Advanced Data Augmentation
def get_transforms():
    """Build the preprocessing pipelines for training and evaluation.

    Returns:
        A ``(train_transform, test_transform)`` tuple of
        ``transforms.Compose`` objects.  The training pipeline applies random
        augmentation and ends in a 224x224 crop; the evaluation pipeline only
        resizes to 224x224.  Both convert to a tensor and normalize with the
        same per-channel mean and standard deviation.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    augmentation_steps = [
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.GaussianBlur(3),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    evaluation_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]

    train_transform = transforms.Compose(augmentation_steps)
    test_transform = transforms.Compose(evaluation_steps)
    return train_transform, test_transform
# Dataset class for 4-class classification
class BrainTumorDataset(Dataset):
    """Wrap an indexable collection of ``{'image', 'label'}`` records.

    Each record is turned into an ``(image, label_index)`` pair: the image is
    converted to an RGB PIL image and passed through ``transform`` (if any);
    the label is resolved to an integer index in ``[0, NUM_CLASSES - 1]``.

    Args:
        dataset: Object supporting ``len()`` and integer indexing that yields
            dict-like records with an ``'image'`` key and an optional
            ``'label'`` key.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        # Class name -> index, in CLASS_NAMES order.
        self.label_to_idx = {name: idx for idx, name in enumerate(CLASS_NAMES)}
        print(f"Label mapping: {self.label_to_idx}")

    def __len__(self):
        return len(self.dataset)

    def _label_to_index(self, label):
        """Resolve a raw label (name or integer) to a valid class index.

        Strings are matched case-insensitively against the known class names.
        Integers (including numpy integer scalars, which previously fell
        through and were silently mapped to class 0) are used as indices.
        Anything unrecognized falls back to index 0, and the result is clamped
        into the valid index range.
        """
        if isinstance(label, str):
            label_idx = self.label_to_idx.get(label.strip().lower(), 0)
        elif isinstance(label, (int, np.integer)):
            label_idx = int(label)
        else:
            label_idx = 0  # Best-effort default: first class
        # label_to_idx has exactly one entry per class, so its length is the
        # class count.
        return max(0, min(label_idx, len(self.label_to_idx) - 1))

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Normalize the image to an RGB PIL image.
        image = item['image']
        if not isinstance(image, PILImage.Image):
            image = PILImage.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        label_idx = self._label_to_index(item.get('label', 0))

        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label_idx, dtype=torch.long)
def analyze_dataset(dataset, max_samples=1000, class_names=None):
    """Count class occurrences among the first ``max_samples`` records.

    Args:
        dataset: Object supporting ``len()`` and integer indexing that yields
            dict-like records with an optional ``'label'`` key (a missing
            label is treated as index 0).
        max_samples: Upper bound on how many leading records are inspected.
        class_names: Ordered class names; defaults to the module-level
            ``CLASS_NAMES``.

    Returns:
        Dict mapping every class name to its count.  String labels are matched
        case-insensitively; integer labels are used as indices into the class
        list.  Labels that match no class (unknown names, negative or
        too-large indices, other types) are ignored.
    """
    names = list(CLASS_NAMES if class_names is None else class_names)
    class_counts = {name: 0 for name in names}

    for i in range(min(max_samples, len(dataset))):
        label = dataset[i].get('label', 0)
        if isinstance(label, str):
            key = label.lower()
            if key in class_counts:
                class_counts[key] += 1
        # The lower bound matters: a negative index would otherwise wrap
        # around and be counted against the wrong class.
        elif isinstance(label, (int, np.integer)) and 0 <= label < len(names):
            class_counts[names[label]] += 1
    return class_counts
def _get_split(dataset, *candidates):
    """Return the first split of ``dataset`` whose name is in ``candidates``."""
    for name in candidates:
        if name in dataset:
            return dataset[name]
    raise KeyError(
        f"none of the splits {candidates} exist; available: {list(dataset.keys())}"
    )


def _predict(model, loader):
    """Run ``model`` over ``loader`` in eval mode; return (predictions, targets) as int lists."""
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(DEVICE))
            predictions.extend(outputs.argmax(dim=1).tolist())
            targets.extend(labels.tolist())
    return predictions, targets


def train_and_save_model():
    """Train CNN model for 4-class brain tumor classification.

    Loads ``DATASET_NAME``, trains a ``BrainTumorCNN`` with AdamW and a cosine
    learning-rate schedule, and stops early after ``patience`` epochs without
    validation improvement or once validation accuracy reaches 98%.  The best
    checkpoint is then evaluated on the test split.

    Side effects: writes ``<CUSTOM_MODEL_NAME>_best.pth``,
    ``<CUSTOM_MODEL_NAME>_final.pth`` and ``<CUSTOM_MODEL_NAME>_model_card.md``
    to the working directory and prints progress to stdout.

    Returns:
        A multi-line report string.  Any exception is caught and returned as
        an error message followed by the traceback instead of being raised.
    """
    try:
        # Load the specified dataset
        print(f"Loading dataset: {DATASET_NAME}")
        dataset = load_dataset(DATASET_NAME)
        splits = list(dataset.keys())
        print(f"Splits available: {splits}")

        # The validation split may be exposed under a different name than
        # 'valid' depending on how the dataset is published, so accept the
        # common spellings instead of failing with a KeyError.
        train_data = _get_split(dataset, 'train')
        valid_data = _get_split(dataset, 'valid', 'validation', 'val')
        test_data = _get_split(dataset, 'test')
        print(f"Training samples: {len(train_data)}")
        print(f"Validation samples: {len(valid_data)}")
        print(f"Test samples: {len(test_data)}")

        # Analyze class distribution
        train_dist = analyze_dataset(train_data)
        valid_dist = analyze_dataset(valid_data)
        print("Training distribution:", train_dist)
        print("Validation distribution:", valid_dist)

        # Datasets and loaders (augmentation on the training split only)
        train_transform, test_transform = get_transforms()
        train_dataset = BrainTumorDataset(train_data, train_transform)
        valid_dataset = BrainTumorDataset(valid_data, test_transform)
        test_dataset = BrainTumorDataset(test_data, test_transform)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
        valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=2)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

        # Initialize model
        model = BrainTumorCNN(num_classes=NUM_CLASSES)
        model.to(DEVICE)

        # Plain (unweighted) cross-entropy; no class weighting is applied.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

        # Training parameters
        num_epochs = 100
        patience = 10
        patience_counter = 0
        # Start below any reachable accuracy so the first epoch always writes
        # a checkpoint; otherwise a run stuck at 0% would later try to load a
        # missing (or stale) best-checkpoint file.
        best_accuracy = -1.0
        best_path = f'{CUSTOM_MODEL_NAME}_best.pth'

        # Anneal over the whole run.  With T_max shorter than the number of
        # epochs the cosine schedule would start increasing the LR again.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        # Report fragments are collected here and joined once at the end.
        log = [
            f"🚀 Training CNN Model for 4-Class Brain Tumor Classification\n\n",
            f"Dataset: {DATASET_NAME}\n",
            f"Classes: {CLASS_NAMES}\n",
            f"Training samples: {len(train_dataset)}\n",
            f"Validation samples: {len(valid_dataset)}\n",
            f"Test samples: {len(test_dataset)}\n",
            f"Epochs: {num_epochs}\n",
            f"Device: {DEVICE}\n\n",
            f"Class Distribution - Train: {train_dist}\n",
            f"Class Distribution - Valid: {valid_dist}\n\n",
        ]

        # Training loop
        for epoch in range(num_epochs):
            # Training phase
            model.train()
            running_loss = 0.0
            train_correct = 0
            train_total = 0
            for images, labels in train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                train_total += labels.size(0)
                train_correct += (predicted == labels).sum().item()

            # Validation phase
            valid_preds, valid_targets = _predict(model, valid_loader)
            valid_correct = sum(p == t for p, t in zip(valid_preds, valid_targets))

            train_accuracy = 100 * train_correct / train_total
            valid_accuracy = 100 * valid_correct / len(valid_targets)
            avg_loss = running_loss / len(train_loader)

            # Update scheduler; the logged LR is the one the scheduler
            # reports after this step.
            scheduler.step()
            current_lr = scheduler.get_last_lr()[0]

            # Save best model
            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                patience_counter = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'accuracy': valid_accuracy,
                    'loss': avg_loss,
                }, best_path)
            else:
                patience_counter += 1

            log.append(
                f'Epoch [{epoch+1}/{num_epochs}], LR: {current_lr:.6f}, Loss: {avg_loss:.4f}, '
                f'Train Acc: {train_accuracy:.2f}%, Valid Acc: {valid_accuracy:.2f}%\n'
            )

            # Early stopping
            if patience_counter >= patience:
                log.append(f"\n⏹️ Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)\n")
                break
            # Target accuracy achieved
            if valid_accuracy >= 98.0:
                log.append(f"\n🎯 Target accuracy achieved! Stopping training at epoch {epoch+1}\n")
                break

        # Load best model for final evaluation
        best_checkpoint = torch.load(best_path, map_location=DEVICE)
        model.load_state_dict(best_checkpoint['model_state_dict'])

        # Final test evaluation
        all_preds, all_labels = _predict(model, test_loader)
        test_correct = sum(p == t for p, t in zip(all_preds, all_labels))
        test_accuracy = 100 * test_correct / len(all_labels)
        log.append(f"\n🏁 FINAL TEST RESULTS:\n")
        log.append(f"Best Validation Accuracy: {best_checkpoint['accuracy']:.2f}%\n")
        log.append(f"Final Test Accuracy: {test_accuracy:.2f}%\n")

        # Class-wise accuracy
        class_correct = [0] * NUM_CLASSES
        class_total = [0] * NUM_CLASSES
        for pred, true in zip(all_preds, all_labels):
            class_total[true] += 1
            if pred == true:
                class_correct[true] += 1
        log.append(f"\n📊 CLASS-WISE ACCURACY:\n")
        for i, class_name in enumerate(CLASS_NAMES):
            if class_total[i] > 0:
                acc = 100 * class_correct[i] / class_total[i]
                log.append(f"{class_name}: {acc:.2f}% ({class_correct[i]}/{class_total[i]})\n")

        # Save final model (weights only)
        torch.save(model.state_dict(), f'{CUSTOM_MODEL_NAME}_final.pth')

        # Create detailed model card.  The text is flush-left on purpose: it
        # is written verbatim to a Markdown file.
        model_card = f"""
# GoGenix Brain MRI Model - 4-Class Classification
## Model Information
- **Architecture**: Custom CNN with Global Average Pooling
- **Task**: Multi-Class Brain Tumor Classification
- **Classes**: {CLASS_NAMES}
- **Test Accuracy**: {test_accuracy:.2f}%
- **Dataset**: {DATASET_NAME}
## Usage
```python
from torchvision import transforms
# Load model
model = BrainTumorCNN(num_classes=4)
model.load_state_dict(torch.load('GoGenix_Brain_MRI_Model_final.pth'))
# Preprocessing
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```
"""
        with open(f'{CUSTOM_MODEL_NAME}_model_card.md', 'w') as f:
            f.write(model_card)

        log.append(f"\n✅ Model saved as '{CUSTOM_MODEL_NAME}_final.pth'\n")
        log.append(f"📁 Model card saved as '{CUSTOM_MODEL_NAME}_model_card.md'\n")

        # Download instructions
        log.append(f"\n📥 DOWNLOAD INSTRUCTIONS:\n")
        log.append(f"1. Files are saved in your working directory\n")
        log.append(f"2. Download '{CUSTOM_MODEL_NAME}_final.pth' for the trained model\n")
        log.append(f"3. Download '{CUSTOM_MODEL_NAME}_model_card.md' for documentation\n")
        return "".join(log)
    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        import traceback
        return f"❌ Training Error: {str(e)}\n\n{traceback.format_exc()}"
def classify_mri(image):
    """Classify MRI image using trained CNN.

    Args:
        image: A PIL image, or an array accepted by ``PIL.Image.fromarray``.

    Returns:
        A ``(probabilities, message)`` tuple on every path.
        ``probabilities`` maps each name in ``CLASS_NAMES`` to a softmax
        probability rounded to 4 decimals; ``message`` is a one-line diagnosis
        string.  If the saved model file is missing, no image was supplied, or
        any exception occurs, all probabilities are 0.0 and ``message``
        explains the problem.
    """
    try:
        # Load model
        model_path = f'{CUSTOM_MODEL_NAME}_final.pth'
        if not os.path.exists(model_path):
            # Must be a 2-tuple like every other return path: the caller
            # unpacks two values, and a bare dict made that unpacking fail.
            return (
                {name: 0.0 for name in CLASS_NAMES},
                f"Error: model file '{model_path}' not found - train the model first",
            )
        if image is None:
            return {name: 0.0 for name in CLASS_NAMES}, "Error: no image provided"

        model = BrainTumorCNN(num_classes=NUM_CLASSES)
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()

        # Preprocess image (same steps as the evaluation pipeline)
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        if not isinstance(image, PILImage.Image):
            image = PILImage.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # Predict: softmax over the logits of the single image in the batch
        with torch.no_grad():
            output = model(image_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)

        results = {
            class_name: round(probabilities[i].item(), 4)
            for i, class_name in enumerate(CLASS_NAMES)
        }

        # Get diagnosis
        max_class = max(results, key=results.get)
        max_prob = results[max_class]
        diagnosis_info = f"Diagnosis: {max_class} (Confidence: {max_prob*100:.1f}%)"
        return results, diagnosis_info
    except Exception as e:
        # UI boundary: surface the failure as text rather than raising.
        return {name: 0.0 for name in CLASS_NAMES}, f"Error: {str(e)}"
# Gradio Interface: three tabs (training, classification, architecture notes)
with gr.Blocks(title="GoGenix Brain MRI Classifier") as demo:
    gr.Markdown("# 🧠 GoGenix Brain MRI CNN Classifier - 4 Classes")
    gr.Markdown(f"**Dataset**: {DATASET_NAME} | **Classes**: {', '.join(CLASS_NAMES)}")

    # Tab 1: run the full training routine and show its text report.
    with gr.Tab("🚀 Train CNN Model"):
        gr.Markdown("### Train 4-Class CNN Model")
        gr.Markdown(f"**Target**: 98%+ Accuracy | **Classes**: {', '.join(CLASS_NAMES)}")
        train_btn = gr.Button("Start 4-Class Training", variant="primary", size="lg")
        output_text = gr.Textbox(
            label="Training Progress",
            lines=25,
            placeholder="Training output will appear here...",
        )
        train_btn.click(fn=train_and_save_model, outputs=output_text)

    # Tab 2: classify a single uploaded scan.
    with gr.Tab("🔍 Classify MRI"):
        gr.Markdown("### Brain Tumor Type Detection")
        gr.Markdown("Upload MRI scan for 4-class classification")
        image_input = gr.Image(type="pil", label="MRI Brain Scan", height=300)
        classify_btn = gr.Button("Analyze Scan", variant="secondary")
        with gr.Row():
            result_label = gr.Label(label="Class Probabilities", num_top_classes=4)
            diagnosis_text = gr.Textbox(label="Diagnostic Result", interactive=False)

        def process_classification(image):
            """Forward the uploaded image to ``classify_mri`` and return its two outputs."""
            scores, verdict = classify_mri(image)
            return scores, verdict

        classify_btn.click(
            fn=process_classification,
            inputs=image_input,
            outputs=[result_label, diagnosis_text],
        )

    # Tab 3: static description of the network and training setup.
    with gr.Tab("📊 Model Architecture"):
        gr.Markdown("### CNN Architecture Details")
        gr.Markdown(f"""
**Architecture**: Custom CNN with 4 Convolutional Blocks + GAP
**Classes**: {NUM_CLASSES}
- Glioma Tumors
- Meningioma Tumors
- No Tumor (Healthy)
- Pituitary Tumors
**Enhanced Features**:
- Global Average Pooling for better generalization
- Advanced data augmentation
- Cosine annealing learning rate
- Early stopping
- Class distribution analysis
""")

if __name__ == "__main__":
    demo.launch()