import argparse
import os
import shutil
import tempfile
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import yaml
from accelerate.logging import get_logger
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from transformers import T5EncoderModel, T5Tokenizer


logger = get_logger(__name__)


def get_args():
    """Parse the --config argument and load all training arguments from the YAML file."""
    parser = argparse.ArgumentParser(description="Training script for CogVideoX using config file.")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the YAML config file.",
    )
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    # Every key in the YAML file becomes an attribute on the returned namespace.
    args = argparse.Namespace(**config)
    return args


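# Example of the kind of YAML config this expects. The exact key set is an assumption
# based on the fields referenced elsewhere in this file (optimizer and Prodigy
# hyper-parameters); any additional keys simply become attributes on `args`.
#
#   optimizer: adamw
#   learning_rate: 1.0e-4
#   adam_beta1: 0.9
#   adam_beta2: 0.95
#   adam_epsilon: 1.0e-8
#   adam_weight_decay: 1.0e-4
#   use_8bit_adam: false
#   prodigy_beta3: null
#   prodigy_decouple: true
#   prodigy_use_bias_correction: true
#   prodigy_safeguard_warmup: true

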
def atomic_save(save_path, accelerator):
    """Save `accelerator` state to `save_path` without ever leaving a half-written checkpoint.

    The state is first written to a temporary directory next to `save_path` and then swapped
    into place; any existing checkpoint is kept as a backup until the swap succeeds.
    """
    parent = os.path.dirname(save_path)
    tmp_dir = tempfile.mkdtemp(dir=parent)
    backup_dir = save_path + "_backup"

    try:
        # Write the full training state into the temporary directory first.
        accelerator.save_state(tmp_dir)

        # Move any existing checkpoint out of the way so it can be restored on failure.
        if os.path.exists(save_path):
            os.rename(save_path, backup_dir)

        # Promote the freshly written state to the final path.
        os.rename(tmp_dir, save_path)

        # The swap succeeded, so the backup is no longer needed.
        if os.path.exists(backup_dir):
            shutil.rmtree(backup_dir)

    except Exception as e:
        # Drop the partially written state.
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)

        # Restore the previous checkpoint if it was moved aside.
        if os.path.exists(backup_dir):
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            os.rename(backup_dir, save_path)

        raise e


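# A minimal usage sketch. `args.output_dir` and `global_step` are assumptions about the
# surrounding training loop, which is not part of this file:
#
#   save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
#   atomic_save(save_path, accelerator)

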
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    # When DeepSpeed manages the optimizer, return a placeholder that accelerate replaces
    # with the optimizer defined in the DeepSpeed config.
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    supported_optimizers = ["adam", "adamw", "prodigy"]
    if args.optimizer not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include "
            f"{supported_optimizers}. Defaulting to AdamW."
        )
        args.optimizer = "adamw"

    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
        logger.warning(
            "use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}."
        )

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if args.optimizer.lower() == "adamw":
        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "adam":
        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0."
            )

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    return optimizer


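# A minimal usage sketch. Passing the learning rate through the parameter group mirrors the
# fact that the Adam/AdamW branches above do not set `lr` themselves; `transformer` stands in
# for whatever module is being trained and is an assumption, not something defined in this file.
#
#   params_to_optimize = [
#       {"params": [p for p in transformer.parameters() if p.requires_grad], "lr": args.learning_rate}
#   ]
#   optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=False)

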
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=grid_crops_coords,
        grid_size=(grid_height, grid_width),
        temporal_size=num_frames,
    )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin


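# A minimal usage sketch at the 480x720 base resolution of the defaults above. Note that
# `num_frames` is the temporal size passed to the 3D RoPE grid (i.e. the latent frame count
# seen by the transformer, not the pixel-space frame count); the value 13 is only an
# illustrative assumption.
#
#   freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
#       height=480,
#       width=720,
#       num_frames=13,
#       attention_head_dim=64,
#       device=torch.device("cuda"),
#   )

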
def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds


def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt_embeds = _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
    return prompt_embeds


def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    if requires_grad:
        prompt_embeds = encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )
    else:
        with torch.no_grad():
            prompt_embeds = encode_prompt(
                tokenizer,
                text_encoder,
                prompt,
                num_videos_per_prompt=1,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
    return prompt_embeds


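# A minimal usage sketch. Loading the tokenizer and text encoder from a pretrained CogVideoX
# checkpoint is an assumption about how this helper is used; `pretrained_model_name_or_path`
# is a placeholder, not something defined in this file.
#
#   tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
#   text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
#   prompt_embeds = compute_prompt_embeddings(
#       tokenizer,
#       text_encoder,
#       ["a panda playing guitar in a bamboo forest"],
#       max_sequence_length=226,
#       device=torch.device("cuda"),
#       dtype=torch.bfloat16,
#       requires_grad=False,
#   )

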
def save_frames_as_pngs(video_array, output_dir, downsample_spatial=1, downsample_temporal=1):
    """Save each frame of a (T, H, W, C) uint8 RGB array as a PNG with compression level 0 (lossless).

    `downsample_temporal` keeps every n-th frame and `downsample_spatial` shrinks each frame
    by that integer factor before writing.
    """
    assert video_array.ndim == 4 and video_array.shape[-1] == 3, "Expected (T, H, W, C=3) array"
    assert video_array.dtype == np.uint8, "Expected uint8 array"

    os.makedirs(output_dir, exist_ok=True)

    # Keep every `downsample_temporal`-th frame.
    frames = video_array[::downsample_temporal]

    _, H, W, _ = frames.shape
    new_size = (W // downsample_spatial, H // downsample_spatial)

    # PNG compression level 0: fastest writes, largest files, still lossless.
    png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]

    for idx, frame in enumerate(frames):
        # OpenCV expects BGR channel order; make the reversed view contiguous before writing.
        bgr = np.ascontiguousarray(frame[..., ::-1])
        if downsample_spatial > 1:
            bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)

        filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
        success = cv2.imwrite(filename, bgr, png_params)
        if not success:
            raise RuntimeError(f"Failed to write frame {filename}")


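# A minimal usage sketch, assuming a decoded clip already sits in memory as uint8 RGB;
# the random data and output directory name are only illustrations.
#
#   video = np.random.randint(0, 256, size=(16, 480, 720, 3), dtype=np.uint8)
#   save_frames_as_pngs(video, "debug_frames", downsample_spatial=2, downsample_temporal=4)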