# Uploaded by Flamekizer11 ("Upload 27 files", commit 64d0ccc, verified).
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from tqdm import tqdm
from dataloader import get_dataloaders
from model import build_model
from utils import get_device, accuracy
def compute_class_weights(csv_path, num_classes=None, max_weight=3.0):
    """Compute per-class loss weights from the ``label_id`` column of a CSV.

    The scheme is a softened inverse frequency:
    ``log1p(total / count)``, rescaled so the weights average 1, then capped
    at ``max_weight``. Entry ``i`` of the result is the weight for
    ``label_id == i``. That is how ``nn.CrossEntropyLoss(weight=...)``
    indexes the vector, so ids must be exactly ``0 .. num_classes - 1``.

    Args:
        csv_path: Path to a CSV file with an integer ``label_id`` column.
        num_classes: Expected number of classes. Defaults to
            ``max(label_id) + 1``.
        max_weight: Upper bound applied after normalisation. The default of
            3.0 matches the previously hard-coded cap. Because the cap is
            applied last, the final mean may be below 1.

    Returns:
        A float32 tensor of shape ``(num_classes,)``.

    Raises:
        ValueError: If the CSV contains no labels, if a label falls outside
            ``0 .. num_classes - 1``, or if some id in that range has no
            samples. Previously such gaps silently shifted every later weight
            onto the wrong class.
    """
    labels = pd.read_csv(csv_path)["label_id"]
    if labels.empty:
        raise ValueError(f"No labels found in {csv_path!r}")
    if num_classes is None:
        num_classes = int(labels.max()) + 1

    counts = labels.value_counts()
    out_of_range = sorted(int(k) for k in counts.index if not 0 <= k < num_classes)
    # Align counts to label ids so that position i always means label_id == i.
    counts = counts.reindex(range(num_classes), fill_value=0)
    missing = [i for i, c in enumerate(counts.tolist()) if c == 0]
    if out_of_range or missing:
        raise ValueError(
            f"label_id values must be exactly 0..{num_classes - 1}; "
            f"out of range: {out_of_range}, without samples: {missing}"
        )

    class_counts = torch.tensor(counts.to_numpy(), dtype=torch.float32)
    # Inverse frequency, log-compressed so rare classes are not weighted
    # by orders of magnitude more than common ones.
    weights = torch.log1p(class_counts.sum() / class_counts)
    # Rescale to mean 1 so the overall loss scale stays comparable.
    weights = weights / weights.mean()
    # Cap extreme weights so a single very rare class cannot dominate.
    return torch.clamp(weights, max=max_weight)
# Training and validation loops; each call processes one epoch.
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train. It is switched to training mode first.
        loader: Iterable of ``(images, labels)`` batches. ``len(loader)`` is
            no longer required.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` over the epoch. Each batch is weighted
        by its number of samples, so a short final batch no longer counts as
        much as a full one.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for images, labels in tqdm(loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Convert per-batch means back to per-sample sums.
        # NOTE(review): this assumes accuracy() returns the fraction correct
        # for the batch (the original divided its sum by the batch count);
        # confirm in utils.
        loss_sum += loss.item() * batch_size
        acc_sum += accuracy(outputs, labels) * batch_size
        seen += batch_size

    if seen == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")
    return loss_sum / seen, acc_sum / seen
def validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any parameters.

    Args:
        model: Network to evaluate. It is switched to eval mode first.
        loader: Iterable of ``(images, labels)`` batches. ``len(loader)`` is
            no longer required.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` over the epoch. Each batch is weighted
        by its number of samples. This matters because the caller uses this
        loss for LR scheduling, checkpointing and early stopping.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation", leave=False):
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # NOTE(review): this assumes accuracy() returns the fraction
            # correct for the batch; confirm in utils.
            loss_sum += loss.item() * batch_size
            acc_sum += accuracy(outputs, labels) * batch_size
            seen += batch_size

    if seen == 0:
        raise ValueError("validate_one_epoch: loader yielded no batches")
    return loss_sum / seen, acc_sum / seen
def main():
    """Train the classifier, keeping the checkpoint with the lowest val loss.

    The steps are:

    1. Load the data and build the model.
    2. Set up a class-weighted, label-smoothed cross-entropy loss, AdamW, and
       a plateau LR scheduler.
    3. Train for up to 20 epochs. Stop early after 4 consecutive epochs
       without a validation-loss improvement.
    """
    # Hyperparameters and paths
    batch_size = 32
    max_epochs = 20
    learning_rate = 1e-4
    early_stop_patience = 4
    csv_path = "data_processed/metadata_final.csv"
    images_dir = "data_processed/images"
    checkpoint_dir = "checkpoints"
    checkpoint_path = f"{checkpoint_dir}/best_model.pth"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Setup
    device = get_device()
    print("Using device:", device)

    num_classes = pd.read_csv(csv_path)["label_id"].nunique()
    train_loader, val_loader = get_dataloaders(
        csv_path=csv_path,
        images_dir=images_dir,
        batch_size=batch_size,
    )
    model = build_model(num_classes, device)

    criterion = nn.CrossEntropyLoss(
        weight=compute_class_weights(csv_path).to(device),
        label_smoothing=0.02,
    )
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # Halve the learning rate when validation loss stalls for 2 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=2, factor=0.5
    )

    best_val_loss = float("inf")
    stale_epochs = 0

    for epoch in range(max_epochs):
        print(f"\nEpoch [{epoch + 1}/{max_epochs}]")
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)
        scheduler.step(val_loss)
        print(
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
        )

        # Strict "<" is kept on purpose: a NaN loss must count as no improvement.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            stale_epochs = 0
            torch.save(model.state_dict(), checkpoint_path)
            print("Best model saved")
            continue

        stale_epochs += 1
        if stale_epochs >= early_stop_patience:
            print("Early stopping triggered")
            break

    print("\nTraining is complete.")


if __name__ == "__main__":
    main()