Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- fastvideo/config_sd/__pycache__/base.cpython-310.pyc +0 -0
- fastvideo/config_sd/base.py +113 -0
- fastvideo/config_sd/dgx.py +60 -0
- fastvideo/data_preprocess/.DS_Store +0 -0
- fastvideo/data_preprocess/preprocess_flux_embedding.py +170 -0
- fastvideo/data_preprocess/preprocess_flux_embedding_rlpt.py +172 -0
- fastvideo/data_preprocess/preprocess_flux_rfpt_embedding.py +224 -0
- fastvideo/data_preprocess/preprocess_qwenimage_embedding.py +220 -0
- fastvideo/data_preprocess/preprocess_rl_embeddings.py +175 -0
- fastvideo/data_preprocess/preprocess_text_embeddings.py +175 -0
- fastvideo/data_preprocess/preprocess_vae_latents.py +137 -0
- fastvideo/data_preprocess/preprocess_validation_text_embeddings.py +80 -0
- fastvideo/dataset/.DS_Store +0 -0
- fastvideo/dataset/__init__.py +104 -0
- fastvideo/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
- fastvideo/dataset/__pycache__/__init__.cpython-312.pyc +0 -0
- fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets.cpython-312.pyc +0 -0
- fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets_all.cpython-312.pyc +0 -0
- fastvideo/dataset/__pycache__/latent_flux_rl_datasets.cpython-312.pyc +0 -0
- fastvideo/dataset/__pycache__/latent_qwenimage_rl_datasets.cpython-310.pyc +0 -0
- fastvideo/dataset/__pycache__/t2v_datasets.cpython-310.pyc +0 -0
- fastvideo/dataset/__pycache__/t2v_datasets.cpython-312.pyc +0 -0
- fastvideo/dataset/__pycache__/transform.cpython-310.pyc +0 -0
- fastvideo/dataset/__pycache__/transform.cpython-312.pyc +0 -0
- fastvideo/dataset/latent_datasets.py +132 -0
- fastvideo/dataset/latent_flux_rfpt_datasets.py +122 -0
- fastvideo/dataset/latent_flux_rfpt_datasets_all.py +134 -0
- fastvideo/dataset/latent_flux_rl_datasets.py +110 -0
- fastvideo/dataset/latent_qwenimage_rl_datasets.py +90 -0
- fastvideo/dataset/latent_rl_datasets.py +99 -0
- fastvideo/dataset/t2v_datasets.py +351 -0
- fastvideo/dataset/transform.py +647 -0
- fastvideo/distill/__init__.py +0 -0
- fastvideo/distill/__pycache__/__init__.cpython-312.pyc +0 -0
- fastvideo/distill/__pycache__/solver.cpython-312.pyc +0 -0
- fastvideo/distill/discriminator.py +84 -0
- fastvideo/distill/solver.py +310 -0
- fastvideo/models/.DS_Store +0 -0
- fastvideo/models/__pycache__/flash_attn_no_pad.cpython-310.pyc +0 -0
- fastvideo/models/__pycache__/flash_attn_no_pad.cpython-312.pyc +0 -0
- fastvideo/models/flash_attn_no_pad.py +37 -0
- fastvideo/reward_model/clip_score.py +98 -0
- fastvideo/reward_model/hps_score.py +79 -0
- fastvideo/reward_model/image_reward.py +40 -0
- fastvideo/reward_model/pick_score.py +107 -0
- fastvideo/reward_model/unified_reward.py +333 -0
- fastvideo/reward_model/utils.py +126 -0
- fastvideo/utils/.DS_Store +0 -0
- fastvideo/utils/checkpoint.py +314 -0
- fastvideo/utils/communications.py +335 -0
fastvideo/config_sd/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (1.26 kB). View file
|
|
|
fastvideo/config_sd/base.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ml_collections


def get_config():
    """Build the default DDPO-style training configuration.

    Returns an ``ml_collections.ConfigDict`` with sections for general run
    settings, the pretrained model, sampling, training, the prompt function,
    and the reward function.
    """
    config = ml_collections.ConfigDict()

    ###### General ######
    # run name for wandb logging and checkpoint saving -- auto-generated from
    # the datetime when left empty.
    config.run_name = ""
    # random seed for reproducibility.
    config.seed = 42
    # top-level logging directory for checkpoint saving.
    config.logdir = "logs"
    # number of epochs to train for; each epoch is one round of sampling from
    # the model followed by training on those samples.
    config.num_epochs = 300
    # number of epochs between saving model checkpoints.
    config.save_freq = 20
    # number of checkpoints to keep before overwriting old ones.
    config.num_checkpoint_limit = 5
    # mixed precision training: "fp16", "bf16", or "no". half precision speeds
    # up training significantly.
    config.mixed_precision = "bf16"
    # allow tf32 on Ampere GPUs, which can speed up training.
    config.allow_tf32 = True
    # resume training from a checkpoint: either an exact checkpoint directory
    # (e.g. checkpoint_50), or a directory containing checkpoints, in which
    # case the latest one is used. `config.use_lora` must match the run that
    # generated the saved checkpoint.
    config.resume_from = ""
    # whether or not to use LoRA. LoRA reduces memory usage significantly; if
    # disabled, training takes a lot of memory and checkpoints are large.
    config.use_lora = False

    ###### Pretrained Model ######
    config.pretrained = pretrained = ml_collections.ConfigDict()
    # base model to load: a local path or a HuggingFace model-hub name.
    pretrained.model = "./data/StableDiffusion"
    # revision of the model to load.
    pretrained.revision = "main"

    ###### Sampling ######
    config.sample = sample = ml_collections.ConfigDict()
    # number of sampler inference steps.
    sample.num_steps = 50
    # DDIM eta: 0.0 is fully deterministic, 1.0 is equivalent to DDPM.
    sample.eta = 1.0
    # classifier-free guidance weight; 1.0 means no guidance.
    sample.guidance_scale = 5.0
    # batch size (per GPU!) to use for sampling.
    sample.batch_size = 1
    # number of batches to sample per epoch; total samples per epoch is
    # num_batches_per_epoch * batch_size * num_gpus.
    sample.num_batches_per_epoch = 2

    ###### Training ######
    config.train = train = ml_collections.ConfigDict()
    # batch size (per GPU!) to use for training.
    train.batch_size = 1
    # whether to use the 8bit Adam optimizer from bitsandbytes.
    train.use_8bit_adam = False
    # learning rate.
    train.learning_rate = 1e-5
    # Adam hyperparameters.
    train.adam_beta1 = 0.9
    train.adam_beta2 = 0.999
    train.adam_weight_decay = 1e-4
    train.adam_epsilon = 1e-8
    # gradient accumulation steps; effective batch size is
    # batch_size * num_gpus * gradient_accumulation_steps.
    train.gradient_accumulation_steps = 1
    # maximum gradient norm for gradient clipping.
    train.max_grad_norm = 1.0
    # number of inner epochs per outer epoch: iterations over the data
    # collected during one outer epoch's round of sampling.
    train.num_inner_epochs = 1
    # if enabled, apply classifier-free guidance during training using the
    # same guidance scale as sampling.
    train.cfg = True
    # clip advantages to the range [-adv_clip_max, adv_clip_max].
    train.adv_clip_max = 5
    # the PPO clip range.
    train.clip_range = 1e-4
    # fraction of timesteps to train on; < 1.0 speeds up training but reduces
    # the accuracy of policy-gradient estimates.
    train.timestep_fraction = 1.0

    ###### Prompt Function ######
    # prompt function to use; see `prompts.py` for available prompt functions.
    config.prompt_fn = "imagenet_animals"
    # kwargs to pass to the prompt function.
    config.prompt_fn_kwargs = {}

    ###### Reward Function ######
    # reward function to use; see `rewards.py` for available reward functions.
    config.reward_fn = "hpsv2"

    ###### Per-Prompt Stat Tracking ######
    # when enabled, the model tracks per-prompt reward mean/std and uses them
    # to compute advantages; disabled here, so advantages are computed from
    # whole-batch statistics instead.
    # config.per_prompt_stat_tracking = ml_collections.ConfigDict()
    # number of reward values to buffer per prompt (persists across epochs).
    # config.per_prompt_stat_tracking.buffer_size = 16
    # minimum buffered rewards before per-prompt stats are used; below this
    # the whole-batch mean/std are used instead.
    # config.per_prompt_stat_tracking.min_count = 16

    return config
|
fastvideo/config_sd/dgx.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ml_collections
import importlib.util
import os

# Load the sibling base.py module by file path. importlib.util replaces the
# long-deprecated `imp.load_source`, since the `imp` module was removed
# entirely in Python 3.12.
_base_spec = importlib.util.spec_from_file_location(
    "base", os.path.join(os.path.dirname(__file__), "base.py")
)
base = importlib.util.module_from_spec(_base_spec)
_base_spec.loader.exec_module(base)


def compressibility():
    """Config for optimizing JPEG compressibility on SD v1.4 (8-GPU DGX run)."""
    config = base.get_config()

    config.pretrained.model = "CompVis/stable-diffusion-v1-4"

    config.num_epochs = 300
    config.save_freq = 50
    config.num_checkpoint_limit = 100000000

    # the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch.
    config.sample.batch_size = 8
    config.sample.num_batches_per_epoch = 4

    # this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch.
    config.train.batch_size = 1
    config.train.gradient_accumulation_steps = 4

    # prompting
    config.prompt_fn = "imagenet_animals"
    config.prompt_fn_kwargs = {}

    # rewards
    config.reward_fn = "jpeg_compressibility"

    config.per_prompt_stat_tracking = {
        "buffer_size": 16,
        "min_count": 16,
    }

    return config


def hps():
    """Config for optimizing the aesthetic-score reward; extends compressibility()."""
    config = compressibility()
    config.num_epochs = 300
    config.reward_fn = "aesthetic_score"

    # this reward is a bit harder to optimize, so I used 2 gradient updates per epoch.
    config.train.gradient_accumulation_steps = 8

    # the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch.
    config.sample.batch_size = 4

    # this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch.
    config.train.batch_size = 4

    config.prompt_fn = "aes"
    config.chosen_number = 16
    config.num_generations = 16
    return config


def get_config(name):
    """Look up a config builder defined in this module by name and call it."""
    return globals()[name]()
|
fastvideo/data_preprocess/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
fastvideo/data_preprocess/preprocess_flux_embedding.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import torch
|
| 15 |
+
from accelerate.logging import get_logger
|
| 16 |
+
import cv2
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
from torch.utils.data import Dataset
|
| 24 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 25 |
+
from torch.utils.data import DataLoader
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
import re
|
| 28 |
+
from diffusers import FluxPipeline
|
| 29 |
+
|
| 30 |
+
def contains_chinese(text):
    """Return True if *text* contains at least one CJK unified ideograph."""
    return re.search(r'[\u4e00-\u9fff]', text) is not None
|
| 32 |
+
|
| 33 |
+
class T5dataset(Dataset):
    """Caption dataset backed by a plain-text file, one caption per line.

    Lines containing CJK characters are dropped and at most the first
    50,000 surviving captions are kept. When ``vae_debug`` is set, a
    precomputed latent tensor is loaded alongside each caption.
    """

    def __init__(self, txt_path, vae_debug):
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        with open(self.txt_path, "r", encoding="utf-8") as f:
            lines = f.read().splitlines()
        kept = [line for line in lines if not contains_chinese(line)]
        # cap the dataset size at 50k captions
        self.train_dataset = kept[:50000]

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, idx):
        caption = self.train_dataset[idx]
        filename = str(idx)
        if self.vae_debug:
            # NOTE(review): this debug branch references the module-level
            # `args` and indexes a caption *string* with ["latent_path"];
            # it looks unusable with this plain-text dataset -- confirm
            # before relying on it.
            latents = torch.load(
                os.path.join(
                    args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
                ),
                map_location="cpu",
            )
        else:
            latents = []
        return dict(caption=caption, latents=latents, filename=filename)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def main(args):
    """Distributed Flux text-embedding preprocessing.

    Each rank encodes its shard of captions with the Flux text encoders and
    writes, per sample: ``prompt_embed/<idx>.pt``,
    ``pooled_prompt_embeds/<idx>.pt`` and ``text_ids/<idx>.pt`` under
    ``args.output_dir``. Rank 0 then merges the per-rank metadata into
    ``videos2caption.json``.
    """
    rank = int(os.getenv("RANK", 0))
    # Fix: the CUDA device index must be the *node-local* rank. The original
    # used the global RANK, which addresses non-existent devices on any node
    # but the first in multi-node runs.
    local_rank = int(os.getenv("LOCAL_RANK", rank))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=rank
        )

    # output layout: one sub-directory per embedding kind
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True)

    latents_txt_path = args.prompt_dir
    train_dataset = T5dataset(latents_txt_path, args.vae_debug)
    sampler = DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    pipe = FluxPipeline.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    ).to(device)

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=rank != 0):
        try:
            with torch.inference_mode():
                if args.vae_debug:
                    latents = data["latents"]
                # Fix: encode the whole batch once. The original re-ran
                # encode_prompt inside the per-sample loop, recomputing the
                # same batch embeddings for every sample in the batch.
                prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
                    prompt=data["caption"], prompt_2=data["caption"]
                )
                for idx, video_name in enumerate(data["filename"]):
                    prompt_embed_path = os.path.join(
                        args.output_dir, "prompt_embed", video_name + ".pt"
                    )
                    pooled_prompt_embeds_path = os.path.join(
                        args.output_dir, "pooled_prompt_embeds", video_name + ".pt"
                    )
                    text_ids_path = os.path.join(
                        args.output_dir, "text_ids", video_name + ".pt"
                    )
                    # save per-sample embeddings
                    torch.save(prompt_embeds[idx], prompt_embed_path)
                    torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path)
                    torch.save(text_ids[idx], text_ids_path)
                    item = {}
                    item["prompt_embed_path"] = video_name + ".pt"
                    item["text_ids"] = video_name + ".pt"
                    item["pooled_prompt_embeds_path"] = video_name + ".pt"
                    item["caption"] = data["caption"][idx]
                    json_data.append(item)
        except Exception as e:
            # keep ranks in lockstep before propagating the failure
            print(f"Rank {rank} Error: {repr(e)}")
            dist.barrier()
            raise
    dist.barrier()
    # merge per-rank metadata on rank 0
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if rank == 0:
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
    # CLI entry point; defaults mirror the original mochi setup.
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
    main(parser.parse_args())
|
fastvideo/data_preprocess/preprocess_flux_embedding_rlpt.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import torch
|
| 15 |
+
from accelerate.logging import get_logger
|
| 16 |
+
import cv2
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
from torch.utils.data import Dataset
|
| 24 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 25 |
+
from torch.utils.data import DataLoader
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
import re
|
| 28 |
+
from diffusers import FluxPipeline
|
| 29 |
+
|
| 30 |
+
def contains_chinese(text):
    """Check whether *text* contains any CJK unified ideograph."""
    match = re.search(r'[\u4e00-\u9fff]', text)
    return match is not None
|
| 32 |
+
|
| 33 |
+
class T5dataset(Dataset):
    """Caption dataset read from a plain-text file, one caption per line.

    Blank lines and lines containing CJK characters are dropped, remaining
    lines are stripped, and at most the first 50,000 captions are kept.
    When ``vae_debug`` is set, a precomputed latent tensor is loaded
    alongside each caption.
    """

    def __init__(self, txt_path, vae_debug):
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        print(f"[DEBUG] Loading captions from: {self.txt_path}")
        with open(self.txt_path, "r", encoding="utf-8") as f:
            raw_lines = f.read().splitlines()
        cleaned = [
            line.strip()
            for line in raw_lines
            if line.strip() and not contains_chinese(line)
        ]
        # cap the dataset size at 50k captions
        self.train_dataset = cleaned[:50000]
        print(f"[DEBUG] Loaded {len(self.train_dataset)} captions after filtering")

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, idx):
        caption = self.train_dataset[idx]
        filename = str(idx)
        if self.vae_debug:
            # NOTE(review): this debug branch references the module-level
            # `args` and indexes a caption *string* with ["latent_path"];
            # it looks unusable with this plain-text dataset -- confirm
            # before relying on it.
            latents = torch.load(
                os.path.join(
                    args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
                ),
                map_location="cpu",
            )
        else:
            latents = []
        return dict(caption=caption, latents=latents, filename=filename)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main(args):
    """Distributed Flux text-embedding preprocessing (rlpt variant).

    Each rank encodes its shard of captions with the Flux text encoders and
    writes, per sample: ``prompt_embed/<idx>.pt``,
    ``pooled_prompt_embeds/<idx>.pt`` and ``text_ids/<idx>.pt`` under
    ``args.output_dir``. Rank 0 then merges the per-rank metadata into
    ``videos2caption.json``.
    """
    rank = int(os.getenv("RANK", 0))
    # Fix: the CUDA device index must be the *node-local* rank. The original
    # used the global RANK, which addresses non-existent devices on any node
    # but the first in multi-node runs.
    local_rank = int(os.getenv("LOCAL_RANK", rank))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=rank
        )

    # output layout: one sub-directory per embedding kind
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True)

    latents_txt_path = args.prompt_dir
    train_dataset = T5dataset(latents_txt_path, args.vae_debug)
    sampler = DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    pipe = FluxPipeline.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    ).to(device)

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=rank != 0):
        try:
            with torch.inference_mode():
                if args.vae_debug:
                    latents = data["latents"]
                # Fix: encode the whole batch once. The original re-ran
                # encode_prompt inside the per-sample loop, recomputing the
                # same batch embeddings for every sample in the batch.
                prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
                    prompt=data["caption"], prompt_2=data["caption"]
                )
                for idx, video_name in enumerate(data["filename"]):
                    prompt_embed_path = os.path.join(
                        args.output_dir, "prompt_embed", video_name + ".pt"
                    )
                    pooled_prompt_embeds_path = os.path.join(
                        args.output_dir, "pooled_prompt_embeds", video_name + ".pt"
                    )
                    text_ids_path = os.path.join(
                        args.output_dir, "text_ids", video_name + ".pt"
                    )
                    # save per-sample embeddings
                    torch.save(prompt_embeds[idx], prompt_embed_path)
                    torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path)
                    torch.save(text_ids[idx], text_ids_path)
                    item = {}
                    item["prompt_embed_path"] = video_name + ".pt"
                    item["text_ids"] = video_name + ".pt"
                    item["pooled_prompt_embeds_path"] = video_name + ".pt"
                    item["caption"] = data["caption"][idx]
                    json_data.append(item)
        except Exception as e:
            # keep ranks in lockstep before propagating the failure
            print(f"Rank {rank} Error: {repr(e)}")
            dist.barrier()
            raise
    dist.barrier()
    # merge per-rank metadata on rank 0
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if rank == 0:
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
    # CLI entry point; defaults mirror the original mochi setup.
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
    main(parser.parse_args())
|
fastvideo/data_preprocess/preprocess_flux_rfpt_embedding.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import torch
|
| 15 |
+
from accelerate.logging import get_logger
|
| 16 |
+
import cv2
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
import pandas as pd
|
| 21 |
+
from torch.utils.data.dataset import ConcatDataset, Dataset
|
| 22 |
+
import io
|
| 23 |
+
import torchvision.transforms as transforms
|
| 24 |
+
logger = get_logger(__name__)
|
| 25 |
+
from torch.utils.data import Dataset
|
| 26 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 27 |
+
from torch.utils.data import DataLoader
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
import re
|
| 30 |
+
from diffusers import FluxPipeline
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 33 |
+
|
| 34 |
+
def contains_chinese(text):
    """Return True if *text* contains at least one CJK unified ideograph (U+4E00..U+9FFF)."""
    return any('\u4e00' <= ch <= '\u9fff' for ch in text)
|
| 36 |
+
|
| 37 |
+
class RFPTdataset(Dataset):
    """Image/caption dataset backed by a directory of parquet shards.

    Each shard row must provide ``image`` (a dict with raw ``bytes``) and a
    ``caption_composition`` string.  ``__getitem__`` decodes the image to a
    float tensor via ``transforms.ToTensor()`` and skips rows with missing data.
    """

    def __init__(
        self, file_path,
    ):
        self.file_path = file_path
        file_names = os.listdir(self.file_path)  # each file contains 5,000 images
        self.file_names = [os.path.join(self.file_path, file_name) for file_name in file_names]
        self.train_dataset = self.read_data()
        self.transform = transforms.ToTensor()

    def read_data(self):
        """Load every parquet shard and concatenate them into one DataFrame."""
        df_list = [pd.read_parquet(file_name) for file_name in self.file_names]
        combined_df = pd.concat(df_list, axis=0, ignore_index=True)
        return combined_df

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, index):
        row = self.train_dataset.iloc[index]
        image_bytes = row['image']['bytes']
        caption = row['caption_composition']
        # BUGFIX: use `is None` instead of `== None` (comparing a tensor with
        # `==` does not yield a plain bool), and validate the raw bytes BEFORE
        # decoding (the original decoded first, so a None blob crashed in
        # Image.open before its own check could fire).  Wrap the skip index
        # with modulo so the last row cannot overrun the dataset.
        if caption is None or image_bytes is None:
            return self.__getitem__((index + 1) % len(self))
        image = self.transform(Image.open(io.BytesIO(image_bytes)).convert('RGB'))
        filename = str(index)
        return dict(caption=caption, image=image, filename=filename)
