Upload adversarial_training_clip_with_object_token.py
Browse files
train/adversarial_training_clip_with_object_token.py
CHANGED
|
@@ -265,7 +265,7 @@ def main(args):
|
|
| 265 |
|
| 266 |
# save final model
|
| 267 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/final.pt')
|
| 268 |
-
torch.save(unwrap_model(proj_head).
|
| 269 |
|
| 270 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
|
| 271 |
|
|
@@ -505,13 +505,15 @@ def train_one_epoch(
|
|
| 505 |
wandb.log(log_data)
|
| 506 |
|
| 507 |
# save 10 models over the course of training
|
| 508 |
-
if args.save_checkpoints and (step_total % (args.steps //
|
| 509 |
# save model and optimizer state_dict
|
| 510 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
|
|
|
|
| 511 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
|
| 512 |
# every 200 steps, save a fallback model, which gets overwritten
|
| 513 |
-
if step_total %
|
| 514 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
|
|
|
|
| 515 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
|
| 516 |
# remove old fallback models
|
| 517 |
for file in os.listdir(f'{args.output_dir}/checkpoints'):
|
|
|
|
| 265 |
|
| 266 |
# save final model
|
| 267 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/final.pt')
|
| 268 |
+
torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/final_proj_head.pt')
|
| 269 |
|
| 270 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
|
| 271 |
|
|
|
|
| 505 |
wandb.log(log_data)
|
| 506 |
|
| 507 |
# save 10 models over the course of training
|
| 508 |
+
if args.save_checkpoints and (step_total % (args.steps // 1) == 0):
|
| 509 |
# save model and optimizer state_dict
|
| 510 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
|
| 511 |
+
torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj_head.pt')
|
| 512 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
|
| 513 |
# every 200 steps, save a fallback model, which gets overwritten
|
| 514 |
+
if step_total % 2 == 0:
|
| 515 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
|
| 516 |
+
torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj_head.pt')
|
| 517 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
|
| 518 |
# remove old fallback models
|
| 519 |
for file in os.listdir(f'{args.output_dir}/checkpoints'):
|