yeomtong commited on
Commit
533cf8d
·
verified ·
1 Parent(s): db3c303

Delete training.py

Browse files
Files changed (1) hide show
  1. training.py +0 -182
training.py DELETED
@@ -1,182 +0,0 @@
1
- from SRL_MODEL import data_prep, SRL_BERT_model
2
- import torch
3
- from transformers import AutoTokenizer, get_linear_schedule_with_warmup
4
- from sklearn.metrics import f1_score
5
- import pickle
6
-
7
- def save_pkl(tgt_list, svg_path):
8
- with open(svg_path, "wb") as f:
9
- pickle.dump(tgt_list, f)
10
-
11
- def load_pkl(path) :
12
- with open(path, "rb") as f:
13
- data = pickle.load(f)
14
- return data
15
-
16
-
17
- def train_one_epoch(
18
- model,
19
- dataloader,
20
- optimizer,
21
- device="cuda",
22
- scheduler=None,
23
- grad_accum_steps=1,
24
- amp=True,
25
- max_grad_norm=1.0,
26
- ):
27
- model.train()
28
- total_loss, n_steps = 0.0, 0
29
-
30
- use_amp = amp and torch.cuda.is_available()
31
- scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
32
-
33
- optimizer.zero_grad(set_to_none=True)
34
-
35
- for step, batch in enumerate(dataloader, 1):
36
- batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
37
-
38
- with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.float16):
39
- _, loss = model(**batch) # model must return (logits, loss)
40
-
41
- total_loss += float(loss.detach().item())
42
- n_steps += 1
43
-
44
- loss = loss / grad_accum_steps # for accumulation
45
-
46
- if use_amp:
47
- scaler.scale(loss).backward()
48
- else:
49
- loss.backward()
50
-
51
- if step % grad_accum_steps == 0:
52
- if use_amp:
53
- scaler.unscale_(optimizer)
54
- nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
55
-
56
- if use_amp:
57
- scaler.step(optimizer)
58
- scaler.update()
59
- else:
60
- optimizer.step()
61
-
62
- optimizer.zero_grad(set_to_none=True)
63
-
64
- if scheduler is not None:
65
- scheduler.step()
66
-
67
- return total_loss / max(1, n_steps)
68
-
69
- #This is Validation
70
- @torch.no_grad()
71
- def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda", average="micro"):
72
-
73
- model.eval()
74
- total_loss, n_batches = 0.0, 0
75
- all_preds, all_golds = [], []
76
-
77
- for batch in dataloader:
78
- gold = batch["labels"] # keep on CPU for masking
79
- mask = (gold != -100)
80
-
81
- batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
82
- logits, loss = model(**batch) # loss computed once here
83
- total_loss += float(loss.item()); n_batches += 1
84
-
85
- preds = logits.argmax(-1).cpu()
86
- all_preds.extend(preds[mask].tolist())
87
- all_golds.extend(gold[mask].tolist())
88
-
89
- f1 = f1_score(all_golds, all_preds, average=average)
90
- return total_loss / max(1, n_batches), f1
91
-
92
-
93
- if __name__ =='__main__':
94
- bert_name = "bert-base-cased"
95
- tokenizer = AutoTokenizer.from_pretrained(bert_name)
96
-
97
- device = "cuda" if torch.cuda.is_available() else "cpu"
98
- # tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
99
-
100
- #data_class_train/dev/test from data_prep
101
- train_dev_test_data = data_class_train + data_class_dev + data_class_test
102
- train_bf_loader, dev_bf_loader,test_bf_loader, label2id, id2label = data_prep.data_processing_for_loader(train_dev_test_data, data_class_train, data_class_dev, data_class_test, tokenizer)
103
-
104
- pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
105
- collate = lambda b: data_prep.srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)
106
-
107
- train_loader = data_prep.DataLoader(train_bf_loader, batch_size=16, shuffle=True, collate_fn=collate)
108
- dev_loader = data_prep.DataLoader(dev_bf_loader, batch_size=16, shuffle=False, collate_fn=collate)
109
- test_loader = data_prep.DataLoader(test_bf_loader, batch_size=16, shuffle=False, collate_fn=collate)
110
-
111
- # bert_name = "bert-base-cased"
112
- # tokenizer = AutoTokenizer.from_pretrained(bert_name)
113
-
114
- # device = "cuda" if torch.cuda.is_available() else "cpu"
115
-
116
- model = SRL_BERT_model.PredicateAwareSRL(
117
- bert_name=bert_name,
118
- num_labels=len(label2id),
119
- use_indicator=True,
120
- use_distance =True,
121
- indicator_dim= 10,
122
- lstm_hidden=768,
123
- mlp_hidden=300,
124
- pos_dim= 50,
125
- max_distance = 128,
126
- dropout=0.1
127
- ).to(device)
128
-
129
- # Optimizer (you may want to use AdamW with weight decay and a scheduler)
130
- num_epochs = 12
131
- grad_accum_steps = 1
132
- optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
133
-
134
- # # Train a couple of epochs (on toy data this is just to check shapes run)
135
- # for epoch in range(3):
136
- # tr_loss = train_one_epoch(model, train_loader, optimizer, device=device)
137
- # f1 = evaluate_token_f1(model, dev_loader, id2label=id2label, device=device)
138
- # print(f"Epoch {epoch+1} | loss={tr_loss:.4f} | token-F1={f1:.4f}")
139
-
140
- total_steps = len(train_loader) * num_epochs // max(1, grad_accum_steps)
141
- warmup_steps = int(0.1 * total_steps)
142
-
143
- scheduler = get_linear_schedule_with_warmup(
144
- optimizer,
145
- num_warmup_steps=warmup_steps,
146
- num_training_steps=total_steps
147
- )
148
-
149
- history = {"epoch": [], "train_loss": [], "dev_loss": [], "dev_f1": []}
150
-
151
- best_dev, best_path = -1.0, "best_srl.ckpt"
152
- for epoch in range(num_epochs):
153
- tr_loss = train_one_epoch(
154
- model, train_loader, optimizer, device=device,
155
- scheduler=scheduler, grad_accum_steps=grad_accum_steps, amp=True, max_grad_norm=1.0
156
- )
157
- dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
158
-
159
-
160
- history["epoch"].append(epoch + 1)
161
- history["train_loss"].append(tr_loss)
162
- history["dev_loss"].append(dev_loss)
163
- history["dev_f1"].append(dev_f1)
164
-
165
- print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f} dev_loss={dev_loss:.4f} dev_F1={dev_f1:.4f}")
166
-
167
- if dev_f1 > best_dev:
168
- best_dev = dev_f1
169
- torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_path)
170
- print(" ↳ new best dev; saved.")
171
-
172
- save_pkl(history, #save_path_for_loss)
173
-
174
- # best_dev, best_path = -1.0, "best_srl.ckpt"
175
- # for epoch in range(num_epochs):
176
- # tr_loss = train_one_epoch(model, train_loader, optimizer, device=device)
177
- # dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
178
- # print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f} dev_loss={dev_loss:.4f} dev_F1={dev_f1:.4f}")
179
- # if dev_f1 > best_dev:
180
- # best_dev = dev_f1
181
- # torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_path)
182
- # print(" ↳ new best dev; saved.")