Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from utils.data_loader import get_data_loaders | |
| from models.resnet_model import MonkeyResNet | |
| import os | |
| import matplotlib.pyplot as plt | |
| from sklearn.utils.class_weight import compute_class_weight | |
| import numpy as np | |
class EarlyStopping:
    """Track validation loss and flag when training should stop.

    Call the instance once per epoch with that epoch's validation loss.
    When the loss fails to improve for ``patience`` consecutive calls,
    ``early_stop`` is set to ``True``; the caller is responsible for
    checking the flag and breaking out of its loop.

    Attributes are public and read by callers:
        patience:   number of consecutive non-improving calls tolerated.
        min_delta:  minimum decrease in loss that counts as an improvement.
        counter:    current run of consecutive non-improving calls.
        best_loss:  lowest loss seen so far (``inf`` before the first call).
        early_stop: ``True`` once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience: int = 5, *, min_delta: float = 0.0) -> None:
        """Create a stopper.

        Args:
            patience: consecutive non-improving epochs allowed before the
                ``early_stop`` flag is raised.
            min_delta: the loss must drop below ``best_loss - min_delta`` to
                count as an improvement. The default of 0.0 means any strict
                decrease counts.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss: float) -> None:
        """Record one epoch's validation loss and update ``early_stop``.

        Args:
            val_loss: validation loss for the epoch that just finished.
        """
        # A NaN loss compares False here, so it is treated as "no improvement".
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
# Hyperparameters
data_dir = "data"
epochs = 50
batch_size = 32
lr = 0.001
patience = 5  # epochs without validation improvement before early stopping

# Load training and validation data
train_loader, val_loader, class_names = get_data_loaders(data_dir, batch_size)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Calculate class weights.
# This makes one full pass over the training loader purely to collect labels.
train_labels = []
for _, labels in train_loader:
    # .tolist() works regardless of which device the label tensor lives on.
    train_labels.extend(labels.tolist())
train_labels = np.array(train_labels)

# compute_class_weight returns one weight per entry of `classes`. The loss
# below needs exactly one weight per model output, so every class must occur
# in the training labels; otherwise the weight vector is too short and the
# loss fails at the first batch with a far less helpful message.
present_classes = np.unique(train_labels)
if len(present_classes) != len(class_names):
    raise ValueError(
        f"Training data contains {len(present_classes)} distinct labels but "
        f"{len(class_names)} classes are defined; every class needs at least "
        "one training sample to compute balanced class weights."
    )
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_labels
)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Set up model, loss function, optimizer, scheduler
model = MonkeyResNet(num_classes=len(class_names)).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiplies the learning rate by 0.1 when the monitored value (mode='min')
# stops improving; it is stepped with the validation loss in the epoch loop.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
early_stopper = EarlyStopping(patience=patience)

# Store per-epoch values for plotting
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
# The checkpoint is written during training (whenever validation loss reaches
# a new minimum), so the output directory has to exist before the loop starts.
os.makedirs("models", exist_ok=True)
best_val_loss = float('inf')

# Start training loop
for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its size so the epoch figure is a
        # per-sample average, even when the last batch is smaller.
        train_loss += loss.item() * labels.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Averaging (rather than summing over batches) keeps train and validation
    # loss on the same scale even though the two loaders differ in length.
    train_loss /= total
    train_accuracy = 100 * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # ---- Validation pass ----
    model.eval()
    val_loss = 0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)
            predicted = outputs.argmax(dim=1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()
    val_loss /= total_val
    val_accuracy = 100 * correct_val / total_val
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    # Keep the weights from the best validation epoch. Without this, early
    # stopping would leave us saving a model that is `patience` epochs past
    # its best validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "models/monkey_resnet.pth")

    # Both of these only compare the loss against earlier epochs.
    scheduler.step(val_loss)
    early_stopper(val_loss)
    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - Train Acc: {train_accuracy:.2f}% - Val Acc: {val_accuracy:.2f}%")
    if early_stopper.early_stop:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Save the trained model.
# Normally the best checkpoint is already on disk. Fall back to the final
# weights if no epoch ever registered an improvement (e.g. the loss was NaN
# throughout, or epochs == 0) so the message below is always true.
if best_val_loss == float('inf'):
    torch.save(model.state_dict(), "models/monkey_resnet.pth")
print("Training done. Model saved.")
# Save training and validation plots
os.makedirs("plots", exist_ok=True)


def _save_curve_pair(train_series, val_series, train_label, val_label,
                     y_label, heading, out_path):
    """Draw a train curve and a validation curve on one figure and save it.

    The x-axis is the position in each series (one point per epoch); the
    figure is closed after being written to ``out_path``.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(train_series, label=train_label)
    plt.plot(val_series, label=val_label)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(heading)
    plt.legend()
    plt.grid(True)
    plt.savefig(out_path)
    plt.close()


# Loss plot
_save_curve_pair(
    train_losses, val_losses,
    "Train Loss", "Val Loss",
    "Loss", "Training and Validation Loss",
    "plots/loss_plot.png",
)

# Accuracy plot
_save_curve_pair(
    train_accuracies, val_accuracies,
    "Train Accuracy", "Val Accuracy",
    "Accuracy (%)", "Training and Validation Accuracy",
    "plots/accuracy_plot.png",
)