aaljabari commited on
Commit
9762f2c
·
verified ·
1 Parent(s): 531c1d5

Create BaseTrainer.py

Browse files
Files changed (1) hide show
  1. Nested/trainers/BaseTrainer.py +125 -0
Nested/trainers/BaseTrainer.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import logging
4
+ import natsort
5
+ import glob
6
+ from huggingface_hub import hf_hub_download, snapshot_download
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class BaseTrainer:
12
+ def __init__(
13
+ self,
14
+ model=None,
15
+ max_epochs=50,
16
+ optimizer=None,
17
+ scheduler=None,
18
+ loss=None,
19
+ train_dataloader=None,
20
+ val_dataloader=None,
21
+ test_dataloader=None,
22
+ log_interval=10,
23
+ summary_writer=None,
24
+ output_path=None,
25
+ clip=5,
26
+ patience=5
27
+ ):
28
+ self.model = model
29
+ self.max_epochs = max_epochs
30
+ self.train_dataloader = train_dataloader
31
+ self.val_dataloader = val_dataloader
32
+ self.test_dataloader = test_dataloader
33
+ self.optimizer = optimizer
34
+ self.scheduler = scheduler
35
+ self.loss = loss
36
+ self.log_interval = log_interval
37
+ self.summary_writer = summary_writer
38
+ self.output_path = output_path
39
+ self.current_timestep = 0
40
+ self.current_epoch = 0
41
+ self.clip = clip
42
+ self.patience = patience
43
+
44
+ def tag(self, dataloader, is_train=True):
45
+ """
46
+ Given a dataloader containing segments, predict the tags
47
+ :param dataloader: torch.utils.data.DataLoader
48
+ :param is_train: boolean - True for training model, False for evaluation
49
+ :return: Iterator
50
+ subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
51
+ gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
52
+ tokens - List[Nested.data.dataset.Token] - list of tokens
53
+ valid_len (B x 1) - int - valiud length of each sequence
54
+ logits (B x T x NUM_LABELS) - logits for each token and each tag
55
+ """
56
+ for subwords, gold_tags, tokens, valid_len in dataloader:
57
+ self.model.train(is_train)
58
+
59
+ if torch.cuda.is_available():
60
+ subwords = subwords.cuda()
61
+ gold_tags = gold_tags.cuda()
62
+
63
+ if is_train:
64
+ self.optimizer.zero_grad()
65
+ logits = self.model(subwords)
66
+ else:
67
+ with torch.no_grad():
68
+ logits = self.model(subwords)
69
+
70
+ yield subwords, gold_tags, tokens, valid_len, logits
71
+
72
+ def segments_to_file(self, segments, filename):
73
+ """
74
+ Write segments to file
75
+ :param segments: [List[Nested.data.dataset.Token]] - list of list of tokens
76
+ :param filename: str - output filename
77
+ :return: None
78
+ """
79
+ with open(filename, "w") as fh:
80
+ results = "\n\n".join(["\n".join([t.__str__() for t in segment]) for segment in segments])
81
+ fh.write("Token\tGold Tag\tPredicted Tag\n")
82
+ fh.write(results)
83
+ logging.info("Predictions written to %s", filename)
84
+
85
+ def save(self):
86
+ """
87
+ Save model checkpoint
88
+ :return:
89
+ """
90
+ filename = os.path.join(
91
+ self.output_path,
92
+ "checkpoints",
93
+ "checkpoint_{}.pt".format(self.current_epoch),
94
+ )
95
+
96
+ checkpoint = {
97
+ "model": self.model.state_dict(),
98
+ "optimizer": self.optimizer.state_dict(),
99
+ "epoch": self.current_epoch
100
+ }
101
+
102
+ logger.info("Saving checkpoint to %s", filename)
103
+ torch.save(checkpoint, filename)
104
+
105
+ def load(self, checkpoint_path):
106
+ """
107
+ Load model checkpoint
108
+ :param checkpoint_path: str - path/to/checkpoints
109
+ :return: None
110
+ """
111
+ # checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
112
+ checkpoint_path = natsort.natsorted(checkpoint_path)
113
+ # checkpoint_path = checkpoint_path[-1]
114
+
115
+ logger.info("Loading checkpoint %s", checkpoint_path)
116
+
117
+ device = None if torch.cuda.is_available() else torch.device('cpu')
118
+ # checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
119
+ repo_path = snapshot_download(repo_id="SinaLab/Nested")
120
+
121
+ model_file = os.path.join(repo_path, "checkpoints", "checkpoint_2.pt")
122
+
123
+ checkpoint = torch.load(model_file, map_location=device, weights_only=False)
124
+
125
+ self.model.load_state_dict(checkpoint["model"])