yeomtong committed on
Commit
3eb6561
·
verified ·
1 Parent(s): 0c0db8b

Upload trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +12 -10
trainer.py CHANGED
@@ -110,22 +110,24 @@ def train_one_epoch(
110
  def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda", average="micro"):
111
  model.eval()
112
  total_loss, n_batches = 0.0, 0
113
- all_preds, all_golds = [], []
114
 
115
  for batch in dataloader:
116
- gold = batch["labels"]
117
- mask = (gold != -100)
118
- batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
119
 
 
120
  logits, loss = model(**batch)
121
  total_loss += float(loss.item()); n_batches += 1
122
 
123
  preds = logits.argmax(-1).cpu()
124
- all_preds.extend(preds[mask].tolist())
125
- all_golds.extend(gold[mask].tolist())
 
 
 
 
126
 
127
- f1 = f1_score(all_golds, all_preds, average=average)
128
- return total_loss / max(1, n_batches), f1
129
 
130
 
131
  # ==============================================================
@@ -251,7 +253,7 @@ if __name__ == "__main__":
251
  # ------------------------------
252
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
253
  total_steps = len(train_loader) * num_epochs // max(1, grad_accum_steps)
254
- warmup_steps = int(warmup_ratio * total_steps)
255
 
256
  scheduler = get_linear_schedule_with_warmup(
257
  optimizer,
@@ -268,7 +270,7 @@ if __name__ == "__main__":
268
  for epoch in range(num_epochs):
269
  tr_loss = train_one_epoch(
270
  model, train_loader, optimizer, device=device,
271
- scheduler=scheduler, grad_accum_steps=grad_accum,
272
  amp=True, max_grad_norm=1.0,
273
  )
274
  dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
 
110
  def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda", average="micro"):
111
  model.eval()
112
  total_loss, n_batches = 0.0, 0
113
+ correct, total = 0, 0
114
 
115
  for batch in dataloader:
116
+ gold = batch["labels"] # CPU
117
+ mask = (gold != -100) # valid word positions
 
118
 
119
+ batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
120
  logits, loss = model(**batch)
121
  total_loss += float(loss.item()); n_batches += 1
122
 
123
  preds = logits.argmax(-1).cpu()
124
+ # micro-F1 == accuracy for single-label classification
125
+ correct += int((preds[mask] == gold[mask]).sum())
126
+ total += int(mask.sum())
127
+
128
+ micro_f1 = (correct / total) if total > 0 else 0.0
129
+ return total_loss / max(1, n_batches), micro_f1
130
 
 
 
131
 
132
 
133
  # ==============================================================
 
253
  # ------------------------------
254
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
255
  total_steps = len(train_loader) * num_epochs // max(1, grad_accum_steps)
256
+ warmup_steps = int(0.1 * total_steps)
257
 
258
  scheduler = get_linear_schedule_with_warmup(
259
  optimizer,
 
270
  for epoch in range(num_epochs):
271
  tr_loss = train_one_epoch(
272
  model, train_loader, optimizer, device=device,
273
+ scheduler=scheduler, grad_accum_steps=grad_accum_steps,
274
  amp=True, max_grad_norm=1.0,
275
  )
276
  dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)