import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from datasets import load_dataset
import numpy as np
import os
from PIL import Image as PILImage
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd

# Configuration
CUSTOM_MODEL_NAME = "GoGenix_Brain_MRI_Model"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Dataset information
DATASET_NAME = "PranomVignesh/MRI-Images-of-Brain-Tumor"
CLASS_NAMES = ["glioma", "meningioma", "no-tumor", "pituitary"]
NUM_CLASSES = len(CLASS_NAMES)

# ImageNet channel statistics, shared by training and inference preprocessing.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]


# Enhanced CNN Architecture for 4-Class Classification
class BrainTumorCNN(nn.Module):
    """Four-block convolutional classifier for brain MRI images.

    Each block is Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2) -> Dropout(0.5),
    with channel widths 64/128/256/512. Global average pooling collapses the
    spatial dimensions before a 512 -> 256 -> 128 -> num_classes head.

    Args:
        num_classes: Number of output logits.

    Input: float tensor of shape (batch, 3, H, W); the adaptive pooling lets
    any H, W large enough to survive four 2x poolings reach the head.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Feature extraction
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)

        # Parameter-free pooling layers: they add no state_dict entries, so
        # checkpoints saved by earlier versions of this class still load.
        self.pool = nn.MaxPool2d(2)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        # Classification head
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

        # Regularization / activation
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        conv_blocks = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in conv_blocks:
            x = self.relu(bn(conv(x)))
            x = self.dropout(self.pool(x))

        x = torch.flatten(self.gap(x), 1)

        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        return self.fc3(x)


def get_transforms():
    """Build the training and evaluation preprocessing pipelines.

    Returns:
        tuple: ``(train_transform, test_transform)``. The training pipeline
        adds random crop, flip, rotation, translation, color jitter and blur;
        the evaluation pipeline only resizes to 224x224, converts to a tensor
        and normalizes with the ImageNet statistics.
    """
    normalize = transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.GaussianBlur(3),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, test_transform


def _to_rgb_pil(image):
    """Return *image* as an RGB PIL image, converting from an array if needed."""
    if not isinstance(image, PILImage.Image):
        image = PILImage.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image


def _label_to_index(label, label_to_idx):
    """Map a raw dataset label (class-name string or integer) to a class index.

    Lenient by design: unknown strings and unsupported types fall back to 0,
    and integers are clamped into ``[0, NUM_CLASSES - 1]``.
    """
    if isinstance(label, str):
        idx = label_to_idx.get(label.lower(), 0)
    elif isinstance(label, (int, np.integer)):
        # Accept numpy integer scalars too, not only Python ints.
        idx = int(label)
    else:
        idx = 0
    return max(0, min(idx, NUM_CLASSES - 1))


class BrainTumorDataset(Dataset):
    """Adapt a dataset split to ``(image_tensor, label_tensor)`` pairs.

    Args:
        dataset: Indexable split whose items are dicts with an ``'image'``
            entry (PIL image or array) and optionally a ``'label'`` entry
            (class-name string or integer index; missing means class 0).
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        self.label_to_idx = {name: idx for idx, name in enumerate(CLASS_NAMES)}
        print(f"Label mapping: {self.label_to_idx}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = _to_rgb_pil(item['image'])
        label_idx = _label_to_index(item.get('label', 0), self.label_to_idx)
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label_idx, dtype=torch.long)


def analyze_dataset(dataset, max_samples=1000):
    """Count labels per class over the first ``max_samples`` items.

    String labels are matched case-insensitively against CLASS_NAMES; integer
    labels are used as positions into CLASS_NAMES. Unrecognized or
    out-of-range labels are not counted.

    Returns:
        dict: ``{class_name: count}`` for every name in CLASS_NAMES.
    """
    class_counts = {name: 0 for name in CLASS_NAMES}
    for i in range(min(max_samples, len(dataset))):
        label = dataset[i].get('label', 0)
        if isinstance(label, str):
            key = label.lower()
            if key in class_counts:
                class_counts[key] += 1
        elif isinstance(label, (int, np.integer)) and 0 <= label < NUM_CLASSES:
            class_counts[CLASS_NAMES[int(label)]] += 1
    return class_counts


