Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import pandas as pd | |
| import numpy as np | |
| from tqdm import tqdm | |
| from dataloader import get_dataloaders | |
| from model import build_model | |
| from utils import get_device, accuracy | |
def compute_class_weights(csv_path, label_column="label_id", max_weight=3.0):
    """Compute softened inverse-frequency class weights from a metadata CSV.

    The weights are meant for ``nn.CrossEntropyLoss(weight=...)``. Classes are
    ordered by ascending label value (``value_counts().sort_index()``), so
    position ``i`` of the result belongs to the i-th smallest label present in
    the CSV. That matches class index ``i`` only when labels are the contiguous
    integers ``0..K-1`` — NOTE(review): confirm against the dataloader's label
    encoding.

    Pipeline: ``total / count`` (inverse frequency) -> ``log1p`` (compress
    extremes) -> divide by the mean (weights average 1 before capping) ->
    clamp at ``max_weight``.

    Args:
        csv_path: Path to a CSV file containing the label column.
        label_column: Name of the column holding class labels.
        max_weight: Upper bound applied to every weight after normalization.

    Returns:
        1-D ``float32`` tensor with one weight per distinct (non-NaN) label.

    Raises:
        ValueError: If the label column contains no labels.
    """
    df = pd.read_csv(csv_path)
    counts = df[label_column].value_counts().sort_index()
    if counts.empty:
        # Without this guard the mean below is NaN and an empty/NaN weight
        # tensor would silently reach the loss function.
        raise ValueError(
            f"No labels found in column {label_column!r} of {csv_path}"
        )

    class_counts = torch.tensor(counts.to_numpy(), dtype=torch.float32)
    total_samples = float(counts.sum())

    # Soft inverse-frequency weighting
    weights = total_samples / class_counts
    # Log-scale to reduce extremes between rare and common classes
    weights = torch.log1p(weights)
    # Normalize so the average weight is 1
    weights = weights / weights.mean()
    # Cap extreme weights so very rare classes cannot dominate the loss
    weights = torch.clamp(weights, max=max_weight)
    return weights
# Training and validation helpers: each runs one full pass (epoch) over a loader.
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader`` and return epoch metrics.

    Each batch's loss and accuracy are weighted by its batch size before
    averaging, so a smaller final batch does not skew the epoch figures.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device each batch is moved to.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` as Python floats.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum, acc_sum, n_samples = 0.0, 0.0, 0
    for images, labels in tqdm(loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        # float() keeps the running total a plain number even if accuracy()
        # returns a tensor; assumes accuracy() is a per-batch mean — TODO confirm.
        acc_sum += float(accuracy(outputs, labels)) * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("Training loader yielded no samples")
    return loss_sum / n_samples, acc_sum / n_samples
def validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Each batch's loss and accuracy are weighted by its batch size before
    averaging, matching ``train_one_epoch`` so the two are comparable.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` as Python floats.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, acc_sum, n_samples = 0.0, 0.0, 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation", leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            # Assumes accuracy() is a per-batch mean — TODO confirm.
            acc_sum += float(accuracy(outputs, labels)) * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("Validation loader yielded no samples")
    return loss_sum / n_samples, acc_sum / n_samples
def main():
    """Train the classifier with LR scheduling and early stopping.

    Builds the loaders, model, class-weighted cross-entropy loss, AdamW
    optimizer and a ReduceLROnPlateau scheduler, then trains for up to
    ``EPOCHS`` epochs. Whenever the validation loss improves, the model's
    ``state_dict`` is saved to ``CHECKPOINT_PATH``. Training stops early after
    ``PATIENCE`` consecutive epochs without improvement.
    """
    # Hyperparameters and paths
    BATCH_SIZE = 32
    EPOCHS = 20
    LR = 1e-4
    WEIGHT_DECAY = 1e-4
    LABEL_SMOOTHING = 0.02
    PATIENCE = 4        # early-stopping patience, in epochs
    LR_PATIENCE = 2     # scheduler patience; smaller than PATIENCE so LR drops first
    LR_FACTOR = 0.5     # multiply LR by this factor when val loss plateaus
    CSV_PATH = "data_processed/metadata_final.csv"
    IMG_DIR = "data_processed/images"
    CHECKPOINT_DIR = "checkpoints"
    CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "best_model.pth")
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Setup
    device = get_device()
    print("Using device:", device)

    # One output unit per distinct label in the metadata.
    num_classes = pd.read_csv(CSV_PATH)["label_id"].nunique()

    train_loader, val_loader = get_dataloaders(
        csv_path=CSV_PATH,
        images_dir=IMG_DIR,
        batch_size=BATCH_SIZE,
    )
    model = build_model(num_classes, device)

    # NOTE(review): weights come from the whole CSV, which presumably also
    # contains the validation rows — confirm whether train-only counts are wanted.
    class_weights = compute_class_weights(CSV_PATH).to(device)
    criterion = nn.CrossEntropyLoss(
        weight=class_weights,
        label_smoothing=LABEL_SMOOTHING,
    )
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LR,
        weight_decay=WEIGHT_DECAY,
    )
    # Reduce the learning rate when the validation loss plateaus.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=LR_PATIENCE, factor=LR_FACTOR
    )

    best_val_loss = float("inf")
    epochs_without_improvement = 0

    # Training loop with early stopping to prevent overfitting
    for epoch in range(1, EPOCHS + 1):
        print(f"\nEpoch [{epoch}/{EPOCHS}]")
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = validate_one_epoch(
            model, val_loader, criterion, device
        )
        scheduler.step(val_loss)

        print(
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            print("Best model saved")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= PATIENCE:
                print("Early stopping triggered")
                break

    print("\nTraining is complete.")


if __name__ == "__main__":
    main()