"""Two-phase fine-tuning of a pretrained EfficientNet-B0 classifier.

Phase 1 trains only a freshly initialised classification head on top of the
frozen feature extractor; phase 2 unfreezes everything and fine-tunes the
whole network with a much smaller learning rate. Loss/accuracy curves for
both phases are plotted at the end.
"""
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from tqdm import tqdm

from src.dataset import create_dataloaders


def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    cuDNN determinism flags are not touched, so GPU runs may still differ
    slightly between executions.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _epoch_metrics(total_loss, correct, total):
    """Turn summed per-sample loss and correct count into (mean loss, accuracy %)."""
    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute epoch metrics")
    return total_loss / total, 100.0 * correct / total


# training loop
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Each batch is unpacked as ``(images, labels, _)``; the third element is
    ignored. Labels are cast to ``long`` as required for class-index targets.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the per-sample mean loss over the epoch
        and the accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels, _ in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device).long()

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Assumes criterion uses reduction="mean" (the nn.CrossEntropyLoss()
        # default used in main). Scaling back by the batch size makes the epoch
        # average per sample, so a short final batch is not over-weighted.
        total_loss += loss.item() * batch_size
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    return _epoch_metrics(total_loss, correct, total)


# validation loop
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` in eval mode without gradient tracking.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the per-sample mean loss and the
        accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels, _ in dataloader:
            images = images.to(device)
            labels = labels.to(device).long()

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Same per-sample weighting as train_one_epoch (assumes reduction="mean").
            total_loss += loss.item() * batch_size
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    return _epoch_metrics(total_loss, correct, total)


def plot_curves(train_losses, val_losses, train_accuracies, val_accuracies):
    """Show per-epoch loss (left) and accuracy (right) curves side by side.

    All four sequences are expected to have one entry per completed epoch;
    the x-axis length is taken from ``train_losses``.
    """
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, marker="o", label="Train Loss")
    plt.plot(epochs, val_losses, marker="o", label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Loss Curve")
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, marker="o", label="Train Accuracy")
    plt.plot(epochs, val_accuracies, marker="o", label="Val Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title("Accuracy Curve")
    plt.legend()

    plt.show()


def _build_model(num_classes):
    """Return EfficientNet-B0 with default pretrained weights, a frozen
    feature extractor, and a new Dropout + Linear head for ``num_classes``."""
    model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)

    # Freeze the feature extractor; only the new classifier head is trainable.
    for param in model.features.parameters():
        param.requires_grad = False

    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(in_features, num_classes),
    )
    return model


def _run_phase(phase_name, num_epochs, model, train_loader, val_loader,
               criterion, optimizer, device, history):
    """Train/validate for ``num_epochs`` epochs, appending metrics to ``history``
    (a dict of lists) and printing one summary line per epoch."""
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(
            f"[{phase_name}] Epoch {epoch + 1}/{num_epochs} | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%"
        )


def main():
    """Run the two-phase fine-tuning and plot the resulting curves."""
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # The test loader is not used by this script.
    train_loader, val_loader, _test_loader, class_names = create_dataloaders()
    num_classes = len(class_names)
    print("Number of classes:", num_classes)
    print("Classes:", class_names)

    model = _build_model(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    # Phase 1: train classifier only (the only parameters with requires_grad).
    # NOTE(review): train_one_epoch calls model.train(), which also puts any
    # BatchNorm layers inside the frozen feature extractor into training mode,
    # so their running statistics keep updating here — confirm this is intended.
    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-3,
        weight_decay=1e-4,
    )
    _run_phase("Phase 1", 10, model, train_loader, val_loader,
               criterion, optimizer, device, history)

    # Phase 2: unfreeze all feature layers for full fine-tuning; the much lower
    # learning rate limits how far the pretrained weights move.
    for param in model.features.parameters():
        param.requires_grad = True
    optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
    _run_phase("Phase 2", 20, model, train_loader, val_loader,
               criterion, optimizer, device, history)

    plot_curves(
        history["train_loss"],
        history["val_loss"],
        history["train_acc"],
        history["val_acc"],
    )


# Guard so importing this module (or DataLoader workers re-importing the main
# module under the "spawn" start method) does not re-run training.
if __name__ == "__main__":
    main()