def _warn_if_label_names_differ(split):
    """Print a warning if the split's ClassLabel names disagree with CLASS_NAMES.

    Integer labels are interpreted positionally against CLASS_NAMES, so a
    different ordering in the dataset would silently mislabel samples.
    Does nothing when the split exposes no ``label`` feature names.
    """
    features = getattr(split, 'features', None)
    label_feature = features.get('label') if features is not None else None
    names = getattr(label_feature, 'names', None)
    if names is not None and [str(n).lower() for n in names] != CLASS_NAMES:
        print(f"WARNING: dataset label names {names} differ from CLASS_NAMES {CLASS_NAMES}")


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimization pass over *loader*.

    Returns:
        tuple: ``(avg_loss, accuracy_percent)``; the loss is the mean of the
        per-batch losses and the accuracy is over all samples seen.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return running_loss / len(loader), 100 * correct / total


def _predict(model, loader):
    """Return ``(predictions, labels)`` as Python int lists for all of *loader*."""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(DEVICE))
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.tolist())
    return all_preds, all_labels


def _accuracy(preds, labels):
    """Percentage of positions where *preds* equals *labels*."""
    return 100 * sum(p == t for p, t in zip(preds, labels)) / len(labels)


def _write_model_card(test_accuracy):
    """Write the Markdown model card for the final model."""
    model_card = f"""# GoGenix Brain MRI Model - 4-Class Classification

## Model Information
- **Architecture**: Custom CNN with Global Average Pooling
- **Task**: Multi-Class Brain Tumor Classification
- **Classes**: {CLASS_NAMES}
- **Test Accuracy**: {test_accuracy:.2f}%
- **Dataset**: {DATASET_NAME}

## Usage
```python
from torchvision import transforms

# Load model
model = BrainTumorCNN(num_classes=4)
model.load_state_dict(torch.load('GoGenix_Brain_MRI_Model_final.pth'))

# Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```
"""
    with open(f'{CUSTOM_MODEL_NAME}_model_card.md', 'w', encoding='utf-8') as f:
        f.write(model_card)


