import wandb from tqdm import tqdm from src.evaluate import evaluate import torch def train_model(model, optimizer, configs, loaders): # Login wandb wandb.login() # Init Wandb for tracking training phase wandb.init( project=configs["project"], name=configs["name"], config=configs ) # Log gradient of parameter wandb.watch(model, log="all") # Save model checkpoint by best F1 best_val_f1 = 0.0 # Training Loop for epoch in range(1, configs["epochs"] + 1): model.train() total_loss = 0.0 # Create progress bar train_bar = tqdm(loaders['train'], desc=f"Train Epoch {epoch}/{configs['epochs']}") for batch_idx, (x, y, _) in enumerate(train_bar, start=1): mask = (y != -1) loss = model(x, y, mask) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() train_bar.set_postfix(batch_loss=loss.item(), avg_loss=total_loss / batch_idx) # Evaluate model after each epoch avg_train_loss = total_loss / len(loaders['train']) train_precision, train_recall, train_f1, train_acc, _, _ = evaluate(model, loaders['train'], count_loss=False) val_precision, val_recall, val_f1, val_acc, avg_val_loss, _= evaluate(model, loaders['val'], count_loss=True) # Log metric for train and val set print(f"Epoch {epoch}: train_loss={avg_train_loss:.4f}, train_f1={train_f1:.4f}, val_loss={avg_val_loss:.4f}, val_f1={val_f1:.4f}") wandb.log({ "epoch": epoch, # Group: Training metrics "Train/Loss": avg_train_loss, "Train/Precision": train_precision, "Train/Recall": train_recall, "Train/F1": train_f1, "Train/Accuracy": train_acc, # Group: Validation metrics "Val/Loss": avg_val_loss, "Val/Precision": val_precision, "Val/Recall": val_recall, "Val/F1": val_f1, "Val/Accuracy": val_acc }) # Save best model based on val_f1 if val_f1 > best_val_f1: best_val_f1 = val_f1 ckpt_path = f"./models/best_epoch_{epoch}.pt" torch.save(model.state_dict(), ckpt_path) wandb.save(ckpt_path) print(f"Saved imporved model to {ckpt_path}") print() # Load best model before test print(f"Loading best model from {ckpt_path} for final evaluation...") model.load_state_dict(torch.load(ckpt_path)) print("Done \n") # Log metric for test set print("Evaluation on test set ...") test_precision, test_recall, test_f1, test_acc, avg_test_loss, report = evaluate(model, loaders['test'], count_loss=True, report=True) wandb.log({ "Test/Loss": avg_test_loss, "Test/Precision": test_precision, "Test/Recall": test_recall, "Test/F1": test_f1, "Test/Accuracy": test_acc, }) print(f"Test_loss={avg_test_loss:.4f}, Test_f1={test_f1:.4f}") print(report) # Finish W&B run wandb.finish()