Upload adversarial_training_clip_with_object_token.py
Browse files
train/adversarial_training_clip_with_object_token.py
CHANGED
|
@@ -108,6 +108,8 @@ def main(args):
|
|
| 108 |
assert str(args.start_step) in args.optimizer_state
|
| 109 |
assert args.pretrained in ['', 'none']
|
| 110 |
args.pretrained = args.optimizer_state.replace('_opt', '')
|
|
|
|
|
|
|
| 111 |
model, _, _ = load_clip_model(args.clip_model_name, args.pretrained)
|
| 112 |
|
| 113 |
# Remove the Normalize transform by creating a new Compose object
|
|
@@ -128,6 +130,9 @@ def main(args):
|
|
| 128 |
cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
|
| 129 |
model_slots = DINOSAURpp(cfg_dict)
|
| 130 |
proj_head = torch.nn.Linear(256, 1024)  # project slot features: slot_dim (256) -> 1024
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
|
| 133 |
# get data
|
|
@@ -505,13 +510,13 @@ def train_one_epoch(
|
|
| 505 |
wandb.log(log_data)
|
| 506 |
|
| 507 |
# save 10 models over the course of training
|
| 508 |
-
if args.save_checkpoints and (step_total % (args.steps //
|
| 509 |
# save model and optimizer state_dict
|
| 510 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
|
| 511 |
torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj_head.pt')
|
| 512 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
|
| 513 |
# every 200 steps, save a fallback model, which gets overwritten
|
| 514 |
-
if step_total %
|
| 515 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
|
| 516 |
torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj_head.pt')
|
| 517 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
|
|
@@ -523,7 +528,7 @@ def train_one_epoch(
|
|
| 523 |
if step_total >= args.steps:
|
| 524 |
break
|
| 525 |
|
| 526 |
-
torch.cuda.empty_cache()
|
| 527 |
return step_total
|
| 528 |
|
| 529 |
|
|
|
|
| 108 |
assert str(args.start_step) in args.optimizer_state
|
| 109 |
assert args.pretrained in ['', 'none']
|
| 110 |
args.pretrained = args.optimizer_state.replace('_opt', '')
|
| 111 |
+
args.pretrained_proj_head = args.optimizer_state.replace('_opt', '_proj_head')
|
| 112 |
+
|
| 113 |
model, _, _ = load_clip_model(args.clip_model_name, args.pretrained)
|
| 114 |
|
| 115 |
# Remove the Normalize transform by creating a new Compose object
|
|
|
|
| 130 |
cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
|
| 131 |
model_slots = DINOSAURpp(cfg_dict)
|
| 132 |
proj_head = torch.nn.Linear(256, 1024)  # project slot features: slot_dim (256) -> 1024
|
| 133 |
+
if args.optimizer_state != '':
|
| 134 |
+
proj_head.load_state_dict(torch.load(args.pretrained_proj_head))
|
| 135 |
+
|
| 136 |
|
| 137 |
|
| 138 |
# get data
|
|
|
|
| 510 |
wandb.log(log_data)
|
| 511 |
|
| 512 |
# save 10 models over the course of training
|
| 513 |
+
if args.save_checkpoints and (step_total % (args.steps // 10) == 0):
|
| 514 |
# save model and optimizer state_dict
|
| 515 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
|
| 516 |
torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj_head.pt')
|
| 517 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
|
| 518 |
# every 2000 steps, save a fallback model, which gets overwritten
|
| 519 |
+
if step_total % 2000 == 0:
|
| 520 |
torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
|
| 521 |
torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj_head.pt')
|
| 522 |
torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
|
|
|
|
| 528 |
if step_total >= args.steps:
|
| 529 |
break
|
| 530 |
|
| 531 |
+
# torch.cuda.empty_cache()
|
| 532 |
return step_total
|
| 533 |
|
| 534 |
|