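"""Training loop for the BERT-based Portuguese-variety identifier.

Fine-tunes LanguageIdentfier with BCE loss, optionally validates per domain
through Tester, and checkpoints the model whenever both validation loss and
F1 improve.
"""
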
import logging
import math
import os

import torch
from tqdm import tqdm

from pt_variety_identifier.src.bert.model import LanguageIdentfier
from pt_variety_identifier.src.bert.tester import Tester


class Trainer:
    def __init__(self, train_dataset, params, validation_dataset_dict=None) -> None:
        """Fine-tunes a LanguageIdentfier model with BCE loss.

        `train_dataset` must yield batches with 'input_ids', 'attention_mask'
        and 'label' tensors; `validation_dataset_dict` optionally maps domain
        names to validation sets used for checkpoint selection.
        """
        self.train_dataset = train_dataset

        self.model = LanguageIdentfier(params['model_name'])

        self.epochs = params['epochs']
        self.lr = 1e-5  # fixed fine-tuning learning rate
        self.loss_fn = torch.nn.BCELoss()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        self.early_stoping = params['early_stoping']

        # Reduce the learning rate (by the default factor of 0.1) when the
        # epoch loss stops improving for half the early-stopping patience.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=self.early_stoping // 2, verbose=True)

        self.device = params['device']
        self.CURRENT_PATH = params['CURRENT_PATH']
        self.CURRENT_TIME = params['CURRENT_TIME']
        self.training_domain = params.get('training_domain', 'all')

        self.validator = None

        logging.info(f"Using {self.device} device")

        if validation_dataset_dict:
            self.validator = Tester(
                test_dataset_dict=validation_dataset_dict,
                model=self.model,
                train_domain=self.training_domain,
            )

    def _epoch_iter(self):
        """Runs one training epoch and returns the mean batch loss."""
        self.model.train()
        self.model.to(self.device)
        self.optimizer.zero_grad()

        total_loss = 0.0

        for batch in tqdm(self.train_dataset):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['label'].to(self.device, dtype=torch.float)

            # The model emits one probability per example; squeeze the
            # trailing dimension so it matches the label shape for BCELoss.
            outputs = self.model(
                input_ids, attention_mask=attention_mask).squeeze(dim=1)
            loss = self.loss_fn(outputs, labels)

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            total_loss += loss.item()

        # ReduceLROnPlateau tracks the summed epoch loss.
        self.scheduler.step(total_loss)

        return total_loss / len(self.train_dataset)

    def train(self):
        """Trains for up to `self.epochs` epochs and returns the best validation metrics."""
        logging.info(f"Training model on {self.device}...")

        best_results = {
            'f1': -math.inf,
            'accuracy': -math.inf,
            'precision': -math.inf,
            'recall': -math.inf,
            'loss': math.inf
        }

        for epoch in tqdm(range(self.epochs)):
            training_loss = self._epoch_iter()

            if self.validator:
                results = self.validator.validate()

                logging.info(f"Results for {self.training_domain} domain: {results} Epoch: {epoch}")

                # Checkpoint only when both the validation loss and the F1
                # score improve on the best values seen so far.
                if results['loss'] < best_results['loss'] and results['f1'] > best_results['f1']:
                    logging.info(
                        f"Saving best model... Domain: {self.training_domain} F1: {results['f1']} and Test Loss: {results['loss']}")

                    best_results['loss'] = results['loss']
                    best_results['accuracy'] = results['accuracy']
                    best_results['f1'] = results['f1']
                    best_results['recall'] = results['recall']
                    best_results['precision'] = results['precision']

                    save_path = os.path.join(
                        self.CURRENT_PATH, "out", str(self.CURRENT_TIME),
                        "models", f'{self.training_domain}.pt')
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                    torch.save(self.model.state_dict(), save_path)
                else:
                    logging.info(f"Not saving model... F1: {results['f1']} and Test Loss: {results['loss']}")

            logging.info(f"Epoch {epoch} Training Loss: {training_loss}")

            # Stop once the mean training loss drops below 0.1.
            if training_loss < 0.1:
                logging.info("Training loss is below 0.1, stopping training...")
                break

        return best_results
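

if __name__ == "__main__":
    # Usage sketch only: the checkpoint name, patience value, and dataset
    # objects below are illustrative assumptions, not values shipped with
    # this module.
    from datetime import datetime

    params = {
        'model_name': 'bert-base-multilingual-cased',  # assumed HF checkpoint
        'epochs': 10,
        'early_stoping': 4,  # key spelling matches what this class reads
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'CURRENT_PATH': os.getcwd(),
        'CURRENT_TIME': datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
        'training_domain': 'news',  # optional; defaults to 'all'
    }

    # `train_loader` must yield dicts with 'input_ids', 'attention_mask' and
    # 'label' tensors (e.g. a torch DataLoader over a tokenized dataset):
    # trainer = Trainer(train_loader, params, validation_dataset_dict=val_sets)
    # best_results = trainer.train()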