Update trainer.py
Browse files
- trainer.py +2 -2
trainer.py
CHANGED
|
@@ -251,7 +251,7 @@ if __name__ == "__main__":
|
|
| 251 |
# 🔧 Optimizer + Scheduler
|
| 252 |
# ------------------------------
|
| 253 |
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 254 |
-
total_steps = len(train_loader) * num_epochs // max(1,
|
| 255 |
warmup_steps = int(0.1 * total_steps)
|
| 256 |
|
| 257 |
scheduler = get_linear_schedule_with_warmup(
|
|
@@ -269,7 +269,7 @@ if __name__ == "__main__":
|
|
| 269 |
for epoch in range(num_epochs):
|
| 270 |
tr_loss = train_one_epoch(
|
| 271 |
model, train_loader, optimizer, device=device,
|
| 272 |
-
scheduler=scheduler, grad_accum_steps=
|
| 273 |
amp=True, max_grad_norm=1.0,
|
| 274 |
)
|
| 275 |
dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
|
|
|
|
| 251 |
# 🔧 Optimizer + Scheduler
|
| 252 |
# ------------------------------
|
| 253 |
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 254 |
+
total_steps = len(train_loader) * num_epochs // max(1, grad_accum)
|
| 255 |
warmup_steps = int(0.1 * total_steps)
|
| 256 |
|
| 257 |
scheduler = get_linear_schedule_with_warmup(
|
|
|
|
| 269 |
for epoch in range(num_epochs):
|
| 270 |
tr_loss = train_one_epoch(
|
| 271 |
model, train_loader, optimizer, device=device,
|
| 272 |
+
scheduler=scheduler, grad_accum_steps=grad_accum,
|
| 273 |
amp=True, max_grad_norm=1.0,
|
| 274 |
)
|
| 275 |
dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
|