|
|
|
|
|
|
|
|
import json |
|
|
import os |
|
|
|
|
|
import torch |
|
|
import torch.distributed.checkpoint as dist_cp |
|
|
from peft import get_peft_model_state_dict |
|
|
from safetensors.torch import load_file, save_file |
|
|
from torch.distributed.checkpoint.default_planner import (DefaultLoadPlanner, |
|
|
DefaultSavePlanner) |
|
|
from torch.distributed.checkpoint.optimizer import \ |
|
|
load_sharded_optimizer_state_dict |
|
|
from torch.distributed.fsdp import (FullOptimStateDictConfig, |
|
|
FullStateDictConfig) |
|
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
|
from torch.distributed.fsdp import StateDictType |
|
|
|
|
|
from fastvideo.utils.logging_ import main_print |
|
|
|
|
|
|
|
|
def save_checkpoint_optimizer(model,
                              optimizer,
                              rank,
                              output_dir,
                              step,
                              discriminator=False):
    """Save a consolidated (full, rank-0-only) model + optimizer checkpoint.

    Gathers the FULL_STATE_DICT of ``model`` and its optimizer onto CPU on
    rank 0, then writes under ``{output_dir}/checkpoint-{step}``:

    * generator (``discriminator=False``):
      ``diffusion_pytorch_model.safetensors``, ``config.json``, ``optimizer.pt``
    * discriminator (``discriminator=True``):
      ``discriminator_pytorch_model.safetensors``, ``discriminator_optimizer.pt``

    Must be called on every rank (FSDP state-dict gathering is collective);
    file writes happen on rank 0 only.

    Args:
        model: FSDP-wrapped module; assumed to expose a HF-style ``.config``
            mapping when ``discriminator=False``.
        optimizer: optimizer whose state matches ``model``.
        rank: distributed rank of this process.
        output_dir: base directory for checkpoints.
        step: training step, used in the checkpoint directory name.
        discriminator: select the discriminator file layout.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        # rank0_only=True: these are populated on rank 0 and empty elsewhere.
        cpu_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(
            model,
            optimizer,
        )

    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)

    if rank <= 0 and not discriminator:
        weight_path = os.path.join(save_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        config_dict = dict(model.config)
        # Drop the non-JSON-serializable dtype entry if present; use a pop
        # default so a config without 'dtype' does not raise KeyError
        # (consistent with save_checkpoint below).
        config_dict.pop('dtype', None)
        config_path = os.path.join(save_dir, "config.json")
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        optimizer_path = os.path.join(save_dir, "optimizer.pt")
        torch.save(optim_state, optimizer_path)
    elif rank <= 0:
        # Bug fix: previously a bare `else` ran on every nonzero rank too,
        # writing the empty (rank0_only) state dicts over the shared
        # discriminator checkpoint files. Only rank 0 writes.
        weight_path = os.path.join(save_dir,
                                   "discriminator_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        optimizer_path = os.path.join(save_dir, "discriminator_optimizer.pt")
        torch.save(optim_state, optimizer_path)
    main_print(f"--> checkpoint saved at step {step}")
|
|
|
|
|
|
|
|
def save_checkpoint(transformer, rank, output_dir, step, epoch):
    """Write a rank-0 full-weight checkpoint for ``transformer``.

    Gathers the FULL_STATE_DICT onto CPU (collective — call on all ranks),
    then rank 0 writes ``diffusion_pytorch_model.safetensors`` plus a
    ``config.json`` (minus the non-serializable ``dtype`` entry) into
    ``{output_dir}/checkpoint-{step}-{epoch}``.
    """
    main_print(f"--> saving checkpoint at step {step}")
    full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(transformer, StateDictType.FULL_STATE_DICT,
                              full_cfg):
        cpu_state = transformer.state_dict()

    if rank <= 0:
        save_dir = os.path.join(output_dir, f"checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)

        save_file(
            cpu_state,
            os.path.join(save_dir, "diffusion_pytorch_model.safetensors"))

        config_dict = dict(transformer.config)
        # 'dtype' is not JSON-serializable; remove it when present.
        config_dict.pop("dtype", None)
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(config_dict, f, indent=4)
    main_print(f"--> checkpoint saved at step {step}")
|
|
|
|
|
|
|
|
def save_checkpoint_generator_discriminator(
    model,
    optimizer,
    discriminator,
    discriminator_optimizer,
    rank,
    output_dir,
    step,
):
    """Save a resumable GAN-style checkpoint under ``{output_dir}/checkpoint-{step}``.

    Three things are written:
      1. ``hf_weights/``: full generator weights + config (rank 0 only),
         suitable for HF-style inference loading.
      2. ``model_weights_state/`` and ``model_optimizer_state/``: SHARDED
         generator weights/optimizer via torch.distributed.checkpoint
         (every rank participates; these are collective writes).
      3. ``discriminator_fsdp_state/discriminator_state.pt``: full
         discriminator model+optimizer state, written by rank 0.

    Must be called on all ranks: the FSDP state-dict gathering and the
    dist_cp saves are collective operations.
    """
    # --- 1) Full generator weights, gathered to CPU on rank 0 only. ---
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()

    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    hf_weight_dir = os.path.join(save_dir, "hf_weights")
    os.makedirs(hf_weight_dir, exist_ok=True)

    if rank <= 0:
        # assumes model.config is a dict-like of JSON-serializable values —
        # note no 'dtype' stripping here, unlike save_checkpoint (TODO confirm
        # intentional).
        config_dict = dict(model.config)
        config_path = os.path.join(hf_weight_dir, "config.json")
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        weight_path = os.path.join(hf_weight_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)

    main_print(f"--> saved HF weight checkpoint at path {hf_weight_dir}")
    # --- 2) Sharded generator weights + optimizer for exact training resume. ---
    model_weight_dir = os.path.join(save_dir, "model_weights_state")
    os.makedirs(model_weight_dir, exist_ok=True)
    model_optimizer_dir = os.path.join(save_dir, "model_optimizer_state")
    os.makedirs(model_optimizer_dir, exist_ok=True)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        optim_state = FSDP.optim_state_dict(model, optimizer)
        model_state = model.state_dict()
        # Keys "model"/"optimizer" must match what load_sharded_model expects.
        weight_state_dict = {"model": model_state}
        dist_cp.save_state_dict(
            state_dict=weight_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_weight_dir),
            planner=DefaultSavePlanner(),
        )
        optimizer_state_dict = {"optimizer": optim_state}
        dist_cp.save_state_dict(
            state_dict=optimizer_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_optimizer_dir),
            planner=DefaultSavePlanner(),
        )

    # --- 3) Full discriminator state (model + optimizer), rank 0 writes. ---
    discriminator_fsdp_state_dir = os.path.join(save_dir,
                                                "discriminator_fsdp_state")
    os.makedirs(discriminator_fsdp_state_dir, exist_ok=True)
    with FSDP.state_dict_type(
            discriminator,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        optim_state = FSDP.optim_state_dict(discriminator,
                                            discriminator_optimizer)
        model_state = discriminator.state_dict()
        state_dict = {"optimizer": optim_state, "model": model_state}
        if rank <= 0:
            discriminator_fsdp_state_fil = os.path.join(
                discriminator_fsdp_state_dir, "discriminator_state.pt")
            torch.save(state_dict, discriminator_fsdp_state_fil)

    main_print("--> saved FSDP state checkpoint")
|
|
|
|
|
|
|
|
def load_sharded_model(model, optimizer, model_dir, optimizer_dir):
    """Restore sharded generator weights and optimizer state in place.

    Counterpart of the SHARDED_STATE_DICT half of
    ``save_checkpoint_generator_discriminator``. Collective: every rank
    must call this with the same directories.

    Args:
        model: FSDP-wrapped module to restore into.
        optimizer: optimizer for ``model``.
        model_dir: directory written with the ``{"model": ...}`` payload.
        optimizer_dir: directory written with the ``{"optimizer": ...}`` payload.

    Returns:
        (model, optimizer) — the same objects, mutated in place.
    """
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        # Current sharded state dict serves as the layout template for both
        # the optimizer rehydration and the in-place weight load below.
        weight_state_dict = {"model": model.state_dict()}

        # Optimizer state is resharded to match the current model layout;
        # "optimizer" matches the key used at save time.
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=weight_state_dict["model"],
            optimizer_key="optimizer",
            storage_reader=dist_cp.FileSystemReader(optimizer_dir),
        )
        optim_state = optim_state["optimizer"]
        # Convert to the flattened form FSDP optimizers consume.
        flattened_osd = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(flattened_osd)
        # dist_cp loads tensors into weight_state_dict in place.
        dist_cp.load_state_dict(
            state_dict=weight_state_dict,
            storage_reader=dist_cp.FileSystemReader(model_dir),
            planner=DefaultLoadPlanner(),
        )
        model_state = weight_state_dict["model"]
        model.load_state_dict(model_state)
    main_print(f"--> loaded model and optimizer from path {model_dir}")
    return model, optimizer
|
|
|
|
|
|
|
|
def load_full_state_model(model, optimizer, checkpoint_file, rank):
    """Restore a full (rank-0 consolidated) discriminator checkpoint in place.

    Counterpart of the discriminator half of
    ``save_checkpoint_generator_discriminator``. Collective: every rank loads
    the checkpoint file and the model weights; the optimizer state dict is
    only needed on rank 0 (``rank0_only=True``) — other ranks pass ``None``
    and FSDP scatters the resharded state.

    Args:
        model: FSDP-wrapped discriminator.
        optimizer: its optimizer.
        checkpoint_file: path to ``discriminator_state.pt`` containing
            ``{"model": ..., "optimizer": ...}``.
        rank: distributed rank of this process.

    Returns:
        (model, optimizer) — the same objects, mutated in place.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        # weights_only=False: the payload contains optimizer objects, not just
        # tensors; torch>=2.6 defaults weights_only=True and would refuse to
        # unpickle it. This also matches every other torch.load in this file.
        # NOTE(review): only load trusted checkpoint files — full unpickling.
        discriminator_state = torch.load(checkpoint_file, weights_only=False)
        model_state = discriminator_state["model"]
        if rank <= 0:
            optim_state = discriminator_state["optimizer"]
        else:
            # Non-zero ranks receive the optimizer state via FSDP's scatter.
            optim_state = None
        model.load_state_dict(model_state)
        discriminator_optim_state = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(discriminator_optim_state)
    main_print(
        f"--> loaded discriminator and discriminator optimizer from path {checkpoint_file}"
    )
    return model, optimizer
