Upload trainer.py
Browse files
- trainer.py +12 -10
trainer.py
CHANGED
|
@@ -110,22 +110,24 @@ def train_one_epoch(
|
|
| 110 |
def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda", average="micro"):
|
| 111 |
model.eval()
|
| 112 |
total_loss, n_batches = 0.0, 0
|
| 113 |
-
|
| 114 |
|
| 115 |
for batch in dataloader:
|
| 116 |
-
gold = batch["labels"]
|
| 117 |
-
mask = (gold != -100)
|
| 118 |
-
batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
|
| 119 |
|
|
|
|
| 120 |
logits, loss = model(**batch)
|
| 121 |
total_loss += float(loss.item()); n_batches += 1
|
| 122 |
|
| 123 |
preds = logits.argmax(-1).cpu()
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
-
f1 = f1_score(all_golds, all_preds, average=average)
|
| 128 |
-
return total_loss / max(1, n_batches), f1
|
| 129 |
|
| 130 |
|
| 131 |
# ==============================================================
|
|
@@ -251,7 +253,7 @@ if __name__ == "__main__":
|
|
| 251 |
# ------------------------------
|
| 252 |
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 253 |
total_steps = len(train_loader) * num_epochs // max(1, grad_accum_steps)
|
| 254 |
-
warmup_steps = int(
|
| 255 |
|
| 256 |
scheduler = get_linear_schedule_with_warmup(
|
| 257 |
optimizer,
|
|
@@ -268,7 +270,7 @@ if __name__ == "__main__":
|
|
| 268 |
for epoch in range(num_epochs):
|
| 269 |
tr_loss = train_one_epoch(
|
| 270 |
model, train_loader, optimizer, device=device,
|
| 271 |
-
scheduler=scheduler, grad_accum_steps=
|
| 272 |
amp=True, max_grad_norm=1.0,
|
| 273 |
)
|
| 274 |
dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
|
|
|
|
| 110 |
def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda", average="micro"):
|
| 111 |
model.eval()
|
| 112 |
total_loss, n_batches = 0.0, 0
|
| 113 |
+
correct, total = 0, 0
|
| 114 |
|
| 115 |
for batch in dataloader:
|
| 116 |
+
gold = batch["labels"] # CPU
|
| 117 |
+
mask = (gold != -100) # valid word positions
|
|
|
|
| 118 |
|
| 119 |
+
batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
|
| 120 |
logits, loss = model(**batch)
|
| 121 |
total_loss += float(loss.item()); n_batches += 1
|
| 122 |
|
| 123 |
preds = logits.argmax(-1).cpu()
|
| 124 |
+
# micro-F1 == accuracy for single-label classification
|
| 125 |
+
correct += int((preds[mask] == gold[mask]).sum())
|
| 126 |
+
total += int(mask.sum())
|
| 127 |
+
|
| 128 |
+
micro_f1 = (correct / total) if total > 0 else 0.0
|
| 129 |
+
return total_loss / max(1, n_batches), micro_f1
|
| 130 |
|
|
|
|
|
|
|
| 131 |
|
| 132 |
|
| 133 |
# ==============================================================
|
|
|
|
| 253 |
# ------------------------------
|
| 254 |
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 255 |
total_steps = len(train_loader) * num_epochs // max(1, grad_accum_steps)
|
| 256 |
+
warmup_steps = int(0.1 * total_steps)
|
| 257 |
|
| 258 |
scheduler = get_linear_schedule_with_warmup(
|
| 259 |
optimizer,
|
|
|
|
| 270 |
for epoch in range(num_epochs):
|
| 271 |
tr_loss = train_one_epoch(
|
| 272 |
model, train_loader, optimizer, device=device,
|
| 273 |
+
scheduler=scheduler, grad_accum_steps=grad_accum_steps,
|
| 274 |
amp=True, max_grad_norm=1.0,
|
| 275 |
)
|
| 276 |
dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
|