import logging import torch from tqdm import tqdm from transformers import get_cosine_schedule_with_warmup from src.config import CHECKPOINT_DIR logger = logging.getLogger(__name__) class EarlyStopping: def __init__(self, patience=7, min_delta=0.001): self.patience = patience self.min_delta = min_delta self.counter = 0 self.best_score = None self.early_stop = False def __call__(self, val_acc): if self.best_score is None: self.best_score = val_acc elif val_acc < self.best_score + self.min_delta: self.counter += 1 logger.info( f"EarlyStopping counter: {self.counter}/{self.patience}" ) if self.counter >= self.patience: self.early_stop = True else: self.best_score = val_acc self.counter = 0 def train_single_input_model( model, train_loader, eval_loader, optimizer, criterion, device, epochs, checkpoint_model_name, patience=7 ): logger.info("Starting single-input training...") num_training_steps = epochs * len(train_loader) num_warmup_steps = int(0.1 * num_training_steps) scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps ) early_stopping = EarlyStopping(patience=patience) best_acc = 0.0 all_preds = [] all_labels = [] for epoch in range(epochs): logger.info(f"Epoch {epoch + 1}/{epochs}") model.train() running_loss = 0 correct = 0 total = 0 for images, labels in tqdm( train_loader, desc=f"Epoch {epoch+1} Training" ): images = images.to(device) labels = labels.to(device) optimizer.zero_grad(set_to_none=True) logits = model(images) loss = criterion(logits, labels) loss.backward() optimizer.step() scheduler.step() running_loss += loss.item() preds = torch.argmax(logits, dim=1) correct += (preds == labels).sum().item() total += labels.size(0) train_loss = running_loss / len(train_loader) train_acc = 100 * correct / total model.eval() val_running_loss = 0 val_correct = 0 val_total = 0 all_preds = [] all_labels = [] with torch.no_grad(): for images, labels in tqdm( eval_loader, desc=f"Epoch {epoch+1} Validation" ): images = images.to(device) labels = labels.to(device) logits = model(images) loss = criterion(logits, labels) val_running_loss += loss.item() preds = torch.argmax(logits, dim=1) val_correct += (preds == labels).sum().item() val_total += labels.size(0) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) val_loss = val_running_loss / len(eval_loader) val_acc = 100 * val_correct / val_total logger.info( f"Train Loss: {train_loss:.4f} | " f"Train Acc: {train_acc:.2f}% || " f"Val Loss: {val_loss:.4f} | " f"Val Acc: {val_acc:.2f}%" ) if val_acc > best_acc: best_acc = val_acc checkpoint_path = CHECKPOINT_DIR / f"{checkpoint_model_name}.pt" torch.save( { "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch, "val_acc": val_acc }, checkpoint_path ) logger.info(f"Best checkpoint saved at: {checkpoint_path}") early_stopping(val_acc) if early_stopping.early_stop: logger.info("Early stopping triggered.") break return all_preds, all_labels def train_dual_input_model( model, train_loader, eval_loader, optimizer, criterion, device, epochs, checkpoint_model_name, patience=7 ): logger.info("Starting dual-input training...") num_training_steps = epochs * len(train_loader) num_warmup_steps = int(0.1 * num_training_steps) scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps ) early_stopping = EarlyStopping(patience=patience) best_acc = 0.0 all_preds = [] all_labels = [] for epoch in range(epochs): logger.info(f"Epoch {epoch + 1}/{epochs}") model.train() running_loss = 0 correct = 0 total = 0 for batch in tqdm( train_loader, desc=f"Epoch {epoch+1} Training" ): images_eff = batch["pixel_values_eff"].to(device) images_cnx = batch["pixel_values_cnx"].to(device) labels = batch["labels"].to(device) optimizer.zero_grad(set_to_none=True) logits = model(images_eff, images_cnx) loss = criterion(logits, labels) loss.backward() optimizer.step() scheduler.step() running_loss += loss.item() preds = torch.argmax(logits, dim=1) correct += (preds == labels).sum().item() total += labels.size(0) train_loss = running_loss / len(train_loader) train_acc = 100 * correct / total model.eval() val_running_loss = 0 val_correct = 0 val_total = 0 all_preds = [] all_labels = [] with torch.no_grad(): for batch in tqdm( eval_loader, desc=f"Epoch {epoch+1} Validation" ): images_eff = batch["pixel_values_eff"].to(device) images_cnx = batch["pixel_values_cnx"].to(device) labels = batch["labels"].to(device) logits = model(images_eff, images_cnx) loss = criterion(logits, labels) val_running_loss += loss.item() preds = torch.argmax(logits, dim=1) val_correct += (preds == labels).sum().item() val_total += labels.size(0) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) val_loss = val_running_loss / len(eval_loader) val_acc = 100 * val_correct / val_total logger.info( f"Train Loss: {train_loss:.4f} | " f"Train Acc: {train_acc:.2f}% || " f"Val Loss: {val_loss:.4f} | " f"Val Acc: {val_acc:.2f}%" ) if val_acc > best_acc: best_acc = val_acc checkpoint_path = CHECKPOINT_DIR / f"{checkpoint_model_name}.pt" torch.save( { "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch, "val_acc": val_acc }, checkpoint_path ) logger.info(f"Best checkpoint saved at: {checkpoint_path}") early_stopping(val_acc) if early_stopping.early_stop: logger.info("Early stopping triggered.") break return all_preds, all_labels if __name__ == "__main__": print("Trainer utilities ready.")