|
| 67 |
+
|
| 68 |
+
class T5dataset(Dataset):
    """Prompt dataset read from a plain-text file, one prompt per line.

    Lines containing Chinese characters are dropped and at most the first
    50,000 remaining prompts are kept.  When ``vae_debug`` is set,
    ``__getitem__`` additionally loads a cached latent from the module-level
    ``args.output_dir``.
    """

    def __init__(
        self, txt_path, vae_debug,
    ):
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        with open(self.txt_path, "r", encoding="utf-8") as f:
            prompts = f.read().splitlines()
        filtered = [prompt for prompt in prompts if not contains_chinese(prompt)]
        self.train_dataset = filtered[:50000]

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, idx):
        filename = str(idx)
        caption = self.train_dataset[idx]
        if not self.vae_debug:
            latents = []
        else:
            # Relies on the module-level `args` (set in the __main__ block).
            latents = torch.load(
                os.path.join(
                    args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
                ),
                map_location="cpu",
            )
        return dict(caption=caption, latents=latents, filename=filename)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def main(args):
    """Distributed preprocessing driver for the Flux RFPT image dataset.

    Each rank iterates over its shard of the parquet dataset, saves the image
    tensor of every sample under <output_dir>/images/<idx>.pt, and finally
    rank 0 writes a combined videos2caption.json index gathered from all ranks.
    """
    # NOTE(review): the *global* RANK is used as the CUDA device index below;
    # on a multi-node job torch.cuda.set_device needs LOCAL_RANK — confirm
    # this script only runs single-node.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
        )

    # Output layout: embeddings/ids folders are created but only images/ is
    # written to in the current (partially commented-out) code path.
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "images"), exist_ok=True)

    # latents_txt_path = args.prompt_dir
    # train_dataset = T5dataset(latents_txt_path, args.vae_debug)

    train_dataset = RFPTdataset(args.prompt_dir)

    # Shuffled shard-per-rank sampling; filenames are dataset indices, so the
    # same index is always saved by exactly one rank per epoch.
    sampler = DistributedSampler(
        train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True
    )

    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    flux_path = args.model_path
    # NOTE(review): the pipeline is loaded but never used — every encode_prompt /
    # vae call below is commented out.  Loading it still costs GPU memory.
    pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16).to(device)
    image_processor = VaeImageProcessor(16)

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
        try:
            with torch.inference_mode():
                if args.vae_debug:
                    latents = data["latents"]
                for idx, video_name in enumerate(data["filename"]):
                    # prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
                    #     prompt=data["caption"], prompt_2=data["caption"]
                    # )
                    # image_latents = pipe.vae.encode(data["image"].to(torch.bfloat16).to(device)).latent_dist.sample()
                    # output_image = pipe.vae.decode(image_latents, return_dict=False)[0]
                    # output_image = image_processor.postprocess(output_image)
                    # output_image[0].save('output.png')
                    # print(image_latents.latent_dist.sample())
                    # print(image_latents.latent_dist.sample().shape)

                    prompt_embed_path = os.path.join(
                        args.output_dir, "prompt_embed", video_name + ".pt"
                    )
                    pooled_prompt_embeds_path = os.path.join(
                        args.output_dir, "pooled_prompt_embeds", video_name + ".pt"
                    )

                    text_ids_path = os.path.join(
                        args.output_dir, "text_ids", video_name + ".pt"
                    )

                    image_latents_path = os.path.join(
                        args.output_dir, "images", video_name + ".pt"
                    )
                    # save latent
                    # torch.save(prompt_embeds[idx], prompt_embed_path)
                    # torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path)
                    # torch.save(text_ids[idx], text_ids_path)
                    # NOTE(review): this saves the WHOLE batch tensor once per
                    # item — presumably data["image"][idx] was intended; with
                    # train_batch_size > 1 every per-item .pt file duplicates
                    # the full batch.  TODO confirm.
                    torch.save(data["image"].to(torch.bfloat16), image_latents_path)
                    item = {}
                    item["prompt_embed_path"] = video_name + ".pt"
                    item["text_ids"] = video_name + ".pt"
                    item["pooled_prompt_embeds_path"] = video_name + ".pt"
                    item["caption"] = data["caption"][idx]
                    json_data.append(item)
        except Exception as e:
            # Synchronize before re-raising so other ranks do not hang on the
            # final barrier while this rank dies.
            print(f"Rank {local_rank} Error: {repr(e)}")
            dist.barrier()
            raise
    dist.barrier()
    # Gather every rank's metadata list and let rank 0 write the merged index.
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if local_rank == 0:
        # os.remove(latents_json_path)
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
    # CLI entry point.  NOTE: the parsed `args` is also read as a module-level
    # global by T5dataset.__getitem__ (vae_debug path).
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
    args = parser.parse_args()
    main(args)
|
fastvideo/data_preprocess/preprocess_qwenimage_embedding.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import torch
|
| 14 |
+
from accelerate.logging import get_logger
|
| 15 |
+
# from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
|
| 16 |
+
from diffusers.utils import export_to_video
|
| 17 |
+
from fastvideo.models.qwenimage.pipeline_qwenimage import QwenImagePipeline
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import torch.distributed as dist
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
from torch.utils.data import Dataset
|
| 24 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 25 |
+
from torch.utils.data import DataLoader
|
| 26 |
+
from fastvideo.utils.load import load_text_encoder, load_vae
|
| 27 |
+
from diffusers.video_processor import VideoProcessor
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
import re
|
| 30 |
+
from diffusers import DiffusionPipeline
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
|
| 33 |
+
def contains_chinese(text):
    """Check whether the string contains any Chinese (CJK) character in U+4E00..U+9FFF."""
    match = re.search(r'[\u4e00-\u9fff]', text)
    return match is not None
|
| 36 |
+
|
| 37 |
+
class T5dataset(Dataset):
    """Prompt dataset read from a plain-text file, one prompt per line.

    Lines containing Chinese characters are filtered out at load time.
    """

    def __init__(
        self, txt_path, vae_debug,
    ):
        # txt_path: path to the prompt text file; vae_debug: when True,
        # __getitem__ also loads a cached latent tensor from disk.
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        with open(self.txt_path, "r", encoding="utf-8") as f:
            self.train_dataset = [
                line for line in f.read().splitlines() if not contains_chinese(line)
            ]
        #self.train_dataset = sorted(train_dataset)

    def __getitem__(self, idx):
        """Return dict(caption, latents, filename) for prompt *idx*; filename is the index as a string."""
        #import pdb;pdb.set_trace()
        caption = self.train_dataset[idx]
        filename = str(idx)
        #length = self.train_dataset[idx]["length"]
        if self.vae_debug:
            # NOTE(review): relies on the module-level `args` global, and
            # train_dataset entries are plain strings here, so indexing with
            # ["latent_path"] would raise TypeError — this debug path looks
            # stale; confirm before enabling --vae_debug.
            latents = torch.load(
                os.path.join(
                    args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
                ),
                map_location="cpu",
            )
        else:
            latents = []

        return dict(caption=caption, latents=latents, filename=filename)

    def __len__(self):
        return len(self.train_dataset)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def main(args):
    """Distributed Qwen-Image text-embedding preprocessing.

    Each rank encodes its shard of prompts with the pipeline's text encoder,
    right-pads embeddings and attention masks to a fixed length of 1024, saves
    them per-sample as .pt files, and rank 0 writes a merged
    videos2caption.json index.
    """
    # NOTE(review): global RANK is used as the CUDA device index; needs
    # LOCAL_RANK for multi-node runs — confirm single-node use.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
        )

    #videoprocessor = VideoProcessor(vae_scale_factor=8)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True)

    latents_txt_path = args.prompt_dir
    train_dataset = T5dataset(latents_txt_path, args.vae_debug)
    #text_encoder = load_text_encoder(args.model_type, args.model_path, device=device)
    #vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
    #vae.enable_tiling()
    # shuffle=False keeps the filename-is-line-index mapping stable across runs.
    sampler = DistributedSampler(
        train_dataset, rank=local_rank, num_replicas=world_size, shuffle=False
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    # Load pipeline but don't move everything to GPU yet
    pipe = QwenImagePipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)

    # Only move text_encoder to GPU for embedding generation
    pipe.text_encoder = pipe.text_encoder.to(device)

    # Delete unused components to free up RAM/VRAM
    if not args.vae_debug:
        # Remove from attributes
        if hasattr(pipe, "transformer"):
            del pipe.transformer
        if hasattr(pipe, "vae"):
            del pipe.vae

        # Remove from components dictionary to ensure garbage collection
        if "transformer" in pipe.components:
            del pipe.components["transformer"]
        if "vae" in pipe.components:
            del pipe.components["vae"]

        import gc
        gc.collect()
        torch.cuda.empty_cache()

    # pipe._execution_device = device # This causes AttributeError, removing it.

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
        with torch.inference_mode():
            with torch.autocast("cuda"):
                prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
                    prompt=data["caption"],
                    device=device # Explicitly pass device
                )

            # ==================== padding to fixed length ====================

            # 1. Record the original sequence length (size of dimension 1).
            #    NOTE(review): encode_prompt on a batch of prompts returns a
            #    batch-padded tensor, so this length is shared by every item in
            #    the batch, not per-prompt — confirm that is intended.
            original_length = prompt_embeds.shape[1]
            target_length = 1024

            # 2. Compute the padding amount.
            # Assumes original_length never exceeds target_length — a longer
            # prompt would make pad_len negative and F.pad would truncate/fail.
            pad_len = target_length - original_length

            # 3. Pad prompt_embeds.
            # prompt_embeds is a 3D tensor (B, L, D); pad dimension 1 (L).
            # F.pad's padding argument is ordered from the last dimension
            # backwards: (pad_D_left, pad_D_right, pad_L_left, pad_L_right).
            # Pad on the right of dimension 1 (sequence length L).
            prompt_embeds = F.pad(prompt_embeds, (0, 0, 0, pad_len), "constant", 0)

            # 4. Pad prompt_attention_mask.
            # prompt_attention_mask is a 2D tensor (B, L); pad dimension 1 (L)
            # on the right with zeros (masked positions).
            prompt_attention_mask = F.pad(prompt_attention_mask, (0, pad_len), "constant", 0)

            # ==================== end padding ====================

            if args.vae_debug:
                latents = data["latents"]
            for idx, video_name in enumerate(data["filename"]):
                prompt_embed_path = os.path.join(
                    args.output_dir, "prompt_embed", video_name + ".pt"
                )
                prompt_attention_mask_path = os.path.join(
                    args.output_dir, "prompt_attention_mask", video_name + ".pt"
                )
                # save latent (note: the padded tensors are what get saved)
                torch.save(prompt_embeds[idx], prompt_embed_path)
                torch.save(prompt_attention_mask[idx], prompt_attention_mask_path)
                item = {}
                item["prompt_embed_path"] = video_name + ".pt"
                item["prompt_attention_mask"] = video_name + ".pt"
                item["caption"] = data["caption"][idx]

                # Record the pre-padding length so consumers can un-pad.
                item["original_length"] = original_length

                json_data.append(item)
    dist.barrier()
    # Gather every rank's metadata list; rank 0 writes the merged index.
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if local_rank == 0:
        # os.remove(latents_json_path)
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
    # CLI entry point for the Qwen-Image embedding preprocessor.  The parsed
    # `args` is also read as a module-level global by T5dataset.__getitem__.
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
    args = parser.parse_args()
    main(args)
|
fastvideo/data_preprocess/preprocess_rl_embeddings.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import torch
|
| 14 |
+
from accelerate.logging import get_logger
|
| 15 |
+
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
|
| 16 |
+
from diffusers.utils import export_to_video
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
|
| 21 |
+
logger = get_logger(__name__)
|
| 22 |
+
from torch.utils.data import Dataset
|
| 23 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 24 |
+
from torch.utils.data import DataLoader
|
| 25 |
+
from fastvideo.utils.load import load_text_encoder, load_vae
|
| 26 |
+
from diffusers.video_processor import VideoProcessor
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
import re
|
| 29 |
+
|
| 30 |
+
def contains_chinese(text):
    """Check whether the string contains any Chinese character (CJK range U+4E00..U+9FFF)."""
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            return True
    return False
|
| 33 |
+
|
| 34 |
+
class T5dataset(Dataset):
    """Prompt dataset loaded from a text file with one prompt per line.

    Prompts containing Chinese characters are dropped at load time.  With
    ``vae_debug`` enabled, ``__getitem__`` also loads a cached latent from the
    module-level ``args.output_dir``.
    """

    def __init__(
        self, txt_path, vae_debug,
    ):
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        with open(self.txt_path, "r", encoding="utf-8") as f:
            lines = f.read().splitlines()
        self.train_dataset = [line for line in lines if not contains_chinese(line)]

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, idx):
        caption = self.train_dataset[idx]
        filename = str(idx)
        if not self.vae_debug:
            latents = []
        else:
            # Relies on the module-level `args` (set in the __main__ block).
            latents = torch.load(
                os.path.join(
                    args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
                ),
                map_location="cpu",
            )
        return dict(caption=caption, latents=latents, filename=filename)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main(args):
    """Distributed text-embedding preprocessing for RL training prompts.

    Each rank encodes its shard of prompts with the text encoder, saves the
    per-sample prompt embedding and attention mask as .pt files, and rank 0
    writes a merged videos2caption.json index gathered from all ranks.
    """
    # NOTE(review): global RANK is used as the CUDA device index; needs
    # LOCAL_RANK for multi-node runs — confirm single-node use.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
        )

    #videoprocessor = VideoProcessor(vae_scale_factor=8)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True)
    #os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True)

    latents_txt_path = args.prompt_dir
    train_dataset = T5dataset(latents_txt_path, args.vae_debug)
    text_encoder = load_text_encoder(args.model_type, args.model_path, device=device)
    #vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
    #vae.enable_tiling()
    # NOTE(review): shuffle=True with filename = dataset index means the
    # index-to-file mapping is stable (one rank owns each index), but
    # iteration order differs across runs.
    sampler = DistributedSampler(
        train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
        with torch.inference_mode():
            with torch.autocast("cuda"):
                prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
                    prompt=data["caption"],
                )
            if args.vae_debug:
                latents = data["latents"]
                #video = vae.decode(latents.to(device), return_dict=False)[0]
                #video = videoprocessor.postprocess_video(video)
            for idx, video_name in enumerate(data["filename"]):
                prompt_embed_path = os.path.join(
                    args.output_dir, "prompt_embed", video_name + ".pt"
                )
                #video_path = os.path.join(
                #    args.output_dir, "video", video_name + ".mp4"
                #)
                prompt_attention_mask_path = os.path.join(
                    args.output_dir, "prompt_attention_mask", video_name + ".pt"
                )
                # save latent
                torch.save(prompt_embeds[idx], prompt_embed_path)
                torch.save(prompt_attention_mask[idx], prompt_attention_mask_path)
                #print(f"sample {video_name} saved")
                #if args.vae_debug:
                #    export_to_video(video[idx], video_path, fps=fps)
                item = {}
                #item["length"] = int(data["length"][idx])
                #item["latent_path"] = video_name + ".pt"
                item["prompt_embed_path"] = video_name + ".pt"
                item["prompt_attention_mask"] = video_name + ".pt"
                item["caption"] = data["caption"][idx]
                json_data.append(item)
    dist.barrier()
    # Gather every rank's metadata list; rank 0 writes the merged index.
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if local_rank == 0:
        # os.remove(latents_json_path)
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
if __name__ == "__main__":
    # CLI entry point for the RL prompt-embedding preprocessor.  The parsed
    # `args` is also read as a module-level global by T5dataset.__getitem__.
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
    args = parser.parse_args()
    main(args)
|
fastvideo/data_preprocess/preprocess_text_embeddings.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from accelerate.logging import get_logger
|
| 10 |
+
from diffusers.utils import export_to_video
|
| 11 |
+
from diffusers.video_processor import VideoProcessor
|
| 12 |
+
from torch.utils.data import DataLoader, Dataset
|
| 13 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from fastvideo.utils.load import load_text_encoder, load_vae
|
| 17 |
+
|
| 18 |
+
logger = get_logger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class T5dataset(Dataset):
    """Caption/latent dataset backed by a videos2caption-style JSON file.

    Entries are sorted by ``latent_path`` at load time so every rank sees the
    same ordering.  With ``vae_debug`` enabled, ``__getitem__`` also loads the
    cached latent tensor from the module-level ``args.output_dir``.
    """

    def __init__(
        self,
        json_path,
        vae_debug,
    ):
        self.json_path = json_path
        self.vae_debug = vae_debug
        with open(self.json_path, "r") as f:
            entries = json.load(f)
        entries.sort(key=lambda entry: entry["latent_path"])
        self.train_dataset = entries

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, idx):
        entry = self.train_dataset[idx]
        caption = entry["caption"]
        # filename is the latent file's stem (everything before the first dot).
        filename = entry["latent_path"].split(".")[0]
        length = entry["length"]
        if not self.vae_debug:
            latents = []
        else:
            # Relies on the module-level `args` (set in the __main__ block).
            latents = torch.load(
                os.path.join(args.output_dir, "latent", entry["latent_path"]),
                map_location="cpu",
            )
        return dict(caption=caption,
                    latents=latents,
                    filename=filename,
                    length=length)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def main(args):
    """Encode every caption in ``videos2caption_temp.json`` with the text
    encoder and cache the results under ``args.output_dir``.

    Writes one ``prompt_embed/<name>.pt`` and one
    ``prompt_attention_mask/<name>.pt`` per sample, optionally decodes the
    cached latents back to .mp4 for debugging (``--vae_debug``), and finally
    gathers every rank's annotation records into ``videos2caption.json``.
    """
    # NOTE(review): RANK is the *global* rank; using it for cuda.set_device
    # assumes a single-node run (LOCAL_RANK would be expected otherwise) —
    # confirm the launch setup.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=world_size,
                                rank=local_rank)

    # vae_scale_factor=8 matches the 8x spatial downsampling of the VAE used
    # by the debug decode path below.
    videoprocessor = VideoProcessor(vae_scale_factor=8)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"),
                exist_ok=True)

    # Annotations produced by the earlier VAE-latent preprocessing step.
    latents_json_path = os.path.join(args.output_dir,
                                     "videos2caption_temp.json")
    train_dataset = T5dataset(latents_json_path, args.vae_debug)
    text_encoder = load_text_encoder(args.model_type,
                                     args.model_path,
                                     device=device)
    vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
    vae.enable_tiling()
    # DistributedSampler shards the dataset so each rank encodes a disjoint
    # subset; the per-rank records are merged via all_gather_object below.
    sampler = DistributedSampler(train_dataset,
                                 rank=local_rank,
                                 num_replicas=world_size,
                                 shuffle=True)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=autocast_type):
                prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
                    prompt=data["caption"], )
            if args.vae_debug:
                # Round-trip the cached latents through the VAE to eyeball
                # reconstruction quality.
                latents = data["latents"]
                video = vae.decode(latents.to(device),
                                   return_dict=False)[0]
                video = videoprocessor.postprocess_video(video)
            for idx, video_name in enumerate(data["filename"]):
                prompt_embed_path = os.path.join(args.output_dir,
                                                 "prompt_embed",
                                                 video_name + ".pt")
                video_path = os.path.join(args.output_dir, "video",
                                          video_name + ".mp4")
                prompt_attention_mask_path = os.path.join(
                    args.output_dir, "prompt_attention_mask",
                    video_name + ".pt")
                # save latent
                torch.save(prompt_embeds[idx], prompt_embed_path)
                torch.save(prompt_attention_mask[idx],
                           prompt_attention_mask_path)
                print(f"sample {video_name} saved")
                if args.vae_debug:
                    export_to_video(video[idx], video_path, fps=fps)
                # All three cached artifacts share the video's base name.
                item = {}
                item["length"] = int(data["length"][idx])
                item["latent_path"] = video_name + ".pt"
                item["prompt_embed_path"] = video_name + ".pt"
                item["prompt_attention_mask"] = video_name + ".pt"
                item["caption"] = data["caption"][idx]
                json_data.append(item)
    # Wait for every rank, then merge all per-rank annotation lists.
    dist.barrier()
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if local_rank == 0:
        # os.remove(latents_json_path)
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"),
                  "w") as f:
            json.dump(all_json_data, f, indent=4)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
    # CLI entry point: collect preprocessing options and run the embedding pass.
    cli = argparse.ArgumentParser()
    # dataset & dataloader
    cli.add_argument("--model_path", type=str, default="data/mochi")
    cli.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    cli.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    cli.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    cli.add_argument("--text_encoder_name",
                     type=str,
                     default="google/t5-v1_1-xxl")
    cli.add_argument("--cache_dir", type=str, default="./cache_dir")
    cli.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    cli.add_argument("--vae_debug", action="store_true")
    main(cli.parse_args())
