# Source file: guiBackend/CNN/CNN.py (5.3 kB)
# Last commit: "Upload folder using huggingface_hub" (068b6e0, verified) by BrianLov
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
def _flatten_labels(labels, device):
    """Move a label batch to *device* as a 1-D int64 tensor.

    ``reshape(-1)`` is used instead of ``squeeze()`` so that a batch holding
    a single sample stays 1-D; ``squeeze()`` would collapse it to a 0-d
    scalar, which breaks both ``CrossEntropyLoss`` and ``labels.size(0)``.
    Assumes one class index per sample (single-label task).
    """
    return labels.to(device).reshape(-1).long()


def _evaluate(model, loader, criterion, device):
    """Return ``(mean_loss, accuracy_percent)`` of *model* over *loader*.

    Switches the model to eval mode and disables gradient tracking. The
    caller must call ``model.train()`` again before resuming training.
    Returns ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = _flatten_labels(labels, device)
            outputs = model(images)
            n_samples = labels.size(0)
            # Weight each batch by its size so a short final batch does not
            # skew the mean.
            loss_sum += criterion(outputs, labels).item() * n_samples
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n_samples
    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100.0 * correct / total


def main(
    *,
    dataset_root=r"C:\Users\USER\Downloads\MedMNIST_Data",
    num_epochs=10,
    batch_size=32,
    learning_rate=1e-4,
):
    """Fine-tune an ImageNet-pretrained ResNet50 on PneumoniaMNIST (224x224).

    Downloads the train and val splits into ``dataset_root``, trains with
    Adam and cross-entropy loss, prints train/validation metrics after every
    epoch, and finally writes two files into ``dataset_root``:

    * ``baseline_resnet50.pth`` - the trained model's ``state_dict``
    * ``learning_curve.png`` - training loss and accuracy per epoch

    Args:
        dataset_root: Directory for the dataset download and the output
            files. Created if missing. The default is a Windows path; pass
            your own location on other machines.
        num_epochs: Number of passes over the training split.
        batch_size: Mini-batch size for both data loaders.
        learning_rate: Adam learning rate.

    Returns:
        None. Results are written to disk and printed.
    """
    # 1. Setup and Hardware Configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # Point dataset_root at a fast secondary drive to avoid I/O bottlenecks
    # on the OS drive.
    os.makedirs(dataset_root, exist_ok=True)

    data_flag = 'pneumoniamnist'
    info = INFO[data_flag]
    DataClass = getattr(medmnist, info['python_class'])

    # 2. Preprocessing & Dynamic Augmentation
    # Normalizing with mean=0.5, std=0.5 maps ToTensor's [0, 1] output to
    # [-1, 1], which is the range the team's generator uses.
    train_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),  # ResNet expects 3 channels
        transforms.RandomHorizontalFlip(),            # Dynamic spatial augmentation
        transforms.RandomRotation(10),                # Dynamic spatial augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    val_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),  # NO spatial augmentation for validation
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # 3. Load Datasets
    print("Fetching 224x224 dataset...")
    train_dataset = DataClass(split='train', transform=train_transform,
                              download=True, size=224, root=dataset_root)
    val_dataset = DataClass(split='val', transform=val_transform,
                            download=True, size=224, root=dataset_root)

    # 4. DataLoaders
    # num_workers=0 is the safest default on Windows, where worker processes
    # are a common source of multiprocessing crashes.
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=0)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=0)

    # 5. Initialize ResNet50
    print("Loading ResNet50...")
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Replace the final layer for binary classification (Pneumonia vs Normal).
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)
    model = model.to(device)

    # 6. Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history_loss = []
    history_acc = []

    # 7. The Training Loop
    for epoch in range(num_epochs):
        model.train()  # _evaluate() leaves the model in eval mode
        loss_sum = 0.0
        correct = 0
        total = 0

        # tqdm draws a progress bar in the terminal; the description is
        # constant for the epoch, so it is set once here.
        loop = tqdm(train_loader, leave=True,
                    desc=f"Epoch [{epoch+1}/{num_epochs}]")
        for images, labels in loop:
            images = images.to(device)
            labels = _flatten_labels(labels, device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            n_samples = labels.size(0)
            loss_sum += loss.item() * n_samples
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n_samples
            loop.set_postfix(loss=loss.item(), acc=100. * correct / total)

        # Sample-weighted averages for this epoch (guarded for an empty loader).
        epoch_loss = loss_sum / total if total else 0.0
        epoch_acc = 100. * correct / total if total else 0.0
        history_loss.append(epoch_loss)
        history_acc.append(epoch_acc)

        # Held-out check so overfitting is visible while training runs.
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] "
            f"train_loss={epoch_loss:.4f} train_acc={epoch_acc:.2f}% "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%"
        )

    # 8. Save the trained weights (state_dict only) for the team
    save_path = os.path.join(dataset_root, 'baseline_resnet50.pth')
    torch.save(model.state_dict(), save_path)
    print(f"\nTraining Complete! Baseline weights saved to: {save_path}")

    # Create the Learning Curve Graph
    epochs_axis = range(1, num_epochs + 1)
    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Plot Loss (Red Line)
    color = 'tab:red'
    ax1.set_xlabel('Epochs', fontweight='bold')
    ax1.set_ylabel('Training Loss', color=color, fontweight='bold')
    ax1.plot(epochs_axis, history_loss, color=color, marker='o', label='Loss')
    ax1.tick_params(axis='y', labelcolor=color)

    # Plot Accuracy (Blue Line) on a second y-axis sharing the same x-axis
    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Training Accuracy (%)', color=color, fontweight='bold')
    ax2.plot(epochs_axis, history_acc, color=color, marker='s', label='Accuracy')
    ax2.tick_params(axis='y', labelcolor=color)

    plt.title('ResNet50 Training Curve', fontsize=14, fontweight='bold')
    fig.tight_layout()

    # Save the image, then release the figure so it does not stay in memory.
    graph_path = os.path.join(dataset_root, 'learning_curve.png')
    fig.savefig(graph_path, dpi=300)
    plt.close(fig)
    print(f"Learning Curve saved to: {graph_path}")
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()