Also save checkpoints for the IP adapter, since it is not available after training

#1
by voakit05 - opened
diffusers_repo/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py CHANGED
@@ -321,6 +321,7 @@ def main():
321
  adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
322
 
323
  ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
 
324
 
325
  weight_dtype = torch.float32
326
  if accelerator.mixed_precision == "fp16":
 
321
  adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
322
 
323
  ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
324
+ accelerator.register_for_checkpointing(ip_adapter)
325
 
326
  weight_dtype = torch.float32
327
  if accelerator.mixed_precision == "fp16":