def train_and_save_model():
    """Train, evaluate and save the 4-class CNN; return a human-readable log.

    Loads the train/valid/test splits of DATASET_NAME, trains with AdamW and a
    cosine learning-rate schedule, checkpoints the best validation epoch to
    ``{CUSTOM_MODEL_NAME}_best.pth``, evaluates that checkpoint on the test
    split, then writes ``{CUSTOM_MODEL_NAME}_final.pth`` and a model card.

    Returns:
        str: The training log, or an error message with traceback if any step
        raises (the Gradio textbox shows whatever is returned).
    """
    try:
        print(f"Loading dataset: {DATASET_NAME}")
        dataset = load_dataset(DATASET_NAME)
        print(f"Splits available: {list(dataset.keys())}")

        train_data = dataset['train']
        valid_data = dataset['valid']
        test_data = dataset['test']
        print(f"Training samples: {len(train_data)}")
        print(f"Validation samples: {len(valid_data)}")
        print(f"Test samples: {len(test_data)}")
        _warn_if_label_names_differ(train_data)

        train_dist = analyze_dataset(train_data)
        valid_dist = analyze_dataset(valid_data)
        print("Training distribution:", train_dist)
        print("Validation distribution:", valid_dist)

        train_transform, test_transform = get_transforms()
        train_dataset = BrainTumorDataset(train_data, train_transform)
        valid_dataset = BrainTumorDataset(valid_data, test_transform)
        test_dataset = BrainTumorDataset(test_data, test_transform)

        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
        valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=2)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

        model = BrainTumorCNN(num_classes=NUM_CLASSES).to(DEVICE)
        # Plain (unweighted) cross-entropy; the class distribution is only reported.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

        num_epochs = 100
        patience = 10
        # Anneal across the whole epoch budget. A shorter T_max makes the cosine
        # curve bottom out mid-run and then climb back toward the initial LR.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        best_checkpoint_path = f'{CUSTOM_MODEL_NAME}_best.pth'
        # Below any reachable accuracy, so epoch 1 always writes a checkpoint and
        # a stale file from an earlier run can never be reloaded below.
        best_accuracy = -1.0
        patience_counter = 0

        log = [
            "🚀 Training CNN Model for 4-Class Brain Tumor Classification\n\n",
            f"Dataset: {DATASET_NAME}\n",
            f"Classes: {CLASS_NAMES}\n",
            f"Training samples: {len(train_dataset)}\n",
            f"Validation samples: {len(valid_dataset)}\n",
            f"Test samples: {len(test_dataset)}\n",
            f"Epochs: {num_epochs}\n",
            f"Device: {DEVICE}\n\n",
            f"Class Distribution - Train: {train_dist}\n",
            f"Class Distribution - Valid: {valid_dist}\n\n",
        ]

        for epoch in range(num_epochs):
            avg_loss, train_accuracy = _train_one_epoch(model, train_loader, criterion, optimizer)
            valid_accuracy = _accuracy(*_predict(model, valid_loader))

            scheduler.step()
            current_lr = scheduler.get_last_lr()[0]

            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                patience_counter = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'accuracy': valid_accuracy,
                    'loss': avg_loss,
                }, best_checkpoint_path)
            else:
                patience_counter += 1

            log.append(
                f'Epoch [{epoch+1}/{num_epochs}], LR: {current_lr:.6f}, Loss: {avg_loss:.4f}, '
                f'Train Acc: {train_accuracy:.2f}%, Valid Acc: {valid_accuracy:.2f}%\n'
            )

            if patience_counter >= patience:
                log.append(f"\n⏹️ Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)\n")
                break
            if valid_accuracy >= 98.0:
                log.append(f"\n🎯 Target accuracy achieved! Stopping training at epoch {epoch+1}\n")
                break

        # Evaluate the best validation checkpoint, not the last epoch's weights.
        best_checkpoint = torch.load(best_checkpoint_path, map_location=DEVICE)
        model.load_state_dict(best_checkpoint['model_state_dict'])

        all_preds, all_labels = _predict(model, test_loader)
        test_accuracy = _accuracy(all_preds, all_labels)
        log.append("\n🏁 FINAL TEST RESULTS:\n")
        log.append(f"Best Validation Accuracy: {best_checkpoint['accuracy']:.2f}%\n")
        log.append(f"Final Test Accuracy: {test_accuracy:.2f}%\n")

        class_correct = [0] * NUM_CLASSES
        class_total = [0] * NUM_CLASSES
        for pred, true in zip(all_preds, all_labels):
            class_total[true] += 1
            class_correct[true] += int(pred == true)

        log.append("\n📊 CLASS-WISE ACCURACY:\n")
        for i, class_name in enumerate(CLASS_NAMES):
            if class_total[i] > 0:
                acc = 100 * class_correct[i] / class_total[i]
                log.append(f"{class_name}: {acc:.2f}% ({class_correct[i]}/{class_total[i]})\n")

        torch.save(model.state_dict(), f'{CUSTOM_MODEL_NAME}_final.pth')
        _write_model_card(test_accuracy)

        log.append(f"\n✅ Model saved as '{CUSTOM_MODEL_NAME}_final.pth'\n")
        log.append(f"📝 Model card saved as '{CUSTOM_MODEL_NAME}_model_card.md'\n")
        log.append("\n📥 DOWNLOAD INSTRUCTIONS:\n")
        log.append("1. Files are saved in your working directory\n")
        log.append(f"2. Download '{CUSTOM_MODEL_NAME}_final.pth' for the trained model\n")
        log.append(f"3. Download '{CUSTOM_MODEL_NAME}_model_card.md' for documentation\n")
        return "".join(log)

    except Exception as e:
        import traceback
        return f"❌ Training Error: {str(e)}\n\n{traceback.format_exc()}"


# Inference model cache: avoids re-reading the weights on every click, and
# reloads automatically when the file is rewritten by a new training run.
# No lock: concurrent first calls may each load once, which is harmless.
_MODEL_CACHE = {"path": None, "mtime": None, "model": None}