|
fastvideo/data_preprocess/preprocess_vae_latents.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from accelerate.logging import get_logger
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from fastvideo.dataset import getdataset
|
| 15 |
+
from fastvideo.utils.load import load_vae
|
| 16 |
+
|
| 17 |
+
logger = get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def main(args):
    """Encode raw videos into VAE latents and cache them to disk.

    Each rank encodes a disjoint shard of the dataset, saves one
    ``latent/<name>.pt`` (bfloat16) per video, and rank 0 merges every
    rank's annotation records into ``videos2caption_temp.json``.
    """
    # NOTE(review): RANK is the *global* rank; using it for cuda.set_device
    # assumes a single-node launch — confirm.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)
    train_dataset = getdataset(args)
    sampler = DistributedSampler(train_dataset,
                                 rank=local_rank,
                                 num_replicas=world_size,
                                 shuffle=True)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    encoder_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=world_size,
                                rank=local_rank)
    vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
    # Tiled encoding keeps peak GPU memory bounded on long/large videos.
    vae.enable_tiling()
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=autocast_type):
                # Sample (not mode) from the VAE posterior for each clip.
                latents = vae.encode(data["pixel_values"].to(
                    encoder_device))["latent_dist"].sample()
            for idx, video_path in enumerate(data["path"]):
                # Cache file is keyed by the video's base name (no extension).
                video_name = os.path.basename(video_path).split(".")[0]
                latent_path = os.path.join(args.output_dir, "latent",
                                           video_name + ".pt")
                # Stored in bfloat16 to halve disk footprint.
                torch.save(latents[idx].to(torch.bfloat16), latent_path)
                item = {}
                # "length" records the latent's temporal extent (dim 1 here).
                item["length"] = latents[idx].shape[1]
                item["latent_path"] = video_name + ".pt"
                item["caption"] = data["text"][idx]
                json_data.append(item)
                print(f"{video_name} processed")
    # Synchronize, then merge per-rank annotation lists on rank 0.
    dist.barrier()
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if local_rank == 0:
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption_temp.json"),
                  "w") as f:
            json.dump(all_json_data, f, indent=4)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
    # CLI entry point: gather dataset/encoder options and run latent caching.
    cli = argparse.ArgumentParser()
    # dataset & dataloader
    cli.add_argument("--model_path", type=str, default="data/mochi")
    cli.add_argument("--model_type", type=str, default="mochi")
    cli.add_argument("--data_merge_path", type=str, required=True)
    cli.add_argument("--num_frames", type=int, default=163)
    cli.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    cli.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    cli.add_argument("--num_latent_t",
                     type=int,
                     default=28,
                     help="Number of latent timesteps.")
    cli.add_argument("--max_height", type=int, default=480)
    cli.add_argument("--max_width", type=int, default=848)
    # NOTE(review): float default with type=int — parsed values are coerced
    # to int; confirm whether a float tolerance was intended.
    cli.add_argument("--video_length_tolerance_range",
                     type=int,
                     default=2.0)
    cli.add_argument("--group_frame", action="store_true")  # TODO
    cli.add_argument("--group_resolution", action="store_true")  # TODO
    cli.add_argument("--dataset", default="t2v")
    cli.add_argument("--train_fps", type=int, default=30)
    cli.add_argument("--use_image_num", type=int, default=0)
    cli.add_argument("--text_max_length", type=int, default=256)
    cli.add_argument("--speed_factor", type=float, default=1.0)
    cli.add_argument("--drop_short_ratio", type=float, default=1.0)
    # text encoder & vae & diffusion model
    cli.add_argument("--text_encoder_name",
                     type=str,
                     default="google/t5-v1_1-xxl")
    cli.add_argument("--cache_dir", type=str, default="./cache_dir")
    cli.add_argument("--cfg", type=float, default=0.0)
    cli.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    cli.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )

    main(cli.parse_args())
|
fastvideo/data_preprocess/preprocess_validation_text_embeddings.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from accelerate.logging import get_logger
|
| 9 |
+
|
| 10 |
+
from fastvideo.utils.load import load_text_encoder
|
| 11 |
+
|
| 12 |
+
logger = get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main(args):
    """Cache text-encoder embeddings for every validation prompt.

    Reads one prompt per line from ``args.validation_prompt_txt`` and writes
    ``validation/prompt_embed/<name>.pt`` and
    ``validation/prompt_attention_mask/<name>.pt`` under ``args.output_dir``.
    """
    # NOTE(review): RANK is the *global* rank; cuda.set_device assumes a
    # single-node launch — confirm. Note also that every rank encodes the
    # full prompt list (no sharding), overwriting the same files.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=world_size,
                                rank=local_rank)

    text_encoder = load_text_encoder(args.model_type,
                                     args.model_path,
                                     device=device)
    # hunyuan's encoder runs in fp16; every other model type uses bf16.
    autocast_type = torch.float16 if args.model_type == "hunyuan" else torch.bfloat16
    # output_dir/validation/prompt_attention_mask
    # output_dir/validation/prompt_embed
    os.makedirs(os.path.join(args.output_dir, "validation"), exist_ok=True)
    os.makedirs(
        os.path.join(args.output_dir, "validation", "prompt_attention_mask"),
        exist_ok=True,
    )
    os.makedirs(os.path.join(args.output_dir, "validation", "prompt_embed"),
                exist_ok=True)

    # One prompt per line; trailing whitespace/newlines stripped.
    with open(args.validation_prompt_txt, "r", encoding="utf-8") as file:
        lines = file.readlines()
    prompts = [line.strip() for line in lines]
    for prompt in prompts:
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=autocast_type):
                prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
                    prompt)
            # The cache file name is the prompt's first sentence.
            # NOTE(review): prompts containing path separators or other
            # filesystem-hostile characters would produce invalid paths —
            # confirm the prompt files are controlled.
            file_name = prompt.split(".")[0]
            prompt_embed_path = os.path.join(args.output_dir, "validation",
                                             "prompt_embed",
                                             f"{file_name}.pt")
            prompt_attention_mask_path = os.path.join(
                args.output_dir,
                "validation",
                "prompt_attention_mask",
                f"{file_name}.pt",
            )
            # Index [0]: encode_prompt was called with a single prompt, so
            # only the first batch element is saved.
            torch.save(prompt_embeds[0], prompt_embed_path)
            torch.save(prompt_attention_mask[0],
                       prompt_attention_mask_path)
            print(f"sample {file_name} saved")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
    # CLI entry point: parse options and cache validation prompt embeddings.
    cli = argparse.ArgumentParser()
    # dataset & dataloader
    cli.add_argument("--model_path", type=str, default="data/mochi")
    cli.add_argument("--model_type", type=str, default="mochi")
    cli.add_argument("--validation_prompt_txt", type=str)
    cli.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    main(cli.parse_args())
|
fastvideo/dataset/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
fastvideo/dataset/__init__.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import transforms
|
| 2 |
+
from torchvision.transforms import Lambda
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
|
| 5 |
+
from fastvideo.dataset.t2v_datasets import T2V_dataset
|
| 6 |
+
from fastvideo.dataset.transform import (CenterCropResizeVideo, Normalize255,
|
| 7 |
+
TemporalRandomCrop)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def getdataset(args):
    """Construct the training dataset selected by ``args.dataset``.

    Only the "t2v" dataset is currently wired up; any other value raises
    ``NotImplementedError``.
    """
    frame_sampler = TemporalRandomCrop(args.num_frames)  # 16 x
    to_signed_range = Lambda(lambda x: 2.0 * x - 1.0)
    target_size = (args.max_height, args.max_width)
    # Video pipeline: center-crop/resize only (normalization steps are
    # deliberately not applied here).
    transform = transforms.Compose([
        # Normalize255(),
        CenterCropResizeVideo(target_size),
    ])
    # Image pipeline: Normalize255, then a top-anchored crop, then map into
    # the signed range via 2x - 1.
    transform_topcrop = transforms.Compose([
        Normalize255(),
        CenterCropResizeVideo(target_size, top_crop=True),
        to_signed_range,
    ])
    # tokenizer = AutoTokenizer.from_pretrained("/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl", cache_dir=args.cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name,
                                              cache_dir=args.cache_dir)
    if args.dataset == "t2v":
        return T2V_dataset(
            args,
            transform=transform,
            temporal_sample=frame_sampler,
            tokenizer=tokenizer,
            transform_topcrop=transform_topcrop,
        )

    raise NotImplementedError(args.dataset)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
    # Ad-hoc sanity check: iterate every cached image-caption entry and count
    # entries whose caption list cannot be sampled from.
    import random

    from accelerate import Accelerator
    from tqdm import tqdm

    from fastvideo.dataset.t2v_datasets import dataset_prog

    args = type(
        "args",
        (),
        {
            "ae": "CausalVAEModel_4x8x8",
            "dataset": "t2v",
            "attention_mode": "xformers",
            "use_rope": True,
            "text_max_length": 300,
            "max_height": 320,
            "max_width": 240,
            "num_frames": 1,
            "use_image_num": 0,
            "interpolation_scale_t": 1,
            "interpolation_scale_h": 1,
            "interpolation_scale_w": 1,
            "cache_dir": "../cache_dir",
            "image_data":
            "/storage/ongoing/new/Open-Sora-Plan-bak/7.14bak/scripts/train_data/image_data.txt",
            "video_data": "1",
            "train_fps": 24,
            "drop_short_ratio": 1.0,
            "use_img_from_vid": False,
            "speed_factor": 1.0,
            "cfg": 0.1,
            "text_encoder_name": "google/mt5-xxl",
            "dataloader_num_workers": 10,
        },
    )
    accelerator = Accelerator()
    dataset = getdataset(args)
    total = len(dataset_prog.img_cap_list)
    bad_entries = 0
    for i in tqdm(range(total)):
        entry = dataset_prog.img_cap_list[i]
        caption_lists = [
            rec["cap"] if isinstance(rec["cap"], list) else [rec["cap"]]
            for rec in entry
        ]
        try:
            caption_lists = [[random.choice(lst)] for lst in caption_lists]
        except Exception as e:
            # random.choice raises on empty caption lists; log and move on.
            print(e)
            # import ipdb;ipdb.set_trace()
            print(entry)
            bad_entries += 1
            continue
        assert caption_lists[0] is not None and len(caption_lists[0]) > 0
    print(total, bad_entries)
    import ipdb

    ipdb.set_trace()
    print("end")
|
fastvideo/dataset/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
fastvideo/dataset/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (3.94 kB). View file
|
|
|
fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets.cpython-312.pyc
ADDED
|
Binary file (5.12 kB). View file
|
|
|
fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets_all.cpython-312.pyc
ADDED
|
Binary file (5.56 kB). View file
|
|
|
fastvideo/dataset/__pycache__/latent_flux_rl_datasets.cpython-312.pyc
ADDED
|
Binary file (4.67 kB). View file
|
|
|
fastvideo/dataset/__pycache__/latent_qwenimage_rl_datasets.cpython-310.pyc
ADDED
|
Binary file (2.24 kB). View file
|
|
|
fastvideo/dataset/__pycache__/t2v_datasets.cpython-310.pyc
ADDED
|
Binary file (9.14 kB). View file
|
|
|
fastvideo/dataset/__pycache__/t2v_datasets.cpython-312.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
fastvideo/dataset/__pycache__/transform.cpython-310.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
fastvideo/dataset/__pycache__/transform.cpython-312.pyc
ADDED
|
Binary file (27.3 kB). View file
|
|
|
fastvideo/dataset/latent_datasets.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LatentDataset(Dataset):
    """Serves precomputed VAE latents plus cached prompt embeddings/masks.

    ``json_path`` is expected to sit next to the cache folders
    (``latent/``, ``prompt_embed/``, ``prompt_attention_mask/``) and to
    contain a JSON list of annotation dicts.
    """

    def __init__(
        self,
        json_path,
        num_latent_t,
        cfg_rate,
    ):
        # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
        self.json_path = json_path
        self.cfg_rate = cfg_rate
        base_dir = os.path.dirname(json_path)
        self.datase_dir_path = base_dir
        self.video_dir = os.path.join(base_dir, "video")
        self.latent_dir = os.path.join(base_dir, "latent")
        self.prompt_embed_dir = os.path.join(base_dir, "prompt_embed")
        self.prompt_attention_mask_dir = os.path.join(
            base_dir, "prompt_attention_mask")
        with open(self.json_path, "r") as f:
            self.data_anno = json.load(f)
        # json.load already preserves the on-disk ordering, so no sort needed.
        self.num_latent_t = num_latent_t
        # Unconditional fallback used on CFG dropout: an all-zero [256, 4096]
        # embedding and an all-False 256-element mask.
        self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
        self.uncond_prompt_mask = torch.zeros(256).bool()
        self.lengths = [item.get("length", 1) for item in self.data_anno]

    def __getitem__(self, idx):
        anno = self.data_anno[idx]
        latent = torch.load(
            os.path.join(self.latent_dir, anno["latent_path"]),
            map_location="cpu",
            weights_only=True,
        )
        # Keep only the trailing num_latent_t latent timesteps.
        latent = latent.squeeze(0)[:, -self.num_latent_t:]
        if random.random() < self.cfg_rate:
            # Classifier-free-guidance dropout: swap in the zero embedding.
            return latent, self.uncond_prompt_embed, self.uncond_prompt_mask
        prompt_embed = torch.load(
            os.path.join(self.prompt_embed_dir, anno["prompt_embed_path"]),
            map_location="cpu",
            weights_only=True,
        )
        prompt_attention_mask = torch.load(
            os.path.join(self.prompt_attention_mask_dir,
                         anno["prompt_attention_mask"]),
            map_location="cpu",
            weights_only=True,
        )
        return latent, prompt_embed, prompt_attention_mask

    def __len__(self):
        return len(self.data_anno)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def latent_collate_function(batch):
    """Pad per-sample latents to a common size and stack the batch.

    Returns ``(latents, prompt_embeds, latent_attn_mask,
    prompt_attention_masks)`` where the mask is 1 on real content and 0 on
    padded positions.
    """
    latents, prompt_embeds, prompt_attention_masks = zip(*batch)
    # Largest extent along each of the three trailing latent dimensions.
    max_t, max_h, max_w = (max(l.shape[d] for l in latents)
                           for d in (1, 2, 3))

    # NOTE(review): F.pad consumes (left, right) pairs starting from the LAST
    # dimension, so these amounts are applied to dims (3, 2, 1) in that
    # order — confirm this matches the intended (t, h, w) padding.
    padded = [
        torch.nn.functional.pad(
            l,
            (
                0,
                max_t - l.shape[1],
                0,
                max_h - l.shape[2],
                0,
                max_w - l.shape[3],
            ),
        ) for l in latents
    ]
    # Start fully valid, then zero every padded region per sample.
    latent_attn_mask = torch.ones(len(padded), max_t, max_h, max_w)
    for i, l in enumerate(latents):
        latent_attn_mask[i, l.shape[1]:, :, :] = 0
        latent_attn_mask[i, :, l.shape[2]:, :] = 0
        latent_attn_mask[i, :, :, l.shape[3]:] = 0

    return (torch.stack(padded, dim=0),
            torch.stack(prompt_embeds, dim=0),
            latent_attn_mask,
            torch.stack(prompt_attention_masks, dim=0))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
    # Smoke-test the dataset and collate function on a local merge file.
    # Fix: cfg_rate is a required constructor argument; the previous call
    # omitted it and raised TypeError before anything ran.
    dataset = LatentDataset("data/Mochi-Synthetic-Data/merge.txt",
                            num_latent_t=28,
                            cfg_rate=0.0)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=latent_collate_function)
    for latent, prompt_embed, latent_attn_mask, prompt_attention_mask in dataloader:
        print(
            latent.shape,
            prompt_embed.shape,
            latent_attn_mask.shape,
            prompt_attention_mask.shape,
        )
        import pdb

        pdb.set_trace()
|
fastvideo/dataset/latent_flux_rfpt_datasets.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LatentDataset(Dataset):
    """Serves cached Flux text-encoder outputs plus image latents.

    ``json_path`` must sit next to the cache folders ``prompt_embed/``,
    ``pooled_prompt_embeds/``, ``text_ids/`` and ``images/``; each annotation
    entry names the files to load and carries the raw caption string.
    """

    def __init__(
        self, json_path, num_latent_t, cfg_rate,
    ):
        # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
        self.json_path = json_path
        self.cfg_rate = cfg_rate
        root = os.path.dirname(json_path)
        self.datase_dir_path = root
        #self.video_dir = os.path.join(self.datase_dir_path, "video")
        #self.latent_dir = os.path.join(self.datase_dir_path, "latent")
        self.prompt_embed_dir = os.path.join(root, "prompt_embed")
        self.pooled_prompt_embeds_dir = os.path.join(root,
                                                     "pooled_prompt_embeds")
        self.text_ids_dir = os.path.join(root, "text_ids")
        self.latents_dir = os.path.join(root, "images")
        with open(self.json_path, "r") as f:
            self.data_anno = json.load(f)
        # json.load already preserves the on-disk ordering; no sorting needed.
        self.num_latent_t = num_latent_t
        # Unconditional fallback embedding ([256, 4096] zeros) used when
        # classifier-free-guidance dropout fires, plus a matching zero mask.
        self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
        self.uncond_prompt_mask = torch.zeros(256).bool()
        self.lengths = [entry.get("length", 1) for entry in self.data_anno]

    def __getitem__(self, idx):
        entry = self.data_anno[idx]
        text_ids_file = entry["text_ids"]
        if random.random() < self.cfg_rate:
            # CFG dropout: replace the prompt embedding with the zero tensor.
            prompt_embed = self.uncond_prompt_embed
        else:
            prompt_embed = torch.load(
                os.path.join(self.prompt_embed_dir,
                             entry["prompt_embed_path"]),
                map_location="cpu",
                weights_only=True,
            )
        pooled_prompt_embeds = torch.load(
            os.path.join(self.pooled_prompt_embeds_dir,
                         entry["pooled_prompt_embeds_path"]),
            map_location="cpu",
            weights_only=True,
        )
        text_ids = torch.load(
            os.path.join(self.text_ids_dir, text_ids_file),
            map_location="cpu",
            weights_only=True,
        )
        # NOTE(review): the image latent is looked up by the *text_ids*
        # filename, i.e. both caches are assumed to share base names —
        # confirm against the preprocessing script.
        latents = torch.load(
            os.path.join(self.latents_dir, text_ids_file),
            map_location="cpu",
            weights_only=True,
        )
        return (prompt_embed, pooled_prompt_embeds, text_ids,
                entry['caption'], latents)

    def __len__(self):
        return len(self.data_anno)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def latent_collate_function(batch):
    """Collate 5-tuples (prompt_embed, pooled_prompt_embeds, text_ids,
    caption, latents) into batch tensors.

    All tensor columns are stacked along a new batch dimension; captions
    pass through unchanged (as the tuple produced by ``zip``).
    """
    columns = list(zip(*batch))
    prompt_embeds = torch.stack(columns[0], dim=0)
    pooled_prompt_embeds = torch.stack(columns[1], dim=0)
    text_ids = torch.stack(columns[2], dim=0)
    caption = columns[3]
    latents = torch.stack(columns[4], dim=0)
    return prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
    # Smoke test: build a loader over precomputed embeddings and print shapes.
    dataset = LatentDataset("data/rl_embeddings/videos2caption.json",
                            num_latent_t=28, cfg_rate=0.0)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function
    )
    # BUG FIX: latent_collate_function yields FIVE items per batch; the
    # original loop unpacked three, raising ValueError on first iteration.
    # Also removed the leftover `pdb.set_trace()` debug hook.
    for prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents in dataloader:
        print(prompt_embeds.shape, pooled_prompt_embeds.shape,
              text_ids.shape, caption, latents.shape)
