changes to training script
Browse files- scripts/train.py +10 -3
scripts/train.py
CHANGED
|
@@ -102,7 +102,6 @@ def main():
|
|
| 102 |
):
|
| 103 |
param.requires_grad = True
|
| 104 |
trainable_params.append(name)
|
| 105 |
-
print(f"Unfreezing parameter: {name}")
|
| 106 |
|
| 107 |
|
| 108 |
print(f"Total trainable parameters: {len(trainable_params)}")
|
|
@@ -192,13 +191,21 @@ def main():
|
|
| 192 |
model.zero_grad()
|
| 193 |
|
| 194 |
# Check for existing checkpoint
|
|
|
|
|
|
|
| 195 |
checkpoint_dir = None
|
| 196 |
if os.path.isdir(training_args.output_dir):
|
| 197 |
-
checkpoints = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
if checkpoints:
|
| 199 |
-
|
|
|
|
| 200 |
print(f"Resuming from checkpoint: {checkpoint_dir}")
|
| 201 |
|
|
|
|
| 202 |
# Train
|
| 203 |
print("Starting training...")
|
| 204 |
trainer.train(resume_from_checkpoint=checkpoint_dir)
|
|
|
|
| 102 |
):
|
| 103 |
param.requires_grad = True
|
| 104 |
trainable_params.append(name)
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
print(f"Total trainable parameters: {len(trainable_params)}")
|
|
|
|
| 191 |
model.zero_grad()
|
| 192 |
|
| 193 |
# Check for existing checkpoint
|
| 194 |
+
import re
|
| 195 |
+
|
| 196 |
checkpoint_dir = None
|
| 197 |
if os.path.isdir(training_args.output_dir):
|
| 198 |
+
checkpoints = [
|
| 199 |
+
os.path.join(training_args.output_dir, d)
|
| 200 |
+
for d in os.listdir(training_args.output_dir)
|
| 201 |
+
if re.match(r"checkpoint-\d+", d)
|
| 202 |
+
]
|
| 203 |
if checkpoints:
|
| 204 |
+
# Extract step numbers and find the highest
|
| 205 |
+
checkpoint_dir = max(checkpoints, key=lambda x: int(x.split('-')[-1]))
|
| 206 |
print(f"Resuming from checkpoint: {checkpoint_dir}")
|
| 207 |
|
| 208 |
+
|
| 209 |
# Train
|
| 210 |
print("Starting training...")
|
| 211 |
trainer.train(resume_from_checkpoint=checkpoint_dir)
|