|
|
|
|
|
|
|
|
def resume_training_generator_discriminator(model, optimizer, discriminator,
                                            discriminator_optimizer,
                                            checkpoint_dir, rank):
    """Resume a generator+discriminator run from a full checkpoint directory.

    Restores the sharded generator state and the rank-0 discriminator state
    written by ``save_checkpoint_generator_discriminator``, and recovers the
    step number from the ``checkpoint-{step}`` directory name.

    Returns:
        (model, optimizer, discriminator, discriminator_optimizer, step)
    """
    # Directory is named "checkpoint-{step}"; last dash-separated token.
    step = int(checkpoint_dir.split("-")[-1])

    model, optimizer = load_sharded_model(
        model,
        optimizer,
        os.path.join(checkpoint_dir, "model_weights_state"),
        os.path.join(checkpoint_dir, "model_optimizer_state"),
    )

    discriminator_ckpt_file = os.path.join(checkpoint_dir,
                                           "discriminator_fsdp_state",
                                           "discriminator_state.pt")
    discriminator, discriminator_optimizer = load_full_state_model(
        discriminator, discriminator_optimizer, discriminator_ckpt_file, rank)

    return model, optimizer, discriminator, discriminator_optimizer, step
|
|
|
|
|
|
|
|
def resume_training(model, optimizer, checkpoint_dir, discriminator=False):
    """Resume from a consolidated checkpoint written by save_checkpoint_optimizer.

    Loads the safetensors weights into the FSDP full state dict (non-strict,
    so extra checkpoint keys are tolerated) and rehydrates the optimizer from
    the pickled state dict. Collective — call on all ranks.

    Returns:
        (model, optimizer, step) where step is parsed from the trailing
        ``-{step}`` of ``checkpoint_dir``.
    """
    if discriminator:
        weight_name = "discriminator_pytorch_model.safetensors"
        optim_name = "discriminator_optimizer.pt"
    else:
        weight_name = "diffusion_pytorch_model.safetensors"
        optim_name = "optimizer.pt"

    model_weights = load_file(os.path.join(checkpoint_dir, weight_name))

    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        # Overlay checkpoint tensors onto the current full state dict so
        # parameters missing from the file keep their current values.
        current_state = model.state_dict()
        current_state.update(model_weights)
        model.load_state_dict(current_state, strict=False)

    # weights_only=False: optimizer payloads contain non-tensor objects.
    optimizer_state_dict = torch.load(
        os.path.join(checkpoint_dir, optim_name), weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=model, optim=optimizer, optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)

    step = int(checkpoint_dir.split("-")[-1])
    return model, optimizer, step