|
fastvideo/dataset/latent_flux_rfpt_datasets_all.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LatentDataset(Dataset):
    """Precomputed FLUX conditioning dataset (embeddings, latents, images).

    Expects sibling directories of *json_path* named ``prompt_embed``,
    ``pooled_prompt_embeds``, ``text_ids``, ``images`` and ``latents``.
    The JSON file holds one annotation dict per sample.
    """

    def __init__(self, json_path, num_latent_t, cfg_rate):
        self.json_path = json_path
        self.cfg_rate = cfg_rate
        self.datase_dir_path = os.path.dirname(json_path)
        # Per-modality subdirectories holding torch-saved tensors.
        self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
        self.pooled_prompt_embeds_dir = os.path.join(self.datase_dir_path,
                                                     "pooled_prompt_embeds")
        self.text_ids_dir = os.path.join(self.datase_dir_path, "text_ids")
        self.images_dir = os.path.join(self.datase_dir_path, "images")
        self.latents_dir = os.path.join(self.datase_dir_path, "latents")
        with open(self.json_path, "r") as f:
            self.data_anno = json.load(f)
        self.num_latent_t = num_latent_t
        # Zero embedding pair used when CFG drops the text conditioning.
        self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
        self.uncond_prompt_mask = torch.zeros(256).bool()
        self.lengths = [item.get("length", 1) for item in self.data_anno]

    @staticmethod
    def _load(directory, filename):
        # weights_only=True refuses arbitrary pickled code in the checkpoint.
        return torch.load(os.path.join(directory, filename),
                          map_location="cpu", weights_only=True)

    def __getitem__(self, idx):
        anno = self.data_anno[idx]
        # Latents and images are stored under the same basename as text_ids.
        shared_name = anno["text_ids"]
        if random.random() < self.cfg_rate:
            prompt_embed = self.uncond_prompt_embed
        else:
            prompt_embed = self._load(self.prompt_embed_dir,
                                      anno["prompt_embed_path"])
        pooled_prompt_embeds = self._load(self.pooled_prompt_embeds_dir,
                                          anno["pooled_prompt_embeds_path"])
        text_ids = self._load(self.text_ids_dir, shared_name)
        latents = self._load(self.latents_dir, shared_name)
        images = self._load(self.images_dir, shared_name)
        return (prompt_embed, pooled_prompt_embeds, text_ids,
                anno["caption"], latents, images)

    def __len__(self):
        return len(self.data_anno)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def latent_collate_function(batch):
    """Collate 6-tuples (prompt_embed, pooled_prompt_embeds, text_ids,
    caption, latents, images) into batch tensors.

    Tensor columns are stacked on a new leading dimension; captions pass
    through unchanged (tuple from ``zip``).
    """
    columns = list(zip(*batch))
    prompt_embeds = torch.stack(columns[0], dim=0)
    pooled_prompt_embeds = torch.stack(columns[1], dim=0)
    text_ids = torch.stack(columns[2], dim=0)
    caption = columns[3]
    latents = torch.stack(columns[4], dim=0)
    images = torch.stack(columns[5], dim=0)
    return prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents, images
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
    # Smoke test: iterate batches and print tensor shapes.
    dataset = LatentDataset("data/rl_embeddings/videos2caption.json",
                            num_latent_t=28, cfg_rate=0.0)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function
    )
    # BUG FIX: latent_collate_function yields SIX items per batch; the
    # original loop unpacked three, raising ValueError on first iteration.
    # Also removed the leftover `pdb.set_trace()` debug hook.
    for prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents, images in dataloader:
        print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape,
              caption, latents.shape, images.shape)
|
fastvideo/dataset/latent_flux_rl_datasets.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LatentDataset(Dataset):
    """Precomputed FLUX text-conditioning dataset for RL training.

    Loads prompt embeddings, pooled embeddings and text ids from sibling
    directories of *json_path*; one annotation dict per sample.
    """

    def __init__(self, json_path, num_latent_t, cfg_rate):
        self.json_path = json_path
        self.cfg_rate = cfg_rate
        self.datase_dir_path = os.path.dirname(json_path)
        self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
        self.pooled_prompt_embeds_dir = os.path.join(self.datase_dir_path,
                                                     "pooled_prompt_embeds")
        self.text_ids_dir = os.path.join(self.datase_dir_path, "text_ids")
        with open(self.json_path, "r") as f:
            self.data_anno = json.load(f)
        self.num_latent_t = num_latent_t
        # Zero embedding pair used when CFG drops the text conditioning.
        self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
        self.uncond_prompt_mask = torch.zeros(256).bool()
        self.lengths = [item.get("length", 1) for item in self.data_anno]

    @staticmethod
    def _load(directory, filename):
        # weights_only=True refuses arbitrary pickled code in the checkpoint.
        return torch.load(os.path.join(directory, filename),
                          map_location="cpu", weights_only=True)

    def __getitem__(self, idx):
        anno = self.data_anno[idx]
        if random.random() < self.cfg_rate:
            prompt_embed = self.uncond_prompt_embed
        else:
            prompt_embed = self._load(self.prompt_embed_dir,
                                      anno["prompt_embed_path"])
        pooled_prompt_embeds = self._load(self.pooled_prompt_embeds_dir,
                                          anno["pooled_prompt_embeds_path"])
        text_ids = self._load(self.text_ids_dir, anno["text_ids"])
        return prompt_embed, pooled_prompt_embeds, text_ids, anno["caption"]

    def __len__(self):
        return len(self.data_anno)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def latent_collate_function(batch):
    """Collate 4-tuples (prompt_embed, pooled_prompt_embeds, text_ids,
    caption); tensor columns are stacked, captions pass through as a tuple."""
    columns = list(zip(*batch))
    prompt_embeds = torch.stack(columns[0], dim=0)
    pooled_prompt_embeds = torch.stack(columns[1], dim=0)
    text_ids = torch.stack(columns[2], dim=0)
    caption = columns[3]
    return prompt_embeds, pooled_prompt_embeds, text_ids, caption
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
    # Smoke test: iterate batches and print tensor shapes.
    dataset = LatentDataset("data/rl_embeddings/videos2caption.json",
                            num_latent_t=28, cfg_rate=0.0)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function
    )
    # BUG FIX: latent_collate_function yields FOUR items per batch; the
    # original loop unpacked three, raising ValueError on first iteration.
    # Also removed the leftover `pdb.set_trace()` debug hook.
    for prompt_embeds, pooled_prompt_embeds, text_ids, caption in dataloader:
        print(prompt_embeds.shape, pooled_prompt_embeds.shape,
              text_ids.shape, caption)
|
fastvideo/dataset/latent_qwenimage_rl_datasets.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LatentDataset(Dataset):
    """Precomputed text embeddings + attention masks for Qwen-Image RL.

    Items carry (prompt_embed, prompt_attention_mask, caption,
    original_length); tensors live under sibling directories of *json_path*.
    """

    def __init__(self, json_path, num_latent_t, cfg_rate):
        self.json_path = json_path
        self.cfg_rate = cfg_rate
        self.datase_dir_path = os.path.dirname(json_path)
        self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
        self.prompt_attention_mask_dir = os.path.join(self.datase_dir_path,
                                                      "prompt_attention_mask")
        with open(self.json_path, "r") as f:
            self.data_anno = json.load(f)
        self.num_latent_t = num_latent_t
        # Zero embedding + all-False mask used when CFG drops the prompt.
        self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
        self.uncond_prompt_mask = torch.zeros(256).bool()
        self.lengths = [item.get("length", 1) for item in self.data_anno]

    def __getitem__(self, idx):
        anno = self.data_anno[idx]
        prompt_embed_file = anno["prompt_embed_path"]
        prompt_attention_mask_file = anno["prompt_attention_mask"]
        if random.random() < self.cfg_rate:
            # CFG dropout: substitute the unconditional embedding pair.
            prompt_embed = self.uncond_prompt_embed
            prompt_attention_mask = self.uncond_prompt_mask
        else:
            prompt_embed = torch.load(
                os.path.join(self.prompt_embed_dir, prompt_embed_file),
                map_location="cpu", weights_only=True)
            prompt_attention_mask = torch.load(
                os.path.join(self.prompt_attention_mask_dir,
                             prompt_attention_mask_file),
                map_location="cpu", weights_only=True)
        return (prompt_embed, prompt_attention_mask,
                anno["caption"], anno["original_length"])

    def __len__(self):
        return len(self.data_anno)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def latent_collate_function(batch):
    """Batch (embed, mask, caption, original_length) samples.

    Embeddings and masks are stacked on a new leading dimension, captions
    become a list, and the original lengths a 1-D long tensor.
    """
    embeds, masks, captions, lengths = zip(*batch)
    return (torch.stack(embeds, dim=0),
            torch.stack(masks, dim=0),
            list(captions),
            torch.tensor(lengths, dtype=torch.long))
|
fastvideo/dataset/latent_rl_datasets.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) [2025] [FastVideo Team]
|
| 2 |
+
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
|
| 3 |
+
# SPDX-License-Identifier: [Apache License 2.0]
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under [Apache License 2.0], with the full license text
|
| 8 |
+
# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LatentDataset(Dataset):
    """Precomputed text embeddings + attention masks for RL training.

    Items carry (prompt_embed, prompt_attention_mask, caption); tensors
    live under sibling directories of *json_path*.
    """

    def __init__(self, json_path, num_latent_t, cfg_rate):
        self.json_path = json_path
        self.cfg_rate = cfg_rate
        self.datase_dir_path = os.path.dirname(json_path)
        self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
        self.prompt_attention_mask_dir = os.path.join(self.datase_dir_path,
                                                      "prompt_attention_mask")
        with open(self.json_path, "r") as f:
            self.data_anno = json.load(f)
        self.num_latent_t = num_latent_t
        # Zero embedding + all-False mask used when CFG drops the prompt.
        self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
        self.uncond_prompt_mask = torch.zeros(256).bool()
        self.lengths = [item.get("length", 1) for item in self.data_anno]

    def __getitem__(self, idx):
        anno = self.data_anno[idx]
        prompt_embed_file = anno["prompt_embed_path"]
        prompt_attention_mask_file = anno["prompt_attention_mask"]
        if random.random() < self.cfg_rate:
            # CFG dropout: substitute the unconditional embedding pair.
            prompt_embed = self.uncond_prompt_embed
            prompt_attention_mask = self.uncond_prompt_mask
        else:
            prompt_embed = torch.load(
                os.path.join(self.prompt_embed_dir, prompt_embed_file),
                map_location="cpu", weights_only=True)
            prompt_attention_mask = torch.load(
                os.path.join(self.prompt_attention_mask_dir,
                             prompt_attention_mask_file),
                map_location="cpu", weights_only=True)
        return prompt_embed, prompt_attention_mask, anno["caption"]

    def __len__(self):
        return len(self.data_anno)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def latent_collate_function(batch):
    """Stack embeddings and masks into batch tensors; captions are
    returned untouched (the tuple produced by ``zip``)."""
    embeds, masks, captions = zip(*batch)
    return torch.stack(embeds, dim=0), torch.stack(masks, dim=0), captions
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
    # Smoke test: iterate batches and print tensor shapes.
    dataset = LatentDataset("data/rl_embeddings/videos2caption.json",
                            num_latent_t=28, cfg_rate=0.0)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function
    )
    for prompt_embed, prompt_attention_mask, caption in dataloader:
        print(prompt_embed.shape, prompt_attention_mask.shape, caption)
    # Removed the leftover `import pdb; pdb.set_trace()` debug hook that
    # would hang any non-interactive run of this script.
|
fastvideo/dataset/t2v_datasets.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
from collections import Counter
|
| 8 |
+
from os.path import join as opj
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torchvision
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from torch.utils.data import Dataset
|
| 16 |
+
|
| 17 |
+
from fastvideo.utils.dataset_utils import DecordInit
|
| 18 |
+
from fastvideo.utils.logging_ import main_print
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SingletonMeta(type):
    """Metaclass caching exactly one instance per class (process-wide).

    Subsequent constructor calls ignore their arguments and return the
    first-created instance.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            # First instantiation: build, cache, then hand back.
            obj = super().__call__(*args, **kwargs)
            cls._instances[cls] = obj
            return obj
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DataSetProg(metaclass=SingletonMeta):
    """Singleton progress tracker that shards a shuffled index list
    across dataloader workers and hands out indices round-robin."""

    def __init__(self):
        self.cap_list = []
        self.elements = []
        self.num_workers = 1
        self.n_elements = 0
        self.worker_elements = dict()
        self.n_used_elements = dict()

    def set_cap_list(self, num_workers, cap_list, n_elements):
        """Shuffle indices 0..n_elements-1 and split them contiguously
        among ``num_workers`` workers; reset each worker's counter."""
        self.num_workers = num_workers
        self.cap_list = cap_list
        self.n_elements = n_elements
        self.elements = list(range(n_elements))
        random.shuffle(self.elements)
        print(f"n_elements: {len(self.elements)}", flush=True)

        # Ceiling division so every element lands in exactly one shard.
        per_worker = int(math.ceil(len(self.elements) / float(self.num_workers)))
        for worker_id in range(self.num_workers):
            self.n_used_elements[worker_id] = 0
            lo = worker_id * per_worker
            hi = min(lo + per_worker, len(self.elements))
            self.worker_elements[worker_id] = self.elements[lo:hi]

    def get_item(self, work_info):
        """Return the next index for the calling worker, wrapping around
        its shard; ``work_info is None`` maps to worker 0."""
        worker_id = 0 if work_info is None else work_info.id
        shard = self.worker_elements[worker_id]
        idx = shard[self.n_used_elements[worker_id] % len(shard)]
        self.n_used_elements[worker_id] += 1
        return idx
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
dataset_prog = DataSetProg()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def filter_resolution(h,
                      w,
                      max_h_div_w_ratio=17 / 16,
                      min_h_div_w_ratio=8 / 16):
    """Return True iff the aspect ratio h/w lies inside
    [min_h_div_w_ratio, max_h_div_w_ratio] (both bounds inclusive)."""
    return min_h_div_w_ratio <= h / w <= max_h_div_w_ratio
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class T2V_dataset(Dataset):
|
| 83 |
+
|
| 84 |
+
    def __init__(self, args, transform, temporal_sample, tokenizer,
                 transform_topcrop):
        """Build the text-to-video dataset from an argparse-style config.

        Args:
            args: namespace carrying data paths, frame counts, CFG rate,
                resolution limits and dataloader settings.
            transform: spatial transform applied to videos / most images.
            temporal_sample: temporal sampling strategy (stored, not used here).
            tokenizer: text tokenizer used in get_video / get_image.
            transform_topcrop: transform used for "human_images" samples.
        """
        self.data = args.data_merge_path
        self.num_frames = args.num_frames
        self.train_fps = args.train_fps
        self.use_image_num = args.use_image_num
        self.transform = transform
        self.transform_topcrop = transform_topcrop
        self.temporal_sample = temporal_sample
        self.tokenizer = tokenizer
        self.text_max_length = args.text_max_length
        self.cfg = args.cfg
        self.speed_factor = args.speed_factor
        self.max_height = args.max_height
        self.max_width = args.max_width
        self.drop_short_ratio = args.drop_short_ratio
        assert self.speed_factor >= 1
        self.v_decoder = DecordInit()
        self.video_length_tolerance_range = args.video_length_tolerance_range
        # Chinese captions are only supported by mT5-family text encoders.
        self.support_Chinese = True
        if "mt5" not in args.text_encoder_name:
            self.support_Chinese = False

        cap_list = self.get_cap_list()

        assert len(cap_list) > 0
        cap_list, self.sample_num_frames = self.define_frame_index(cap_list)
        self.lengths = self.sample_num_frames

        # Register the (filtered) annotation list with the process-wide
        # singleton so every dataloader worker shares the same shards.
        n_elements = len(cap_list)
        dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list,
                                  n_elements)

        print(f"video length: {len(dataset_prog.cap_list)}", flush=True)
|
| 118 |
+
|
| 119 |
+
def set_checkpoint(self, n_used_elements):
|
| 120 |
+
for i in range(len(dataset_prog.n_used_elements)):
|
| 121 |
+
dataset_prog.n_used_elements[i] = n_used_elements
|
| 122 |
+
|
| 123 |
+
    def __len__(self):
        # Length comes from the shared singleton populated in __init__
        # via dataset_prog.set_cap_list(...), not from a local list.
        return dataset_prog.n_elements
|
| 125 |
+
|
| 126 |
+
def __getitem__(self, idx):
|
| 127 |
+
|
| 128 |
+
data = self.get_data(idx)
|
| 129 |
+
return data
|
| 130 |
+
|
| 131 |
+
def get_data(self, idx):
|
| 132 |
+
path = dataset_prog.cap_list[idx]["path"]
|
| 133 |
+
if path.endswith(".mp4"):
|
| 134 |
+
return self.get_video(idx)
|
| 135 |
+
else:
|
| 136 |
+
return self.get_image(idx)
|
| 137 |
+
|
| 138 |
+
    def get_video(self, idx):
        """Load, transform and tokenize one video sample.

        Returns a dict with ``pixel_values`` (C,T,H,W float in [-1, 1]),
        the chosen caption ``text``, tokenizer ``input_ids``/``cond_mask``,
        and the source ``path``.
        """
        video_path = dataset_prog.cap_list[idx]["path"]
        assert os.path.exists(video_path), f"file {video_path} do not exist!"
        # Frame indices were precomputed by define_frame_index.
        frame_indices = dataset_prog.cap_list[idx]["sample_frame_index"]
        torchvision_video, _, metadata = torchvision.io.read_video(
            video_path, output_format="TCHW")
        video = torchvision_video[frame_indices]
        video = self.transform(video)
        video = rearrange(video, "t c h w -> c t h w")
        # The spatial transform must keep values in uint8 range here;
        # normalization to [-1, 1] happens below.
        video = video.to(torch.uint8)
        assert video.dtype == torch.uint8

        h, w = video.shape[-2:]
        assert (
            h / w <= 17 / 16 and h / w >= 8 / 16
        ), f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({video_path}) found ratio is {round(h / w, 2)} with the shape of {video.shape}"

        # Map uint8 [0, 255] to float [-1, 1].
        video = video.float() / 127.5 - 1.0

        text = dataset_prog.cap_list[idx]["cap"]
        if not isinstance(text, list):
            text = [text]
        # Multiple captions per clip: sample one at random.
        text = [random.choice(text)]

        # Classifier-free guidance: drop the caption with probability self.cfg.
        text = text[0] if random.random() > self.cfg else ""
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.text_max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = text_tokens_and_mask["input_ids"]
        cond_mask = text_tokens_and_mask["attention_mask"]
        return dict(
            pixel_values=video,
            text=text,
            input_ids=input_ids,
            cond_mask=cond_mask,
            path=video_path,
        )
|
| 181 |
+
|
| 182 |
+
    def get_image(self, idx):
        """Load, crop/resize and tokenize one image sample.

        Images are returned as single-frame videos: ``pixel_values`` has
        shape (C, 1, H, W) with values in [-1, 1].
        """
        image_data = dataset_prog.cap_list[
            idx]  # [{'path': path, 'cap': cap}, ...]

        image = Image.open(image_data["path"]).convert("RGB")  # [h, w, c]
        image = torch.from_numpy(np.array(image))  # [h, w, c]
        image = rearrange(image, "h w c -> c h w").unsqueeze(0)  # [1 c h w]

        # "human_images" sources are cropped from the top (faces kept);
        # everything else uses the standard transform.
        image = (self.transform_topcrop(image) if "human_images"
                 in image_data["path"] else self.transform(image)
                 )  # [1 C H W] -> num_img [1 C H W]
        image = image.transpose(0, 1)  # [1 C H W] -> [C 1 H W]

        # Map uint8 [0, 255] to float [-1, 1].
        image = image.float() / 127.5 - 1.0

        caps = (image_data["cap"] if isinstance(image_data["cap"], list) else
                [image_data["cap"]])
        # Pick one caption at random when several are available.
        caps = [random.choice(caps)]
        text = caps
        input_ids, cond_mask = [], []
        # Classifier-free guidance: drop the caption (use "") with
        # probability self.cfg.
        text = text[0] if random.random() > self.cfg else ""
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.text_max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = text_tokens_and_mask["input_ids"]  # 1, l
        cond_mask = text_tokens_and_mask["attention_mask"]  # 1, l
        return dict(
            pixel_values=image,
            text=text,
            input_ids=input_ids,
            cond_mask=cond_mask,
            path=image_data["path"],
        )
