"""Training utilities: logging setup, train/validation loops, checkpointing,
and TensorBoard tracking."""

import logging
import os
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from src import config


def calculate_accuracy(y_pred, y_true):
    """Return the fraction of correct top-1 predictions in a batch."""
    preds = torch.argmax(y_pred, dim=1)
    correct = (preds == y_true).sum().item()
    return correct / len(y_true)


def setup_logging(log_dir):
    """Configure root logging to both a timestamped file and the console."""
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),
        ],
    )
    return log_file


def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one full pass over the training set; return mean loss and accuracy."""
    model.train()
    running_loss, running_acc = 0.0, 0.0
    batch_count = len(dataloader)
    logging.info(f"Training on {batch_count} batches")
    for batch_idx, (images, labels) in enumerate(dataloader):
        if batch_idx % 10 == 0:
            logging.info(f"  Batch {batch_idx}/{batch_count}")
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        acc = calculate_accuracy(outputs, labels)
        loss.backward()
        # Clip gradient norm to stabilize training before the optimizer step.
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=config.GRAD_CLIP_VALUE)
        optimizer.step()
        running_loss += loss.item()
        running_acc += acc
    return running_loss / batch_count, running_acc / batch_count


def train_model(model, train_loader, val_loader, epochs=config.EPOCHS,
                lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY,
                device=config.DEVICE):
    """Train `model`, checkpointing the best weights by validation accuracy,
    and return the best validation accuracy reached."""
    log_file = setup_logging(config.LOG_DIR)
    logging.info(f"Training logs will be saved to: {log_file}")
    logging.info("Training configuration:")
    logging.info(f"  Epochs: {epochs}")
    logging.info(f"  Learning rate: {lr}")
    logging.info(f"  Weight decay: {weight_decay}")
    logging.info(f"  Device: {device}")
    logging.info(f"  Batch size: {config.BATCH_SIZE}")
    logging.info(f"  Image size: {config.IMAGE_SIZE}")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Reduce the learning rate when validation accuracy plateaus
    # (mode='max' because the monitored metric should increase).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=config.LR_SCHEDULER_FACTOR,
        patience=config.LR_SCHEDULER_PATIENCE,
    )
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0.0

    run_name = time.strftime("run_%Y%m%d-%H%M")
    writer = SummaryWriter(log_dir=os.path.join(config.LOG_DIR, run_name))
    logging.info(f"Training on: {str(device).upper()}\n")

    for epoch in range(epochs):
        epoch_start_time = time.time()
        logging.info(f"Epoch {epoch + 1}/{epochs} started")
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device)
        logging.info("Validating...")
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        epoch_time = time.time() - epoch_start_time
        scheduler.step(val_acc)

        logging.info(
            f"Epoch {epoch + 1}/{epochs} completed in {epoch_time:.2f}s")
        logging.info(
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc * 100:.2f}%")
        logging.info(
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc * 100:.2f}%")
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)

        # Checkpoint whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), config.MODEL_SAVE_PATH)
            logging.info("Model saved!")

    writer.close()
    logging.info(f"Training complete. Best Val Acc: {best_val_acc * 100:.2f}%")
    return best_val_acc


def validate(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader`; return mean loss and accuracy."""
    model.eval()
    val_loss, val_acc = 0.0, 0.0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            acc = calculate_accuracy(outputs, labels)
            val_loss += loss.item()
            val_acc += acc
    return val_loss / len(dataloader), val_acc / len(dataloader)
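

# ---------------------------------------------------------------------------
# Usage sketch: a minimal smoke test of the training loop. This assumes
# `src.config` defines the constants referenced above (BATCH_SIZE, EPOCHS,
# MODEL_SAVE_PATH, and IMAGE_SIZE as an int); the synthetic tensors and the
# throwaway linear classifier below are illustrative placeholders, not part
# of the project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    num_classes = 3  # arbitrary class count for the synthetic data
    images = torch.randn(64, 3, config.IMAGE_SIZE, config.IMAGE_SIZE)
    labels = torch.randint(0, num_classes, (64,))
    loader = DataLoader(TensorDataset(images, labels),
                        batch_size=config.BATCH_SIZE, shuffle=True)

    # A trivial model, just enough to exercise train_model end to end;
    # the same loader doubles as the validation set for this smoke test.
    toy_model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(3 * config.IMAGE_SIZE * config.IMAGE_SIZE, num_classes),
    )
    train_model(toy_model, loader, loader)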