import os

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from tqdm import tqdm

from dataloader import get_dataloaders
from model import build_model
from utils import get_device, accuracy


def compute_class_weights(csv_path, num_classes=None):
    """Compute per-class loss weights from the label distribution in a CSV.

    Weights follow a softened inverse-frequency scheme: ``log1p(N / count)``,
    normalized to mean 1 and then capped at 3.0.

    Args:
        csv_path: Path to a CSV file with an integer ``label_id`` column.
        num_classes: Length of the returned weight vector. Defaults to
            ``max(label_id) + 1``. Pass the model's class count so that
            ``weights[i]`` always corresponds to class id ``i``, even when
            some class has no rows in the CSV.

    Returns:
        A float32 tensor of shape ``(num_classes,)`` whose entry ``i`` is the
        weight for class id ``i``.

    Raises:
        ValueError: If the CSV contains no labels, or a label id falls outside
            ``[0, num_classes)``.
    """
    labels = pd.read_csv(csv_path)["label_id"]
    if labels.empty:
        raise ValueError(f"No labels found in {csv_path}")
    if num_classes is None:
        num_classes = int(labels.max()) + 1
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(
            f"label_id values must lie in [0, {num_classes}); "
            f"found range [{labels.min()}, {labels.max()}]"
        )

    # Index counts by class id rather than by sorted presence. With
    # sort_index() alone, a class missing from the CSV would shift every
    # later weight onto the wrong class and shorten the vector.
    class_counts = labels.value_counts().reindex(range(num_classes), fill_value=0)
    total_samples = float(class_counts.sum())
    class_counts = torch.tensor(class_counts.to_numpy(), dtype=torch.float32)

    # Soft inverse-frequency weighting. A zero count is treated as one sample
    # so the weight stays finite. An inf here would turn every weight into NaN
    # after the mean normalization below.
    weights = total_samples / class_counts.clamp(min=1.0)
    # Log-scale to reduce extremes between rare and common classes.
    weights = torch.log1p(weights)
    # Normalize to mean 1 so the overall loss scale stays comparable.
    weights = weights / weights.mean()
    # Cap extreme weights so very rare classes cannot dominate the loss.
    weights = torch.clamp(weights, max=3.0)
    return weights


# Train and validation functions for one epoch each
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: Module to train; it is put in training mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` weighted by batch size, so a smaller
        final batch is not over-weighted.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss, total_acc, num_samples = 0.0, 0.0, 0

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # NOTE(review): assumes `accuracy` returns the per-batch mean, so
        # scaling by batch size recovers a per-sample total. Confirm in utils.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_acc += accuracy(outputs, labels) * batch_size
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("Training loader yielded no samples")
    return total_loss / num_samples, total_acc / num_samples


def validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module to evaluate; it is put in eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` weighted by batch size.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss, total_acc, num_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation", leave=False):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Same batch-size weighting as train_one_epoch.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_acc += accuracy(outputs, labels) * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("Validation loader yielded no samples")
    return total_loss / num_samples, total_acc / num_samples


def main():
    """Train the classifier with LR-on-plateau and early stopping.

    The weights with the lowest validation loss are saved to CHECKPOINT_PATH.
    """
    # Hyperparameters and paths
    BATCH_SIZE = 32
    EPOCHS = 20
    LR = 1e-4
    PATIENCE = 4
    CSV_PATH = "data_processed/metadata_final.csv"
    IMG_DIR = "data_processed/images"
    CHECKPOINT_DIR = "checkpoints"
    CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "best_model.pth")

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Setup
    device = get_device()
    print("Using device:", device)

    df = pd.read_csv(CSV_PATH)
    # Assumes label ids are the contiguous range 0..num_classes-1.
    # compute_class_weights below raises if that does not hold.
    num_classes = df["label_id"].nunique()

    train_loader, val_loader = get_dataloaders(
        csv_path=CSV_PATH,
        images_dir=IMG_DIR,
        batch_size=BATCH_SIZE,
    )

    model = build_model(num_classes, device)

    # Size the weight vector to the model head so weights[i] matches class i.
    class_weights = compute_class_weights(CSV_PATH, num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss(
        weight=class_weights,
        label_smoothing=0.02,
    )

    optimizer = optim.AdamW(
        model.parameters(),
        lr=LR,
        weight_decay=1e-4,
    )

    # Halve the learning rate when validation loss plateaus for 2 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        patience=2,
        factor=0.5,
    )

    best_val_loss = float("inf")
    epochs_without_improvement = 0

    # Training loop with early stopping to prevent overfitting
    for epoch in range(EPOCHS):
        print(f"\nEpoch [{epoch + 1}/{EPOCHS}]")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = validate_one_epoch(
            model, val_loader, criterion, device
        )

        scheduler.step(val_loss)

        print(
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            print("Best model saved")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= PATIENCE:
                print("Early stopping triggered")
                break

    print("\nTraining is complete.")


if __name__ == "__main__":
    main()