File size: 3,205 Bytes
fff452e
 
95062a5
fff452e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os

import torch
import wandb
from tqdm import tqdm

from src.evaluate import evaluate

def train_model(model, optimizer, configs, loaders):
    """Train *model* with *optimizer*, logging metrics to Weights & Biases.

    Args:
        model: A torch.nn.Module whose forward call ``model(x, y, mask)``
            returns a scalar loss tensor. Positions labelled ``-1`` in ``y``
            are masked out of the loss (presumably padding — confirm against
            the dataset/model classes).
        optimizer: A torch optimizer over ``model``'s parameters.
        configs: Dict with at least the keys ``"project"``, ``"name"`` and
            ``"epochs"``; the whole dict is also logged as the W&B config.
        loaders: Dict with ``"train"``, ``"val"`` and ``"test"`` DataLoaders
            yielding ``(x, y, _)`` batches.

    Side effects:
        Logs in to W&B and opens a run, saves the best checkpoint (by
        validation F1) under ``./models/``, reloads it before the final
        test-set evaluation, and finishes the W&B run.
    """
    # Login wandb
    wandb.login()

    # Init Wandb for tracking training phase
    wandb.init(
        project=configs["project"],
        name=configs["name"],
        config=configs
    )

    # Log gradient of parameter
    wandb.watch(model, log="all")

    # Save model checkpoint by best F1. best_ckpt_path stays None until
    # the first improvement — guards the final reload against a NameError
    # when no epoch ever beats val_f1 = 0.0 (bug in the original: ckpt_path
    # could be referenced while never assigned).
    best_val_f1 = 0.0
    best_ckpt_path = None

    # Training Loop
    for epoch in range(1, configs["epochs"] + 1):
        model.train()
        total_loss = 0.0

        # Create progress bar
        train_bar = tqdm(loaders['train'], desc=f"Train Epoch {epoch}/{configs['epochs']}")

        for batch_idx, (x, y, _) in enumerate(train_bar, start=1):
            # Exclude positions labelled -1 from the loss
            mask = (y != -1)
            loss = model(x, y, mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            train_bar.set_postfix(batch_loss=loss.item(), avg_loss=total_loss / batch_idx)

        # Evaluate model after each epoch
        avg_train_loss = total_loss / len(loaders['train'])
        train_precision, train_recall, train_f1, train_acc, _, _ = evaluate(model, loaders['train'], count_loss=False)
        val_precision, val_recall, val_f1, val_acc, avg_val_loss, _ = evaluate(model, loaders['val'], count_loss=True)

        # Log metric for train and val set
        print(f"Epoch {epoch}: train_loss={avg_train_loss:.4f}, train_f1={train_f1:.4f}, val_loss={avg_val_loss:.4f}, val_f1={val_f1:.4f}")
        wandb.log({

            "epoch": epoch,

            # Group: Training metrics
            "Train/Loss": avg_train_loss,
            "Train/Precision": train_precision,
            "Train/Recall": train_recall,
            "Train/F1": train_f1,
            "Train/Accuracy": train_acc,

            # Group: Validation metrics
            "Val/Loss": avg_val_loss,
            "Val/Precision": val_precision,
            "Val/Recall": val_recall,
            "Val/F1": val_f1,
            "Val/Accuracy": val_acc
        })

        # Save best model based on val_f1
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_ckpt_path = f"./models/best_epoch_{epoch}.pt"
            # Ensure the checkpoint directory exists (torch.save does not create it)
            os.makedirs(os.path.dirname(best_ckpt_path), exist_ok=True)
            torch.save(model.state_dict(), best_ckpt_path)
            wandb.save(best_ckpt_path)
            print(f"Saved improved model to {best_ckpt_path}")

        print()

    # Load best model before test; skip the reload if no epoch ever
    # improved on the initial best_val_f1 (evaluate current weights instead)
    if best_ckpt_path is not None:
        print(f"Loading best model from {best_ckpt_path} for final evaluation...")
        model.load_state_dict(torch.load(best_ckpt_path))
        print("Done \n")
    else:
        print("No checkpoint improved on validation F1; evaluating current weights.")

    # Log metric for test set
    print("Evaluation on test set ...")
    test_precision, test_recall, test_f1, test_acc, avg_test_loss, report = evaluate(model, loaders['test'], count_loss=True, report=True)
    wandb.log({
        "Test/Loss": avg_test_loss,
        "Test/Precision": test_precision,
        "Test/Recall": test_recall,
        "Test/F1": test_f1,
        "Test/Accuracy": test_acc,
    })
    print(f"Test_loss={avg_test_loss:.4f}, Test_f1={test_f1:.4f}")
    print(report)

    # Finish W&B run
    wandb.finish()