Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torchvision import transforms, models | |
| import medmnist | |
| from medmnist import INFO | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to eval mode here, so the
            caller must call ``model.train()`` again before further training.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``; ``(0.0, 0.0)`` when the
        loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    batches = 0
    with torch.no_grad():  # no graph needed: saves memory and time
        for images, labels in loader:
            images = images.to(device)
            # reshape(-1) always yields a 1-D target, even for a 1-sample batch.
            labels = labels.to(device).reshape(-1).long()
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
            batches += 1
    if total == 0:
        return 0.0, 0.0
    return loss_sum / batches, 100.0 * correct / total


def main(
    dataset_root: str = r"C:\Users\USER\Downloads\MedMNIST_Data",
    num_epochs: int = 10,
    batch_size: int = 32,
    learning_rate: float = 1e-4,
) -> None:
    """Fine-tune an ImageNet-pretrained ResNet50 on 224x224 PneumoniaMNIST.

    Trains a 2-class classifier, reports validation loss/accuracy after every
    epoch, then writes ``baseline_resnet50.pth`` (the model ``state_dict``)
    and ``learning_curve.png`` (training loss/accuracy per epoch) into
    ``dataset_root``.

    Args:
        dataset_root: Directory used for the dataset download and for the
            output files; created if missing. Point it at a fast drive to
            avoid I/O bottlenecks.
        num_epochs: Number of passes over the training split.
        batch_size: Samples per batch for both loaders.
        learning_rate: Adam learning rate.
    """
    # 1. Setup and hardware configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    os.makedirs(dataset_root, exist_ok=True)

    data_flag = 'pneumoniamnist'
    info = INFO[data_flag]
    DataClass = getattr(medmnist, info['python_class'])

    # 2. Preprocessing and on-the-fly augmentation.
    # ToTensor scales to [0, 1]; Normalize(mean=0.5, std=0.5) then maps to
    # [-1, 1] so inputs match the range used by the team's generator.
    channel_stats = [0.5, 0.5, 0.5]
    train_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),  # ResNet expects 3 channels
        transforms.RandomHorizontalFlip(),            # spatial augmentation
        transforms.RandomRotation(10),                # spatial augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_stats, std=channel_stats),
    ])
    # Validation gets the same colour handling but NO random augmentation.
    val_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_stats, std=channel_stats),
    ])

    # 3. Load datasets (downloaded into dataset_root on first use)
    print("Fetching 224x224 dataset...")
    train_dataset = DataClass(split='train', transform=train_transform,
                              download=True, size=224, root=dataset_root)
    val_dataset = DataClass(split='val', transform=val_transform,
                            download=True, size=224, root=dataset_root)

    # 4. DataLoaders
    # num_workers=0 is the safest default on Windows, where worker processes
    # are a common source of multiprocessing crashes.
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=0)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=0)

    # 5. ResNet50 with its final layer replaced by a 2-way classifier head
    print("Loading ResNet50...")
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)
    model = model.to(device)

    # 6. Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history_loss = []
    history_acc = []

    # 7. Training loop
    for epoch in range(num_epochs):
        model.train()  # _evaluate() leaves the model in eval mode
        running_loss = 0.0
        correct = 0
        total = 0

        # tqdm renders a progress bar in the terminal
        loop = tqdm(train_loader, leave=True,
                    desc=f"Epoch [{epoch+1}/{num_epochs}]")
        for images, labels in loop:
            images = images.to(device)
            # A bare squeeze() would turn a single-sample batch into a 0-d
            # tensor and break the loss and size(0); reshape(-1) stays 1-D.
            labels = labels.to(device).reshape(-1).long()

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # read once: each .item() syncs the device
            running_loss += batch_loss
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loop.set_postfix(loss=batch_loss, acc=100.*correct/total)

        # Average loss (per batch) and accuracy (per sample) for this epoch
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total
        history_loss.append(epoch_loss)
        history_acc.append(epoch_acc)

        # Held-out check so over-fitting is visible while training runs
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%")

    # 8. Save the trained weights for the team
    save_path = os.path.join(dataset_root, 'baseline_resnet50.pth')
    torch.save(model.state_dict(), save_path)
    print(f"\nTraining Complete! Baseline weights saved to: {save_path}")

    # 9. Learning-curve graph: loss on the left axis, accuracy on the right
    epochs = range(1, len(history_loss) + 1)
    fig, ax1 = plt.subplots(figsize=(10, 6))

    color = 'tab:red'
    ax1.set_xlabel('Epochs', fontweight='bold')
    ax1.set_ylabel('Training Loss', color=color, fontweight='bold')
    ax1.plot(epochs, history_loss, color=color, marker='o', label='Loss')
    ax1.tick_params(axis='y', labelcolor=color)

    ax2 = ax1.twinx()  # second y-axis sharing the same x-axis
    color = 'tab:blue'
    ax2.set_ylabel('Training Accuracy (%)', color=color, fontweight='bold')
    ax2.plot(epochs, history_acc, color=color, marker='s', label='Accuracy')
    ax2.tick_params(axis='y', labelcolor=color)

    plt.title('ResNet50 Training Curve', fontsize=14, fontweight='bold')
    fig.tight_layout()

    graph_path = os.path.join(dataset_root, 'learning_curve.png')
    fig.savefig(graph_path, dpi=300)
    plt.close(fig)  # release the figure so repeated calls do not leak memory
    print(f"Learning Curve saved to: {graph_path}")
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()