# This code file is from https://github.com/hao-ai-lab/FastVideo, which is licensed under the Apache License 2.0.
import json
import os

import torch
import torch.distributed.checkpoint as dist_cp
from peft import get_peft_model_state_dict
from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.default_planner import (DefaultLoadPlanner,
                                                          DefaultSavePlanner)
from torch.distributed.checkpoint.optimizer import \
    load_sharded_optimizer_state_dict
from torch.distributed.fsdp import (FullOptimStateDictConfig,
                                    FullStateDictConfig)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

from fastvideo.utils.logging_ import main_print


def save_checkpoint_optimizer(model,
                              optimizer,
                              rank,
                              output_dir,
                              step,
                              discriminator=False):
    # Gather the full (unsharded) model and optimizer state, offloaded to CPU
    # and materialized only on rank 0.
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(
            model,
            optimizer,
        )

    # todo move to get_state_dict
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    # save using safetensors
    if rank <= 0 and not discriminator:
        weight_path = os.path.join(save_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        config_dict = dict(model.config)
        # torch.dtype is not JSON-serializable, so drop it if present.
        config_dict.pop("dtype", None)
        config_path = os.path.join(save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        optimizer_path = os.path.join(save_dir, "optimizer.pt")
        torch.save(optim_state, optimizer_path)
    else:
        weight_path = os.path.join(save_dir,
                                   "discriminator_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        optimizer_path = os.path.join(save_dir, "discriminator_optimizer.pt")
        torch.save(optim_state, optimizer_path)
    main_print(f"--> checkpoint saved at step {step}")


def save_checkpoint(transformer, rank, output_dir, step, epoch):
    main_print(f"--> saving checkpoint at step {step}")
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = transformer.state_dict()
    # todo move to get_state_dict
    if rank <= 0:
        save_dir = os.path.join(output_dir, f"checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)
        # save using safetensors
        weight_path = os.path.join(save_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        config_dict = dict(transformer.config)
        if "dtype" in config_dict:
            del config_dict["dtype"]  # TODO
        config_path = os.path.join(save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
    main_print(f"--> checkpoint saved at step {step}")


def save_checkpoint_generator_discriminator(
    model,
    optimizer,
    discriminator,
    discriminator_optimizer,
    rank,
    output_dir,
    step,
):
    # Generator weights are saved twice: once as consolidated HF-style
    # safetensors (for inference) and once as sharded FSDP state (for
    # resuming training), see below.
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()

    # todo move to get_state_dict
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    hf_weight_dir = os.path.join(save_dir, "hf_weights")
    os.makedirs(hf_weight_dir, exist_ok=True)
    # save using safetensors
    if rank <= 0:
        config_dict = dict(model.config)
        config_path = os.path.join(hf_weight_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        weight_path = os.path.join(hf_weight_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        main_print(f"--> saved HF weight checkpoint at path {hf_weight_dir}")

    # Save sharded (per-rank) generator weights and optimizer state via
    # torch.distributed.checkpoint so training can resume without
    # consolidating the full state dict.
    model_weight_dir = os.path.join(save_dir, "model_weights_state")
    os.makedirs(model_weight_dir, exist_ok=True)
    model_optimizer_dir = os.path.join(save_dir, "model_optimizer_state")
    os.makedirs(model_optimizer_dir, exist_ok=True)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        optim_state = FSDP.optim_state_dict(model, optimizer)
        model_state = model.state_dict()
        weight_state_dict = {"model": model_state}
        dist_cp.save_state_dict(
            state_dict=weight_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_weight_dir),
            planner=DefaultSavePlanner(),
        )

        optimizer_state_dict = {"optimizer": optim_state}
        dist_cp.save_state_dict(
            state_dict=optimizer_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_optimizer_dir),
            planner=DefaultSavePlanner(),
        )

    # Save the discriminator as a single consolidated state dict on rank 0.
    discriminator_fsdp_state_dir = os.path.join(save_dir,
                                                "discriminator_fsdp_state")
    os.makedirs(discriminator_fsdp_state_dir, exist_ok=True)
    with FSDP.state_dict_type(
            discriminator,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        optim_state = FSDP.optim_state_dict(discriminator,
                                            discriminator_optimizer)
        model_state = discriminator.state_dict()
        state_dict = {"optimizer": optim_state, "model": model_state}
        if rank <= 0:
            discriminator_fsdp_state_file = os.path.join(
                discriminator_fsdp_state_dir, "discriminator_state.pt")
            torch.save(state_dict, discriminator_fsdp_state_file)

    main_print("--> saved FSDP state checkpoint")


def load_sharded_model(model, optimizer, model_dir, optimizer_dir):
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        weight_state_dict = {"model": model.state_dict()}
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=weight_state_dict["model"],
            optimizer_key="optimizer",
            storage_reader=dist_cp.FileSystemReader(optimizer_dir),
        )
        optim_state = optim_state["optimizer"]
        flattened_osd = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(flattened_osd)

        dist_cp.load_state_dict(
            state_dict=weight_state_dict,
            storage_reader=dist_cp.FileSystemReader(model_dir),
            planner=DefaultLoadPlanner(),
        )
        model_state = weight_state_dict["model"]
        model.load_state_dict(model_state)
    main_print(f"--> loaded model and optimizer from path {model_dir}")
    return model, optimizer


def load_full_state_model(model, optimizer, checkpoint_file, rank):
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        discriminator_state = torch.load(checkpoint_file, weights_only=False)
        model_state = discriminator_state["model"]
        if rank <= 0:
            optim_state = discriminator_state["optimizer"]
        else:
            optim_state = None
        model.load_state_dict(model_state)
        discriminator_optim_state = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(discriminator_optim_state)
    main_print(
        f"--> loaded discriminator and discriminator optimizer from path {checkpoint_file}"
    )
    return model, optimizer


def resume_training_generator_discriminator(model, optimizer, discriminator,
                                            discriminator_optimizer,
                                            checkpoint_dir, rank):
    # The step is encoded in the directory name, e.g. "checkpoint-500".
    step = int(checkpoint_dir.split("-")[-1])
    model_weight_dir = os.path.join(checkpoint_dir, "model_weights_state")
    model_optimizer_dir = os.path.join(checkpoint_dir, "model_optimizer_state")
    model, optimizer = load_sharded_model(model, optimizer, model_weight_dir,
                                          model_optimizer_dir)
    discriminator_ckpt_file = os.path.join(checkpoint_dir,
                                           "discriminator_fsdp_state",
                                           "discriminator_state.pt")
    discriminator, discriminator_optimizer = load_full_state_model(
        discriminator, discriminator_optimizer, discriminator_ckpt_file, rank)
    return model, optimizer, discriminator, discriminator_optimizer, step


def resume_training(model, optimizer, checkpoint_dir, discriminator=False):
    weight_path = os.path.join(checkpoint_dir,
                               "diffusion_pytorch_model.safetensors")
    if discriminator:
        weight_path = os.path.join(checkpoint_dir,
                                   "discriminator_pytorch_model.safetensors")
    model_weights = load_file(weight_path)
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        # Merge the checkpointed weights into the current state dict so that
        # any parameters missing from the file keep their current values.
        current_state = model.state_dict()
        current_state.update(model_weights)
        model.load_state_dict(current_state, strict=False)
    if discriminator:
        optim_path = os.path.join(checkpoint_dir, "discriminator_optimizer.pt")
    else:
        optim_path = os.path.join(checkpoint_dir, "optimizer.pt")
    optimizer_state_dict = torch.load(optim_path, weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=model, optim=optimizer, optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)
    step = int(checkpoint_dir.split("-")[-1])
    return model, optimizer, step


def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step,
                         pipeline, epoch):
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        full_state_dict = transformer.state_dict()
        lora_optim_state = FSDP.optim_state_dict(
            transformer,
            optimizer,
        )

    if rank <= 0:
        save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)

        # save optimizer
        optim_path = os.path.join(save_dir, "lora_optimizer.pt")
        torch.save(lora_optim_state, optim_path)

        # save lora weight
        main_print(f"--> saving LoRA checkpoint at step {step}")
        transformer_lora_layers = get_peft_model_state_dict(
            model=transformer, state_dict=full_state_dict)
        pipeline.save_lora_weights(
            save_directory=save_dir,
            transformer_lora_layers=transformer_lora_layers,
            is_main_process=True,
        )

        # save config
        lora_config = {
            "step": step,
            "lora_params": {
                "lora_rank": transformer.config.lora_rank,
                "lora_alpha": transformer.config.lora_alpha,
                "target_modules": transformer.config.lora_target_modules,
            },
        }
        config_path = os.path.join(save_dir, "lora_config.json")
        with open(config_path, "w") as f:
            json.dump(lora_config, f, indent=4)
    main_print(f"--> LoRA checkpoint saved at step {step}")


def resume_lora_optimizer(transformer, checkpoint_dir, optimizer):
    config_path = os.path.join(checkpoint_dir, "lora_config.json")
    with open(config_path, "r") as f:
        config_dict = json.load(f)
    optim_path = os.path.join(checkpoint_dir, "lora_optimizer.pt")
    optimizer_state_dict = torch.load(optim_path, weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=transformer,
        optim=optimizer,
        optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)
    step = config_dict["step"]
    main_print(f"--> Successfully resuming LoRA optimizer from step {step}")
    return transformer, optimizer, step
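

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original FastVideo file): a
# minimal save/resume round trip with the helpers above. The `model`,
# `optimizer`, `rank`, and `output_dir` arguments are hypothetical stand-ins
# supplied by the caller; `model` is assumed to be an FSDP-wrapped transformer
# exposing a HuggingFace-style `.config` mapping.
def _example_save_and_resume(model, optimizer, rank, output_dir, step=100):
    # Writes output_dir/checkpoint-<step>/ containing
    # diffusion_pytorch_model.safetensors, config.json, and optimizer.pt
    # (written on rank 0 only); all ranks participate in gathering FSDP state.
    save_checkpoint_optimizer(model, optimizer, rank, output_dir, step)

    # resume_training parses the step back from the trailing "-<number>" of
    # the checkpoint directory name, reloads the safetensors weights, and
    # rebuilds the sharded optimizer state on every rank.
    ckpt_dir = os.path.join(output_dir, f"checkpoint-{step}")
    return resume_training(model, optimizer, ckpt_dir)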