|
| 224 |
+
|
| 225 |
+
    def define_frame_index(self, cap_list):
        """Filter raw annotations and precompute per-sample frame indices.

        Keeps only items that have a caption, usable fps/duration/resolution
        and an acceptable aspect ratio / length. For each surviving video it
        stores ``sample_frame_index`` (the frames to decode later) and
        ``sample_num_frames``; images pass through with a single frame.

        Returns:
            (new_cap_list, sample_num_frames): the filtered annotation list
            and the per-item frame counts (used by the group sampler).
        """
        new_cap_list = []
        sample_num_frames = []
        # Per-reason drop counters, reported in the summary line below.
        cnt_too_long = 0
        cnt_too_short = 0
        cnt_no_cap = 0
        cnt_no_resolution = 0
        cnt_resolution_mismatch = 0
        cnt_movie = 0  # NOTE(review): never incremented; kept only for the summary
        cnt_img = 0
        for i in cap_list:
            path = i["path"]
            cap = i.get("cap", None)
            # ======no caption=====
            if cap is None:
                cnt_no_cap += 1
                continue
            if path.endswith(".mp4"):
                # ======no fps and duration=====
                duration = i.get("duration", None)
                fps = i.get("fps", None)
                if fps is None or duration is None:
                    continue

                # ======resolution mismatch=====
                resolution = i.get("resolution", None)
                if resolution is None:
                    cnt_no_resolution += 1
                    continue
                else:
                    if (resolution.get("height", None) is None
                            or resolution.get("width", None) is None):
                        cnt_no_resolution += 1
                        continue
                    height, width = i["resolution"]["height"], i["resolution"][
                        "width"]
                    aspect = self.max_height / self.max_width
                    # Accept h/w ratios within a 1.5x band around the target.
                    hw_aspect_thr = 1.5
                    is_pick = filter_resolution(
                        height,
                        width,
                        max_h_div_w_ratio=hw_aspect_thr * aspect,
                        min_h_div_w_ratio=1 / hw_aspect_thr * aspect,
                    )
                    if not is_pick:
                        print("resolution mismatch")
                        cnt_resolution_mismatch += 1
                        continue

                i["num_frames"] = math.ceil(fps * duration)
                # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration.
                if i["num_frames"] / fps > self.video_length_tolerance_range * (
                        self.num_frames / self.train_fps * self.speed_factor
                ):  # too long video is not suitable for this training stage (self.num_frames)
                    cnt_too_long += 1
                    continue

                # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24)
                frame_interval = fps / self.train_fps
                start_frame_idx = 0
                frame_indices = np.arange(start_frame_idx, i["num_frames"],
                                          frame_interval).astype(int)

                # Short clips are only dropped with probability
                # drop_short_ratio (comment out to enable dynamic frames).
                if (len(frame_indices) < self.num_frames
                        and random.random() < self.drop_short_ratio):
                    cnt_too_short += 1
                    continue

                # too long video will be temporal-crop randomly
                if len(frame_indices) > self.num_frames:
                    begin_index, end_index = self.temporal_sample(
                        len(frame_indices))
                    frame_indices = frame_indices[begin_index:end_index]
                i["sample_frame_index"] = frame_indices.tolist()
                new_cap_list.append(i)
                i["sample_num_frames"] = len(
                    i["sample_frame_index"]
                )  # will use in dataloader(group sampler)
                sample_num_frames.append(i["sample_num_frames"])
            elif path.endswith(".jpg"):  # image
                cnt_img += 1
                new_cap_list.append(i)
                i["sample_num_frames"] = 1
                sample_num_frames.append(i["sample_num_frames"])
            else:
                raise NameError(
                    f"Unknown file extension {path.split('.')[-1]}, only support .mp4 for video and .jpg for image"
                )
        main_print(
            f"no_cap: {cnt_no_cap}, too_long: {cnt_too_long}, too_short: {cnt_too_short}, "
            f"no_resolution: {cnt_no_resolution}, resolution_mismatch: {cnt_resolution_mismatch}, "
            f"Counter(sample_num_frames): {Counter(sample_num_frames)}, cnt_movie: {cnt_movie}, cnt_img: {cnt_img}, "
            f"before filter: {len(cap_list)}, after filter: {len(new_cap_list)}"
        )
        return new_cap_list, sample_num_frames
|
| 324 |
+
|
| 325 |
+
def decord_read(self, path, frame_indices):
|
| 326 |
+
decord_vr = self.v_decoder(path)
|
| 327 |
+
video_data = decord_vr.get_batch(frame_indices).asnumpy()
|
| 328 |
+
video_data = torch.from_numpy(video_data)
|
| 329 |
+
video_data = video_data.permute(0, 3, 1,
|
| 330 |
+
2) # (T, H, W, C) -> (T C H W)
|
| 331 |
+
return video_data
|
| 332 |
+
|
| 333 |
+
def read_jsons(self, data):
|
| 334 |
+
cap_lists = []
|
| 335 |
+
with open(data, "r") as f:
|
| 336 |
+
folder_anno = [
|
| 337 |
+
i.strip().split(",") for i in f.readlines()
|
| 338 |
+
if len(i.strip()) > 0
|
| 339 |
+
]
|
| 340 |
+
print(folder_anno)
|
| 341 |
+
for folder, anno in folder_anno:
|
| 342 |
+
with open(anno, "r") as f:
|
| 343 |
+
sub_list = json.load(f)
|
| 344 |
+
for i in range(len(sub_list)):
|
| 345 |
+
sub_list[i]["path"] = opj(folder, sub_list[i]["path"])
|
| 346 |
+
cap_lists += sub_list
|
| 347 |
+
return cap_lists
|
| 348 |
+
|
| 349 |
+
def get_cap_list(self):
|
| 350 |
+
cap_lists = self.read_jsons(self.data)
|
| 351 |
+
return cap_lists
|
fastvideo/dataset/transform.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
import numbers
import random

