Update train.py
Browse files
train.py
CHANGED
|
@@ -282,7 +282,7 @@ def train(
|
|
| 282 |
grad_accum=8,
|
| 283 |
epochs=1,
|
| 284 |
lr=1e-5,
|
| 285 |
-
warmup_steps=
|
| 286 |
):
|
| 287 |
accelerator = Accelerator(
|
| 288 |
mixed_precision="bf16" if torch.cuda.is_bf16_supported() else "fp16",
|
|
@@ -491,6 +491,7 @@ if __name__ == "__main__":
|
|
| 491 |
|
| 492 |
model = Transformer(args)
|
| 493 |
|
|
|
|
| 494 |
RESUME_FROM = "checkpoints/step_200000.pt"
|
| 495 |
|
| 496 |
if os.path.exists(RESUME_FROM):
|
|
@@ -505,6 +506,7 @@ if __name__ == "__main__":
|
|
| 505 |
# Old format: checkpoint is directly the model state_dict
|
| 506 |
model.load_state_dict(checkpoint)
|
| 507 |
print(f"[Resume] Loaded model (old format)")
|
|
|
|
| 508 |
|
| 509 |
train(
|
| 510 |
model,
|
|
|
|
| 282 |
grad_accum=8,
|
| 283 |
epochs=1,
|
| 284 |
lr=1e-5,
|
| 285 |
+
warmup_steps=500,
|
| 286 |
):
|
| 287 |
accelerator = Accelerator(
|
| 288 |
mixed_precision="bf16" if torch.cuda.is_bf16_supported() else "fp16",
|
|
|
|
| 491 |
|
| 492 |
model = Transformer(args)
|
| 493 |
|
| 494 |
+
'''
|
| 495 |
RESUME_FROM = "checkpoints/step_200000.pt"
|
| 496 |
|
| 497 |
if os.path.exists(RESUME_FROM):
|
|
|
|
| 506 |
# Old format: checkpoint is directly the model state_dict
|
| 507 |
model.load_state_dict(checkpoint)
|
| 508 |
print(f"[Resume] Loaded model (old format)")
|
| 509 |
+
'''
|
| 510 |
|
| 511 |
train(
|
| 512 |
model,
|