|
|
|
|
|
|
|
|
def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step,
                         pipeline, epoch):
    """Save LoRA adapter weights, optimizer state, and LoRA hyperparameters.

    Gathers the full (rank-0 consolidated) model and optimizer state dicts —
    collective, so call on all ranks — then on rank 0 writes, under
    ``{output_dir}/lora-checkpoint-{step}-{epoch}``:

    * ``lora_optimizer.pt`` — pickled optimizer state,
    * pipeline-formatted LoRA weights (via ``pipeline.save_lora_weights``),
    * ``lora_config.json`` — step plus LoRA rank/alpha/target modules read
      from ``transformer.config``.
    """
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        full_state_dict = transformer.state_dict()
        lora_optim_state = FSDP.optim_state_dict(
            transformer,
            optimizer,
        )

    if rank <= 0:
        save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)

        torch.save(lora_optim_state,
                   os.path.join(save_dir, "lora_optimizer.pt"))

        main_print(f"--> saving LoRA checkpoint at step {step}")
        # Extract only the adapter (PEFT) tensors from the full state dict.
        transformer_lora_layers = get_peft_model_state_dict(
            model=transformer, state_dict=full_state_dict)
        pipeline.save_lora_weights(
            save_directory=save_dir,
            transformer_lora_layers=transformer_lora_layers,
            is_main_process=True,
        )

        # Record the hyperparameters needed to rebuild the adapters, plus the
        # step for resume_lora_optimizer.
        lora_config = {
            "step": step,
            "lora_params": {
                "lora_rank": transformer.config.lora_rank,
                "lora_alpha": transformer.config.lora_alpha,
                "target_modules": transformer.config.lora_target_modules,
            },
        }
        with open(os.path.join(save_dir, "lora_config.json"), "w") as f:
            json.dump(lora_config, f, indent=4)
    main_print(f"--> LoRA checkpoint saved at step {step}")
|
|
|
|
|
|
|
|
def resume_lora_optimizer(transformer, checkpoint_dir, optimizer):
    """Restore the LoRA optimizer state saved by ``save_lora_checkpoint``.

    Reads ``lora_config.json`` (for the step counter) and
    ``lora_optimizer.pt`` from ``checkpoint_dir``, reshapes the pickled
    optimizer state for the current FSDP layout, and loads it in place.

    Returns:
        (transformer, optimizer, step)
    """
    with open(os.path.join(checkpoint_dir, "lora_config.json"), "r") as f:
        config_dict = json.load(f)

    # weights_only=False: optimizer payloads contain non-tensor objects.
    optimizer_state_dict = torch.load(
        os.path.join(checkpoint_dir, "lora_optimizer.pt"), weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=transformer,
        optim=optimizer,
        optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)

    step = config_dict["step"]
    main_print(f"--> Successfully resuming LoRA optimizer from step {step}")
    return transformer, optimizer, step
|
|
|