import numpy as np
import torch
from PIL import Image
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _is_tensor_video_clip(clip):
|
| 11 |
+
if not torch.is_tensor(clip):
|
| 12 |
+
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
| 13 |
+
|
| 14 |
+
if not clip.ndimension() == 4:
|
| 15 |
+
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
| 16 |
+
|
| 17 |
+
return True
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Repeatedly halve with a box filter while the short side is still at
    # least twice the target, then finish with a single bicubic resize.
    while min(*pil_image.size) >= 2 * image_size:
        half = tuple(x // 2 for x in pil_image.size)
        pil_image = pil_image.resize(half, resample=Image.BOX)

    scale = image_size / min(*pil_image.size)
    new_size = tuple(round(x * scale) for x in pil_image.size)
    pil_image = pil_image.resize(new_size, resample=Image.BICUBIC)

    # Center-crop the resulting array to an image_size square.
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    cropped = arr[crop_y:crop_y + image_size, crop_x:crop_x + image_size]
    return Image.fromarray(cropped)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
    """
    if clip.dim() != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i:i + h, j:j + w]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def resize(clip, target_size, interpolation_mode):
    """Resize a (T, C, H, W) clip to target_size=(height, width)."""
    if len(target_size) != 2:
        raise ValueError(
            f"target size should be tuple (height, width), instead got {target_size}"
        )
    opts = dict(mode=interpolation_mode, align_corners=True, antialias=True)
    return torch.nn.functional.interpolate(clip, size=target_size, **opts)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def resize_scale(clip, target_size, interpolation_mode):
    """Scale the clip so its short side matches target_size[0]."""
    if len(target_size) != 2:
        raise ValueError(
            f"target size should be tuple (height, width), instead got {target_size}"
        )
    height, width = clip.size(-2), clip.size(-1)
    factor = target_size[0] / min(height, width)
    return torch.nn.functional.interpolate(
        clip,
        scale_factor=factor,
        mode=interpolation_mode,
        align_corners=True,
        antialias=True,
    )
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    cropped = crop(clip, i, j, h, w)
    return resize(cropped, size, interpolation_mode)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def center_crop(clip, crop_size):
    """Crop the spatial center of a (T, C, H, W) clip to crop_size=(th, tw)."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")

    top = int(round((h - th) / 2.0))
    left = int(round((w - tw) / 2.0))
    return crop(clip, top, left, th, tw)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def center_crop_using_short_edge(clip):
    """Center-crop to a square whose side is the clip's short edge."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    edge = min(h, w)
    # Offsets collapse to 0 along the short axis; the long axis is centered.
    top = int(round((h - edge) / 2.0))
    left = int(round((w - edge) / 2.0))
    return crop(clip, top, left, edge, edge)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def center_crop_th_tw(clip, th, tw, top_crop):
    """Crop the largest region with aspect ratio th/tw, centered (or
    top-aligned when *top_crop* is true)."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")

    h, w = clip.size(-2), clip.size(-1)
    target_ratio = th / tw
    if h / w > target_ratio:
        # Clip is too tall: keep the full width, shrink the height.
        new_h, new_w = int(w * target_ratio), w
    else:
        # Clip is too wide (or matches): keep the full height, shrink width.
        new_h, new_w = h, int(h / target_ratio)

    i = 0 if top_crop else int(round((h - new_h) / 2.0))
    j = int(round((w - new_w) / 2.0))
    return crop(clip, i, j, new_h, new_w)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def random_shift_crop(clip):
    """
    Slide along the long edge, with the short edge as crop size
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)

    # The square crop side is the short edge; the window slides randomly
    # along whichever axis has slack.
    edge = min(h, w)
    i = torch.randint(0, h - edge + 1, size=(1, )).item()
    j = torch.randint(0, w - edge + 1, size=(1, )).item()
    return crop(clip, i, j, edge, edge)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def normalize_video(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" %
                        str(clip.dtype))
    return clip.float() / 255.0
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def normalize(clip, mean, std, inplace=False):
    """
    Normalize a video clip channel-wise: (clip - mean) / std.

    NOTE: the broadcasting below applies mean/std along dim 0, so for a
    3-element mean/std the clip layout must be channels-first, (C, T, H, W)
    — the original docstring said (T, C, H, W), which contradicted the code
    (and the `Normalize` wrapper class, which documents (C, T, H, W)).

    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
        inplace (bool): mutate *clip* in place instead of cloning it.
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    # Flip along the last (width) axis.
    return clip.flip(-1)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class RandomCropVideo:
    """Randomly crop a (T, C, H, W) clip to a fixed spatial size."""

    def __init__(self, size):
        # Normalize a bare number into a (size, size) pair.
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        top, left, height, width = self.get_params(clip)
        return crop(clip, top, left, height, width)

    def get_params(self, clip):
        h, w = clip.shape[-2:]
        th, tw = self.size

        if h < th or w < tw:
            raise ValueError(
                f"Required crop size {(th, tw)} is larger than input image size {(h, w)}"
            )
        if w == tw and h == th:
            return 0, 0, h, w

        top = torch.randint(0, h - th + 1, size=(1, )).item()
        left = torch.randint(0, w - tw + 1, size=(1, )).item()
        return top, left, th, tw

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class SpatialStrideCropVideo:
    """Top-left crop so both spatial dims become multiples of `stride`."""

    def __init__(self, stride):
        self.stride = stride

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: cropped video clip by stride.
                size is (T, C, OH, OW)
        """
        i, j, h, w = self.get_params(clip)
        return crop(clip, i, j, h, w)

    def get_params(self, clip):
        h, w = clip.shape[-2:]
        # Round each spatial dim down to the nearest multiple of the stride.
        th, tw = h // self.stride * self.stride, w // self.stride * self.stride
        return 0, 0, th, tw  # from top-left

    def __repr__(self) -> str:
        # BUG FIX: the original formatted `self.size`, an attribute this
        # class never defines, so repr() raised AttributeError.
        return f"{self.__class__.__name__}(stride={self.stride})"
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class LongSideResizeVideo:
    """
    First use the long side,
    then resize to the specified size
    """

    def __init__(
        self,
        size,
        skip_low_resolution=False,
        interpolation_mode="bilinear",
    ):
        self.size = size
        self.skip_low_resolution = skip_low_resolution
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized video clip.
                size is (T, C, 512, *) or (T, C, *, 512)
        """
        _, _, h, w = clip.shape
        # Optionally leave already-small clips untouched.
        if self.skip_low_resolution and max(h, w) <= self.size:
            return clip
        # Scale so the long side equals self.size, preserving aspect ratio.
        if h > w:
            h, w = self.size, int(w * self.size / h)
        else:
            h, w = int(h * self.size / w), self.size
        return resize(clip,
                      target_size=(h, w),
                      interpolation_mode=self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class CenterCropResizeVideo:
    """
    First use the short side for cropping length,
    center crop video, then resize to the specified size
    """

    def __init__(
        self,
        size,
        top_crop=False,
        interpolation_mode="bilinear",
    ):
        if len(size) != 2:
            raise ValueError(
                f"size should be tuple (height, width), instead got {size}")
        self.size = size
        self.top_crop = top_crop
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        target_h, target_w = self.size
        # Crop to the target aspect ratio first, then resize to target size.
        cropped = center_crop_th_tw(clip, target_h, target_w,
                                    top_crop=self.top_crop)
        return resize(
            cropped,
            target_size=self.size,
            interpolation_mode=self.interpolation_mode,
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class UCFCenterCropVideo:
    """
    First scale to the specified size in equal proportion to the short edge,
    then center cropping
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        self.size = self._as_pair(size)
        self.interpolation_mode = interpolation_mode

    @staticmethod
    def _as_pair(size):
        # Accept a bare int (square) or an explicit (height, width) pair.
        if not isinstance(size, tuple):
            return (size, size)
        if len(size) != 2:
            raise ValueError(
                f"size should be tuple (height, width), instead got {size}"
            )
        return size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        scaled = resize_scale(clip=clip,
                              target_size=self.size,
                              interpolation_mode=self.interpolation_mode)
        return center_crop(scaled, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
class KineticsRandomCropResizeVideo:
    """
    Slide along the long edge, with the short edge as crop size. And resie to the desired size.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(
                    f"size should be tuple (height, width), instead got {size}"
                )
        self.size = size if isinstance(size, tuple) else (size, size)
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        # Random square crop along the long edge, then resize to target.
        cropped = random_shift_crop(clip)
        return resize(cropped, self.size, self.interpolation_mode)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
class CenterCropVideo:
    """Center-crop a clip to a fixed spatial size (no resizing)."""

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(
                    f"size should be tuple (height, width), instead got {size}"
                )
        self.size = size if isinstance(size, tuple) else (size, size)
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        return center_crop(clip, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class Normalize:
    """
    Normalize the video clip by mean subtraction and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std, self.inplace = mean, std, inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
        """
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
class Normalize255:
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        # Delegate to the module-level helper.
        return normalize_video(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (T, C, H, W)
        Return:
            clip (torch.tensor): Size is (T, C, H, W)
        """
        # Flip only when the coin toss lands under the configured probability.
        return hflip(clip) if random.random() < self.p else clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# ------------------------------------------------------------
|
| 543 |
+
# --------------------- Sampling ---------------------------
|
| 544 |
+
# ------------------------------------------------------------
|
| 545 |
+
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        # Latest admissible start position; 0 when the clip is too short.
        last_start = max(0, total_frames - self.size - 1)
        start = random.randint(0, last_start)
        stop = min(start + self.size, total_frames)
        return start, stop
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class DynamicSampleDuration(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, t_stride, extra_1):
        self.t_stride = t_stride
        self.extra_1 = extra_1

    def __call__(self, t, h, w):
        if self.extra_1:
            t -= 1
        # Candidate lengths: at least half of t, stepped by t_stride.
        candidates = list(range(t + 1))[t // 2:][::self.t_stride]
        truncate_t = random.choice(candidates)
        if self.extra_1:
            truncate_t += 1
        return 0, truncate_t
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
if __name__ == "__main__":
    # Manual smoke test for the transforms defined in this module: reads a
    # sample UCF-101 clip, applies the pipeline, and writes the result back
    # out as a video plus per-frame PNGs for visual inspection.
    import os

    import numpy as np
    import torchvision.io as io
    from torchvision import transforms
    from torchvision.utils import save_image

    vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi",
                                           pts_unit="sec",
                                           output_format="TCHW")

    # Normalize255 / RandomHorizontalFlipVideo / UCFCenterCropVideo are
    # defined earlier in this module.
    trans = transforms.Compose([
        Normalize255(),
        RandomHorizontalFlipVideo(),
        UCFCenterCropVideo(512),
        # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5],
                             inplace=True),
    ])

    target_video_len = 32
    frame_interval = 1
    total_frames = len(vframes)
    print(total_frames)

    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)

    # Sampling video frames
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    # print(start_frame_ind)
    # print(end_frame_ind)
    assert end_frame_ind - start_frame_ind >= target_video_len
    # Evenly spaced indices spanning the sampled window.
    frame_indice = np.linspace(start_frame_ind,
                               end_frame_ind - 1,
                               target_video_len,
                               dtype=int)
    print(frame_indice)

    select_vframes = vframes[frame_indice]
    print(select_vframes.shape)
    print(select_vframes.dtype)

    select_vframes_trans = trans(select_vframes)
    print(select_vframes_trans.shape)
    print(select_vframes_trans.dtype)

    # Map [-1, 1] back to uint8 [0, 255] for writing.
    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) *
                                255).to(dtype=torch.uint8)
    print(select_vframes_trans_int.dtype)
    print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)

    # write_video expects THWC layout.
    io.write_video("./test.avi",
                   select_vframes_trans_int.permute(0, 2, 3, 1),
                   fps=8)

    # NOTE(review): assumes ./test000 already exists — save_image does not
    # create directories; confirm or add os.makedirs.
    for i in range(target_video_len):
        save_image(
            select_vframes_trans[i],
            os.path.join("./test000", "%04d.png" % i),
            normalize=True,
            value_range=(-1, 1),
        )
|
fastvideo/distill/__init__.py
ADDED
|
File without changes
|
fastvideo/distill/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (225 Bytes). View file
|
|
|
fastvideo/distill/__pycache__/solver.cpython-312.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
fastvideo/distill/discriminator.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from diffusers.utils import logging
|
| 5 |
+
|
| 6 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DiscriminatorHead(nn.Module):
    """Small convolutional head that scores flattened transformer features.

    The input is a sequence of per-frame spatial tokens of shape
    (B, T*feat_h*feat_w, C). It is reshaped to (B*T, C, feat_h, feat_w) and
    passed through two 1x1-conv stages (with a residual around the second)
    before the output projection.

    Args:
        input_channel: channel dim C of the incoming features.
        output_channel: number of output logits per spatial location.
        feat_h: height of the spatial token grid (default 30, the value
            previously hard-coded).
        feat_w: width of the spatial token grid (default 53, the value
            previously hard-coded).
    """

    def __init__(self, input_channel, output_channel=1, feat_h=30, feat_w=53):
        super().__init__()
        inner_channel = 1024
        self.feat_h = feat_h
        self.feat_w = feat_w
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channel, inner_channel, 1, 1, 0),
            nn.GroupNorm(32, inner_channel),
            nn.LeakyReLU(
                inplace=True
            ),  # use LeakyReLu instead of GELU shown in the paper to save memory
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(inner_channel, inner_channel, 1, 1, 0),
            nn.GroupNorm(32, inner_channel),
            nn.LeakyReLU(
                inplace=True
            ),  # use LeakyReLu instead of GELU shown in the paper to save memory
        )

        self.conv_out = nn.Conv2d(inner_channel, output_channel, 1, 1, 0)

    def forward(self, x):
        """Score features.

        Args:
            x: tensor of shape (B, T*feat_h*feat_w, C).

        Returns:
            Tensor of shape (B*T, output_channel, feat_h, feat_w).
        """
        b, twh, c = x.shape
        hw = self.feat_h * self.feat_w
        t = twh // hw  # number of frames packed into the sequence dim
        x = x.view(-1, hw, c)
        x = x.permute(0, 2, 1)
        x = x.view(b * t, c, self.feat_h, self.feat_w)
        x = self.conv1(x)
        x = self.conv2(x) + x  # residual around the second conv stage
        x = self.conv_out(x)
        return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Discriminator(nn.Module):
    """Multi-head discriminator over features taken every `stride` layers.

    Builds ``total_layers // stride`` groups of ``num_h_per_head``
    DiscriminatorHead modules, one group per tapped transformer layer.

    Args:
        stride: tap one feature map every `stride` transformer layers.
        num_h_per_head: number of heads applied to each tapped feature.
        adapter_channel_dims: channel dim(s) of the tapped features; a
            single-element list is repeated for every tap. Defaults to
            [3072]. (Was a mutable default argument; now None-guarded.)
        total_layers: depth of the backbone being tapped.
    """

    def __init__(
        self,
        stride=8,
        num_h_per_head=1,
        adapter_channel_dims=None,
        total_layers=48,
    ):
        super().__init__()
        if adapter_channel_dims is None:
            adapter_channel_dims = [3072]
        adapter_channel_dims = adapter_channel_dims * (total_layers // stride)
        self.stride = stride
        self.num_h_per_head = num_h_per_head
        self.head_num = len(adapter_channel_dims)
        self.heads = nn.ModuleList([
            nn.ModuleList([
                DiscriminatorHead(adapter_channel)
                for _ in range(self.num_h_per_head)
            ]) for adapter_channel in adapter_channel_dims
        ])

    def forward(self, features):
        """Apply every head to its matching feature map.

        Args:
            features: list of tensors, one per head group (must have the
                same length as ``self.heads``).

        Returns:
            Flat list of head outputs, ordered by feature then head.
        """
        assert len(features) == len(self.heads)
        outputs = []
        for i in range(len(features)):
            for h in self.heads[i]:
                outputs.append(h(features[i]))
        return outputs
|
fastvideo/distill/solver.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 10 |
+
from diffusers.utils import BaseOutput, logging
|
| 11 |
+
|
| 12 |
+
from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
|
| 13 |
+
|
| 14 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class PCMFMSchedulerOutput(BaseOutput):
    """Output container returned by `PCMFMScheduler.step`."""

    # The sample advanced one Euler step toward t=0 (next denoiser input).
    prev_sample: torch.FloatTensor
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def extract_into_tensor(a, t, x_shape):
    """Gather `a[t]` per batch element and reshape for broadcasting.

    Args:
        a: 1-D tensor of schedule values.
        t: integer index tensor of shape (B,).
        x_shape: shape of the tensor the result will broadcast against.

    Returns:
        Tensor of shape (B, 1, ..., 1) with len(x_shape) dims.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *((1, ) * (len(x_shape) - 1)))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PCMFMScheduler(SchedulerMixin, ConfigMixin):
    """Flow-matching Euler scheduler over a PCM-subsampled sigma schedule.

    Builds `pcm_timesteps` sigmas out of a `num_train_timesteps`-long
    training schedule (either shifted-linear or linear-quadratic) and steps
    through them with a first-order Euler update.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        pcm_timesteps: int = 50,
        linear_quadratic=False,
        linear_quadratic_threshold=0.025,
        linear_range=0.5,
    ):
        # Build the full training sigma schedule first.
        if linear_quadratic:
            linear_steps = int(num_train_timesteps * linear_range)
            sigmas = linear_quadratic_schedule(num_train_timesteps,
                                               linear_quadratic_threshold,
                                               linear_steps)
            sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
        else:
            timesteps = np.linspace(1,
                                    num_train_timesteps,
                                    num_train_timesteps,
                                    dtype=np.float32)[::-1].copy()
            timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
            sigmas = timesteps / num_train_timesteps
            # Resolution-shifted schedule: shift > 1 biases toward high noise.
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        # Evenly subsample `pcm_timesteps` indices from the training schedule.
        self.euler_timesteps = (np.arange(1, pcm_timesteps + 1) *
                                (num_train_timesteps //
                                 pcm_timesteps)).round().astype(np.int64) - 1
        # Index the ascending schedule, then flip back to descending order.
        self.sigmas = sigmas.numpy()[::-1][self.euler_timesteps]
        self.sigmas = torch.from_numpy((self.sigmas[::-1].copy()))
        self.timesteps = self.sigmas * num_train_timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to(
            "cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: linearly interpolates between
        `noise` (weight sigma) and `sample` (weight 1 - sigma).

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        # Timesteps are sigmas scaled into [0, num_train_timesteps].
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self,
                      num_inference_steps: int,
                      device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        # Evenly pick `num_inference_steps` of the pcm_timesteps indices.
        inference_indices = np.linspace(0,
                                        self.config.pcm_timesteps,
                                        num=num_inference_steps,
                                        endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = torch.from_numpy(inference_indices).long()

        self.sigmas_ = self.sigmas[inference_indices]
        timesteps = self.sigmas_ * self.config.num_train_timesteps
        self.timesteps = timesteps.to(device=device)
        # Append the terminal sigma=0 so step() can always look one ahead.
        self.sigmas_ = torch.cat(
            [self.sigmas_,
             torch.zeros(1, device=self.sigmas_.device)])
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        # Lazily resolve the step index from the timestep on first use.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[PCMFMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator (unused; kept for API compatibility).
            return_dict (`bool`):
                Whether or not to return a [`PCMFMSchedulerOutput`] or tuple.

        Returns:
            [`PCMFMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`PCMFMSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (isinstance(timestep, int) or isinstance(timestep, torch.IntTensor)
                or isinstance(timestep, torch.LongTensor)):
            raise ValueError((
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."), )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast for the update; cast back to the model dtype at the end.
        sample = sample.to(torch.float32)

        sigma = self.sigmas_[self.step_index]

        # Flow-matching Euler update: model_output is the velocity field,
        # so `derivative` reduces to model_output; kept in this two-line
        # form to mirror the Euler-discrete formulation.
        denoised = sample - model_output * sigma
        derivative = (sample - denoised) / sigma

        dt = self.sigmas_[self.step_index + 1] - sigma
        prev_sample = sample + derivative * dt
        prev_sample = prev_sample.to(model_output.dtype)
        self._step_index += 1

        if not return_dict:
            return (prev_sample, )

        return PCMFMSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class EulerSolver:
    """Euler solver over a subsampled sigma schedule, used for distillation.

    Holds paired arrays (`sigmas`, `sigmas_prev`) where index i maps a
    solver step to the next (lower-noise) step, plus a multiphase variant
    that can jump several solver steps at once.
    """

    def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
        # Evenly subsample `euler_timesteps` indices from the full schedule.
        self.step_ratio = timesteps // euler_timesteps
        self.euler_timesteps = (np.arange(1, euler_timesteps + 1) *
                                self.step_ratio).round().astype(np.int64) - 1
        # "prev" arrays are the same schedule shifted one solver step
        # toward t=0 (index 0 maps to itself / sigma[0]).
        self.euler_timesteps_prev = np.asarray(
            [0] + self.euler_timesteps[:-1].tolist())
        self.sigmas = sigmas[self.euler_timesteps]
        self.sigmas_prev = np.asarray(
            [sigmas[0]] + sigmas[self.euler_timesteps[:-1]].tolist()
        )  # either use sigma0 or 0

        self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long()
        self.euler_timesteps_prev = torch.from_numpy(
            self.euler_timesteps_prev).long()
        self.sigmas = torch.from_numpy(self.sigmas)
        self.sigmas_prev = torch.from_numpy(self.sigmas_prev)

    def to(self, device):
        """Move all schedule tensors to `device`; returns self for chaining."""
        self.euler_timesteps = self.euler_timesteps.to(device)
        self.euler_timesteps_prev = self.euler_timesteps_prev.to(device)

        self.sigmas = self.sigmas.to(device)
        self.sigmas_prev = self.sigmas_prev.to(device)
        return self

    def euler_step(self, sample, model_pred, timestep_index):
        """One Euler step from `timestep_index` to the previous solver step.

        `timestep_index` is a per-batch-element index tensor into the
        subsampled schedule.
        """
        sigma = extract_into_tensor(self.sigmas, timestep_index,
                                    model_pred.shape)
        sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index,
                                         model_pred.shape)
        x_prev = sample + (sigma_prev - sigma) * model_pred
        return x_prev

    def euler_style_multiphase_pred(
        self,
        sample,
        model_pred,
        timestep_index,
        multiphase,
        is_target=False,
    ):
        """Jump from `timestep_index` to the nearest of `multiphase` anchor
        indices at or below it, in a single Euler step.

        When `is_target` is True the starting sigma is taken from
        `sigmas_prev` (the target network starts one step later).
        Returns (x_prev, timestep_index_end).
        """
        # Anchor indices: `multiphase` evenly spaced points in the schedule.
        inference_indices = np.linspace(0,
                                        len(self.euler_timesteps),
                                        num=multiphase,
                                        endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = (torch.from_numpy(inference_indices).long().to(
            self.euler_timesteps.device))
        # For each batch element, find the largest anchor <= its index.
        expanded_timestep_index = timestep_index.unsqueeze(1).expand(
            -1, inference_indices.size(0))
        valid_indices_mask = expanded_timestep_index >= inference_indices
        last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(
            dim=1)
        last_valid_index = inference_indices.size(0) - 1 - last_valid_index
        timestep_index_end = inference_indices[last_valid_index]

        if is_target:
            sigma = extract_into_tensor(self.sigmas_prev, timestep_index,
                                        sample.shape)
        else:
            sigma = extract_into_tensor(self.sigmas, timestep_index,
                                        sample.shape)
        sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index_end,
                                         sample.shape)
        x_prev = sample + (sigma_prev - sigma) * model_pred

        return x_prev, timestep_index_end
|
fastvideo/models/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
fastvideo/models/__pycache__/flash_attn_no_pad.cpython-310.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
fastvideo/models/__pycache__/flash_attn_no_pad.cpython-312.pyc
ADDED
|
Binary file (1.41 kB). View file
|
|
|
fastvideo/models/flash_attn_no_pad.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from einops import rearrange
|
| 2 |
+
from flash_attn import flash_attn_varlen_qkvpacked_func
|
| 3 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def flash_attn_no_pad(qkv,
                      key_padding_mask,
                      causal=False,
                      dropout_p=0.0,
                      softmax_scale=None):
    """Run varlen FlashAttention on a padded packed-QKV tensor.

    Unpads `qkv` of shape (batch, seqlen, 3, nheads, headdim) using
    `key_padding_mask`, runs flash_attn_varlen_qkvpacked_func, and re-pads
    the output back to (batch, seqlen, nheads, headdim).
    """
    # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
    batch_size = qkv.shape[0]
    seqlen = qkv.shape[1]
    nheads = qkv.shape[-2]
    # Flatten (three, h, d) so unpad_input sees a (b, s, features) tensor.
    x = rearrange(qkv, "b s three h d -> b s (three h d)")
    # NOTE(review): the 5-value return of unpad_input is version-dependent
    # in flash-attn — confirm against the pinned flash-attn release.
    x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input(
        x, key_padding_mask)

    x_unpad = rearrange(x_unpad,
                        "nnz (three h d) -> nnz three h d",
                        three=3,
                        h=nheads)
    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_seqlens,
        max_s,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    # Re-pad to the original (b, s, h, d) layout.
    output = rearrange(
        pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices,
                  batch_size, seqlen),
        "b s (h d) -> b s h d",
        h=nheads,
    )
    return output
|
fastvideo/reward_model/clip_score.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import clip
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from typing import List, Tuple, Union
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import os
|
| 10 |
+
from open_clip import create_model_from_pretrained, get_tokenizer
|
| 11 |
+
import argparse
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@torch.no_grad()
def calculate_clip_score(prompts, images, clip_model, device):
    """Per-pair cosine similarity between CLIP image and text embeddings.

    `images` must already be a preprocessed batch tensor; `prompts` is a
    list of strings tokenized here with truncation.
    """
    tokenized = clip.tokenize(prompts, truncate=True).to(device=device)

    img_emb = clip_model.encode_image(images)
    txt_emb = clip_model.encode_text(tokenized)

    return F.cosine_similarity(img_emb, txt_emb)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CLIPScoreRewardModel():
    """Text-image similarity reward model backed by open_clip.

    Loads a pretrained CLIP checkpoint and scores (prompt, image) pairs by
    the cosine similarity of their normalized embeddings.
    """

    def __init__(self, clip_model_path, device, http_proxy=None, https_proxy=None, clip_model_type='ViT-H-14'):
        super().__init__()
        # Optional proxies for downloading hub-hosted checkpoints.
        if http_proxy:
            os.environ["http_proxy"] = http_proxy
        if https_proxy:
            os.environ["https_proxy"] = https_proxy
        self.clip_model_path = clip_model_path
        self.clip_model_type = clip_model_type
        self.device = device
        self.load_model()

    def load_model(self, logger=None):
        """Instantiate model, preprocess, and tokenizer on `self.device`."""
        self.model, self.preprocess = create_model_from_pretrained(self.clip_model_path)
        self.tokenizer = get_tokenizer(self.clip_model_type)
        self.model.to(self.device)

    # calculate clip score directly, such as for rerank
    @torch.no_grad()
    def __call__(
        self,
        prompts: Union[str, List[str]],
        images: List[Image.Image]
    ) -> List[float]:
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)
        if len(prompts) != len(images):
            raise ValueError("prompts must have the same length as images")

        results = []
        for caption, img in zip(prompts, images):
            pixel_batch = self.preprocess(img).unsqueeze(0).to(self.device)
            token_batch = self.tokenizer(
                [caption],
                context_length=self.model.context_length
            ).to(self.device)

            img_emb = F.normalize(self.model.encode_image(pixel_batch), dim=-1)
            txt_emb = F.normalize(self.model.encode_text(token_batch), dim=-1)

            # Cosine similarity of the (already normalized) embeddings.
            results.append((img_emb @ txt_emb.T).item())

        return results
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
    # CLI smoke test: score one demo image against one prompt.
    # NOTE(review): the description string says "PickScore" but this script
    # computes a CLIP score — likely copied from pick_score.py; confirm.
    parser = argparse.ArgumentParser(description="PickScore Reward Model")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on (e.g., 'cuda', 'cpu')")
    parser.add_argument("--http_proxy", type=str, default=None, help="HTTP proxy URL")
    parser.add_argument("--https_proxy", type=str, default=None, help="HTTPS proxy URL")
    args = parser.parse_args()

    # Example usage
    clip_model_path = 'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384'
    reward_model = CLIPScoreRewardModel(
        clip_model_path,
        device=args.device,
        http_proxy=args.http_proxy,
        https_proxy=args.https_proxy
    )

    image_path = "assets/reward_demo.jpg"
    prompt = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting."

    image = Image.open(image_path).convert("RGB")
    clip_score = reward_model(prompt, [image])

    print(f"CLIP Score: {clip_score}")
|
fastvideo/reward_model/hps_score.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, List
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
from HPSv2.hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class HPSClipRewardModel(object):
    """Human Preference Score (HPSv2) reward model.

    Loads an open_clip backbone plus the HPS fine-tuned checkpoint and
    scores (image, text) pairs with the diagonal of the image-text logit
    matrix.
    """

    def __init__(self, device, clip_ckpt_path, hps_ckpt_path, model_name='ViT-H-14'):
        self.device = device
        self.clip_ckpt_path = clip_ckpt_path
        self.hps_ckpt_path = hps_ckpt_path
        self.model_name = model_name
        self.reward_model, self.text_processor, self.img_processor = self.build_reward_model()

    def build_reward_model(self):
        """Build the open_clip model, load the HPS weights, and return
        (model_in_eval_mode, tokenizer, image_preprocess)."""
        model, preprocess_train, img_preprocess_val = create_model_and_transforms(
            self.model_name,
            self.clip_ckpt_path,
            precision='amp',
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False
        )

        # Convert device name to proper format
        # NOTE(review): assumes self.device is an int or a str; a
        # torch.device instance would break `.startswith` — confirm callers.
        if isinstance(self.device, int):
            ml_device = str(self.device)
        else:
            ml_device = self.device

        if not ml_device.startswith('cuda'):
            ml_device = f'cuda:{ml_device}' if ml_device.isdigit() else ml_device

        checkpoint = torch.load(self.hps_ckpt_path, map_location=ml_device)
        model.load_state_dict(checkpoint['state_dict'])
        text_processor = get_tokenizer(self.model_name)
        reward_model = model.to(self.device)
        reward_model.eval()

        return reward_model, text_processor, img_preprocess_val

    @torch.no_grad()
    def __call__(
        self,
        images: Union[Image.Image, List[Image.Image]],
        texts: Union[str, List[str]],
    ):
        """Score each (image, text) pair; returns a list of float scores.

        Pairs are zipped, so single inputs are wrapped and extra elements
        of the longer list are silently dropped.
        """
        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(texts, str):
            texts = [texts]

        rewards = []
        for image, text in zip(images, texts):
            image = self.img_processor(image).unsqueeze(0).to(self.device, non_blocking=True)
            text = self.text_processor([text]).to(device=self.device, non_blocking=True)
            with torch.amp.autocast('cuda'):
                outputs = self.reward_model(image, text)
                image_features, text_features = outputs["image_features"], outputs["text_features"]
                logits_per_image = image_features @ text_features.T
                hps_score = torch.diagonal(logits_per_image)

            # reward is a tensor of shape (1,) --> list
            rewards.append(hps_score.float().cpu().item())

        return rewards
|
fastvideo/reward_model/image_reward.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image-Reward: Copyied from https://github.com/THUDM/ImageReward
|
| 2 |
+
import os
|
| 3 |
+
from typing import Union, List
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
try:
    import ImageReward as RM
except ImportError as err:
    # Raise ImportError (chained to the original failure) rather than the
    # previous `raise Warning(...)`, which misused the warning class as an
    # exception and hid the underlying import error.
    raise ImportError(
        "ImageReward is required to be installed (`pip install image-reward`) when using ImageReward for post-training."
    ) from err
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ImageRewardModel(object):
    """Thin wrapper around THUDM's ImageReward scorer.

    Mirrors the list-in / list-out calling convention of the other reward
    models in this package.
    """

    def __init__(self, model_name, device, http_proxy=None, https_proxy=None, med_config=None):
        # Export optional proxies so the checkpoint download can reach the
        # network from restricted environments.
        if http_proxy:
            os.environ["http_proxy"] = http_proxy
        if https_proxy:
            os.environ["https_proxy"] = https_proxy
        # Fall back to the published default checkpoint name.
        self.model_name = model_name if model_name else "ImageReward-v1.0"
        self.device = device
        self.med_config = med_config
        self.build_reward_model()

    def build_reward_model(self):
        """Load the ImageReward checkpoint onto self.device."""
        self.model = RM.load(self.model_name, device=self.device, med_config=self.med_config)

    @torch.no_grad()
    def __call__(
        self,
        images,
        texts,
    ):
        """Score each image against its text; one entry per (image, text) pair.

        A single string `texts` is broadcast to every image.

        NOTE(review): `inference_rank(text, [image])` appears to return its
        rewards as a list; appending `reward` unmodified would then make the
        returned value a list of lists rather than the flat list of floats the
        other reward models produce -- confirm against the ImageReward API.
        """
        if isinstance(texts, str):
            texts = [texts] * len(images)

        rewards = []
        for image, text in zip(images, texts):
            ranking, reward = self.model.inference_rank(text, [image])
            rewards.append(reward)
        return rewards
|
fastvideo/reward_model/pick_score.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
from typing import List, Tuple, Union
|
| 5 |
+
from transformers import AutoProcessor, AutoModel
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PickScoreRewardModel(object):
    """PickScore (CLIP-H) image-text preference reward model.

    Scores are the logit-scaled cosine similarities between CLIP text and
    image embeddings, normalized with fixed (mean, std) constants so the
    downstream trainer sees roughly unit-scale rewards.
    """

    def __init__(self, device: str = "cuda", http_proxy=None, https_proxy=None, mean=18.0, std=8.0):
        """
        Initialize PickScore reward model.

        Args:
            device: Device to run the model on ('cuda' or 'cpu')
            http_proxy: optional HTTP proxy exported before HF downloads
            https_proxy: optional HTTPS proxy exported before HF downloads
            mean: normalization offset subtracted from the raw score
            std: normalization scale dividing the centered score
        """
        if http_proxy:
            os.environ["http_proxy"] = http_proxy
        if https_proxy:
            os.environ["https_proxy"] = https_proxy
        self.device = device
        self.processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        self.model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"
        self.mean = mean
        self.std = std

        # Initialize model and processor (downloads weights on first use)
        self.processor = AutoProcessor.from_pretrained(self.processor_name_or_path)
        self.model = AutoModel.from_pretrained(self.model_pretrained_name_or_path).eval().to(device)

    @torch.no_grad()
    def __call__(
        self,
        images: List[Image.Image],
        prompts: Union[str, List[str]],
    ) -> List[float]:
        """
        Calculate normalized PickScore rewards for images given prompts.

        Args:
            images: List of PIL Images to evaluate
            prompts: a single prompt (broadcast to all images) or one prompt
                per image

        Returns:
            List of normalized float scores, one per image.
            (Fix: the original annotation and docstring advertised a
            (probabilities, scores) tuple, but only the scores list is
            returned.)

        Raises:
            ValueError: if `prompts` is a list with a different length than
                `images`.
        """
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)
        if len(prompts) != len(images):
            raise ValueError("prompts must have the same length as images")

        scores = []
        for prompt, image in zip(prompts, images):
            # Preprocess image
            image_inputs = self.processor(
                images=[image],
                padding=True,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(self.device)

            # Preprocess text
            text_inputs = self.processor(
                text=prompt,
                padding=True,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(self.device)

            # L2-normalized CLIP embeddings
            image_embs = self.model.get_image_features(**image_inputs)
            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

            text_embs = self.model.get_text_features(**text_inputs)
            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

            # Logit-scaled similarity, then (mean, std) normalization
            score = self.model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
            score = (score - self.mean) / self.std
            scores.extend(score.cpu().tolist())

        return scores
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="PickScore Reward Model")
    cli.add_argument("--device", type=str, default="cuda", help="Device to run the model on (e.g., 'cuda', 'cpu')")
    cli.add_argument("--http_proxy", type=str, default=None, help="HTTP proxy URL")
    cli.add_argument("--https_proxy", type=str, default=None, help="HTTPS proxy URL")
    cli_args = cli.parse_args()

    # Smoke test: score the bundled demo image against a fixed prompt.
    reward_model = PickScoreRewardModel(
        device=cli_args.device,
        http_proxy=cli_args.http_proxy,
        https_proxy=cli_args.https_proxy,
    )
    pil_images = [Image.open("assets/reward_demo.jpg")]

    prompt = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting."

    scores = reward_model(pil_images, [prompt] * len(pil_images))
    # Undo the (mean, std) normalization and rescale to a 0-1-ish range.
    scores = [(s * reward_model.std + reward_model.mean) / 100.0 for s in scores]
    print(f"scores: {scores}")
|
| 107 |
+
|
fastvideo/reward_model/unified_reward.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import base64
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import requests
|
| 6 |
+
import time
|
| 7 |
+
import concurrent.futures
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
from typing import List, Optional, Union
|
| 10 |
+
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Prompt template for the "semantic" question type: asks the VLM for per-word
# grounding scores plus three 1-5 axis ratings (alignment/coherence/style).
# The single `{}` placeholder receives the image's text caption.
QUESTION_TEMPLATE_SEMANTIC = (
    "You are presented with a generated image and its associated text caption. Your task is to analyze the image across multiple dimensions in relation to the caption. Specifically:\n\n"
    "1. Evaluate each word in the caption based on how well it is visually represented in the image. Assign a numerical score to each word using the format:\n"
    " Word-wise Scores: [[\"word1\", score1], [\"word2\", score2], ..., [\"wordN\", scoreN], [\"[No_mistakes]\", scoreM]]\n"
    " - A higher score indicates that the word is less well represented in the image.\n"
    " - The special token [No_mistakes] represents whether all elements in the caption were correctly depicted. A high score suggests no mistakes; a low score suggests missing or incorrect elements.\n\n"
    "2. Provide overall assessments for the image along the following axes (each rated from 1 to 5):\n"
    "- Alignment Score: How well the image matches the caption in terms of content.\n"
    "- Coherence Score: How logically consistent the image is (absence of visual glitches, object distortions, etc.).\n"
    "- Style Score: How aesthetically appealing the image looks, regardless of caption accuracy.\n\n"
    "Output your evaluation using the format below:\n\n"
    "---\n\n"
    "Word-wise Scores: [[\"word1\", score1], ..., [\"[No_mistakes]\", scoreM]]\n\n"
    "Alignment Score (1-5): X\n"
    "Coherence Score (1-5): Y\n"
    "Style Score (1-5): Z\n\n"
    "Your task is provided as follows:\nText Caption: [{}]"
)

# Prompt template for the "score" question type: asks the VLM for a checklist
# of caption elements and a single "Final Score:" in [1, 5].
# The `{}` placeholder receives the image's text caption.
QUESTION_TEMPLATE_SCORE = (
    "You are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n"
    "1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n"
    "2. Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\n"
    "Extract key elements from the provided text caption, evaluate their presence in the generated image using the format: \'element (type): value\' (where value=0 means not generated, and value=1 means generated), and assign a score from 1 to 5 after \'Final Score:\'.\n"
    "Your task is provided as follows:\nText Caption: [{}]"
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class VLMessageClient:
    """Thin HTTP client for a vision-language reward server exposing an
    OpenAI-style ``/v1/chat/completions`` endpoint.

    Images (file paths or PIL images) are inlined as base64 JPEG data URLs
    before sending.

    NOTE(review): one requests.Session is shared per client, and
    UnifiedRewardModel calls process_item from a thread pool;
    requests.Session is not documented as thread-safe -- confirm this is
    acceptable for the target server/load.
    """

    def __init__(self, api_url):
        self.api_url = api_url
        self._session = None

    @property
    def session(self):
        # Created lazily so constructing the client stays cheap.
        if self._session is None:
            self._session = requests.Session()
        return self._session

    def close(self):
        """Close the session if it exists."""
        if self._session is not None:
            self._session.close()
            self._session = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _encode_image_base64(self, image):
        """Encode a file path or PIL image as a base64 JPEG string."""
        if isinstance(image, str):
            with Image.open(image) as img:
                img = img.convert("RGB")
                buffered = BytesIO()
                img.save(buffered, format="JPEG", quality=95)
                return base64.b64encode(buffered.getvalue()).decode("utf-8")
        elif isinstance(image, Image.Image):
            buffered = BytesIO()
            image.save(buffered, format="JPEG", quality=95)
            return base64.b64encode(buffered.getvalue()).decode("utf-8")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

    def build_messages(self, item, image_root=""):
        """Build a raw chat message list from {'image': ..., 'question': ...}.

        File-path images become ``file://`` URLs (resolved under
        `image_root`); PIL images are tagged ``pil_image`` and converted to
        data URLs later in format_messages.
        """
        if isinstance(item['image'], str):
            image_path = os.path.join(image_root, item['image'])
            return [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"file://{image_path}"}},
                        {
                            "type": "text",
                            "text": f"{item['question']}"
                        }
                    ]
                }
            ]
        assert isinstance(item['image'], Image.Image), f"image must be a PIL.Image.Image, but got {type(item['image'])}"
        return [
            {
                "role": "user",
                "content": [
                    {"type": "pil_image", "pil_image": item['image']},
                    {
                        "type": "text",
                        "text": f"{item['question']}"
                    }
                ]
            }
        ]

    def format_messages(self, messages):
        """Convert raw messages into API form: system content collapses to a
        plain string; file/PIL image parts become base64 data URLs."""
        formatted = []
        for msg in messages:
            new_msg = {"role": msg["role"], "content": []}

            if msg["role"] == "system":
                new_msg["content"] = msg["content"][0]["text"]
            else:
                for part in msg["content"]:
                    if part["type"] == "image_url":
                        img_path = part["image_url"]["url"].replace("file://", "")
                        base64_image = self._encode_image_base64(img_path)
                        new_part = {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
                        }
                        new_msg["content"].append(new_part)
                    elif part["type"] == "pil_image":
                        base64_image = self._encode_image_base64(part["pil_image"])
                        new_part = {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
                        }
                        new_msg["content"].append(new_part)
                    else:
                        new_msg["content"].append(part)
            formatted.append(new_msg)
        return formatted

    def process_item(self, item, image_root=""):
        """POST one item to the server with up to 3 attempts.

        Returns (result_dict, success_flag). On final failure the last
        exception is re-raised (callers such as
        UnifiedRewardModel._process_item_wrapper catch and log it).
        """
        max_retries = 3
        attempt = 0
        result = None

        while attempt < max_retries:
            try:
                attempt += 1
                raw_messages = self.build_messages(item, image_root)
                formatted_messages = self.format_messages(raw_messages)

                payload = {
                    "model": "UnifiedReward",
                    "messages": formatted_messages,
                    "temperature": 0,
                    "max_tokens": 4096,
                }

                # Timeout grows with the attempt number to ride out slow spells.
                response = self.session.post(
                    f"{self.api_url}/v1/chat/completions",
                    json=payload,
                    timeout=30 + attempt*5
                )
                response.raise_for_status()

                output = response.json()["choices"][0]["message"]["content"]

                result = {
                    "question": item["question"],
                    "image_path": item["image"] if isinstance(item["image"], str) else "PIL_Image",
                    "model_output": output,
                    "attempt": attempt,
                    "success": True
                }
                break

            except Exception as e:
                if attempt == max_retries:
                    result = {
                        "question": item["question"],
                        "image_path": item["image"] if isinstance(item["image"], str) else "PIL_Image",
                        "error": str(e),
                        "attempt": attempt,
                        "success": False
                    }
                    # Fix: bare `raise` re-raises with the original traceback
                    # intact (the original `raise(e)` truncated it).
                    raise
                else:
                    # Exponential backoff, capped at 10 seconds.
                    sleep_time = min(2 ** attempt, 10)
                    time.sleep(sleep_time)

        return result, result.get("success", False)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class UnifiedRewardModel(object):
    """Reward model that scores images by querying a remote UnifiedReward VLM
    server and parsing the numeric score out of its free-text reply.
    """

    def __init__(self, api_url, default_question_type="score", num_workers=8):
        self.api_url = api_url
        self.num_workers = num_workers
        # "score" -> single Final Score; "semantic" -> word-wise + axis scores.
        self.default_question_type = default_question_type
        self.question_template_score = QUESTION_TEMPLATE_SCORE
        self.question_template_semantic = QUESTION_TEMPLATE_SEMANTIC

    def question_constructor(self, prompt, question_type=None):
        """Fill the template for `question_type` with the image caption."""
        if question_type is None:
            question_type = self.default_question_type
        if question_type == "score":
            return self.question_template_score.format(prompt)
        elif question_type == "semantic":
            return self.question_template_semantic.format(prompt)
        else:
            raise ValueError(f"Invalid question type: {question_type}")

    def _process_item_wrapper(self, client, image, question):
        """Run a single request; returns the result dict, or None on error."""
        try:
            item = {
                "image": image,
                "question": question,
            }
            result, _ = client.process_item(item)
            return result
        except Exception as e:
            print(f"Encountered error in unified reward model processing: {str(e)}")
            return None

    def _reset_proxy(self):
        # Drop proxies exported by the other reward models; they would block
        # access to the (typically local) reward server.
        os.environ.pop('http_proxy', None)
        os.environ.pop('https_proxy', None)

    def __call__(self,
                 images: "Union[List[Image.Image], List[str]]",
                 prompts: Union[str, List[str]],
                 question_type: Optional[str] = None,
                 ):
        """Score `images` against `prompts` via the remote server.

        Returns (results, successes) aligned with the inputs: results[i] is a
        float score or None, successes[i] a bool.

        Fix: a reply whose text cannot be parsed into a score now counts as a
        failure. Previously it produced results[i] is None with successes[i]
        True, which crashed `float(None)` in the downstream reward merging.
        """
        # Reset proxy, otherwise cannot access the server url
        self._reset_proxy()
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)
        if len(prompts) != len(images):
            raise ValueError("prompts must have the same length as images")

        with VLMessageClient(self.api_url) as client:
            questions = [self.question_constructor(prompt, question_type) for prompt in prompts]

            # Pre-filled so results stay aligned with the input order.
            results = [None] * len(images)
            successes = [False] * len(images)

            with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                # Submit all tasks and keep track of their order
                future_to_idx = {
                    executor.submit(self._process_item_wrapper, client, image, question): idx
                    for idx, (image, question) in enumerate(zip(images, questions))
                }

                # Collect in completion order, store in input order.
                for future in concurrent.futures.as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    result = future.result()
                    if result is not None and result.get("success", False):
                        output = result.get("model_output", "")
                        score = self.score_parser(output, question_type)
                        if score is not None:
                            results[idx] = score
                            successes[idx] = True
                        else:
                            # Unparseable reply -> treat the sample as failed.
                            results[idx] = None
                            successes[idx] = False
                    else:
                        results[idx] = None
                        successes[idx] = False

        return results, successes

    def score_parser(self, text, question_type=None):
        """Parse the score matching `question_type` out of the reply text."""
        if question_type is None:
            question_type = self.default_question_type
        if question_type == "score":
            return self.extract_final_score(text)
        elif question_type == "semantic":
            return self.extract_alignment_score(text)
        else:
            raise ValueError(f"Invalid question type: {question_type}")

    @staticmethod
    def extract_alignment_score(text):
        """
        Extract Alignment Score (1-5) from the evaluation text.
        Returns a float score if found, None otherwise.
        """
        match = re.search(r'Alignment Score \(1-5\):\s*([0-5](?:\.\d+)?)', text)
        if match:
            return float(match.group(1))
        else:
            return None

    @staticmethod
    def extract_final_score(text):
        """
        Extract Final Score from the evaluation text.
        Returns a float score if found, None otherwise.
        Example input:
        'ocean (location): 0
        clouds (object): 1
        birds (animal): 0
        day time (attribute): 1
        low depth field effect (attribute): 1
        painting (attribute): 1
        Final Score: 2.33'
        """
        match = re.search(r'Final Score:\s*([0-5](?:\.\d+)?)', text)
        if match:
            return float(match.group(1))
        else:
            return None
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--api_url", type=str)
    arg_parser.add_argument("--max_workers", type=int)
    opts = arg_parser.parse_args()

    # Smoke test: score 16 copies of the demo image against one prompt.
    unified_reward_model = UnifiedRewardModel(opts.api_url, num_workers=opts.max_workers)
    img_path = "assets/reward_demo.jpg"
    images = [Image.open(img_path).convert("RGB") for i in range(1, 5)] * 4
    prompts = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting."
    results, successes = unified_reward_model(images, prompts, question_type="semantic")
    print(results)
    print(successes)

    # # Concurrency test
    # proc_num = 32

    # for i in range(5):
    #     with concurrent.futures.ThreadPoolExecutor(max_workers=proc_num) as executor:
    #         futures = [executor.submit(unified_reward_model, images, prompts, question_type="semantic") for _ in range(proc_num)]
    #         results = [future.result() for future in concurrent.futures.as_completed(futures)]
    #         print(results)
|
fastvideo/reward_model/utils.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import concurrent.futures
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
def _compute_single_reward(reward_model, images, input_prompts):
|
| 5 |
+
"""Compute reward for a single reward model."""
|
| 6 |
+
reward_model_name = type(reward_model).__name__
|
| 7 |
+
try:
|
| 8 |
+
if reward_model_name == 'HPSClipRewardModel':
|
| 9 |
+
rewards = reward_model(images, input_prompts)
|
| 10 |
+
successes = [1] * len(rewards)
|
| 11 |
+
|
| 12 |
+
elif reward_model_name == 'CLIPScoreRewardModel':
|
| 13 |
+
rewards = reward_model(input_prompts, images)
|
| 14 |
+
successes = [1] * len(rewards)
|
| 15 |
+
|
| 16 |
+
elif reward_model_name == 'ImageRewardModel':
|
| 17 |
+
rewards = reward_model(images, input_prompts)
|
| 18 |
+
successes = [1] * len(rewards)
|
| 19 |
+
|
| 20 |
+
elif reward_model_name == 'UnifiedRewardModel':
|
| 21 |
+
rewards, successes_bool = reward_model(images, input_prompts)
|
| 22 |
+
rewards = [float(reward) if success else 0.0 for reward, success in zip(rewards, successes_bool)]
|
| 23 |
+
successes = [1 if success else 0 for success in successes_bool]
|
| 24 |
+
|
| 25 |
+
elif reward_model_name == 'PickScoreRewardModel':
|
| 26 |
+
rewards = reward_model(images, input_prompts)
|
| 27 |
+
successes = [1] * len(rewards)
|
| 28 |
+
|
| 29 |
+
else:
|
| 30 |
+
raise ValueError(f"Unknown reward model: {reward_model_name}")
|
| 31 |
+
|
| 32 |
+
# Verify the length of results matches input
|
| 33 |
+
assert len(rewards) == len(input_prompts), \
|
| 34 |
+
f"Length mismatch in {reward_model_name}: rewards ({len(rewards)}) != input_prompts ({len(input_prompts)})"
|
| 35 |
+
assert len(successes) == len(input_prompts), \
|
| 36 |
+
f"Length mismatch in {reward_model_name}: successes ({len(successes)}) != input_prompts ({len(input_prompts)})"
|
| 37 |
+
|
| 38 |
+
return rewards, successes
|
| 39 |
+
|
| 40 |
+
except Exception as e:
|
| 41 |
+
raise ValueError(f"Error in _compute_single_reward with {reward_model_name}: {e}") from e
|
| 42 |
+
|
| 43 |
+
def compute_reward(images, input_prompts, reward_models, reward_weights):
    """Score `images` against `input_prompts` with every reward model in
    parallel, then merge per-model rewards with `reward_weights`.

    Args:
        images: generated images, one per prompt.
        input_prompts: text prompts, same length as `images`.
        reward_models: reward model instances dispatched through
            _compute_single_reward.
        reward_weights: dict mapping model class name -> merge weight.

    Returns:
        (merged_rewards, merged_successes, rewards_dict, successes_dict).
        merged_successes[i] is 1 only when every weighted model succeeded on
        sample i; failed samples keep a merged reward of 0.0.

    NOTE(review): a name in `reward_weights` that never produced an entry in
    `successes_dict` (e.g. the model is absent from `reward_models`) does not
    falsify all_success, so its weight is silently skipped -- confirm weights
    always correspond to the instantiated models.
    """
    assert (
        len(images) == len(input_prompts)
    ), f"length of `images` ({len(images)}) must be equal to length of `input_prompts` ({len(input_prompts)})"

    # Per-model results keyed by class name.
    rewards_dict = {}
    successes_dict = {}

    # Fix: ThreadPoolExecutor(max_workers=0) raises ValueError, so keep the
    # worker count at least 1 even for an empty model list.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, len(reward_models))) as executor:
        # Submit all reward computation tasks
        future_to_model = {
            executor.submit(_compute_single_reward, reward_model, images, input_prompts): reward_model
            for reward_model in reward_models
        }

        # Process results as they complete
        for future in concurrent.futures.as_completed(future_to_model):
            reward_model = future_to_model[future]
            model_name = type(reward_model).__name__
            try:
                model_rewards, model_successes = future.result()
                rewards_dict[model_name] = model_rewards
                successes_dict[model_name] = model_successes
            except Exception as e:
                # A failing model contributes zero reward and marks every
                # sample failed rather than aborting the whole batch.
                print(f"Error computing reward with {model_name}: {e}")
                rewards_dict[model_name] = [0.0] * len(input_prompts)
                successes_dict[model_name] = [0] * len(input_prompts)
                continue

    # Merge rewards based on weights
    merged_rewards = [0.0] * len(input_prompts)
    merged_successes = [0] * len(input_prompts)

    for i in range(len(merged_rewards)):
        # A sample counts as successful only when all weighted models agree.
        all_success = all(
            successes_dict[model_name][i] == 1
            for model_name in reward_weights
            if model_name in successes_dict
        )

        if all_success:
            # Only compute the weighted sum if all models are successful.
            for model_name, weight in reward_weights.items():
                if model_name in rewards_dict:
                    merged_rewards[i] += rewards_dict[model_name][i] * weight
            merged_successes[i] = 1

    return merged_rewards, merged_successes, rewards_dict, successes_dict
|
| 94 |
+
|
| 95 |
+
def balance_pos_neg(samples, use_random=False):
    """Reorder `samples` so positive- and negative-advantage samples alternate.

    With use_random=True the list is simply shuffled (via random.sample).
    Otherwise samples are split by the sign of sample['advantages'].item(),
    each group is shuffled, the groups are interleaved starting with the
    smaller one, and the surplus of the larger group is appended at the end.

    NOTE(review): in the non-random branch, samples whose advantage is exactly
    0 belong to neither group and are silently dropped -- confirm intended.
    """
    if use_random:
        return random.sample(samples, len(samples))

    positives = [s for s in samples if s['advantages'].item() > 0]
    negatives = [s for s in samples if s['advantages'].item() < 0]

    # Shuffle within each sign group (positives first, matching call order).
    positives = random.sample(positives, len(positives))
    negatives = random.sample(negatives, len(negatives))

    if len(positives) < len(negatives):
        shorter, longer = positives, negatives
    else:
        shorter, longer = negatives, positives

    # Interleave pairwise, then append whatever is left of the larger group.
    interleaved = [s for pair in zip(shorter, longer) for s in pair]
    interleaved.extend(longer[len(shorter):])
    return interleaved
|
| 126 |
+
|
fastvideo/utils/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
fastvideo/utils/checkpoint.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed.checkpoint as dist_cp
|
| 8 |
+
from peft import get_peft_model_state_dict
|
| 9 |
+
from safetensors.torch import load_file, save_file
|
| 10 |
+
from torch.distributed.checkpoint.default_planner import (DefaultLoadPlanner,
|
| 11 |
+
DefaultSavePlanner)
|
| 12 |
+
from torch.distributed.checkpoint.optimizer import \
|
| 13 |
+
load_sharded_optimizer_state_dict
|
| 14 |
+
from torch.distributed.fsdp import (FullOptimStateDictConfig,
|
| 15 |
+
FullStateDictConfig)
|
| 16 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 17 |
+
from torch.distributed.fsdp import StateDictType
|
| 18 |
+
|
| 19 |
+
from fastvideo.utils.logging_ import main_print
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def save_checkpoint_optimizer(model,
                              optimizer,
                              rank,
                              output_dir,
                              step,
                              discriminator=False):
    """Save a full (unsharded) model + optimizer checkpoint.

    The model/optimizer state is gathered onto rank 0's CPU
    (``rank0_only=True``), so only rank 0 writes files. Artifacts go to
    ``output_dir/checkpoint-{step}``:

    * generator (``discriminator=False``): ``diffusion_pytorch_model.safetensors``,
      ``config.json``, ``optimizer.pt``
    * discriminator: ``discriminator_pytorch_model.safetensors``,
      ``discriminator_optimizer.pt``

    Args:
        model: FSDP-wrapped module with a ``.config`` attribute.
        optimizer: optimizer whose state is saved alongside the weights.
        rank: distributed rank of this process.
        output_dir: root directory for checkpoints.
        step: training step, used in the directory name.
        discriminator: select discriminator file names instead of generator.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(
            model,
            optimizer,
        )

    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    # BUGFIX: the original wrote the discriminator files from every non-zero
    # rank (the old condition was `if rank <= 0 and not discriminator: ... else:`),
    # even though with rank0_only=True only rank 0 holds real state. Gate all
    # file writes on rank 0 and branch on `discriminator` inside.
    if rank <= 0:
        if not discriminator:
            weight_path = os.path.join(save_dir,
                                       "diffusion_pytorch_model.safetensors")
            save_file(cpu_state, weight_path)
            config_dict = dict(model.config)
            # torch.dtype objects are not JSON-serializable; drop if present
            # (pop with default avoids a KeyError when the key is absent).
            config_dict.pop('dtype', None)
            config_path = os.path.join(save_dir, "config.json")
            with open(config_path, "w") as f:
                json.dump(config_dict, f, indent=4)
            optimizer_path = os.path.join(save_dir, "optimizer.pt")
            torch.save(optim_state, optimizer_path)
        else:
            weight_path = os.path.join(
                save_dir, "discriminator_pytorch_model.safetensors")
            save_file(cpu_state, weight_path)
            optimizer_path = os.path.join(save_dir,
                                          "discriminator_optimizer.pt")
            torch.save(optim_state, optimizer_path)
    main_print(f"--> checkpoint saved at step {step}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def save_checkpoint(transformer, rank, output_dir, step, epoch):
    """Write a full transformer checkpoint from rank 0.

    Gathers the complete state dict to rank 0's CPU and saves it as
    safetensors plus a JSON config under
    ``output_dir/checkpoint-{step}-{epoch}``.

    Args:
        transformer: FSDP-wrapped transformer with a ``.config`` attribute.
        rank: distributed rank of this process.
        output_dir: root directory for checkpoints.
        step: training step used in the directory name.
        epoch: epoch number used in the directory name.
    """
    main_print(f"--> saving checkpoint at step {step}")
    gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(transformer, StateDictType.FULL_STATE_DICT,
                              gather_cfg):
        cpu_state = transformer.state_dict()
    # Only rank 0 holds the gathered state, so only it touches the filesystem.
    if rank <= 0:
        save_dir = os.path.join(output_dir, f"checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)
        save_file(
            cpu_state,
            os.path.join(save_dir, "diffusion_pytorch_model.safetensors"))
        config_dict = dict(transformer.config)
        # torch.dtype is not JSON-serializable; equivalent to the guarded delete.
        config_dict.pop("dtype", None)
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(config_dict, f, indent=4)
    main_print(f"--> checkpoint saved at step {step}")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def save_checkpoint_generator_discriminator(
    model,
    optimizer,
    discriminator,
    discriminator_optimizer,
    rank,
    output_dir,
    step,
):
    """Checkpoint a generator/discriminator training pair.

    Writes under ``output_dir/checkpoint-{step}``:

    * ``hf_weights/``               -- full generator weights + config (rank 0)
    * ``model_weights_state/`` and ``model_optimizer_state/``
                                    -- sharded generator state (all ranks)
    * ``discriminator_fsdp_state/`` -- full discriminator + optimizer (rank 0)
    """
    # Full (unsharded) generator weights, gathered onto rank 0 CPU memory.
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()

    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    hf_weight_dir = os.path.join(save_dir, "hf_weights")
    os.makedirs(hf_weight_dir, exist_ok=True)
    # Only rank 0 holds the gathered state, so only it writes the HF files.
    if rank <= 0:
        with open(os.path.join(hf_weight_dir, "config.json"), "w") as f:
            json.dump(dict(model.config), f, indent=4)
        save_file(
            cpu_state,
            os.path.join(hf_weight_dir,
                         "diffusion_pytorch_model.safetensors"))

    main_print(f"--> saved HF weight checkpoint at path {hf_weight_dir}")
    model_weight_dir = os.path.join(save_dir, "model_weights_state")
    os.makedirs(model_weight_dir, exist_ok=True)
    model_optimizer_dir = os.path.join(save_dir, "model_optimizer_state")
    os.makedirs(model_optimizer_dir, exist_ok=True)
    # Sharded save: every rank participates and writes its own shard.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        sharded_optim = FSDP.optim_state_dict(model, optimizer)
        sharded_weights = {"model": model.state_dict()}
        dist_cp.save_state_dict(
            state_dict=sharded_weights,
            storage_writer=dist_cp.FileSystemWriter(model_weight_dir),
            planner=DefaultSavePlanner(),
        )
        dist_cp.save_state_dict(
            state_dict={"optimizer": sharded_optim},
            storage_writer=dist_cp.FileSystemWriter(model_optimizer_dir),
            planner=DefaultSavePlanner(),
        )

    discriminator_fsdp_state_dir = os.path.join(save_dir,
                                                "discriminator_fsdp_state")
    os.makedirs(discriminator_fsdp_state_dir, exist_ok=True)
    # Discriminator: full state gathered to rank 0 and saved as a single file.
    with FSDP.state_dict_type(
            discriminator,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        disc_optim_state = FSDP.optim_state_dict(discriminator,
                                                 discriminator_optimizer)
        disc_model_state = discriminator.state_dict()
        if rank <= 0:
            disc_file = os.path.join(discriminator_fsdp_state_dir,
                                     "discriminator_state.pt")
            torch.save(
                {
                    "optimizer": disc_optim_state,
                    "model": disc_model_state
                }, disc_file)

    main_print("--> saved FSDP state checkpoint")
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def load_sharded_model(model, optimizer, model_dir, optimizer_dir):
    """Restore sharded FSDP model weights and optimizer state from disk.

    Args:
        model: FSDP-wrapped module to load into.
        optimizer: optimizer to load into.
        model_dir: directory written by ``dist_cp`` with the weight shards.
        optimizer_dir: directory with the sharded optimizer state.

    Returns:
        The (model, optimizer) pair with state loaded in place.
    """
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_holder = {"model": model.state_dict()}

        # Optimizer first: it is resharded against the current model layout.
        loaded = load_sharded_optimizer_state_dict(
            model_state_dict=state_holder["model"],
            optimizer_key="optimizer",
            storage_reader=dist_cp.FileSystemReader(optimizer_dir),
        )
        flattened_osd = FSDP.optim_state_dict_to_load(
            model=model,
            optim=optimizer,
            optim_state_dict=loaded["optimizer"])
        optimizer.load_state_dict(flattened_osd)

        # Then read the weight shards in place and load them.
        dist_cp.load_state_dict(
            state_dict=state_holder,
            storage_reader=dist_cp.FileSystemReader(model_dir),
            planner=DefaultLoadPlanner(),
        )
        model.load_state_dict(state_holder["model"])
    main_print(f"--> loaded model and optimizer from path {model_dir}")
    return model, optimizer
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def load_full_state_model(model, optimizer, checkpoint_file, rank):
    """Load a full (unsharded) checkpoint saved by rank 0.

    Only rank 0 feeds real optimizer state into
    ``FSDP.optim_state_dict_to_load``; FSDP distributes it to the other
    ranks (non-zero ranks pass ``None``).

    Args:
        model: FSDP-wrapped discriminator to load into.
        optimizer: its optimizer.
        checkpoint_file: path to the ``discriminator_state.pt`` file.
        rank: distributed rank of this process.

    Returns:
        The (model, optimizer) pair with state loaded in place.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        checkpoint = torch.load(checkpoint_file)
        model_state = checkpoint["model"]
        optim_state = checkpoint["optimizer"] if rank <= 0 else None
        model.load_state_dict(model_state)
        sharded_optim_state = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(sharded_optim_state)
    main_print(
        f"--> loaded discriminator and discriminator optimizer from path {checkpoint_file}"
    )
    return model, optimizer
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def resume_training_generator_discriminator(model, optimizer, discriminator,
                                            discriminator_optimizer,
                                            checkpoint_dir, rank):
    """Resume GAN-style training from a combined checkpoint directory.

    The directory must have been produced by
    ``save_checkpoint_generator_discriminator``; the step is parsed from the
    trailing ``-{step}`` suffix of its name.

    Returns:
        (model, optimizer, discriminator, discriminator_optimizer, step)
    """
    step = int(checkpoint_dir.split("-")[-1])
    model, optimizer = load_sharded_model(
        model,
        optimizer,
        os.path.join(checkpoint_dir, "model_weights_state"),
        os.path.join(checkpoint_dir, "model_optimizer_state"),
    )
    disc_ckpt_file = os.path.join(checkpoint_dir, "discriminator_fsdp_state",
                                  "discriminator_state.pt")
    discriminator, discriminator_optimizer = load_full_state_model(
        discriminator, discriminator_optimizer, disc_ckpt_file, rank)
    return model, optimizer, discriminator, discriminator_optimizer, step
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def resume_training(model, optimizer, checkpoint_dir, discriminator=False):
    """Resume from a full checkpoint written by ``save_checkpoint_optimizer``.

    Args:
        model: FSDP-wrapped module to restore.
        optimizer: optimizer to restore.
        checkpoint_dir: ``checkpoint-{step}`` directory; the step is parsed
            from its trailing ``-{step}`` suffix.
        discriminator: load the discriminator-named files instead.

    Returns:
        (model, optimizer, step)
    """
    if discriminator:
        weight_name = "discriminator_pytorch_model.safetensors"
        optim_name = "discriminator_optimizer.pt"
    else:
        weight_name = "diffusion_pytorch_model.safetensors"
        optim_name = "optimizer.pt"
    model_weights = load_file(os.path.join(checkpoint_dir, weight_name))

    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        # Merge on top of the current state so keys absent from the file keep
        # their existing values (hence strict=False below).
        current_state = model.state_dict()
        current_state.update(model_weights)
        model.load_state_dict(current_state, strict=False)
    optimizer_state_dict = torch.load(
        os.path.join(checkpoint_dir, optim_name), weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=model, optim=optimizer, optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)
    step = int(checkpoint_dir.split("-")[-1])
    return model, optimizer, step
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step,
                         pipeline, epoch):
    """Save LoRA adapter weights, optimizer state, and a small config JSON.

    Files are written by rank 0 only, into
    ``output_dir/lora-checkpoint-{step}-{epoch}``.

    Args:
        transformer: FSDP-wrapped, PEFT-adapted transformer.
        optimizer: LoRA optimizer whose state is saved.
        rank: distributed rank of this process.
        output_dir: root directory for checkpoints.
        step: training step used in the directory name and config.
        pipeline: diffusers pipeline providing ``save_lora_weights``.
        epoch: epoch number used in the directory name.
    """
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        full_state_dict = transformer.state_dict()
        lora_optim_state = FSDP.optim_state_dict(transformer, optimizer)

    if rank <= 0:
        save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)

        # Optimizer state.
        torch.save(lora_optim_state,
                   os.path.join(save_dir, "lora_optimizer.pt"))

        # LoRA adapter weights, filtered out of the full state dict by PEFT.
        main_print(f"--> saving LoRA checkpoint at step {step}")
        transformer_lora_layers = get_peft_model_state_dict(
            model=transformer, state_dict=full_state_dict)
        pipeline.save_lora_weights(
            save_directory=save_dir,
            transformer_lora_layers=transformer_lora_layers,
            is_main_process=True,
        )

        # Minimal metadata needed to rebuild/resume the adapter.
        lora_config = {
            "step": step,
            "lora_params": {
                "lora_rank": transformer.config.lora_rank,
                "lora_alpha": transformer.config.lora_alpha,
                "target_modules": transformer.config.lora_target_modules,
            },
        }
        with open(os.path.join(save_dir, "lora_config.json"), "w") as f:
            json.dump(lora_config, f, indent=4)
    main_print(f"--> LoRA checkpoint saved at step {step}")
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def resume_lora_optimizer(transformer, checkpoint_dir, optimizer):
    """Restore LoRA optimizer state saved by ``save_lora_checkpoint``.

    Args:
        transformer: FSDP-wrapped transformer the optimizer belongs to.
        checkpoint_dir: ``lora-checkpoint-...`` directory.
        optimizer: optimizer to load into.

    Returns:
        (transformer, optimizer, step) with step read from lora_config.json.
    """
    with open(os.path.join(checkpoint_dir, "lora_config.json"), "r") as f:
        config_dict = json.load(f)
    saved_optim_state = torch.load(
        os.path.join(checkpoint_dir, "lora_optimizer.pt"), weights_only=False)
    shardable_state = FSDP.optim_state_dict_to_load(
        model=transformer, optim=optimizer,
        optim_state_dict=saved_optim_state)
    optimizer.load_state_dict(shardable_state)
    step = config_dict["step"]
    main_print(f"--> Successfully resuming LoRA optimizer from step {step}")
    return transformer, optimizer, step
|
fastvideo/utils/communications.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
|
| 2 |
+
|
| 3 |
+
from typing import Any, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
from fastvideo.utils.parallel_states import nccl_info
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def broadcast(input_: torch.Tensor):
    """Broadcast *input_* in place from the first rank of this SP group.

    The source is the global rank of the group's rank-0 member
    (``group_id * sp_size``).
    """
    src_rank = nccl_info.group_id * nccl_info.sp_size
    dist.broadcast(input_, src=src_rank, group=nccl_info.group)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _all_to_all_4D(input: torch.Tensor,
                   scatter_idx: int = 2,
                   gather_idx: int = 1,
                   group=None) -> torch.Tensor:
    """
    all-to-all for QKV

    Args:
        input (torch.Tensor): a 4D tensor sharded along the scatter dim
        scatter_idx (int): default 2
        gather_idx (int): default 1
        group : torch process group

    Returns:
        torch.Tensor: resharded tensor.
            scatter_idx=2, gather_idx=1: (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs)
            scatter_idx=1, gather_idx=2: (bs, seqlen, hc/P, hs) -> (bs, seqlen/P, hc, hs)
    """
    # NOTE: the original docstring claimed defaults of scatter_idx=1 /
    # gather_idx=2, contradicting the signature; fixed above.
    assert (
        input.dim() == 4
    ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = input.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
        input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc,
                                 hs).transpose(0, 2).contiguous())

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t
        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(seqlen, bs, shard_hc, hs)

        # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
        output = output.transpose(0, 1).contiguous().reshape(
            bs, seqlen, shard_hc, hs)

        return output

    elif scatter_idx == 1 and gather_idx == 2:
        # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
        bs, seqlen, shard_hc, hs = input.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size
        # (the original redundantly recomputed seq_world_size here; removed)

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
        input_t = (input.reshape(
            bs, seq_world_size, shard_seqlen, shard_hc,
            hs).transpose(0, 3).transpose(0, 1).contiguous().reshape(
                seq_world_size, shard_hc, shard_seqlen, bs, hs))

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(hc, shard_seqlen, bs, hs)

        # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
        output = output.transpose(0, 2).contiguous().reshape(
            bs, shard_seqlen, hc, hs)

        return output
    else:
        raise RuntimeError(
            "scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class SeqAllToAll4D(torch.autograd.Function):
    """Differentiable 4D all-to-all.

    Backward applies the inverse all-to-all by swapping the scatter and
    gather dimensions recorded in forward.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        # Remember how we resharded so backward can undo it.
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any,
                 *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # Gradient of an all-to-all is the all-to-all with dims swapped;
        # None entries match the non-tensor forward arguments.
        grad_input = SeqAllToAll4D.apply(ctx.group, *grad_output,
                                         ctx.gather_idx, ctx.scatter_idx)
        return (None, grad_input, None, None)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def all_to_all_4D(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Autograd-aware 4D all-to-all over the sequence-parallel group."""
    return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim,
                               gather_dim)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    """Split *input_* along scatter_dim, exchange chunks, concat along gather_dim."""
    send_chunks = [
        chunk.contiguous()
        for chunk in torch.tensor_split(input_, world_size, scatter_dim)
    ]
    recv_chunks = [
        torch.empty_like(send_chunks[0]) for _ in range(world_size)
    ]
    dist.all_to_all(recv_chunks, send_chunks, group=group)
    return torch.cat(recv_chunks, dim=gather_dim).contiguous()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class _AllToAll(torch.autograd.Function):
    """All-to-all communication with autograd support.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        # Stash the collective parameters for the backward pass.
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        return _all_to_all(input_, ctx.world_size, process_group, scatter_dim,
                           gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Undo the forward resharding by swapping scatter and gather dims.
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        return grad_input, None, None, None
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def all_to_all(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Autograd-aware all-to-all over the sequence-parallel group."""
    return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class _AllGather(torch.autograd.Function):
    """All-gather with autograd support.

    Forward concatenates every SP rank's tensor along ``dim``; backward
    returns only the slice of the gradient belonging to this rank.
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        world_size = nccl_info.sp_size
        # Record this rank's extent along `dim` so backward can slice it out.
        ctx.input_size = list(input_.size())[dim]

        gathered = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(gathered, input_.contiguous(), group=nccl_info.group)
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank contributed an equal-sized chunk; keep ours only.
        chunk_sizes = [ctx.input_size] * nccl_info.sp_size
        chunks = torch.split(grad_output, chunk_sizes, dim=ctx.dim)
        return chunks[nccl_info.rank_within_group], None
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def all_gather(input_: torch.Tensor, dim: int = 1):
    """Gather *input_* from every sequence-parallel rank and concatenate.

    Args:
        input_ (torch.Tensor): tensor to gather, e.g. shape [B, H, S, D].
        dim (int, optional): concatenation dimension. Defaults to 1.

    Returns:
        torch.Tensor: tensors from all ranks concatenated along ``dim``.
    """
    return _AllGather.apply(input_, dim)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def prepare_sequence_parallel_data(
    encoder_hidden_states, encoder_attention_mask, caption
):
    """Reshard text-conditioning tensors for sequence parallelism.

    With ``sp_size == 1`` the inputs pass through unchanged. Otherwise each
    tensor is replicated ``sp_size`` times along its sequence dimension and
    redistributed with an all-to-all (scatter seq dim, gather batch dim) so
    every SP rank holds a full copy for its batch slice. ``caption`` is
    returned untouched.
    """
    sp_size = nccl_info.sp_size
    if sp_size == 1:
        # Nothing to reshard for a single-rank SP group.
        return (
            encoder_hidden_states,
            encoder_attention_mask,
            caption,
        )

    resharded_hidden = all_to_all(
        encoder_hidden_states.repeat(1, sp_size, 1),
        scatter_dim=1,
        gather_dim=0,
    )
    resharded_mask = all_to_all(
        encoder_attention_mask.repeat(1, sp_size),
        scatter_dim=1,
        gather_dim=0,
    )
    return resharded_hidden, resharded_mask, caption
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def sp_parallel_dataloader_wrapper(
    dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
    """Infinite generator adapting a dataloader for sequence parallelism.

    Each batch ``(cond, cond_mask, caption)`` is moved to *device*, resharded
    via ``prepare_sequence_parallel_data``, then yielded in micro-batches of
    ``train_sp_batch_size`` rows out of the ``train_batch_size * sp_size``
    effective batch.

    Yields:
        (encoder_hidden_states, encoder_attention_mask, caption) tuples.
    """
    while True:
        for data_item in dataloader:
            cond, cond_mask, caption = data_item
            cond = cond.to(device)
            cond_mask = cond_mask.to(device)
            # NOTE(review): frame count is hard-coded; the commented-out code
            # it replaced derived it from latents.shape[2] -- confirm intent.
            frame = 19
            if frame == 1:
                yield cond, cond_mask, caption
            else:
                cond, cond_mask, caption = prepare_sequence_parallel_data(
                    cond, cond_mask, caption
                )
                # Fix: the check is >=, so the message now says so (the
                # original message claimed strictly "greater than").
                assert (
                    train_batch_size * sp_size >= train_sp_batch_size
                ), "train_batch_size * sp_size should be greater than or equal to train_sp_batch_size"
                num_micro_batches = (train_batch_size * sp_size
                                     // train_sp_batch_size)
                # Renamed from `iter`, which shadowed the builtin.
                for micro_step in range(num_micro_batches):
                    st_idx = micro_step * train_sp_batch_size
                    ed_idx = (micro_step + 1) * train_sp_batch_size
                    encoder_hidden_states = cond[st_idx:ed_idx]
                    encoder_attention_mask = cond_mask[st_idx:ed_idx]
                    yield (
                        encoder_hidden_states,
                        encoder_attention_mask,
                        caption
                    )