def _load_inference_model(model_path):
    """Return an eval-mode BrainTumorCNN for *model_path*, cached by file mtime."""
    mtime = os.path.getmtime(model_path)
    cached_key = (_MODEL_CACHE["path"], _MODEL_CACHE["mtime"])
    if _MODEL_CACHE["model"] is None or cached_key != (model_path, mtime):
        model = BrainTumorCNN(num_classes=NUM_CLASSES)
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()
        _MODEL_CACHE.update(path=model_path, mtime=mtime, model=model)
    return _MODEL_CACHE["model"]


def classify_mri(image):
    """Classify one MRI image with the saved final model.

    Args:
        image: PIL image or array-like image; Gradio passes None when nothing
            was uploaded.

    Returns:
        tuple: ``(probabilities, diagnosis)``. *probabilities* maps every name
        in CLASS_NAMES to a softmax probability rounded to 4 decimals;
        *diagnosis* is a one-line summary. On any failure (no trained model,
        no image, processing error) the probabilities are all 0.0 and
        *diagnosis* describes the problem. A pair is always returned because
        the Gradio handler unpacks two values.
    """
    zero_scores = {name: 0.0 for name in CLASS_NAMES}
    model_path = f'{CUSTOM_MODEL_NAME}_final.pth'
    if not os.path.exists(model_path):
        return zero_scores, f"Error: model file '{model_path}' not found. Train the model first."
    if image is None:
        return zero_scores, "Error: no image provided."

    try:
        model = _load_inference_model(model_path)
        # Same preprocessing as validation/test during training.
        _, transform = get_transforms()
        image_tensor = transform(_to_rgb_pil(image)).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            probabilities = torch.softmax(model(image_tensor)[0], dim=0)

        results = {
            class_name: round(probabilities[i].item(), 4)
            for i, class_name in enumerate(CLASS_NAMES)
        }
        max_class = max(results, key=results.get)
        max_prob = results[max_class]
        diagnosis_info = f"Diagnosis: {max_class} (Confidence: {max_prob*100:.1f}%)"
        return results, diagnosis_info
    except Exception as e:
        return zero_scores, f"Error: {str(e)}"


# Gradio Interface
with gr.Blocks(title="GoGenix Brain MRI Classifier") as demo:
    gr.Markdown("# 🧠 GoGenix Brain MRI CNN Classifier - 4 Classes")
    gr.Markdown(f"**Dataset**: {DATASET_NAME} | **Classes**: {', '.join(CLASS_NAMES)}")

    with gr.Tab("🚀 Train CNN Model"):
        gr.Markdown("### Train 4-Class CNN Model")
        gr.Markdown(f"**Target**: 98%+ Accuracy | **Classes**: {', '.join(CLASS_NAMES)}")
        train_btn = gr.Button("Start 4-Class Training", variant="primary", size="lg")
        output_text = gr.Textbox(
            label="Training Progress",
            lines=25,
            placeholder="Training output will appear here...",
        )
        train_btn.click(fn=train_and_save_model, outputs=output_text)

    with gr.Tab("🔍 Classify MRI"):
        gr.Markdown("### Brain Tumor Type Detection")
        gr.Markdown("Upload MRI scan for 4-class classification")
        image_input = gr.Image(type="pil", label="MRI Brain Scan", height=300)
        classify_btn = gr.Button("Analyze Scan", variant="secondary")

        with gr.Row():
            result_label = gr.Label(label="Class Probabilities", num_top_classes=4)
            diagnosis_text = gr.Textbox(label="Diagnostic Result", interactive=False)

        def process_classification(image):
            """Gradio handler: return classify_mri's (probabilities, diagnosis) pair."""
            results, diagnosis = classify_mri(image)
            return results, diagnosis

        classify_btn.click(
            fn=process_classification,
            inputs=image_input,
            outputs=[result_label, diagnosis_text],
        )

    with gr.Tab("📊 Model Architecture"):
        gr.Markdown("### CNN Architecture Details")
        gr.Markdown(f"""
**Architecture**: Custom CNN with 4 Convolutional Blocks + GAP

**Classes**: {NUM_CLASSES}
- Glioma Tumors
- Meningioma Tumors
- No Tumor (Healthy)
- Pituitary Tumors

**Enhanced Features**:
- Global Average Pooling for better generalization
- Advanced data augmentation
- Cosine annealing learning rate
- Early stopping
- Class distribution analysis
""")


if __name__ == "__main__":
    demo.launch()