studyOverflow committed on
Commit
b171568
·
verified ·
1 Parent(s): 45d12c1

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fastvideo/config_sd/__pycache__/base.cpython-310.pyc +0 -0
  2. fastvideo/config_sd/base.py +113 -0
  3. fastvideo/config_sd/dgx.py +60 -0
  4. fastvideo/data_preprocess/.DS_Store +0 -0
  5. fastvideo/data_preprocess/preprocess_flux_embedding.py +170 -0
  6. fastvideo/data_preprocess/preprocess_flux_embedding_rlpt.py +172 -0
  7. fastvideo/data_preprocess/preprocess_flux_rfpt_embedding.py +224 -0
  8. fastvideo/data_preprocess/preprocess_qwenimage_embedding.py +220 -0
  9. fastvideo/data_preprocess/preprocess_rl_embeddings.py +175 -0
  10. fastvideo/data_preprocess/preprocess_text_embeddings.py +175 -0
  11. fastvideo/data_preprocess/preprocess_vae_latents.py +137 -0
  12. fastvideo/data_preprocess/preprocess_validation_text_embeddings.py +80 -0
  13. fastvideo/dataset/.DS_Store +0 -0
  14. fastvideo/dataset/__init__.py +104 -0
  15. fastvideo/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  16. fastvideo/dataset/__pycache__/__init__.cpython-312.pyc +0 -0
  17. fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets.cpython-312.pyc +0 -0
  18. fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets_all.cpython-312.pyc +0 -0
  19. fastvideo/dataset/__pycache__/latent_flux_rl_datasets.cpython-312.pyc +0 -0
  20. fastvideo/dataset/__pycache__/latent_qwenimage_rl_datasets.cpython-310.pyc +0 -0
  21. fastvideo/dataset/__pycache__/t2v_datasets.cpython-310.pyc +0 -0
  22. fastvideo/dataset/__pycache__/t2v_datasets.cpython-312.pyc +0 -0
  23. fastvideo/dataset/__pycache__/transform.cpython-310.pyc +0 -0
  24. fastvideo/dataset/__pycache__/transform.cpython-312.pyc +0 -0
  25. fastvideo/dataset/latent_datasets.py +132 -0
  26. fastvideo/dataset/latent_flux_rfpt_datasets.py +122 -0
  27. fastvideo/dataset/latent_flux_rfpt_datasets_all.py +134 -0
  28. fastvideo/dataset/latent_flux_rl_datasets.py +110 -0
  29. fastvideo/dataset/latent_qwenimage_rl_datasets.py +90 -0
  30. fastvideo/dataset/latent_rl_datasets.py +99 -0
  31. fastvideo/dataset/t2v_datasets.py +351 -0
  32. fastvideo/dataset/transform.py +647 -0
  33. fastvideo/distill/__init__.py +0 -0
  34. fastvideo/distill/__pycache__/__init__.cpython-312.pyc +0 -0
  35. fastvideo/distill/__pycache__/solver.cpython-312.pyc +0 -0
  36. fastvideo/distill/discriminator.py +84 -0
  37. fastvideo/distill/solver.py +310 -0
  38. fastvideo/models/.DS_Store +0 -0
  39. fastvideo/models/__pycache__/flash_attn_no_pad.cpython-310.pyc +0 -0
  40. fastvideo/models/__pycache__/flash_attn_no_pad.cpython-312.pyc +0 -0
  41. fastvideo/models/flash_attn_no_pad.py +37 -0
  42. fastvideo/reward_model/clip_score.py +98 -0
  43. fastvideo/reward_model/hps_score.py +79 -0
  44. fastvideo/reward_model/image_reward.py +40 -0
  45. fastvideo/reward_model/pick_score.py +107 -0
  46. fastvideo/reward_model/unified_reward.py +333 -0
  47. fastvideo/reward_model/utils.py +126 -0
  48. fastvideo/utils/.DS_Store +0 -0
  49. fastvideo/utils/checkpoint.py +314 -0
  50. fastvideo/utils/communications.py +335 -0
fastvideo/config_sd/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.26 kB). View file
 
fastvideo/config_sd/base.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ml_collections
2
+
3
+
4
def get_config():
    """Build the base RL-finetuning configuration.

    Returns an ``ml_collections.ConfigDict`` holding general run options plus
    three sub-configs: ``pretrained`` (model source), ``sample`` (inference
    settings) and ``train`` (optimizer / PPO settings).
    """
    config = ml_collections.ConfigDict()

    # --- General -------------------------------------------------------------
    # wandb run name / checkpoint tag; auto-generated from the datetime if empty.
    config.run_name = ""
    # RNG seed for reproducibility.
    config.seed = 42
    # top-level directory for checkpoint saving.
    config.logdir = "logs"
    # one epoch = one round of sampling from the model followed by training on
    # those samples.
    config.num_epochs = 300
    # epochs between checkpoint saves.
    config.save_freq = 20
    # checkpoints retained before old ones are overwritten.
    config.num_checkpoint_limit = 5
    # mixed precision: "fp16", "bf16" or "no"; half precision is much faster.
    config.mixed_precision = "bf16"
    # enable TF32 on Ampere GPUs for extra speed.
    config.allow_tf32 = True
    # checkpoint to resume from: either an exact checkpoint dir (e.g.
    # checkpoint_50) or a dir of checkpoints (the latest is used). `use_lora`
    # must match the run that produced the checkpoint.
    config.resume_from = ""
    # LoRA injects small weight matrices into the UNet attention layers and cuts
    # memory drastically (about 10GB with fp16 and batch size 1); without it,
    # training memory and checkpoint files are both large.
    config.use_lora = False

    # --- Pretrained model ----------------------------------------------------
    config.pretrained = ml_collections.ConfigDict(
        dict(
            # local path or HuggingFace hub model name.
            model="./data/StableDiffusion",
            # model revision to load.
            revision="main",
        )
    )

    # --- Sampling ------------------------------------------------------------
    config.sample = ml_collections.ConfigDict(
        dict(
            # sampler inference steps.
            num_steps=50,
            # DDIM eta: 0.0 is fully deterministic, 1.0 matches the DDPM sampler.
            eta=1.0,
            # classifier-free guidance weight; 1.0 disables guidance.
            guidance_scale=5.0,
            # per-GPU sampling batch size.
            batch_size=1,
            # batches sampled per epoch; total samples per epoch is
            # num_batches_per_epoch * batch_size * num_gpus.
            num_batches_per_epoch=2,
        )
    )

    # --- Training ------------------------------------------------------------
    train = ml_collections.ConfigDict()
    config.train = train
    # per-GPU training batch size.
    train.batch_size = 1
    # use the bitsandbytes 8-bit Adam optimizer.
    train.use_8bit_adam = False
    train.learning_rate = 1e-5
    train.adam_beta1 = 0.9
    train.adam_beta2 = 0.999
    train.adam_weight_decay = 1e-4
    train.adam_epsilon = 1e-8
    # effective batch size = batch_size * num_gpus * gradient_accumulation_steps.
    train.gradient_accumulation_steps = 1
    # gradient-clipping norm.
    train.max_grad_norm = 1.0
    # inner epochs per outer epoch (passes over one sampling round's data).
    train.num_inner_epochs = 1
    # apply classifier-free guidance during training with the sampling-time scale.
    train.cfg = True
    # advantages are clipped to [-adv_clip_max, adv_clip_max].
    train.adv_clip_max = 5
    # PPO clip range.
    train.clip_range = 1e-4
    # fraction of timesteps trained per sample; <1.0 trades policy-gradient
    # estimate accuracy for speed.
    train.timestep_fraction = 1.0

    # --- Prompting -----------------------------------------------------------
    # prompt function name (see prompts.py) and its kwargs.
    config.prompt_fn = "imagenet_animals"
    config.prompt_fn_kwargs = {}

    # --- Reward --------------------------------------------------------------
    # reward function name (see rewards.py).
    config.reward_fn = "hpsv2"

    # --- Per-prompt stat tracking (disabled) ---------------------------------
    # When enabled, reward mean/std are tracked per prompt and used to compute
    # advantages; when left unset/None, whole-batch statistics are used instead.
    # config.per_prompt_stat_tracking = ml_collections.ConfigDict()
    # rewards buffered per prompt (the buffer persists across epochs):
    # config.per_prompt_stat_tracking.buffer_size = 16
    # minimum buffered rewards before per-prompt stats replace batch stats:
    # config.per_prompt_stat_tracking.min_count = 16

    return config
fastvideo/config_sd/dgx.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import ml_collections
import importlib.util
import os

# Load the sibling base.py as a module. The legacy `imp.load_source` used here
# previously was deprecated since Python 3.4 and removed in Python 3.12;
# importlib.util is the supported replacement and yields an equivalent module
# object bound to the same name.
_base_spec = importlib.util.spec_from_file_location(
    "base", os.path.join(os.path.dirname(__file__), "base.py")
)
base = importlib.util.module_from_spec(_base_spec)
_base_spec.loader.exec_module(base)
6
+
7
+
8
def compressibility():
    """Config for JPEG-compressibility reward training on Stable Diffusion v1.4.

    Starts from the base config and overrides model, schedule, batch sizes,
    prompting and reward settings.
    """
    config = base.get_config()

    config.pretrained.model = "CompVis/stable-diffusion-v1-4"

    config.num_epochs = 300
    config.save_freq = 50
    config.num_checkpoint_limit = 100000000

    # Sampling: on the 8-GPU DGX used originally this yields
    # 8 * 8 * 4 = 256 samples per epoch.
    config.sample.batch_size = 8
    config.sample.num_batches_per_epoch = 4

    # Training: corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch.
    config.train.batch_size = 1
    config.train.gradient_accumulation_steps = 4

    # Prompting and reward.
    config.prompt_fn = "imagenet_animals"
    config.prompt_fn_kwargs = {}
    config.reward_fn = "jpeg_compressibility"

    # Per-prompt reward normalization settings.
    config.per_prompt_stat_tracking = dict(
        buffer_size=16,
        min_count=16,
    )

    return config
38
+
39
def hps():
    """Variant of `compressibility` tuned for a harder-to-optimize reward."""
    config = compressibility()

    config.num_epochs = 300
    # NOTE(review): despite the function name, this selects "aesthetic_score",
    # not an HPS reward — confirm this is intentional.
    config.reward_fn = "aesthetic_score"

    # This reward is a bit harder to optimize, so accumulate more gradients
    # per update.
    config.train.gradient_accumulation_steps = 8

    config.sample.batch_size = 4
    config.train.batch_size = 4

    # Prompting and generation counts.
    config.prompt_fn = "aes"
    config.chosen_number = 16
    config.num_generations = 16

    return config
57
+
58
+
59
def get_config(name):
    """Look up a config-builder function in this module by name and invoke it."""
    builder = globals()[name]
    return builder()
fastvideo/data_preprocess/.DS_Store ADDED
Binary file (6.15 kB). View file
 
fastvideo/data_preprocess/preprocess_flux_embedding.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+
13
+ import argparse
14
+ import torch
15
+ from accelerate.logging import get_logger
16
+ import cv2
17
+ import json
18
+ import os
19
+ import torch.distributed as dist
20
+ from pathlib import Path
21
+
22
+ logger = get_logger(__name__)
23
+ from torch.utils.data import Dataset
24
+ from torch.utils.data.distributed import DistributedSampler
25
+ from torch.utils.data import DataLoader
26
+ from tqdm import tqdm
27
+ import re
28
+ from diffusers import FluxPipeline
29
+
30
def contains_chinese(text):
    """Return True if *text* contains any CJK unified ideograph (U+4E00–U+9FFF)."""
    return any("\u4e00" <= ch <= "\u9fff" for ch in text)
32
+
33
class T5dataset(Dataset):
    """Caption dataset read from a plain-text file, one caption per line.

    Lines containing Chinese characters are filtered out and at most the
    first 50,000 remaining lines are kept.
    """

    def __init__(
        self, txt_path, vae_debug,
    ):
        # txt_path: path to the caption .txt file.
        # vae_debug: when truthy, __getitem__ also loads a precomputed latent.
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        with open(self.txt_path, "r", encoding="utf-8") as f:
            self.train_dataset = [
                line for line in f.read().splitlines() if not contains_chinese(line)
            ][:50000]

    def __getitem__(self, idx):
        """Return dict(caption=str, latents=tensor-or-[], filename=str(idx))."""
        #import pdb;pdb.set_trace()
        caption = self.train_dataset[idx]
        filename = str(idx)
        #length = self.train_dataset[idx]["length"]
        if self.vae_debug:
            # NOTE(review): this branch depends on a module-global `args` and
            # indexes a caption *string* with ["latent_path"], which would raise
            # TypeError. It looks stale from a JSON-based dataset variant —
            # confirm before running with vae_debug=True.
            latents = torch.load(
                os.path.join(
                    args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
                ),
                map_location="cpu",
            )
        else:
            latents = []

        return dict(caption=caption, latents=latents, filename=filename)

    def __len__(self):
        return len(self.train_dataset)
63
+
64
+
65
def main(args):
    """Encode captions with the Flux text encoders and save, per sample,
    the prompt embedding, pooled prompt embedding and text ids as .pt files,
    plus a videos2caption.json manifest (written by rank 0).

    Runs under torch.distributed: each rank processes a shard of the dataset
    (via DistributedSampler) and the per-rank manifests are gathered at the end.
    """
    # NOTE(review): RANK is the *global* rank; torch.cuda.set_device should use
    # LOCAL_RANK on multi-node setups — confirm this only runs single-node.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
        )

    # One output subdirectory per saved tensor kind.
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True)

    latents_txt_path = args.prompt_dir
    train_dataset = T5dataset(latents_txt_path, args.vae_debug)
    sampler = DistributedSampler(
        train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    flux_path = args.model_path
    pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16).to(device)

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
        try:
            with torch.inference_mode():
                if args.vae_debug:
                    latents = data["latents"]
                # Encode the whole batch once. (The original called
                # encode_prompt inside the per-sample loop, re-encoding the
                # identical batch for every sample in it.)
                prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
                    prompt=data["caption"], prompt_2=data["caption"]
                )
                for idx, video_name in enumerate(data["filename"]):
                    prompt_embed_path = os.path.join(
                        args.output_dir, "prompt_embed", video_name + ".pt"
                    )
                    pooled_prompt_embeds_path = os.path.join(
                        args.output_dir, "pooled_prompt_embeds", video_name + ".pt"
                    )
                    text_ids_path = os.path.join(
                        args.output_dir, "text_ids", video_name + ".pt"
                    )
                    # Save the per-sample slices.
                    torch.save(prompt_embeds[idx], prompt_embed_path)
                    torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path)
                    torch.save(text_ids[idx], text_ids_path)
                    item = {}
                    item["prompt_embed_path"] = video_name + ".pt"
                    item["text_ids"] = video_name + ".pt"
                    item["pooled_prompt_embeds_path"] = video_name + ".pt"
                    item["caption"] = data["caption"][idx]
                    json_data.append(item)
        except Exception as e:
            # Surface the failure on this rank, sync so every rank sees the
            # stop point, then re-raise to abort the job.
            print(f"Rank {local_rank} Error: {repr(e)}")
            dist.barrier()
            raise
    dist.barrier()
    # Gather every rank's manifest entries; rank 0 writes the combined JSON.
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if local_rank == 0:
        # os.remove(latents_json_path)
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
139
+
140
+
141
if __name__ == "__main__":
    # CLI for the embedding-preprocessing script; parses args and runs main().
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    # NOTE(review): the "mochi" defaults look inherited from another script —
    # this file loads a FluxPipeline; confirm callers always pass --model_path.
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    # NOTE(review): --text_encoder_name and --cache_dir are never read in this
    # script's main(); they appear vestigial.
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    # Path to the caption .txt file (one caption per line).
    parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
    args = parser.parse_args()
    main(args)
fastvideo/data_preprocess/preprocess_flux_embedding_rlpt.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+
13
+ import argparse
14
+ import torch
15
+ from accelerate.logging import get_logger
16
+ import cv2
17
+ import json
18
+ import os
19
+ import torch.distributed as dist
20
+ from pathlib import Path
21
+
22
+ logger = get_logger(__name__)
23
+ from torch.utils.data import Dataset
24
+ from torch.utils.data.distributed import DistributedSampler
25
+ from torch.utils.data import DataLoader
26
+ from tqdm import tqdm
27
+ import re
28
+ from diffusers import FluxPipeline
29
+
30
def contains_chinese(text):
    """Return True when *text* contains at least one CJK unified ideograph."""
    match = re.search(r'[\u4e00-\u9fff]', text)
    return match is not None
32
+
33
class T5dataset(Dataset):
    """Caption dataset from a plain-text file, one caption per line.

    Empty lines and lines containing Chinese characters are dropped; at most
    the first 50,000 remaining (stripped) captions are kept.
    """

    def __init__(
        self, txt_path, vae_debug,
    ):
        # txt_path: path to the caption .txt file.
        # vae_debug: when truthy, __getitem__ also loads a precomputed latent.
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        print(f"[DEBUG] Loading captions from: {self.txt_path}")
        with open(self.txt_path, "r", encoding="utf-8") as f:
            self.train_dataset = [
                line.strip() for line in f.read().splitlines() if line.strip() and not contains_chinese(line)
            ][:50000]
        print(f"[DEBUG] Loaded {len(self.train_dataset)} captions after filtering")

    def __getitem__(self, idx):
        """Return dict(caption=str, latents=tensor-or-[], filename=str(idx))."""
        #import pdb;pdb.set_trace()
        caption = self.train_dataset[idx]
        filename = str(idx)
        #length = self.train_dataset[idx]["length"]
        if self.vae_debug:
            # NOTE(review): this branch depends on a module-global `args` and
            # indexes a caption *string* with ["latent_path"], which would raise
            # TypeError. It looks stale from a JSON-based dataset variant —
            # confirm before running with vae_debug=True.
            latents = torch.load(
                os.path.join(
                    args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
                ),
                map_location="cpu",
            )
        else:
            latents = []

        return dict(caption=caption, latents=latents, filename=filename)

    def __len__(self):
        return len(self.train_dataset)
65
+
66
+
67
def main(args):
    """Encode captions with the Flux text encoders and save, per sample,
    the prompt embedding, pooled prompt embedding and text ids as .pt files,
    plus a videos2caption.json manifest (written by rank 0).

    Runs under torch.distributed: each rank processes a shard of the dataset
    (via DistributedSampler) and the per-rank manifests are gathered at the end.
    """
    # NOTE(review): RANK is the *global* rank; torch.cuda.set_device should use
    # LOCAL_RANK on multi-node setups — confirm this only runs single-node.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
        )

    # One output subdirectory per saved tensor kind.
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True)

    latents_txt_path = args.prompt_dir
    train_dataset = T5dataset(latents_txt_path, args.vae_debug)
    sampler = DistributedSampler(
        train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    flux_path = args.model_path
    pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16).to(device)

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
        try:
            with torch.inference_mode():
                if args.vae_debug:
                    latents = data["latents"]
                # Encode the whole batch once. (The original called
                # encode_prompt inside the per-sample loop, re-encoding the
                # identical batch for every sample in it.)
                prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
                    prompt=data["caption"], prompt_2=data["caption"]
                )
                for idx, video_name in enumerate(data["filename"]):
                    prompt_embed_path = os.path.join(
                        args.output_dir, "prompt_embed", video_name + ".pt"
                    )
                    pooled_prompt_embeds_path = os.path.join(
                        args.output_dir, "pooled_prompt_embeds", video_name + ".pt"
                    )
                    text_ids_path = os.path.join(
                        args.output_dir, "text_ids", video_name + ".pt"
                    )
                    # Save the per-sample slices.
                    torch.save(prompt_embeds[idx], prompt_embed_path)
                    torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path)
                    torch.save(text_ids[idx], text_ids_path)
                    item = {}
                    item["prompt_embed_path"] = video_name + ".pt"
                    item["text_ids"] = video_name + ".pt"
                    item["pooled_prompt_embeds_path"] = video_name + ".pt"
                    item["caption"] = data["caption"][idx]
                    json_data.append(item)
        except Exception as e:
            # Surface the failure on this rank, sync so every rank sees the
            # stop point, then re-raise to abort the job.
            print(f"Rank {local_rank} Error: {repr(e)}")
            dist.barrier()
            raise
    dist.barrier()
    # Gather every rank's manifest entries; rank 0 writes the combined JSON.
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if local_rank == 0:
        # os.remove(latents_json_path)
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
141
+
142
+
143
if __name__ == "__main__":
    # CLI for the embedding-preprocessing script; parses args and runs main().
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    # NOTE(review): the "mochi" defaults look inherited from another script —
    # this file loads a FluxPipeline; confirm callers always pass --model_path.
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    # NOTE(review): --text_encoder_name and --cache_dir are never read in this
    # script's main(); they appear vestigial.
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    # Path to the caption .txt file (one caption per line).
    parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
    args = parser.parse_args()
    main(args)
fastvideo/data_preprocess/preprocess_flux_rfpt_embedding.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+
13
+ import argparse
14
+ import torch
15
+ from accelerate.logging import get_logger
16
+ import cv2
17
+ import json
18
+ import os
19
+ import torch.distributed as dist
20
+ import pandas as pd
21
+ from torch.utils.data.dataset import ConcatDataset, Dataset
22
+ import io
23
+ import torchvision.transforms as transforms
24
+ logger = get_logger(__name__)
25
+ from torch.utils.data import Dataset
26
+ from torch.utils.data.distributed import DistributedSampler
27
+ from torch.utils.data import DataLoader
28
+ from tqdm import tqdm
29
+ import re
30
+ from diffusers import FluxPipeline
31
+ from PIL import Image
32
+ from diffusers.image_processor import VaeImageProcessor
33
+
34
def contains_chinese(text):
    """True iff *text* contains a character in the CJK range U+4E00–U+9FFF."""
    return re.search(r'[\u4e00-\u9fff]', text) is not None
36
+
37
class RFPTdataset(Dataset):
    """Image/caption dataset read from a directory of parquet shards.

    Each parquet row is expected to hold an ``image`` column (dict with raw
    ``bytes``) and a ``caption_composition`` string. Rows with a missing image
    or caption are skipped by advancing to the next row (wrapping at the end).
    """

    def __init__(
        self, file_path,
    ):
        # file_path: directory containing the parquet shards.
        self.file_path = file_path
        file_names = os.listdir(self.file_path)  # each file contains 5,000 images
        self.file_names = [os.path.join(self.file_path, file_name) for file_name in file_names]
        self.train_dataset = self.read_data()
        self.transform = transforms.ToTensor()

    def read_data(self):
        """Concatenate every parquet shard into one DataFrame."""
        df_list = [pd.read_parquet(file_name) for file_name in self.file_names]
        combined_df = pd.concat(df_list, axis=0, ignore_index=True)
        return combined_df

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, index):
        row = self.train_dataset.iloc[index]
        image_bytes = row['image']['bytes']
        caption = row['caption_composition']
        # Validate *before* decoding and with `is None`:
        # the original compared with `== None` (ambiguous on array-like values)
        # and only after running the transform, so a None image crashed inside
        # PIL before the check could fire. Wrapping with modulo also fixes the
        # IndexError the original raised when the *last* row was invalid.
        if caption is None or image_bytes is None:
            return self.__getitem__((index + 1) % len(self))
        image = self.transform(Image.open(io.BytesIO(image_bytes)).convert('RGB'))
        filename = str(index)
        return dict(caption=caption, image=image, filename=filename)
67
+
68
class T5dataset(Dataset):
    """Caption dataset read from a plain-text file, one caption per line.

    Lines containing Chinese characters are filtered out and at most the
    first 50,000 remaining lines are kept. (Unused by main() in this script,
    which builds an RFPTdataset instead.)
    """

    def __init__(
        self, txt_path, vae_debug,
    ):
        # txt_path: path to the caption .txt file.
        # vae_debug: when truthy, __getitem__ also loads a precomputed latent.
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        with open(self.txt_path, "r", encoding="utf-8") as f:
            self.train_dataset = [
                line for line in f.read().splitlines() if not contains_chinese(line)
            ][:50000]

    def __getitem__(self, idx):
        """Return dict(caption=str, latents=tensor-or-[], filename=str(idx))."""
        #import pdb;pdb.set_trace()
        caption = self.train_dataset[idx]
        filename = str(idx)
        #length = self.train_dataset[idx]["length"]
        if self.vae_debug:
            # NOTE(review): this branch depends on a module-global `args` and
            # indexes a caption *string* with ["latent_path"], which would raise
            # TypeError. It looks stale from a JSON-based dataset variant —
            # confirm before running with vae_debug=True.
            latents = torch.load(
                os.path.join(
                    args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
                ),
                map_location="cpu",
            )
        else:
            latents = []

        return dict(caption=caption, latents=latents, filename=filename)

    def __len__(self):
        return len(self.train_dataset)
98
+
99
+
100
def main(args):
    """Save raw image tensors (bf16) from an RFPTdataset to .pt files and
    write a videos2caption.json manifest (rank 0 only).

    Runs under torch.distributed; each rank processes its DistributedSampler
    shard. Text-embedding and VAE-encoding steps are present but commented
    out, so only images are written.
    """
    # NOTE(review): RANK is the *global* rank; torch.cuda.set_device should use
    # LOCAL_RANK on multi-node setups — confirm this only runs single-node.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
        )

    # Output subdirectories; only "images" actually receives files below.
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "images"), exist_ok=True)

    # latents_txt_path = args.prompt_dir
    # train_dataset = T5dataset(latents_txt_path, args.vae_debug)

    train_dataset = RFPTdataset(args.prompt_dir)


    sampler = DistributedSampler(
        train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True
    )

    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    flux_path = args.model_path
    # NOTE(review): the pipeline and image_processor are loaded but unused
    # while the encoding code below stays commented out.
    pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16).to(device)
    image_processor = VaeImageProcessor(16)

    json_data = []
    for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
        try:
            with torch.inference_mode():
                if args.vae_debug:
                    latents = data["latents"]
                for idx, video_name in enumerate(data["filename"]):
                    # prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
                    #     prompt=data["caption"], prompt_2=data["caption"]
                    # )
                    # image_latents = pipe.vae.encode(data["image"].to(torch.bfloat16).to(device)).latent_dist.sample()
                    # output_image = pipe.vae.decode(image_latents, return_dict=False)[0]
                    # output_image = image_processor.postprocess(output_image)
                    # output_image[0].save('output.png')
                    # print(image_latents.latent_dist.sample())
                    # print(image_latents.latent_dist.sample().shape)

                    prompt_embed_path = os.path.join(
                        args.output_dir, "prompt_embed", video_name + ".pt"
                    )
                    pooled_prompt_embeds_path = os.path.join(
                        args.output_dir, "pooled_prompt_embeds", video_name + ".pt"
                    )

                    text_ids_path = os.path.join(
                        args.output_dir, "text_ids", video_name + ".pt"
                    )

                    image_latents_path = os.path.join(
                        args.output_dir, "images", video_name + ".pt"
                    )
                    # save latent
                    # torch.save(prompt_embeds[idx], prompt_embed_path)
                    # torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path)
                    # torch.save(text_ids[idx], text_ids_path)
                    # NOTE(review): this saves the *whole batch* tensor
                    # data["image"] (not data["image"][idx]) under each
                    # per-sample filename — correct only for batch size 1;
                    # confirm intent.
                    torch.save(data["image"].to(torch.bfloat16), image_latents_path)
                    item = {}
                    # NOTE(review): the manifest references prompt_embed /
                    # text_ids / pooled files that are never written while the
                    # encoding code above is commented out.
                    item["prompt_embed_path"] = video_name + ".pt"
                    item["text_ids"] = video_name + ".pt"
                    item["pooled_prompt_embeds_path"] = video_name + ".pt"
                    item["caption"] = data["caption"][idx]
                    json_data.append(item)
        except Exception as e:
            # Surface the failure, sync ranks, then abort the job.
            print(f"Rank {local_rank} Error: {repr(e)}")
            dist.barrier()
            raise
    dist.barrier()
    # Gather every rank's manifest entries; rank 0 writes the combined JSON.
    local_data = json_data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    if local_rank == 0:
        # os.remove(latents_json_path)
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
193
+
194
+
195
if __name__ == "__main__":
    # CLI for the RFPT image-preprocessing script; parses args and runs main().
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    # NOTE(review): the "mochi" defaults look inherited from another script —
    # this file loads a FluxPipeline; confirm callers always pass --model_path.
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # text encoder & vae & diffusion model
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    # NOTE(review): --text_encoder_name and --cache_dir are never read in this
    # script's main(); they appear vestigial.
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    # For this script, --prompt_dir is the directory of parquet shards.
    parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
    args = parser.parse_args()
    main(args)
fastvideo/data_preprocess/preprocess_qwenimage_embedding.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import argparse
13
+ import torch
14
+ from accelerate.logging import get_logger
15
+ # from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
16
+ from diffusers.utils import export_to_video
17
+ from fastvideo.models.qwenimage.pipeline_qwenimage import QwenImagePipeline
18
+ import json
19
+ import os
20
+ import torch.distributed as dist
21
+
22
+ logger = get_logger(__name__)
23
+ from torch.utils.data import Dataset
24
+ from torch.utils.data.distributed import DistributedSampler
25
+ from torch.utils.data import DataLoader
26
+ from fastvideo.utils.load import load_text_encoder, load_vae
27
+ from diffusers.video_processor import VideoProcessor
28
+ from tqdm import tqdm
29
+ import re
30
+ from diffusers import DiffusionPipeline
31
+ import torch.nn.functional as F
32
+
33
+ def contains_chinese(text):
34
+ """检查字符串是否包含中文字符"""
35
+ return bool(re.search(r'[\u4e00-\u9fff]', text))
36
+
37
+ class T5dataset(Dataset):
38
+ def __init__(
39
+ self, txt_path, vae_debug,
40
+ ):
41
+ self.txt_path = txt_path
42
+ self.vae_debug = vae_debug
43
+ with open(self.txt_path, "r", encoding="utf-8") as f:
44
+ self.train_dataset = [
45
+ line for line in f.read().splitlines() if not contains_chinese(line)
46
+ ]
47
+ #self.train_dataset = sorted(train_dataset)
48
+
49
+ def __getitem__(self, idx):
50
+ #import pdb;pdb.set_trace()
51
+ caption = self.train_dataset[idx]
52
+ filename = str(idx)
53
+ #length = self.train_dataset[idx]["length"]
54
+ if self.vae_debug:
55
+ latents = torch.load(
56
+ os.path.join(
57
+ args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
58
+ ),
59
+ map_location="cpu",
60
+ )
61
+ else:
62
+ latents = []
63
+
64
+ return dict(caption=caption, latents=latents, filename=filename)
65
+
66
+ def __len__(self):
67
+ return len(self.train_dataset)
68
+
69
+
70
+ def main(args):
71
+ local_rank = int(os.getenv("RANK", 0))
72
+ world_size = int(os.getenv("WORLD_SIZE", 1))
73
+ print("world_size", world_size, "local rank", local_rank)
74
+
75
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
+ torch.cuda.set_device(local_rank)
77
+ if not dist.is_initialized():
78
+ dist.init_process_group(
79
+ backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
80
+ )
81
+
82
+ #videoprocessor = VideoProcessor(vae_scale_factor=8)
83
+ os.makedirs(args.output_dir, exist_ok=True)
84
+ os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
85
+ os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True)
86
+
87
+ latents_txt_path = args.prompt_dir
88
+ train_dataset = T5dataset(latents_txt_path, args.vae_debug)
89
+ #text_encoder = load_text_encoder(args.model_type, args.model_path, device=device)
90
+ #vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
91
+ #vae.enable_tiling()
92
+ sampler = DistributedSampler(
93
+ train_dataset, rank=local_rank, num_replicas=world_size, shuffle=False
94
+ )
95
+ train_dataloader = DataLoader(
96
+ train_dataset,
97
+ sampler=sampler,
98
+ batch_size=args.train_batch_size,
99
+ num_workers=args.dataloader_num_workers,
100
+ )
101
+ # Load pipeline but don't move everything to GPU yet
102
+ pipe = QwenImagePipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
103
+
104
+ # Only move text_encoder to GPU for embedding generation
105
+ pipe.text_encoder = pipe.text_encoder.to(device)
106
+
107
+ # Delete unused components to free up RAM/VRAM
108
+ if not args.vae_debug:
109
+ # Remove from attributes
110
+ if hasattr(pipe, "transformer"):
111
+ del pipe.transformer
112
+ if hasattr(pipe, "vae"):
113
+ del pipe.vae
114
+
115
+ # Remove from components dictionary to ensure garbage collection
116
+ if "transformer" in pipe.components:
117
+ del pipe.components["transformer"]
118
+ if "vae" in pipe.components:
119
+ del pipe.components["vae"]
120
+
121
+ import gc
122
+ gc.collect()
123
+ torch.cuda.empty_cache()
124
+
125
+ # pipe._execution_device = device # This causes AttributeError, removing it.
126
+
127
+ json_data = []
128
+ for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
129
+ with torch.inference_mode():
130
+ with torch.autocast("cuda"):
131
+ prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
132
+ prompt=data["caption"],
133
+ device=device # Explicitly pass device
134
+ )
135
+
136
+ # ==================== 代码修改开始 ====================
137
+
138
+ # 1. 记录原始的序列长度 (第二个维度的大小)
139
+ original_length = prompt_embeds.shape[1]
140
+ target_length = 1024
141
+
142
+ # 2. 计算需要填充的长度
143
+ # 假设 original_length 不会超过 target_length
144
+ pad_len = target_length - original_length
145
+
146
+ # 3. 填充 prompt_embeds
147
+ # prompt_embeds 是一个3D张量 (B, L, D),我们需要填充第二个维度 L
148
+ # F.pad 的填充参数顺序是从最后一个维度开始的 (pad_dim_D_left, pad_dim_D_right, pad_dim_L_left, pad_dim_L_right, ...)
149
+ # 我们在维度1(序列长度L)的右侧进行填充
150
+ prompt_embeds = F.pad(prompt_embeds, (0, 0, 0, pad_len), "constant", 0)
151
+
152
+ # 4. 填充 prompt_attention_mask
153
+ # prompt_attention_mask 是一个2D张量 (B, L),我们同样填充第二个维度 L
154
+ # 我们在维度1(序列长度L)的右侧进行填充
155
+ prompt_attention_mask = F.pad(prompt_attention_mask, (0, pad_len), "constant", 0)
156
+
157
+ # ==================== 代码修改结束 ====================
158
+
159
+ if args.vae_debug:
160
+ latents = data["latents"]
161
+ for idx, video_name in enumerate(data["filename"]):
162
+ prompt_embed_path = os.path.join(
163
+ args.output_dir, "prompt_embed", video_name + ".pt"
164
+ )
165
+ prompt_attention_mask_path = os.path.join(
166
+ args.output_dir, "prompt_attention_mask", video_name + ".pt"
167
+ )
168
+ # 保存 latent (注意这里保存的是填充后的张量)
169
+ torch.save(prompt_embeds[idx], prompt_embed_path)
170
+ torch.save(prompt_attention_mask[idx], prompt_attention_mask_path)
171
+ item = {}
172
+ item["prompt_embed_path"] = video_name + ".pt"
173
+ item["prompt_attention_mask"] = video_name + ".pt"
174
+ item["caption"] = data["caption"][idx]
175
+
176
+ # [新增] 将原始长度记录到 item 字典中
177
+ item["original_length"] = original_length
178
+
179
+ json_data.append(item)
180
+ dist.barrier()
181
+ local_data = json_data
182
+ gathered_data = [None] * world_size
183
+ dist.all_gather_object(gathered_data, local_data)
184
+ if local_rank == 0:
185
+ # os.remove(latents_json_path)
186
+ all_json_data = [item for sublist in gathered_data for item in sublist]
187
+ with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
188
+ json.dump(all_json_data, f, indent=4)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ parser = argparse.ArgumentParser()
193
+ # dataset & dataloader
194
+ parser.add_argument("--model_path", type=str, default="data/mochi")
195
+ parser.add_argument("--model_type", type=str, default="mochi")
196
+ # text encoder & vae & diffusion model
197
+ parser.add_argument(
198
+ "--dataloader_num_workers",
199
+ type=int,
200
+ default=1,
201
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
202
+ )
203
+ parser.add_argument(
204
+ "--train_batch_size",
205
+ type=int,
206
+ default=1,
207
+ help="Batch size (per device) for the training dataloader.",
208
+ )
209
+ parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
210
+ parser.add_argument("--cache_dir", type=str, default="./cache_dir")
211
+ parser.add_argument(
212
+ "--output_dir",
213
+ type=str,
214
+ default=None,
215
+ help="The output directory where the model predictions and checkpoints will be written.",
216
+ )
217
+ parser.add_argument("--vae_debug", action="store_true")
218
+ parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
219
+ args = parser.parse_args()
220
+ main(args)
fastvideo/data_preprocess/preprocess_rl_embeddings.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import argparse
13
+ import torch
14
+ from accelerate.logging import get_logger
15
+ from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
16
+ from diffusers.utils import export_to_video
17
+ import json
18
+ import os
19
+ import torch.distributed as dist
20
+
21
+ logger = get_logger(__name__)
22
+ from torch.utils.data import Dataset
23
+ from torch.utils.data.distributed import DistributedSampler
24
+ from torch.utils.data import DataLoader
25
+ from fastvideo.utils.load import load_text_encoder, load_vae
26
+ from diffusers.video_processor import VideoProcessor
27
+ from tqdm import tqdm
28
+ import re
29
+
30
+ def contains_chinese(text):
31
+ """检查字符串是否包含中文字符"""
32
+ return bool(re.search(r'[\u4e00-\u9fff]', text))
33
+
34
+ class T5dataset(Dataset):
35
+ def __init__(
36
+ self, txt_path, vae_debug,
37
+ ):
38
+ self.txt_path = txt_path
39
+ self.vae_debug = vae_debug
40
+ with open(self.txt_path, "r", encoding="utf-8") as f:
41
+ self.train_dataset = [
42
+ line for line in f.read().splitlines() if not contains_chinese(line)
43
+ ]
44
+ #self.train_dataset = sorted(train_dataset)
45
+
46
+ def __getitem__(self, idx):
47
+ #import pdb;pdb.set_trace()
48
+ caption = self.train_dataset[idx]
49
+ filename = str(idx)
50
+ #length = self.train_dataset[idx]["length"]
51
+ if self.vae_debug:
52
+ latents = torch.load(
53
+ os.path.join(
54
+ args.output_dir, "latent", self.train_dataset[idx]["latent_path"]
55
+ ),
56
+ map_location="cpu",
57
+ )
58
+ else:
59
+ latents = []
60
+
61
+ return dict(caption=caption, latents=latents, filename=filename)
62
+
63
+ def __len__(self):
64
+ return len(self.train_dataset)
65
+
66
+
67
+ def main(args):
68
+ local_rank = int(os.getenv("RANK", 0))
69
+ world_size = int(os.getenv("WORLD_SIZE", 1))
70
+ print("world_size", world_size, "local rank", local_rank)
71
+
72
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
+ torch.cuda.set_device(local_rank)
74
+ if not dist.is_initialized():
75
+ dist.init_process_group(
76
+ backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
77
+ )
78
+
79
+ #videoprocessor = VideoProcessor(vae_scale_factor=8)
80
+ os.makedirs(args.output_dir, exist_ok=True)
81
+ os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True)
82
+ #os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
83
+ os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
84
+ os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True)
85
+
86
+ latents_txt_path = args.prompt_dir
87
+ train_dataset = T5dataset(latents_txt_path, args.vae_debug)
88
+ text_encoder = load_text_encoder(args.model_type, args.model_path, device=device)
89
+ #vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
90
+ #vae.enable_tiling()
91
+ sampler = DistributedSampler(
92
+ train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True
93
+ )
94
+ train_dataloader = DataLoader(
95
+ train_dataset,
96
+ sampler=sampler,
97
+ batch_size=args.train_batch_size,
98
+ num_workers=args.dataloader_num_workers,
99
+ )
100
+
101
+ json_data = []
102
+ for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
103
+ with torch.inference_mode():
104
+ with torch.autocast("cuda"):
105
+ prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
106
+ prompt=data["caption"],
107
+ )
108
+ if args.vae_debug:
109
+ latents = data["latents"]
110
+ #video = vae.decode(latents.to(device), return_dict=False)[0]
111
+ #video = videoprocessor.postprocess_video(video)
112
+ for idx, video_name in enumerate(data["filename"]):
113
+ prompt_embed_path = os.path.join(
114
+ args.output_dir, "prompt_embed", video_name + ".pt"
115
+ )
116
+ #video_path = os.path.join(
117
+ # args.output_dir, "video", video_name + ".mp4"
118
+ #)
119
+ prompt_attention_mask_path = os.path.join(
120
+ args.output_dir, "prompt_attention_mask", video_name + ".pt"
121
+ )
122
+ # save latent
123
+ torch.save(prompt_embeds[idx], prompt_embed_path)
124
+ torch.save(prompt_attention_mask[idx], prompt_attention_mask_path)
125
+ #print(f"sample {video_name} saved")
126
+ #if args.vae_debug:
127
+ # export_to_video(video[idx], video_path, fps=fps)
128
+ item = {}
129
+ #item["length"] = int(data["length"][idx])
130
+ #item["latent_path"] = video_name + ".pt"
131
+ item["prompt_embed_path"] = video_name + ".pt"
132
+ item["prompt_attention_mask"] = video_name + ".pt"
133
+ item["caption"] = data["caption"][idx]
134
+ json_data.append(item)
135
+ dist.barrier()
136
+ local_data = json_data
137
+ gathered_data = [None] * world_size
138
+ dist.all_gather_object(gathered_data, local_data)
139
+ if local_rank == 0:
140
+ # os.remove(latents_json_path)
141
+ all_json_data = [item for sublist in gathered_data for item in sublist]
142
+ with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
143
+ json.dump(all_json_data, f, indent=4)
144
+
145
+
146
+ if __name__ == "__main__":
147
+ parser = argparse.ArgumentParser()
148
+ # dataset & dataloader
149
+ parser.add_argument("--model_path", type=str, default="data/mochi")
150
+ parser.add_argument("--model_type", type=str, default="mochi")
151
+ # text encoder & vae & diffusion model
152
+ parser.add_argument(
153
+ "--dataloader_num_workers",
154
+ type=int,
155
+ default=1,
156
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
157
+ )
158
+ parser.add_argument(
159
+ "--train_batch_size",
160
+ type=int,
161
+ default=1,
162
+ help="Batch size (per device) for the training dataloader.",
163
+ )
164
+ parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
165
+ parser.add_argument("--cache_dir", type=str, default="./cache_dir")
166
+ parser.add_argument(
167
+ "--output_dir",
168
+ type=str,
169
+ default=None,
170
+ help="The output directory where the model predictions and checkpoints will be written.",
171
+ )
172
+ parser.add_argument("--vae_debug", action="store_true")
173
+ parser.add_argument("--prompt_dir", type=str, default="./empty.txt")
174
+ args = parser.parse_args()
175
+ main(args)
fastvideo/data_preprocess/preprocess_text_embeddings.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from accelerate.logging import get_logger
10
+ from diffusers.utils import export_to_video
11
+ from diffusers.video_processor import VideoProcessor
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from tqdm import tqdm
15
+
16
+ from fastvideo.utils.load import load_text_encoder, load_vae
17
+
18
+ logger = get_logger(__name__)
19
+
20
+
21
+ class T5dataset(Dataset):
22
+
23
+ def __init__(
24
+ self,
25
+ json_path,
26
+ vae_debug,
27
+ ):
28
+ self.json_path = json_path
29
+ self.vae_debug = vae_debug
30
+ with open(self.json_path, "r") as f:
31
+ train_dataset = json.load(f)
32
+ self.train_dataset = sorted(train_dataset,
33
+ key=lambda x: x["latent_path"])
34
+
35
+ def __getitem__(self, idx):
36
+ caption = self.train_dataset[idx]["caption"]
37
+ filename = self.train_dataset[idx]["latent_path"].split(".")[0]
38
+ length = self.train_dataset[idx]["length"]
39
+ if self.vae_debug:
40
+ latents = torch.load(
41
+ os.path.join(args.output_dir, "latent",
42
+ self.train_dataset[idx]["latent_path"]),
43
+ map_location="cpu",
44
+ )
45
+ else:
46
+ latents = []
47
+
48
+ return dict(caption=caption,
49
+ latents=latents,
50
+ filename=filename,
51
+ length=length)
52
+
53
+ def __len__(self):
54
+ return len(self.train_dataset)
55
+
56
+
57
+ def main(args):
58
+ local_rank = int(os.getenv("RANK", 0))
59
+ world_size = int(os.getenv("WORLD_SIZE", 1))
60
+ print("world_size", world_size, "local rank", local_rank)
61
+
62
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+ torch.cuda.set_device(local_rank)
64
+ if not dist.is_initialized():
65
+ dist.init_process_group(backend="nccl",
66
+ init_method="env://",
67
+ world_size=world_size,
68
+ rank=local_rank)
69
+
70
+ videoprocessor = VideoProcessor(vae_scale_factor=8)
71
+ os.makedirs(args.output_dir, exist_ok=True)
72
+ os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True)
73
+ os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
74
+ os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
75
+ os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"),
76
+ exist_ok=True)
77
+
78
+ latents_json_path = os.path.join(args.output_dir,
79
+ "videos2caption_temp.json")
80
+ train_dataset = T5dataset(latents_json_path, args.vae_debug)
81
+ text_encoder = load_text_encoder(args.model_type,
82
+ args.model_path,
83
+ device=device)
84
+ vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
85
+ vae.enable_tiling()
86
+ sampler = DistributedSampler(train_dataset,
87
+ rank=local_rank,
88
+ num_replicas=world_size,
89
+ shuffle=True)
90
+ train_dataloader = DataLoader(
91
+ train_dataset,
92
+ sampler=sampler,
93
+ batch_size=args.train_batch_size,
94
+ num_workers=args.dataloader_num_workers,
95
+ )
96
+
97
+ json_data = []
98
+ for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
99
+ with torch.inference_mode():
100
+ with torch.autocast("cuda", dtype=autocast_type):
101
+ prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
102
+ prompt=data["caption"], )
103
+ if args.vae_debug:
104
+ latents = data["latents"]
105
+ video = vae.decode(latents.to(device),
106
+ return_dict=False)[0]
107
+ video = videoprocessor.postprocess_video(video)
108
+ for idx, video_name in enumerate(data["filename"]):
109
+ prompt_embed_path = os.path.join(args.output_dir,
110
+ "prompt_embed",
111
+ video_name + ".pt")
112
+ video_path = os.path.join(args.output_dir, "video",
113
+ video_name + ".mp4")
114
+ prompt_attention_mask_path = os.path.join(
115
+ args.output_dir, "prompt_attention_mask",
116
+ video_name + ".pt")
117
+ # save latent
118
+ torch.save(prompt_embeds[idx], prompt_embed_path)
119
+ torch.save(prompt_attention_mask[idx],
120
+ prompt_attention_mask_path)
121
+ print(f"sample {video_name} saved")
122
+ if args.vae_debug:
123
+ export_to_video(video[idx], video_path, fps=fps)
124
+ item = {}
125
+ item["length"] = int(data["length"][idx])
126
+ item["latent_path"] = video_name + ".pt"
127
+ item["prompt_embed_path"] = video_name + ".pt"
128
+ item["prompt_attention_mask"] = video_name + ".pt"
129
+ item["caption"] = data["caption"][idx]
130
+ json_data.append(item)
131
+ dist.barrier()
132
+ local_data = json_data
133
+ gathered_data = [None] * world_size
134
+ dist.all_gather_object(gathered_data, local_data)
135
+ if local_rank == 0:
136
+ # os.remove(latents_json_path)
137
+ all_json_data = [item for sublist in gathered_data for item in sublist]
138
+ with open(os.path.join(args.output_dir, "videos2caption.json"),
139
+ "w") as f:
140
+ json.dump(all_json_data, f, indent=4)
141
+
142
+
143
+ if __name__ == "__main__":
144
+ parser = argparse.ArgumentParser()
145
+ # dataset & dataloader
146
+ parser.add_argument("--model_path", type=str, default="data/mochi")
147
+ parser.add_argument("--model_type", type=str, default="mochi")
148
+ # text encoder & vae & diffusion model
149
+ parser.add_argument(
150
+ "--dataloader_num_workers",
151
+ type=int,
152
+ default=1,
153
+ help=
154
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
155
+ )
156
+ parser.add_argument(
157
+ "--train_batch_size",
158
+ type=int,
159
+ default=1,
160
+ help="Batch size (per device) for the training dataloader.",
161
+ )
162
+ parser.add_argument("--text_encoder_name",
163
+ type=str,
164
+ default="google/t5-v1_1-xxl")
165
+ parser.add_argument("--cache_dir", type=str, default="./cache_dir")
166
+ parser.add_argument(
167
+ "--output_dir",
168
+ type=str,
169
+ default=None,
170
+ help=
171
+ "The output directory where the model predictions and checkpoints will be written.",
172
+ )
173
+ parser.add_argument("--vae_debug", action="store_true")
174
+ args = parser.parse_args()
175
+ main(args)
fastvideo/data_preprocess/preprocess_vae_latents.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from accelerate.logging import get_logger
10
+ from torch.utils.data import DataLoader
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from tqdm import tqdm
13
+
14
+ from fastvideo.dataset import getdataset
15
+ from fastvideo.utils.load import load_vae
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
+ def main(args):
21
+ local_rank = int(os.getenv("RANK", 0))
22
+ world_size = int(os.getenv("WORLD_SIZE", 1))
23
+ print("world_size", world_size, "local rank", local_rank)
24
+ train_dataset = getdataset(args)
25
+ sampler = DistributedSampler(train_dataset,
26
+ rank=local_rank,
27
+ num_replicas=world_size,
28
+ shuffle=True)
29
+ train_dataloader = DataLoader(
30
+ train_dataset,
31
+ sampler=sampler,
32
+ batch_size=args.train_batch_size,
33
+ num_workers=args.dataloader_num_workers,
34
+ )
35
+
36
+ encoder_device = torch.device(
37
+ "cuda" if torch.cuda.is_available() else "cpu")
38
+ torch.cuda.set_device(local_rank)
39
+ if not dist.is_initialized():
40
+ dist.init_process_group(backend="nccl",
41
+ init_method="env://",
42
+ world_size=world_size,
43
+ rank=local_rank)
44
+ vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
45
+ vae.enable_tiling()
46
+ os.makedirs(args.output_dir, exist_ok=True)
47
+ os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
48
+
49
+ json_data = []
50
+ for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
51
+ with torch.inference_mode():
52
+ with torch.autocast("cuda", dtype=autocast_type):
53
+ latents = vae.encode(data["pixel_values"].to(
54
+ encoder_device))["latent_dist"].sample()
55
+ for idx, video_path in enumerate(data["path"]):
56
+ video_name = os.path.basename(video_path).split(".")[0]
57
+ latent_path = os.path.join(args.output_dir, "latent",
58
+ video_name + ".pt")
59
+ torch.save(latents[idx].to(torch.bfloat16), latent_path)
60
+ item = {}
61
+ item["length"] = latents[idx].shape[1]
62
+ item["latent_path"] = video_name + ".pt"
63
+ item["caption"] = data["text"][idx]
64
+ json_data.append(item)
65
+ print(f"{video_name} processed")
66
+ dist.barrier()
67
+ local_data = json_data
68
+ gathered_data = [None] * world_size
69
+ dist.all_gather_object(gathered_data, local_data)
70
+ if local_rank == 0:
71
+ all_json_data = [item for sublist in gathered_data for item in sublist]
72
+ with open(os.path.join(args.output_dir, "videos2caption_temp.json"),
73
+ "w") as f:
74
+ json.dump(all_json_data, f, indent=4)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ parser = argparse.ArgumentParser()
79
+ # dataset & dataloader
80
+ parser.add_argument("--model_path", type=str, default="data/mochi")
81
+ parser.add_argument("--model_type", type=str, default="mochi")
82
+ parser.add_argument("--data_merge_path", type=str, required=True)
83
+ parser.add_argument("--num_frames", type=int, default=163)
84
+ parser.add_argument(
85
+ "--dataloader_num_workers",
86
+ type=int,
87
+ default=1,
88
+ help=
89
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
90
+ )
91
+ parser.add_argument(
92
+ "--train_batch_size",
93
+ type=int,
94
+ default=16,
95
+ help="Batch size (per device) for the training dataloader.",
96
+ )
97
+ parser.add_argument("--num_latent_t",
98
+ type=int,
99
+ default=28,
100
+ help="Number of latent timesteps.")
101
+ parser.add_argument("--max_height", type=int, default=480)
102
+ parser.add_argument("--max_width", type=int, default=848)
103
+ parser.add_argument("--video_length_tolerance_range",
104
+ type=int,
105
+ default=2.0)
106
+ parser.add_argument("--group_frame", action="store_true") # TODO
107
+ parser.add_argument("--group_resolution", action="store_true") # TODO
108
+ parser.add_argument("--dataset", default="t2v")
109
+ parser.add_argument("--train_fps", type=int, default=30)
110
+ parser.add_argument("--use_image_num", type=int, default=0)
111
+ parser.add_argument("--text_max_length", type=int, default=256)
112
+ parser.add_argument("--speed_factor", type=float, default=1.0)
113
+ parser.add_argument("--drop_short_ratio", type=float, default=1.0)
114
+ # text encoder & vae & diffusion model
115
+ parser.add_argument("--text_encoder_name",
116
+ type=str,
117
+ default="google/t5-v1_1-xxl")
118
+ parser.add_argument("--cache_dir", type=str, default="./cache_dir")
119
+ parser.add_argument("--cfg", type=float, default=0.0)
120
+ parser.add_argument(
121
+ "--output_dir",
122
+ type=str,
123
+ default=None,
124
+ help=
125
+ "The output directory where the model predictions and checkpoints will be written.",
126
+ )
127
+ parser.add_argument(
128
+ "--logging_dir",
129
+ type=str,
130
+ default="logs",
131
+ help=
132
+ ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
133
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
134
+ )
135
+
136
+ args = parser.parse_args()
137
+ main(args)
fastvideo/data_preprocess/preprocess_validation_text_embeddings.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import argparse
4
+ import os
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ from accelerate.logging import get_logger
9
+
10
+ from fastvideo.utils.load import load_text_encoder
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
+ def main(args):
16
+ local_rank = int(os.getenv("RANK", 0))
17
+ world_size = int(os.getenv("WORLD_SIZE", 1))
18
+ print("world_size", world_size, "local rank", local_rank)
19
+
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ torch.cuda.set_device(local_rank)
22
+ if not dist.is_initialized():
23
+ dist.init_process_group(backend="nccl",
24
+ init_method="env://",
25
+ world_size=world_size,
26
+ rank=local_rank)
27
+
28
+ text_encoder = load_text_encoder(args.model_type,
29
+ args.model_path,
30
+ device=device)
31
+ autocast_type = torch.float16 if args.model_type == "hunyuan" else torch.bfloat16
32
+ # output_dir/validation/prompt_attention_mask
33
+ # output_dir/validation/prompt_embed
34
+ os.makedirs(os.path.join(args.output_dir, "validation"), exist_ok=True)
35
+ os.makedirs(
36
+ os.path.join(args.output_dir, "validation", "prompt_attention_mask"),
37
+ exist_ok=True,
38
+ )
39
+ os.makedirs(os.path.join(args.output_dir, "validation", "prompt_embed"),
40
+ exist_ok=True)
41
+
42
+ with open(args.validation_prompt_txt, "r", encoding="utf-8") as file:
43
+ lines = file.readlines()
44
+ prompts = [line.strip() for line in lines]
45
+ for prompt in prompts:
46
+ with torch.inference_mode():
47
+ with torch.autocast("cuda", dtype=autocast_type):
48
+ prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
49
+ prompt)
50
+ file_name = prompt.split(".")[0]
51
+ prompt_embed_path = os.path.join(args.output_dir, "validation",
52
+ "prompt_embed",
53
+ f"{file_name}.pt")
54
+ prompt_attention_mask_path = os.path.join(
55
+ args.output_dir,
56
+ "validation",
57
+ "prompt_attention_mask",
58
+ f"{file_name}.pt",
59
+ )
60
+ torch.save(prompt_embeds[0], prompt_embed_path)
61
+ torch.save(prompt_attention_mask[0],
62
+ prompt_attention_mask_path)
63
+ print(f"sample {file_name} saved")
64
+
65
+
66
+ if __name__ == "__main__":
67
+ parser = argparse.ArgumentParser()
68
+ # dataset & dataloader
69
+ parser.add_argument("--model_path", type=str, default="data/mochi")
70
+ parser.add_argument("--model_type", type=str, default="mochi")
71
+ parser.add_argument("--validation_prompt_txt", type=str)
72
+ parser.add_argument(
73
+ "--output_dir",
74
+ type=str,
75
+ default=None,
76
+ help=
77
+ "The output directory where the model predictions and checkpoints will be written.",
78
+ )
79
+ args = parser.parse_args()
80
+ main(args)
fastvideo/dataset/.DS_Store ADDED
Binary file (6.15 kB). View file
 
fastvideo/dataset/__init__.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from torchvision.transforms import Lambda
3
+ from transformers import AutoTokenizer
4
+
5
+ from fastvideo.dataset.t2v_datasets import T2V_dataset
6
+ from fastvideo.dataset.transform import (CenterCropResizeVideo, Normalize255,
7
+ TemporalRandomCrop)
8
+
9
+
10
+ def getdataset(args):
11
+ temporal_sample = TemporalRandomCrop(args.num_frames) # 16 x
12
+ norm_fun = Lambda(lambda x: 2.0 * x - 1.0)
13
+ resize_topcrop = [
14
+ CenterCropResizeVideo((args.max_height, args.max_width),
15
+ top_crop=True),
16
+ ]
17
+ resize = [
18
+ CenterCropResizeVideo((args.max_height, args.max_width)),
19
+ ]
20
+ transform = transforms.Compose([
21
+ # Normalize255(),
22
+ *resize,
23
+ ])
24
+ transform_topcrop = transforms.Compose([
25
+ Normalize255(),
26
+ *resize_topcrop,
27
+ norm_fun,
28
+ ])
29
+ # tokenizer = AutoTokenizer.from_pretrained("/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl", cache_dir=args.cache_dir)
30
+ tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name,
31
+ cache_dir=args.cache_dir)
32
+ if args.dataset == "t2v":
33
+ return T2V_dataset(
34
+ args,
35
+ transform=transform,
36
+ temporal_sample=temporal_sample,
37
+ tokenizer=tokenizer,
38
+ transform_topcrop=transform_topcrop,
39
+ )
40
+
41
+ raise NotImplementedError(args.dataset)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ import random
46
+
47
+ from accelerate import Accelerator
48
+ from tqdm import tqdm
49
+
50
+ from fastvideo.dataset.t2v_datasets import dataset_prog
51
+
52
+ args = type(
53
+ "args",
54
+ (),
55
+ {
56
+ "ae": "CausalVAEModel_4x8x8",
57
+ "dataset": "t2v",
58
+ "attention_mode": "xformers",
59
+ "use_rope": True,
60
+ "text_max_length": 300,
61
+ "max_height": 320,
62
+ "max_width": 240,
63
+ "num_frames": 1,
64
+ "use_image_num": 0,
65
+ "interpolation_scale_t": 1,
66
+ "interpolation_scale_h": 1,
67
+ "interpolation_scale_w": 1,
68
+ "cache_dir": "../cache_dir",
69
+ "image_data":
70
+ "/storage/ongoing/new/Open-Sora-Plan-bak/7.14bak/scripts/train_data/image_data.txt",
71
+ "video_data": "1",
72
+ "train_fps": 24,
73
+ "drop_short_ratio": 1.0,
74
+ "use_img_from_vid": False,
75
+ "speed_factor": 1.0,
76
+ "cfg": 0.1,
77
+ "text_encoder_name": "google/mt5-xxl",
78
+ "dataloader_num_workers": 10,
79
+ },
80
+ )
81
+ accelerator = Accelerator()
82
+ dataset = getdataset(args)
83
+ num = len(dataset_prog.img_cap_list)
84
+ zero = 0
85
+ for idx in tqdm(range(num)):
86
+ image_data = dataset_prog.img_cap_list[idx]
87
+ caps = [
88
+ i["cap"] if isinstance(i["cap"], list) else [i["cap"]]
89
+ for i in image_data
90
+ ]
91
+ try:
92
+ caps = [[random.choice(i)] for i in caps]
93
+ except Exception as e:
94
+ print(e)
95
+ # import ipdb;ipdb.set_trace()
96
+ print(image_data)
97
+ zero += 1
98
+ continue
99
+ assert caps[0] is not None and len(caps[0]) > 0
100
+ print(num, zero)
101
+ import ipdb
102
+
103
+ ipdb.set_trace()
104
+ print("end")
fastvideo/dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.78 kB). View file
 
fastvideo/dataset/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (3.94 kB). View file
 
fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets.cpython-312.pyc ADDED
Binary file (5.12 kB). View file
 
fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets_all.cpython-312.pyc ADDED
Binary file (5.56 kB). View file
 
fastvideo/dataset/__pycache__/latent_flux_rl_datasets.cpython-312.pyc ADDED
Binary file (4.67 kB). View file
 
fastvideo/dataset/__pycache__/latent_qwenimage_rl_datasets.cpython-310.pyc ADDED
Binary file (2.24 kB). View file
 
fastvideo/dataset/__pycache__/t2v_datasets.cpython-310.pyc ADDED
Binary file (9.14 kB). View file
 
fastvideo/dataset/__pycache__/t2v_datasets.cpython-312.pyc ADDED
Binary file (16 kB). View file
 
fastvideo/dataset/__pycache__/transform.cpython-310.pyc ADDED
Binary file (18.3 kB). View file
 
fastvideo/dataset/__pycache__/transform.cpython-312.pyc ADDED
Binary file (27.3 kB). View file
 
fastvideo/dataset/latent_datasets.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import json
4
+ import os
5
+ import random
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+
10
+
11
+ class LatentDataset(Dataset):
12
+
13
+ def __init__(
14
+ self,
15
+ json_path,
16
+ num_latent_t,
17
+ cfg_rate,
18
+ ):
19
+ # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
20
+ self.json_path = json_path
21
+ self.cfg_rate = cfg_rate
22
+ self.datase_dir_path = os.path.dirname(json_path)
23
+ self.video_dir = os.path.join(self.datase_dir_path, "video")
24
+ self.latent_dir = os.path.join(self.datase_dir_path, "latent")
25
+ self.prompt_embed_dir = os.path.join(self.datase_dir_path,
26
+ "prompt_embed")
27
+ self.prompt_attention_mask_dir = os.path.join(self.datase_dir_path,
28
+ "prompt_attention_mask")
29
+ with open(self.json_path, "r") as f:
30
+ self.data_anno = json.load(f)
31
+ # json.load(f) already keeps the order
32
+ # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
33
+ self.num_latent_t = num_latent_t
34
+ # just zero embeddings [256, 4096]
35
+ self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
36
+ # 256 zeros
37
+ self.uncond_prompt_mask = torch.zeros(256).bool()
38
+ self.lengths = [
39
+ data_item["length"] if "length" in data_item else 1
40
+ for data_item in self.data_anno
41
+ ]
42
+
43
+ def __getitem__(self, idx):
44
+ latent_file = self.data_anno[idx]["latent_path"]
45
+ prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
46
+ prompt_attention_mask_file = self.data_anno[idx][
47
+ "prompt_attention_mask"]
48
+ # load
49
+ latent = torch.load(
50
+ os.path.join(self.latent_dir, latent_file),
51
+ map_location="cpu",
52
+ weights_only=True,
53
+ )
54
+ latent = latent.squeeze(0)[:, -self.num_latent_t:]
55
+ if random.random() < self.cfg_rate:
56
+ prompt_embed = self.uncond_prompt_embed
57
+ prompt_attention_mask = self.uncond_prompt_mask
58
+ else:
59
+ prompt_embed = torch.load(
60
+ os.path.join(self.prompt_embed_dir, prompt_embed_file),
61
+ map_location="cpu",
62
+ weights_only=True,
63
+ )
64
+ prompt_attention_mask = torch.load(
65
+ os.path.join(self.prompt_attention_mask_dir,
66
+ prompt_attention_mask_file),
67
+ map_location="cpu",
68
+ weights_only=True,
69
+ )
70
+ return latent, prompt_embed, prompt_attention_mask
71
+
72
+ def __len__(self):
73
+ return len(self.data_anno)
74
+
75
+
76
+ def latent_collate_function(batch):
77
+ # return latent, prompt, latent_attn_mask, text_attn_mask
78
+ # latent_attn_mask: # b t h w
79
+ # text_attn_mask: b 1 l
80
+ # needs to check if the latent/prompt' size and apply padding & attn mask
81
+ latents, prompt_embeds, prompt_attention_masks = zip(*batch)
82
+ # calculate max shape
83
+ max_t = max([latent.shape[1] for latent in latents])
84
+ max_h = max([latent.shape[2] for latent in latents])
85
+ max_w = max([latent.shape[3] for latent in latents])
86
+
87
+ # padding
88
+ latents = [
89
+ torch.nn.functional.pad(
90
+ latent,
91
+ (
92
+ 0,
93
+ max_t - latent.shape[1],
94
+ 0,
95
+ max_h - latent.shape[2],
96
+ 0,
97
+ max_w - latent.shape[3],
98
+ ),
99
+ ) for latent in latents
100
+ ]
101
+ # attn mask
102
+ latent_attn_mask = torch.ones(len(latents), max_t, max_h, max_w)
103
+ # set to 0 if padding
104
+ for i, latent in enumerate(latents):
105
+ latent_attn_mask[i, latent.shape[1]:, :, :] = 0
106
+ latent_attn_mask[i, :, latent.shape[2]:, :] = 0
107
+ latent_attn_mask[i, :, :, latent.shape[3]:] = 0
108
+
109
+ prompt_embeds = torch.stack(prompt_embeds, dim=0)
110
+ prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0)
111
+ latents = torch.stack(latents, dim=0)
112
+ return latents, prompt_embeds, latent_attn_mask, prompt_attention_masks
113
+
114
+
115
+ if __name__ == "__main__":
116
+ dataset = LatentDataset("data/Mochi-Synthetic-Data/merge.txt",
117
+ num_latent_t=28)
118
+ dataloader = torch.utils.data.DataLoader(
119
+ dataset,
120
+ batch_size=2,
121
+ shuffle=False,
122
+ collate_fn=latent_collate_function)
123
+ for latent, prompt_embed, latent_attn_mask, prompt_attention_mask in dataloader:
124
+ print(
125
+ latent.shape,
126
+ prompt_embed.shape,
127
+ latent_attn_mask.shape,
128
+ prompt_attention_mask.shape,
129
+ )
130
+ import pdb
131
+
132
+ pdb.set_trace()
fastvideo/dataset/latent_flux_rfpt_datasets.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+ import json
15
+ import os
16
+ import random
17
+
18
+
19
+ class LatentDataset(Dataset):
20
+ def __init__(
21
+ self, json_path, num_latent_t, cfg_rate,
22
+ ):
23
+ # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
24
+ self.json_path = json_path
25
+ self.cfg_rate = cfg_rate
26
+ self.datase_dir_path = os.path.dirname(json_path)
27
+ #self.video_dir = os.path.join(self.datase_dir_path, "video")
28
+ #self.latent_dir = os.path.join(self.datase_dir_path, "latent")
29
+ self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
30
+ self.pooled_prompt_embeds_dir = os.path.join(
31
+ self.datase_dir_path, "pooled_prompt_embeds"
32
+ )
33
+ self.text_ids_dir = os.path.join(
34
+ self.datase_dir_path, "text_ids"
35
+ )
36
+ self.latents_dir = os.path.join(
37
+ self.datase_dir_path, "images"
38
+ )
39
+ with open(self.json_path, "r") as f:
40
+ self.data_anno = json.load(f)
41
+ # json.load(f) already keeps the order
42
+ # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
43
+ self.num_latent_t = num_latent_t
44
+ # just zero embeddings [256, 4096]
45
+ self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
46
+ # 256 zeros
47
+ self.uncond_prompt_mask = torch.zeros(256).bool()
48
+ self.lengths = [
49
+ data_item["length"] if "length" in data_item else 1
50
+ for data_item in self.data_anno
51
+ ]
52
+
53
+ def __getitem__(self, idx):
54
+ #latent_file = self.data_anno[idx]["latent_path"]
55
+ prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
56
+ pooled_prompt_embeds_file = self.data_anno[idx]["pooled_prompt_embeds_path"]
57
+ text_ids_file = self.data_anno[idx]["text_ids"]
58
+ latent_file = text_ids_file
59
+ if random.random() < self.cfg_rate:
60
+ prompt_embed = self.uncond_prompt_embed
61
+ else:
62
+ prompt_embed = torch.load(
63
+ os.path.join(self.prompt_embed_dir, prompt_embed_file),
64
+ map_location="cpu",
65
+ weights_only=True,
66
+ )
67
+ pooled_prompt_embeds = torch.load(
68
+ os.path.join(
69
+ self.pooled_prompt_embeds_dir, pooled_prompt_embeds_file
70
+ ),
71
+ map_location="cpu",
72
+ weights_only=True,
73
+ )
74
+ text_ids = torch.load(
75
+ os.path.join(
76
+ self.text_ids_dir, text_ids_file
77
+ ),
78
+ map_location="cpu",
79
+ weights_only=True,
80
+ )
81
+ latents = torch.load(
82
+ os.path.join(
83
+ self.latents_dir, latent_file
84
+ ),
85
+ map_location="cpu",
86
+ weights_only=True,
87
+ )
88
+ return prompt_embed, pooled_prompt_embeds, text_ids, self.data_anno[idx]['caption'], latents
89
+
90
+ def __len__(self):
91
+ return len(self.data_anno)
92
+
93
+
94
+ def latent_collate_function(batch):
95
+ # return latent, prompt, latent_attn_mask, text_attn_mask
96
+ # latent_attn_mask: # b t h w
97
+ # text_attn_mask: b 1 l
98
+ # needs to check if the latent/prompt' size and apply padding & attn mask
99
+ prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents = zip(*batch)
100
+ # attn mask
101
+ prompt_embeds = torch.stack(prompt_embeds, dim=0)
102
+ pooled_prompt_embeds = torch.stack(pooled_prompt_embeds, dim=0)
103
+ text_ids = torch.stack(text_ids, dim=0)
104
+ latents= torch.stack(latents, dim=0)
105
+ #latents = torch.stack(latents, dim=0)
106
+ return prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents
107
+
108
+
109
+ if __name__ == "__main__":
110
+ dataset = LatentDataset("data/rl_embeddings/videos2caption.json", num_latent_t=28, cfg_rate=0.0)
111
+ dataloader = torch.utils.data.DataLoader(
112
+ dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function
113
+ )
114
+ for prompt_embed, prompt_attention_mask, caption in dataloader:
115
+ print(
116
+ prompt_embed.shape,
117
+ prompt_attention_mask.shape,
118
+ caption
119
+ )
120
+ import pdb
121
+
122
+ pdb.set_trace()
fastvideo/dataset/latent_flux_rfpt_datasets_all.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+ import json
15
+ import os
16
+ import random
17
+
18
+
19
+ class LatentDataset(Dataset):
20
+ def __init__(
21
+ self, json_path, num_latent_t, cfg_rate,
22
+ ):
23
+ # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
24
+ self.json_path = json_path
25
+ self.cfg_rate = cfg_rate
26
+ self.datase_dir_path = os.path.dirname(json_path)
27
+ #self.video_dir = os.path.join(self.datase_dir_path, "video")
28
+ #self.latent_dir = os.path.join(self.datase_dir_path, "latent")
29
+ self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
30
+ self.pooled_prompt_embeds_dir = os.path.join(
31
+ self.datase_dir_path, "pooled_prompt_embeds"
32
+ )
33
+ self.text_ids_dir = os.path.join(
34
+ self.datase_dir_path, "text_ids"
35
+ )
36
+ self.images_dir = os.path.join(
37
+ self.datase_dir_path, "images"
38
+ )
39
+ self.latents_dir = os.path.join(
40
+ self.datase_dir_path, "latents"
41
+ )
42
+ with open(self.json_path, "r") as f:
43
+ self.data_anno = json.load(f)
44
+ # json.load(f) already keeps the order
45
+ # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
46
+ self.num_latent_t = num_latent_t
47
+ # just zero embeddings [256, 4096]
48
+ self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
49
+ # 256 zeros
50
+ self.uncond_prompt_mask = torch.zeros(256).bool()
51
+ self.lengths = [
52
+ data_item["length"] if "length" in data_item else 1
53
+ for data_item in self.data_anno
54
+ ]
55
+
56
+ def __getitem__(self, idx):
57
+ #latent_file = self.data_anno[idx]["latent_path"]
58
+ prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
59
+ pooled_prompt_embeds_file = self.data_anno[idx]["pooled_prompt_embeds_path"]
60
+ text_ids_file = self.data_anno[idx]["text_ids"]
61
+ latent_file = text_ids_file
62
+ image_file = text_ids_file
63
+ if random.random() < self.cfg_rate:
64
+ prompt_embed = self.uncond_prompt_embed
65
+ else:
66
+ prompt_embed = torch.load(
67
+ os.path.join(self.prompt_embed_dir, prompt_embed_file),
68
+ map_location="cpu",
69
+ weights_only=True,
70
+ )
71
+ pooled_prompt_embeds = torch.load(
72
+ os.path.join(
73
+ self.pooled_prompt_embeds_dir, pooled_prompt_embeds_file
74
+ ),
75
+ map_location="cpu",
76
+ weights_only=True,
77
+ )
78
+ text_ids = torch.load(
79
+ os.path.join(
80
+ self.text_ids_dir, text_ids_file
81
+ ),
82
+ map_location="cpu",
83
+ weights_only=True,
84
+ )
85
+ latents = torch.load(
86
+ os.path.join(
87
+ self.latents_dir, latent_file
88
+ ),
89
+ map_location="cpu",
90
+ weights_only=True,
91
+ )
92
+ images = torch.load(
93
+ os.path.join(
94
+ self.images_dir, image_file
95
+ ),
96
+ map_location="cpu",
97
+ weights_only=True,
98
+ )
99
+ return prompt_embed, pooled_prompt_embeds, text_ids, self.data_anno[idx]['caption'], latents, images
100
+
101
+ def __len__(self):
102
+ return len(self.data_anno)
103
+
104
+
105
+ def latent_collate_function(batch):
106
+ # return latent, prompt, latent_attn_mask, text_attn_mask
107
+ # latent_attn_mask: # b t h w
108
+ # text_attn_mask: b 1 l
109
+ # needs to check if the latent/prompt' size and apply padding & attn mask
110
+ prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents, images = zip(*batch)
111
+ # attn mask
112
+ prompt_embeds = torch.stack(prompt_embeds, dim=0)
113
+ pooled_prompt_embeds = torch.stack(pooled_prompt_embeds, dim=0)
114
+ text_ids = torch.stack(text_ids, dim=0)
115
+ latents= torch.stack(latents, dim=0)
116
+ images= torch.stack(images, dim=0)
117
+ #latents = torch.stack(latents, dim=0)
118
+ return prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents, images
119
+
120
+
121
+ if __name__ == "__main__":
122
+ dataset = LatentDataset("data/rl_embeddings/videos2caption.json", num_latent_t=28, cfg_rate=0.0)
123
+ dataloader = torch.utils.data.DataLoader(
124
+ dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function
125
+ )
126
+ for prompt_embed, prompt_attention_mask, caption in dataloader:
127
+ print(
128
+ prompt_embed.shape,
129
+ prompt_attention_mask.shape,
130
+ caption
131
+ )
132
+ import pdb
133
+
134
+ pdb.set_trace()
fastvideo/dataset/latent_flux_rl_datasets.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+ import json
15
+ import os
16
+ import random
17
+
18
+
19
+ class LatentDataset(Dataset):
20
+ def __init__(
21
+ self, json_path, num_latent_t, cfg_rate,
22
+ ):
23
+ # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
24
+ self.json_path = json_path
25
+ self.cfg_rate = cfg_rate
26
+ self.datase_dir_path = os.path.dirname(json_path)
27
+ #self.video_dir = os.path.join(self.datase_dir_path, "video")
28
+ #self.latent_dir = os.path.join(self.datase_dir_path, "latent")
29
+ self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
30
+ self.pooled_prompt_embeds_dir = os.path.join(
31
+ self.datase_dir_path, "pooled_prompt_embeds"
32
+ )
33
+ self.text_ids_dir = os.path.join(
34
+ self.datase_dir_path, "text_ids"
35
+ )
36
+ with open(self.json_path, "r") as f:
37
+ self.data_anno = json.load(f)
38
+ # json.load(f) already keeps the order
39
+ # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
40
+ self.num_latent_t = num_latent_t
41
+ # just zero embeddings [256, 4096]
42
+ self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
43
+ # 256 zeros
44
+ self.uncond_prompt_mask = torch.zeros(256).bool()
45
+ self.lengths = [
46
+ data_item["length"] if "length" in data_item else 1
47
+ for data_item in self.data_anno
48
+ ]
49
+
50
+ def __getitem__(self, idx):
51
+ #latent_file = self.data_anno[idx]["latent_path"]
52
+ prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
53
+ pooled_prompt_embeds_file = self.data_anno[idx]["pooled_prompt_embeds_path"]
54
+ text_ids_file = self.data_anno[idx]["text_ids"]
55
+ if random.random() < self.cfg_rate:
56
+ prompt_embed = self.uncond_prompt_embed
57
+ else:
58
+ prompt_embed = torch.load(
59
+ os.path.join(self.prompt_embed_dir, prompt_embed_file),
60
+ map_location="cpu",
61
+ weights_only=True,
62
+ )
63
+ pooled_prompt_embeds = torch.load(
64
+ os.path.join(
65
+ self.pooled_prompt_embeds_dir, pooled_prompt_embeds_file
66
+ ),
67
+ map_location="cpu",
68
+ weights_only=True,
69
+ )
70
+ text_ids = torch.load(
71
+ os.path.join(
72
+ self.text_ids_dir, text_ids_file
73
+ ),
74
+ map_location="cpu",
75
+ weights_only=True,
76
+ )
77
+ return prompt_embed, pooled_prompt_embeds, text_ids, self.data_anno[idx]['caption']
78
+
79
+ def __len__(self):
80
+ return len(self.data_anno)
81
+
82
+
83
+ def latent_collate_function(batch):
84
+ # return latent, prompt, latent_attn_mask, text_attn_mask
85
+ # latent_attn_mask: # b t h w
86
+ # text_attn_mask: b 1 l
87
+ # needs to check if the latent/prompt' size and apply padding & attn mask
88
+ prompt_embeds, pooled_prompt_embeds, text_ids, caption = zip(*batch)
89
+ # attn mask
90
+ prompt_embeds = torch.stack(prompt_embeds, dim=0)
91
+ pooled_prompt_embeds = torch.stack(pooled_prompt_embeds, dim=0)
92
+ text_ids = torch.stack(text_ids, dim=0)
93
+ #latents = torch.stack(latents, dim=0)
94
+ return prompt_embeds, pooled_prompt_embeds, text_ids, caption
95
+
96
+
97
+ if __name__ == "__main__":
98
+ dataset = LatentDataset("data/rl_embeddings/videos2caption.json", num_latent_t=28, cfg_rate=0.0)
99
+ dataloader = torch.utils.data.DataLoader(
100
+ dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function
101
+ )
102
+ for prompt_embed, prompt_attention_mask, caption in dataloader:
103
+ print(
104
+ prompt_embed.shape,
105
+ prompt_attention_mask.shape,
106
+ caption
107
+ )
108
+ import pdb
109
+
110
+ pdb.set_trace()
fastvideo/dataset/latent_qwenimage_rl_datasets.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+ import json
15
+ import os
16
+ import random
17
+
18
+
19
+ class LatentDataset(Dataset):
20
+ def __init__(
21
+ self, json_path, num_latent_t, cfg_rate,
22
+ ):
23
+ # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
24
+ self.json_path = json_path
25
+ self.cfg_rate = cfg_rate
26
+ self.datase_dir_path = os.path.dirname(json_path)
27
+ #self.video_dir = os.path.join(self.datase_dir_path, "video")
28
+ #self.latent_dir = os.path.join(self.datase_dir_path, "latent")
29
+ self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
30
+ self.prompt_attention_mask_dir = os.path.join(
31
+ self.datase_dir_path, "prompt_attention_mask"
32
+ )
33
+ with open(self.json_path, "r") as f:
34
+ self.data_anno = json.load(f)
35
+ # json.load(f) already keeps the order
36
+ # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
37
+ self.num_latent_t = num_latent_t
38
+ # just zero embeddings [256, 4096]
39
+ self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
40
+ # 256 zeros
41
+ self.uncond_prompt_mask = torch.zeros(256).bool()
42
+ self.lengths = [
43
+ data_item["length"] if "length" in data_item else 1
44
+ for data_item in self.data_anno
45
+ ]
46
+
47
+ def __getitem__(self, idx):
48
+ #latent_file = self.data_anno[idx]["latent_path"]
49
+ prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
50
+ prompt_attention_mask_file = self.data_anno[idx]["prompt_attention_mask"]
51
+ if random.random() < self.cfg_rate:
52
+ prompt_embed = self.uncond_prompt_embed
53
+ prompt_attention_mask = self.uncond_prompt_mask
54
+ else:
55
+ prompt_embed = torch.load(
56
+ os.path.join(self.prompt_embed_dir, prompt_embed_file),
57
+ map_location="cpu",
58
+ weights_only=True,
59
+ )
60
+ prompt_attention_mask = torch.load(
61
+ os.path.join(
62
+ self.prompt_attention_mask_dir, prompt_attention_mask_file
63
+ ),
64
+ map_location="cpu",
65
+ weights_only=True,
66
+ )
67
+ return prompt_embed, prompt_attention_mask, self.data_anno[idx]['caption'], self.data_anno[idx]['original_length']
68
+
69
+ def __len__(self):
70
+ return len(self.data_anno)
71
+
72
+
73
+ def latent_collate_function(batch):
74
+ # return latent, prompt, latent_attn_mask, text_attn_mask
75
+ # latent_attn_mask: # b t h w
76
+ # text_attn_mask: b 1 l
77
+ # needs to check if the latent/prompt' size and apply padding & attn mask
78
+ prompt_embeds, prompt_attention_masks, caption, original_length = zip(*batch)
79
+ # attn mask
80
+ prompt_embeds = torch.stack(prompt_embeds, dim=0)
81
+ prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0)
82
+
83
+ # Convert original_length to tensor
84
+ original_length = torch.tensor(original_length, dtype=torch.long)
85
+
86
+ # Convert caption to list
87
+ caption = list(caption)
88
+
89
+ #latents = torch.stack(latents, dim=0)
90
+ return prompt_embeds, prompt_attention_masks, caption, original_length
fastvideo/dataset/latent_rl_datasets.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) [2025] [FastVideo Team]
2
+ # Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
3
+ # SPDX-License-Identifier: [Apache License 2.0]
4
+ #
5
+ # This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025.
6
+ #
7
+ # Original file was released under [Apache License 2.0], with the full license text
8
+ # available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE].
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+ import json
15
+ import os
16
+ import random
17
+
18
+
19
+ class LatentDataset(Dataset):
20
+ def __init__(
21
+ self, json_path, num_latent_t, cfg_rate,
22
+ ):
23
+ # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
24
+ self.json_path = json_path
25
+ self.cfg_rate = cfg_rate
26
+ self.datase_dir_path = os.path.dirname(json_path)
27
+ #self.video_dir = os.path.join(self.datase_dir_path, "video")
28
+ #self.latent_dir = os.path.join(self.datase_dir_path, "latent")
29
+ self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
30
+ self.prompt_attention_mask_dir = os.path.join(
31
+ self.datase_dir_path, "prompt_attention_mask"
32
+ )
33
+ with open(self.json_path, "r") as f:
34
+ self.data_anno = json.load(f)
35
+ # json.load(f) already keeps the order
36
+ # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
37
+ self.num_latent_t = num_latent_t
38
+ # just zero embeddings [256, 4096]
39
+ self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
40
+ # 256 zeros
41
+ self.uncond_prompt_mask = torch.zeros(256).bool()
42
+ self.lengths = [
43
+ data_item["length"] if "length" in data_item else 1
44
+ for data_item in self.data_anno
45
+ ]
46
+
47
+ def __getitem__(self, idx):
48
+ #latent_file = self.data_anno[idx]["latent_path"]
49
+ prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
50
+ prompt_attention_mask_file = self.data_anno[idx]["prompt_attention_mask"]
51
+ if random.random() < self.cfg_rate:
52
+ prompt_embed = self.uncond_prompt_embed
53
+ prompt_attention_mask = self.uncond_prompt_mask
54
+ else:
55
+ prompt_embed = torch.load(
56
+ os.path.join(self.prompt_embed_dir, prompt_embed_file),
57
+ map_location="cpu",
58
+ weights_only=True,
59
+ )
60
+ prompt_attention_mask = torch.load(
61
+ os.path.join(
62
+ self.prompt_attention_mask_dir, prompt_attention_mask_file
63
+ ),
64
+ map_location="cpu",
65
+ weights_only=True,
66
+ )
67
+ return prompt_embed, prompt_attention_mask, self.data_anno[idx]['caption']
68
+
69
+ def __len__(self):
70
+ return len(self.data_anno)
71
+
72
+
73
+ def latent_collate_function(batch):
74
+ # return latent, prompt, latent_attn_mask, text_attn_mask
75
+ # latent_attn_mask: # b t h w
76
+ # text_attn_mask: b 1 l
77
+ # needs to check if the latent/prompt' size and apply padding & attn mask
78
+ prompt_embeds, prompt_attention_masks, caption = zip(*batch)
79
+ # attn mask
80
+ prompt_embeds = torch.stack(prompt_embeds, dim=0)
81
+ prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0)
82
+ #latents = torch.stack(latents, dim=0)
83
+ return prompt_embeds, prompt_attention_masks, caption
84
+
85
+
86
+ if __name__ == "__main__":
87
+ dataset = LatentDataset("data/rl_embeddings/videos2caption.json", num_latent_t=28, cfg_rate=0.0)
88
+ dataloader = torch.utils.data.DataLoader(
89
+ dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function
90
+ )
91
+ for prompt_embed, prompt_attention_mask, caption in dataloader:
92
+ print(
93
+ prompt_embed.shape,
94
+ prompt_attention_mask.shape,
95
+ caption
96
+ )
97
+ import pdb
98
+
99
+ pdb.set_trace()
fastvideo/dataset/t2v_datasets.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ from collections import Counter
8
+ from os.path import join as opj
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ from einops import rearrange
14
+ from PIL import Image
15
+ from torch.utils.data import Dataset
16
+
17
+ from fastvideo.utils.dataset_utils import DecordInit
18
+ from fastvideo.utils.logging_ import main_print
19
+
20
+
21
class SingletonMeta(type):
    """Metaclass caching exactly one shared instance per class.

    The first instantiation constructs the object; subsequent calls return
    the cached instance (``__init__`` runs only once).
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]
29
+
30
+
31
class DataSetProg(metaclass=SingletonMeta):
    """Singleton that shards a shuffled sample-index list across workers.

    ``set_cap_list`` shuffles all indices once and splits them into
    contiguous per-worker buckets; ``get_item`` hands each worker the next
    index from its own bucket, wrapping around when exhausted.
    """

    def __init__(self):
        self.cap_list = []
        self.elements = []
        self.num_workers = 1
        self.n_elements = 0
        self.worker_elements = dict()
        self.n_used_elements = dict()

    def set_cap_list(self, num_workers, cap_list, n_elements):
        """Shuffle ``n_elements`` indices and partition them per worker."""
        self.num_workers = num_workers
        self.cap_list = cap_list
        self.n_elements = n_elements
        self.elements = list(range(n_elements))
        random.shuffle(self.elements)
        print(f"n_elements: {len(self.elements)}", flush=True)

        per_worker = int(
            math.ceil(len(self.elements) / float(self.num_workers)))
        for worker_id in range(self.num_workers):
            self.n_used_elements[worker_id] = 0
            start = worker_id * per_worker
            stop = min(start + per_worker, len(self.elements))
            self.worker_elements[worker_id] = self.elements[start:stop]

    def get_item(self, work_info):
        """Return the next shuffled dataset index for the calling worker.

        ``work_info`` is ``torch.utils.data.get_worker_info()`` (or None in
        the main process, which maps to worker 0).
        """
        worker_id = 0 if work_info is None else work_info.id
        bucket = self.worker_elements[worker_id]
        idx = bucket[self.n_used_elements[worker_id] % len(bucket)]
        self.n_used_elements[worker_id] += 1
        return idx


dataset_prog = DataSetProg()
71
+
72
+
73
def filter_resolution(h,
                      w,
                      max_h_div_w_ratio=17 / 16,
                      min_h_div_w_ratio=8 / 16):
    """Return True when the aspect ratio h/w lies within [min, max] (inclusive)."""
    ratio = h / w
    return min_h_div_w_ratio <= ratio <= max_h_div_w_ratio
80
+
81
+
82
class T2V_dataset(Dataset):
    """Joint text-video / text-image dataset.

    Reads a merge file listing ``folder,annotation.json`` pairs, filters
    entries by caption/resolution/duration, precomputes the frame indices to
    sample for each video, and returns pixel tensors plus tokenized captions.
    Caption dropout (probability ``args.cfg``) is applied for classifier-free
    guidance training. Sample sharding across dataloader workers is delegated
    to the module-level ``dataset_prog`` singleton.
    """

    def __init__(self, args, transform, temporal_sample, tokenizer,
                 transform_topcrop):
        # NOTE(review): ``args`` must expose every attribute read below
        # (data_merge_path, num_frames, train_fps, ...) — confirm with caller.
        self.data = args.data_merge_path
        self.num_frames = args.num_frames
        self.train_fps = args.train_fps
        self.use_image_num = args.use_image_num
        self.transform = transform
        self.transform_topcrop = transform_topcrop
        self.temporal_sample = temporal_sample
        self.tokenizer = tokenizer
        self.text_max_length = args.text_max_length
        self.cfg = args.cfg  # caption-dropout probability for CFG
        self.speed_factor = args.speed_factor
        self.max_height = args.max_height
        self.max_width = args.max_width
        self.drop_short_ratio = args.drop_short_ratio
        assert self.speed_factor >= 1
        self.v_decoder = DecordInit()
        self.video_length_tolerance_range = args.video_length_tolerance_range
        # Only the mT5 text encoder is treated as Chinese-capable.
        self.support_Chinese = True
        if "mt5" not in args.text_encoder_name:
            self.support_Chinese = False

        cap_list = self.get_cap_list()

        assert len(cap_list) > 0
        cap_list, self.sample_num_frames = self.define_frame_index(cap_list)
        self.lengths = self.sample_num_frames

        n_elements = len(cap_list)
        # Shuffle and shard sample indices across dataloader workers.
        dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list,
                                  n_elements)

        print(f"video length: {len(dataset_prog.cap_list)}", flush=True)

    def set_checkpoint(self, n_used_elements):
        # Restore per-worker consumption counters when resuming training.
        for i in range(len(dataset_prog.n_used_elements)):
            dataset_prog.n_used_elements[i] = n_used_elements

    def __len__(self):
        return dataset_prog.n_elements

    def __getitem__(self, idx):

        data = self.get_data(idx)
        return data

    def get_data(self, idx):
        # Dispatch on file extension: .mp4 -> video sample, anything else image.
        path = dataset_prog.cap_list[idx]["path"]
        if path.endswith(".mp4"):
            return self.get_video(idx)
        else:
            return self.get_image(idx)

    def get_video(self, idx):
        """Load, subsample, transform and tokenize one video sample."""
        video_path = dataset_prog.cap_list[idx]["path"]
        assert os.path.exists(video_path), f"file {video_path} do not exist!"
        # Frame indices were precomputed by define_frame_index().
        frame_indices = dataset_prog.cap_list[idx]["sample_frame_index"]
        torchvision_video, _, metadata = torchvision.io.read_video(
            video_path, output_format="TCHW")
        video = torchvision_video[frame_indices]
        video = self.transform(video)
        video = rearrange(video, "t c h w -> c t h w")
        video = video.to(torch.uint8)
        assert video.dtype == torch.uint8

        h, w = video.shape[-2:]
        assert (
            h / w <= 17 / 16 and h / w >= 8 / 16
        ), f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({video_path}) found ratio is {round(h / w, 2)} with the shape of {video.shape}"

        # uint8 [0, 255] -> float [-1, 1]
        video = video.float() / 127.5 - 1.0

        text = dataset_prog.cap_list[idx]["cap"]
        if not isinstance(text, list):
            text = [text]
        text = [random.choice(text)]

        # Classifier-free guidance: drop the caption with probability self.cfg.
        text = text[0] if random.random() > self.cfg else ""
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.text_max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = text_tokens_and_mask["input_ids"]
        cond_mask = text_tokens_and_mask["attention_mask"]
        return dict(
            pixel_values=video,
            text=text,
            input_ids=input_ids,
            cond_mask=cond_mask,
            path=video_path,
        )

    def get_image(self, idx):
        """Load, transform and tokenize one image sample (as a 1-frame clip)."""
        image_data = dataset_prog.cap_list[
            idx]  # [{'path': path, 'cap': cap}, ...]

        image = Image.open(image_data["path"]).convert("RGB")  # [h, w, c]
        image = torch.from_numpy(np.array(image))  # [h, w, c]
        image = rearrange(image, "h w c -> c h w").unsqueeze(0)  # [1 c h w]
        # for i in image:
        #     h, w = i.shape[-2:]
        #     assert h / w <= 17 / 16 and h / w >= 8 / 16, f'Only image with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But found ratio is {round(h / w, 2)} with the shape of {i.shape}'

        # "human_images" samples use the top-crop transform, others the default.
        image = (self.transform_topcrop(image) if "human_images"
                 in image_data["path"] else self.transform(image)
                 )  # [1 C H W] -> num_img [1 C H W]
        image = image.transpose(0, 1)  # [1 C H W] -> [C 1 H W]

        # uint8 [0, 255] -> float [-1, 1]
        image = image.float() / 127.5 - 1.0

        caps = (image_data["cap"] if isinstance(image_data["cap"], list) else
                [image_data["cap"]])
        caps = [random.choice(caps)]
        text = caps
        input_ids, cond_mask = [], []
        # Classifier-free guidance caption dropout.
        text = text[0] if random.random() > self.cfg else ""
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.text_max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = text_tokens_and_mask["input_ids"]  # 1, l
        cond_mask = text_tokens_and_mask["attention_mask"]  # 1, l
        return dict(
            pixel_values=image,
            text=text,
            input_ids=input_ids,
            cond_mask=cond_mask,
            path=image_data["path"],
        )

    def define_frame_index(self, cap_list):
        """Filter annotations and precompute per-sample frame indices.

        Drops entries without caption/fps/duration/resolution, entries whose
        aspect ratio falls outside the allowed band, over-long videos, and
        (probabilistically) too-short ones. Returns the filtered list plus
        the per-sample frame counts, and prints filter statistics.
        """
        new_cap_list = []
        sample_num_frames = []
        cnt_too_long = 0
        cnt_too_short = 0
        cnt_no_cap = 0
        cnt_no_resolution = 0
        cnt_resolution_mismatch = 0
        cnt_movie = 0
        cnt_img = 0
        for i in cap_list:
            path = i["path"]
            cap = i.get("cap", None)
            # ======no caption=====
            if cap is None:
                cnt_no_cap += 1
                continue
            if path.endswith(".mp4"):
                # ======no fps and duration=====
                duration = i.get("duration", None)
                fps = i.get("fps", None)
                if fps is None or duration is None:
                    continue

                # ======resolution mismatch=====
                resolution = i.get("resolution", None)
                if resolution is None:
                    cnt_no_resolution += 1
                    continue
                else:
                    if (resolution.get("height", None) is None
                            or resolution.get("width", None) is None):
                        cnt_no_resolution += 1
                        continue
                    height, width = i["resolution"]["height"], i["resolution"][
                        "width"]
                    aspect = self.max_height / self.max_width
                    hw_aspect_thr = 1.5
                    is_pick = filter_resolution(
                        height,
                        width,
                        max_h_div_w_ratio=hw_aspect_thr * aspect,
                        min_h_div_w_ratio=1 / hw_aspect_thr * aspect,
                    )
                    if not is_pick:
                        print("resolution mismatch")
                        cnt_resolution_mismatch += 1
                        continue

                # import ipdb;ipdb.set_trace()
                i["num_frames"] = math.ceil(fps * duration)
                # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration.
                if i["num_frames"] / fps > self.video_length_tolerance_range * (
                        self.num_frames / self.train_fps * self.speed_factor
                ):  # too long video is not suitable for this training stage (self.num_frames)
                    cnt_too_long += 1
                    continue

                # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24)
                frame_interval = fps / self.train_fps
                start_frame_idx = 0
                frame_indices = np.arange(start_frame_idx, i["num_frames"],
                                          frame_interval).astype(int)

                # comment out it to enable dynamic frames training
                if (len(frame_indices) < self.num_frames
                        and random.random() < self.drop_short_ratio):
                    cnt_too_short += 1
                    continue

                # too long video will be temporal-crop randomly
                if len(frame_indices) > self.num_frames:
                    begin_index, end_index = self.temporal_sample(
                        len(frame_indices))
                    frame_indices = frame_indices[begin_index:end_index]
                    # frame_indices = frame_indices[:self.num_frames] # head crop
                i["sample_frame_index"] = frame_indices.tolist()
                new_cap_list.append(i)
                i["sample_num_frames"] = len(
                    i["sample_frame_index"]
                )  # will use in dataloader(group sampler)
                sample_num_frames.append(i["sample_num_frames"])
            elif path.endswith(".jpg"):  # image
                cnt_img += 1
                new_cap_list.append(i)
                i["sample_num_frames"] = 1
                sample_num_frames.append(i["sample_num_frames"])
            else:
                raise NameError(
                    f"Unknown file extension {path.split('.')[-1]}, only support .mp4 for video and .jpg for image"
                )
        # import ipdb;ipdb.set_trace()
        main_print(
            f"no_cap: {cnt_no_cap}, too_long: {cnt_too_long}, too_short: {cnt_too_short}, "
            f"no_resolution: {cnt_no_resolution}, resolution_mismatch: {cnt_resolution_mismatch}, "
            f"Counter(sample_num_frames): {Counter(sample_num_frames)}, cnt_movie: {cnt_movie}, cnt_img: {cnt_img}, "
            f"before filter: {len(cap_list)}, after filter: {len(new_cap_list)}"
        )
        return new_cap_list, sample_num_frames

    def decord_read(self, path, frame_indices):
        """Decode the given frame indices from ``path`` via decord."""
        decord_vr = self.v_decoder(path)
        video_data = decord_vr.get_batch(frame_indices).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1,
                                        2)  # (T, H, W, C) -> (T C H W)
        return video_data

    def read_jsons(self, data):
        """Parse the merge file: each non-empty line is "folder,annotation.json".

        Every annotation entry's relative path is prefixed with its folder.
        """
        cap_lists = []
        with open(data, "r") as f:
            folder_anno = [
                i.strip().split(",") for i in f.readlines()
                if len(i.strip()) > 0
            ]
        print(folder_anno)
        for folder, anno in folder_anno:
            with open(anno, "r") as f:
                sub_list = json.load(f)
            for i in range(len(sub_list)):
                sub_list[i]["path"] = opj(folder, sub_list[i]["path"])
            cap_lists += sub_list
        return cap_lists

    def get_cap_list(self):
        cap_lists = self.read_jsons(self.data)
        return cap_lists
fastvideo/dataset/transform.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import numbers
4
+ import random
5
+
6
+ import torch
7
+ from PIL import Image
8
+
9
+
10
def _is_tensor_video_clip(clip):
    """Validate that ``clip`` is a 4-D torch tensor; raise otherwise."""
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if clip.ndimension() != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    return True
18
+
19
+
20
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Repeatedly halves the image while it is at least twice the target size,
    rescales so the short side equals ``image_size``, then center-crops an
    ``image_size`` x ``image_size`` square. Returns a PIL Image.
    """
    # Bug fix: numpy is not imported at this file's module level (only inside
    # the __main__ block), so the original body raised NameError on np.array.
    import numpy as np

    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size),
                                     resample=Image.BOX)

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(
        round(x * scale) for x in pil_image.size),
                                 resample=Image.BICUBIC)

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y:crop_y + image_size,
                               crop_x:crop_x + image_size])
39
+
40
+
41
def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)

    Returns the (h, w) spatial window whose top-left corner is at (i, j).
    """
    if clip.dim() != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i:i + h, j:j + w]
49
+
50
+
51
def resize(clip, target_size, interpolation_mode):
    """Resize ``clip`` so its spatial size equals ``target_size`` (height, width)."""
    if len(target_size) != 2:
        raise ValueError(
            f"target size should be tuple (height, width), instead got {target_size}"
        )
    return torch.nn.functional.interpolate(clip,
                                           size=target_size,
                                           mode=interpolation_mode,
                                           align_corners=True,
                                           antialias=True)
63
+
64
+
65
def resize_scale(clip, target_size, interpolation_mode):
    """Uniformly scale ``clip`` so its short side matches ``target_size[0]``."""
    if len(target_size) != 2:
        raise ValueError(
            f"target size should be tuple (height, width), instead got {target_size}"
        )
    short_side = min(clip.size(-2), clip.size(-1))
    return torch.nn.functional.interpolate(clip,
                                           scale_factor=target_size[0] /
                                           short_side,
                                           mode=interpolation_mode,
                                           align_corners=True,
                                           antialias=True)
79
+
80
+
81
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return resize(crop(clip, i, j, h, w), size, interpolation_mode)
99
+
100
+
101
def center_crop(clip, crop_size):
    """Center-crop ``clip`` to spatial size ``crop_size`` == (th, tw)."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")
    top = int(round((h - th) / 2.0))
    left = int(round((w - tw) / 2.0))
    return crop(clip, top, left, th, tw)
112
+
113
+
114
def center_crop_using_short_edge(clip):
    """Center-crop ``clip`` to a square whose side is the short spatial edge."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    side = min(h, w)
    top = 0 if h < w else int(round((h - side) / 2.0))
    left = int(round((w - side) / 2.0)) if h < w else 0
    return crop(clip, top, left, side, side)
127
+
128
+
129
def center_crop_th_tw(clip, th, tw, top_crop):
    """Crop the largest window with aspect ratio th/tw, centered (or top-aligned
    vertically when ``top_crop`` is true)."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")

    h, w = clip.size(-2), clip.size(-1)
    target_ratio = th / tw
    if h / w > target_ratio:
        # Too tall: keep full width, shrink height to match the ratio.
        new_h, new_w = int(w * target_ratio), w
    else:
        # Too wide (or exact): keep full height, shrink width.
        new_h, new_w = h, int(h / target_ratio)

    top = 0 if top_crop else int(round((h - new_h) / 2.0))
    left = int(round((w - new_w) / 2.0))
    return crop(clip, top, left, new_h, new_w)
146
+
147
+
148
def random_shift_crop(clip):
    """
    Slide along the long edge, with the short edge as crop size
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    side = min(h, w)
    top = torch.randint(0, h - side + 1, size=(1, )).item()
    left = torch.randint(0, w - side + 1, size=(1, )).item()
    return crop(clip, top, left, side, side)
166
+
167
+
168
def normalize_video(clip):
    """
    Convert a uint8 (T, C, H, W) clip to float values in [0, 1].

    Raises TypeError when the input dtype is not uint8.
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" %
                        str(clip.dtype))
    return clip.float() / 255.0
183
+
184
+
185
def normalize(clip, mean, std, inplace=False):
    """
    Channel-wise normalization: (clip - mean) / std.

    Args:
        clip (torch.tensor): channel-first video clip, 4-D.
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor), a copy unless ``inplace`` is true.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    target = clip if inplace else clip.clone()
    mean_t = torch.as_tensor(mean, dtype=target.dtype, device=target.device)
    std_t = torch.as_tensor(std, dtype=target.dtype, device=target.device)
    target.sub_(mean_t[:, None, None, None]).div_(std_t[:, None, None, None])
    return target
203
+
204
+
205
def hflip(clip):
    """
    Flip ``clip`` along its last (width) axis.

    Args:
        clip (torch.tensor): Size is (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return torch.flip(clip, dims=(-1, ))
215
+
216
+
217
class RandomCropVideo:
    """Randomly crop a (T, C, H, W) clip to a fixed spatial size."""

    def __init__(self, size):
        # An int means a square crop.
        self.size = (int(size),
                     int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        top, left, h, w = self.get_params(clip)
        return crop(clip, top, left, h, w)

    def get_params(self, clip):
        """Pick a uniformly random valid top-left corner for the crop."""
        h, w = clip.shape[-2:]
        th, tw = self.size

        if h < th or w < tw:
            raise ValueError(
                f"Required crop size {(th, tw)} is larger than input image size {(h, w)}"
            )

        if (h, w) == (th, tw):
            return 0, 0, h, w

        top = torch.randint(0, h - th + 1, size=(1, )).item()
        left = torch.randint(0, w - tw + 1, size=(1, )).item()
        return top, left, th, tw

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
255
+
256
+
257
class SpatialStrideCropVideo:
    """Crop a clip from the top-left so H and W become multiples of ``stride``."""

    def __init__(self, stride):
        self.stride = stride

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: cropped video clip by stride.
                size is (T, C, OH, OW)
        """
        i, j, h, w = self.get_params(clip)
        return crop(clip, i, j, h, w)

    def get_params(self, clip):
        h, w = clip.shape[-2:]
        # Round both spatial dims down to the nearest multiple of the stride.
        th, tw = h // self.stride * self.stride, w // self.stride * self.stride
        return 0, 0, th, tw  # from top-left

    def __repr__(self) -> str:
        # Bug fix: the original referenced self.size, which this class never
        # defines, so repr() raised AttributeError. Report the stride instead.
        return f"{self.__class__.__name__}(stride={self.stride})"
282
+
283
+
284
class LongSideResizeVideo:
    """
    First use the long side,
    then resize to the specified size
    """

    def __init__(
        self,
        size,
        skip_low_resolution=False,
        interpolation_mode="bilinear",
    ):
        self.size = size
        self.skip_low_resolution = skip_low_resolution
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """Scale ``clip`` so its long spatial side equals ``self.size``.

        Clips already at or below the target are returned unchanged when
        ``skip_low_resolution`` is set.
        """
        _, _, h, w = clip.shape
        if self.skip_low_resolution and max(h, w) <= self.size:
            return clip
        if h > w:
            h, w = self.size, int(w * self.size / h)
        else:
            h, w = int(h * self.size / w), self.size
        return resize(clip,
                      target_size=(h, w),
                      interpolation_mode=self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
324
+
325
+
326
class CenterCropResizeVideo:
    """
    First use the short side for cropping length,
    center crop video, then resize to the specified size
    """

    def __init__(
        self,
        size,
        top_crop=False,
        interpolation_mode="bilinear",
    ):
        if len(size) != 2:
            raise ValueError(
                f"size should be tuple (height, width), instead got {size}")
        self.size = size
        self.top_crop = top_crop
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """Crop to the target aspect ratio (centered or top-aligned), then resize."""
        cropped = center_crop_th_tw(clip,
                                    self.size[0],
                                    self.size[1],
                                    top_crop=self.top_crop)
        return resize(cropped,
                      target_size=self.size,
                      interpolation_mode=self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
368
+
369
+
370
class UCFCenterCropVideo:
    """
    First scale to the specified size in equal proportion to the short edge,
    then center cropping
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(
                    f"size should be tuple (height, width), instead got {size}"
                )
            self.size = size
        else:
            # An int means a square output.
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """Short-edge scale, then center-crop to ``self.size``."""
        scaled = resize_scale(clip=clip,
                              target_size=self.size,
                              interpolation_mode=self.interpolation_mode)
        return center_crop(scaled, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
408
+
409
+
410
class KineticsRandomCropResizeVideo:
    """
    Slide along the long edge, with the short edge as crop size. And resize to the desired size.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(
                    f"size should be tuple (height, width), instead got {size}"
                )
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """Random short-edge square crop followed by a resize."""
        return resize(random_shift_crop(clip), self.size,
                      self.interpolation_mode)
437
+
438
+ class CenterCropVideo:
439
+
440
+ def __init__(
441
+ self,
442
+ size,
443
+ interpolation_mode="bilinear",
444
+ ):
445
+ if isinstance(size, tuple):
446
+ if len(size) != 2:
447
+ raise ValueError(
448
+ f"size should be tuple (height, width), instead got {size}"
449
+ )
450
+ self.size = size
451
+ else:
452
+ self.size = (size, size)
453
+
454
+ self.interpolation_mode = interpolation_mode
455
+
456
+ def __call__(self, clip):
457
+ """
458
+ Args:
459
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
460
+ Returns:
461
+ torch.tensor: center cropped video clip.
462
+ size is (T, C, crop_size, crop_size)
463
+ """
464
+ clip_center_crop = center_crop(clip, self.size)
465
+ return clip_center_crop
466
+
467
+ def __repr__(self) -> str:
468
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
469
+
470
+
471
class Normalize:
    """
    Normalize the video clip by mean subtraction and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """Apply channel-wise (clip - mean) / std to a channel-first clip."""
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
494
+
495
+
496
class Normalize255:
    """Convert a uint8 clip to float and scale its values into [0, 1]."""

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return normalize_video(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__
515
+
516
+
517
class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """Return ``clip`` flipped horizontally with probability ``self.p``."""
        return hflip(clip) if random.random() < self.p else clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
540
+
541
+
542
+ # ------------------------------------------------------------
543
+ # --------------------- Sampling ---------------------------
544
+ # ------------------------------------------------------------
545
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        """Return a random (begin, end) window of at most ``self.size`` frames."""
        latest_start = max(0, total_frames - self.size - 1)
        begin_index = random.randint(0, latest_start)
        return begin_index, min(begin_index + self.size, total_frames)
560
+
561
+
562
class DynamicSampleDuration(object):
    """Randomly pick a truncated temporal length for dynamic-duration training.

    Keeps at least half of the frames, choosing among lengths sampled at
    ``t_stride`` steps; ``extra_1`` accounts for the leading extra frame.
    """

    def __init__(self, t_stride, extra_1):
        self.t_stride = t_stride
        self.extra_1 = extra_1

    def __call__(self, t, h, w):
        base_t = t - 1 if self.extra_1 else t
        # need half at least
        candidates = list(range(base_t + 1))[base_t // 2:][::self.t_stride]
        truncate_t = random.choice(candidates)
        if self.extra_1:
            truncate_t += 1
        return 0, truncate_t
582
+
583
+
584
if __name__ == "__main__":
    # Ad-hoc smoke test: read a sample video, run it through the transform
    # pipeline, then write the result to disk for visual inspection.
    import os

    import numpy as np
    import torchvision.io as io
    from torchvision import transforms
    from torchvision.utils import save_image

    vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi",
                                           pts_unit="sec",
                                           output_format="TCHW")

    trans = transforms.Compose([
        Normalize255(),
        RandomHorizontalFlipVideo(),
        UCFCenterCropVideo(512),
        # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5],
                             inplace=True),
    ])

    target_video_len = 32
    frame_interval = 1
    total_frames = len(vframes)
    print(total_frames)

    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)

    # Sampling video frames
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    # print(start_frame_ind)
    # print(end_frame_ind)
    assert end_frame_ind - start_frame_ind >= target_video_len
    # Evenly spaced indices inside the sampled window.
    frame_indice = np.linspace(start_frame_ind,
                               end_frame_ind - 1,
                               target_video_len,
                               dtype=int)
    print(frame_indice)

    select_vframes = vframes[frame_indice]
    print(select_vframes.shape)
    print(select_vframes.dtype)

    select_vframes_trans = trans(select_vframes)
    print(select_vframes_trans.shape)
    print(select_vframes_trans.dtype)

    # Map normalized [-1, 1] floats back to uint8 [0, 255] for writing.
    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) *
                                255).to(dtype=torch.uint8)
    print(select_vframes_trans_int.dtype)
    print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)

    io.write_video("./test.avi",
                   select_vframes_trans_int.permute(0, 2, 3, 1),
                   fps=8)

    for i in range(target_video_len):
        save_image(
            select_vframes_trans[i],
            os.path.join("./test000", "%04d.png" % i),
            normalize=True,
            value_range=(-1, 1),
        )
fastvideo/distill/__init__.py ADDED
File without changes
fastvideo/distill/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (225 Bytes). View file
 
fastvideo/distill/__pycache__/solver.cpython-312.pyc ADDED
Binary file (16.1 kB). View file
 
fastvideo/distill/discriminator.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import torch.nn as nn
4
+ from diffusers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
7
+
8
+
9
class DiscriminatorHead(nn.Module):
    """1x1-conv head that scores transformer features as real/fake.

    The input is a flattened sequence of per-frame feature tokens; it is
    reshaped back to a fixed 30x53 spatial grid before convolution.
    """

    def __init__(self, input_channel, output_channel=1):
        super().__init__()
        inner_channel = 1024
        # LeakyReLU is used instead of the paper's GELU to save memory.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channel, inner_channel, 1, 1, 0),
            nn.GroupNorm(32, inner_channel),
            nn.LeakyReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(inner_channel, inner_channel, 1, 1, 0),
            nn.GroupNorm(32, inner_channel),
            nn.LeakyReLU(inplace=True),
        )
        self.conv_out = nn.Conv2d(inner_channel, output_channel, 1, 1, 0)

    def forward(self, x):
        """Map (b, t*30*53, c) token features to (b*t, out, 30, 53) logits."""
        b, twh, c = x.shape
        t = twh // (30 * 53)  # assumes a 30x53 token grid per frame — TODO confirm
        grid = x.view(-1, 30 * 53, c).permute(0, 2, 1).view(b * t, c, 30, 53)
        grid = self.conv1(grid)
        grid = self.conv2(grid) + grid  # residual connection
        return self.conv_out(grid)
41
+
42
+
43
class Discriminator(nn.Module):
    """Ensemble of `DiscriminatorHead`s, one group per sampled transformer layer.

    One head group is created for every `stride`-th layer out of
    `total_layers`, each group holding `num_h_per_head` heads over features of
    width `adapter_channel_dims[i]`.

    Args:
        stride: Sample every `stride`-th transformer layer.
        num_h_per_head: Number of heads per sampled layer.
        adapter_channel_dims: Feature widths, repeated per sampled layer.
            Defaults to `[3072]`.
        total_layers: Total number of transformer layers being tapped.
    """

    def __init__(
        self,
        stride=8,
        num_h_per_head=1,
        adapter_channel_dims=None,
        total_layers=48,
    ):
        super().__init__()
        # Avoid the mutable-default-argument pitfall; [3072] matches the
        # transformer hidden width used elsewhere in this repo.
        if adapter_channel_dims is None:
            adapter_channel_dims = [3072]
        adapter_channel_dims = adapter_channel_dims * (total_layers // stride)
        self.stride = stride
        self.num_h_per_head = num_h_per_head
        self.head_num = len(adapter_channel_dims)
        self.heads = nn.ModuleList([
            nn.ModuleList([
                DiscriminatorHead(adapter_channel)
                for _ in range(self.num_h_per_head)
            ]) for adapter_channel in adapter_channel_dims
        ])

    def forward(self, features):
        """Score each feature tensor with its head group.

        Args:
            features: One feature tensor per head group (same length as
                `self.heads`).

        Returns:
            Flat list of head outputs, ordered group-by-group.
        """
        assert len(features) == len(self.heads)
        outputs = []
        # NOTE: gradient checkpointing was previously applied per head and
        # intentionally disabled; the heads are cheap enough to run directly.
        for feature, head_group in zip(features, self.heads):
            for head in head_group:
                outputs.append(head(feature))
        return outputs
fastvideo/distill/solver.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
10
+ from diffusers.utils import BaseOutput, logging
11
+
12
+ from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
13
+
14
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
15
+
16
+
17
@dataclass
class PCMFMSchedulerOutput(BaseOutput):
    """Output container returned by `PCMFMScheduler.step`."""

    # Denoised sample propagated one scheduler step backwards (x_{t-1});
    # feed this as `sample` into the next `step` call.
    prev_sample: torch.FloatTensor
20
+
21
+
22
def extract_into_tensor(a, t, x_shape):
    """Gather per-batch values from `a` and reshape them for broadcasting.

    Selects `a[t_i]` for every index in `t` and returns the result with shape
    (batch, 1, ..., 1) — as many trailing singleton dims as needed to
    broadcast against a tensor of shape `x_shape`.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
26
+
27
+
28
class PCMFMScheduler(SchedulerMixin, ConfigMixin):
    """Phased-Consistency-Model flow-matching Euler scheduler.

    Builds a flow-matching sigma schedule (either a shifted linear schedule
    or the linear-quadratic schedule from the Mochi pipeline), subsamples it
    down to `pcm_timesteps` discrete Euler timesteps, and performs plain
    Euler updates in `step`.

    Args:
        num_train_timesteps: Length of the full training schedule.
        shift: Timestep shift applied to the sigmas (non linear-quadratic
            branch only).
        pcm_timesteps: Number of subsampled Euler timesteps kept.
        linear_quadratic: Use `linear_quadratic_schedule` instead of the
            shifted linear schedule.
        linear_quadratic_threshold: Threshold forwarded to
            `linear_quadratic_schedule`.
        linear_range: Fraction of `num_train_timesteps` treated as the
            linear part of the linear-quadratic schedule.
    """
    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        pcm_timesteps: int = 50,
        linear_quadratic=False,
        linear_quadratic_threshold=0.025,
        linear_range=0.5,
    ):
        if linear_quadratic:
            linear_steps = int(num_train_timesteps * linear_range)
            sigmas = linear_quadratic_schedule(num_train_timesteps,
                                               linear_quadratic_threshold,
                                               linear_steps)
            sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
        else:
            # Descending timesteps num_train_timesteps..1, mapped to sigmas in
            # (0, 1] and then time-shifted.
            timesteps = np.linspace(1,
                                    num_train_timesteps,
                                    num_train_timesteps,
                                    dtype=np.float32)[::-1].copy()
            timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
            sigmas = timesteps / num_train_timesteps
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        # Keep one sigma every (num_train_timesteps // pcm_timesteps) steps.
        self.euler_timesteps = (np.arange(1, pcm_timesteps + 1) *
                                (num_train_timesteps //
                                 pcm_timesteps)).round().astype(np.int64) - 1
        # Double reversal: index into the ascending schedule, then restore
        # descending order for inference.
        self.sigmas = sigmas.numpy()[::-1][self.euler_timesteps]
        self.sigmas = torch.from_numpy((self.sigmas[::-1].copy()))
        self.timesteps = self.sigmas * num_train_timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to(
            "cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: interpolate `sample` towards `noise`
        by the sigma of the current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        # Sigma and timestep are proportional in this flow-matching schedule.
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self,
                      num_inference_steps: int,
                      device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        # Evenly subsample the pcm timesteps for inference.
        inference_indices = np.linspace(0,
                                        self.config.pcm_timesteps,
                                        num=num_inference_steps,
                                        endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = torch.from_numpy(inference_indices).long()

        self.sigmas_ = self.sigmas[inference_indices]
        timesteps = self.sigmas_ * self.config.num_train_timesteps
        self.timesteps = timesteps.to(device=device)
        # Append a terminal sigma of 0 so `step` can always look one index
        # ahead for dt.
        self.sigmas_ = torch.cat(
            [self.sigmas_,
             torch.zeros(1, device=self.sigmas_.device)])
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        # Map a timestep value back to its position in the schedule.
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        # Lazily initialize the step counter, either from the timestep value
        # or from an explicitly set begin index.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[PCMFMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise) with a plain Euler update.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator (unused; kept for interface compatibility).
            return_dict (`bool`):
                Whether or not to return a [`PCMFMSchedulerOutput`] or tuple.

        Returns:
            [`PCMFMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`PCMFMSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (isinstance(timestep, int) or isinstance(timestep, torch.IntTensor)
                or isinstance(timestep, torch.LongTensor)):
            raise ValueError((
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."), )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to float32 to avoid precision issues during integration.
        sample = sample.to(torch.float32)

        sigma = self.sigmas_[self.step_index]

        # Euler update: x_prev = x + (sigma_next - sigma) * v, written via the
        # denoised estimate.
        denoised = sample - model_output * sigma
        derivative = (sample - denoised) / sigma

        dt = self.sigmas_[self.step_index + 1] - sigma
        prev_sample = sample + derivative * dt
        prev_sample = prev_sample.to(model_output.dtype)
        self._step_index += 1

        if not return_dict:
            return (prev_sample, )

        return PCMFMSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
240
+
241
+
242
class EulerSolver:
    """Euler solver over a subsampled sigma schedule.

    Subsamples `euler_timesteps` indices out of a full `timesteps`-long
    schedule and keeps both the selected sigmas and their predecessors so a
    single Euler step (or a multiphase consistency target) can be taken.
    """

    def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
        ratio = timesteps // euler_timesteps
        self.step_ratio = ratio
        # Indices of the kept timesteps, and for each one the index it steps to.
        picked = (np.arange(1, euler_timesteps + 1) *
                  ratio).round().astype(np.int64) - 1
        picked_prev = np.asarray([0] + picked[:-1].tolist())
        sel_sigmas = sigmas[picked]
        # Either use sigma0 or 0 as the very first "previous" sigma.
        sel_sigmas_prev = np.asarray([sigmas[0]] +
                                     sigmas[picked[:-1]].tolist())

        self.euler_timesteps = torch.from_numpy(picked).long()
        self.euler_timesteps_prev = torch.from_numpy(picked_prev).long()
        self.sigmas = torch.from_numpy(sel_sigmas)
        self.sigmas_prev = torch.from_numpy(sel_sigmas_prev)

    def to(self, device):
        """Move all cached tensors to `device`; returns self for chaining."""
        for name in ("euler_timesteps", "euler_timesteps_prev", "sigmas",
                     "sigmas_prev"):
            setattr(self, name, getattr(self, name).to(device))
        return self

    def euler_step(self, sample, model_pred, timestep_index):
        """One Euler update from sigma[idx] to sigma_prev[idx]."""
        sigma_now = extract_into_tensor(self.sigmas, timestep_index,
                                        model_pred.shape)
        sigma_next = extract_into_tensor(self.sigmas_prev, timestep_index,
                                         model_pred.shape)
        return sample + (sigma_next - sigma_now) * model_pred

    def euler_style_multiphase_pred(
        self,
        sample,
        model_pred,
        timestep_index,
        multiphase,
        is_target=False,
    ):
        """Euler-jump each sample to the start of its consistency phase.

        The schedule is split into `multiphase` phases; each sample jumps from
        its current sigma (or its predecessor when `is_target`) to the
        previous-sigma of the last phase boundary at or before it.
        Returns the jumped sample and the destination indices.
        """
        phase_starts = np.floor(
            np.linspace(0,
                        len(self.euler_timesteps),
                        num=multiphase,
                        endpoint=False)).astype(np.int64)
        phase_starts = torch.from_numpy(phase_starts).long().to(
            self.euler_timesteps.device)
        # For each sample, find the greatest phase start <= its index.
        expanded = timestep_index.unsqueeze(1).expand(-1,
                                                      phase_starts.size(0))
        valid = expanded >= phase_starts
        last_valid = valid.flip(dims=[1]).long().argmax(dim=1)
        last_valid = phase_starts.size(0) - 1 - last_valid
        timestep_index_end = phase_starts[last_valid]

        lookup = self.sigmas_prev if is_target else self.sigmas
        sigma = extract_into_tensor(lookup, timestep_index, sample.shape)
        sigma_prev = extract_into_tensor(self.sigmas_prev,
                                         timestep_index_end, sample.shape)
        x_prev = sample + (sigma_prev - sigma) * model_pred

        return x_prev, timestep_index_end
fastvideo/models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
fastvideo/models/__pycache__/flash_attn_no_pad.cpython-310.pyc ADDED
Binary file (1.04 kB). View file
 
fastvideo/models/__pycache__/flash_attn_no_pad.cpython-312.pyc ADDED
Binary file (1.41 kB). View file
 
fastvideo/models/flash_attn_no_pad.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ from flash_attn import flash_attn_varlen_qkvpacked_func
3
+ from flash_attn.bert_padding import pad_input, unpad_input
4
+
5
+
6
def flash_attn_no_pad(qkv,
                      key_padding_mask,
                      causal=False,
                      dropout_p=0.0,
                      softmax_scale=None):
    """Run varlen flash attention on a padded, packed-QKV tensor.

    Unpads `qkv` (shape (b, s, 3, h, d)) according to `key_padding_mask`,
    runs `flash_attn_varlen_qkvpacked_func`, then re-pads the output back to
    (b, s, h, d).
    """
    # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
    batch_size, seqlen = qkv.shape[0], qkv.shape[1]
    nheads = qkv.shape[-2]

    packed = rearrange(qkv, "b s three h d -> b s (three h d)")
    packed_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input(
        packed, key_padding_mask)
    qkv_unpad = rearrange(packed_unpad,
                          "nnz (three h d) -> nnz three h d",
                          three=3,
                          h=nheads)

    out_unpad = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_s,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
    )

    out_padded = pad_input(rearrange(out_unpad, "nnz h d -> nnz (h d)"),
                           indices, batch_size, seqlen)
    return rearrange(out_padded, "b s (h d) -> b s h d", h=nheads)
fastvideo/reward_model/clip_score.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torchvision import transforms
4
+ import torch.nn.functional as F
5
+ import clip
6
+ from PIL import Image
7
+ from typing import List, Tuple, Union
8
+ from PIL import Image
9
+ import os
10
+ from open_clip import create_model_from_pretrained, get_tokenizer
11
+ import argparse
12
+
13
+
14
+
15
@torch.no_grad()
def calculate_clip_score(prompts, images, clip_model, device):
    """Cosine similarity between CLIP embeddings of prompts and images.

    `images` must already be preprocessed into the model's input format;
    prompts are tokenized here (truncated to the context length).
    """
    tokens = clip.tokenize(prompts, truncate=True).to(device=device)
    img_emb = clip_model.encode_image(images)
    txt_emb = clip_model.encode_text(tokens)
    return F.cosine_similarity(img_emb, txt_emb)
24
+
25
+
26
class CLIPScoreRewardModel():
    """Scores prompt/image pairs with an open_clip model.

    The score is the dot product of L2-normalized image and text embeddings
    (cosine similarity).
    """

    def __init__(self, clip_model_path, device, http_proxy=None, https_proxy=None, clip_model_type='ViT-H-14'):
        super().__init__()
        # Optionally route model downloads through a proxy.
        for env_key, value in (("http_proxy", http_proxy),
                               ("https_proxy", https_proxy)):
            if value:
                os.environ[env_key] = value
        self.clip_model_path = clip_model_path
        self.clip_model_type = clip_model_type
        self.device = device
        self.load_model()

    def load_model(self, logger=None):
        """Instantiate the CLIP model/preprocessor/tokenizer on `self.device`."""
        self.model, self.preprocess = create_model_from_pretrained(self.clip_model_path)
        self.tokenizer = get_tokenizer(self.clip_model_type)
        self.model.to(self.device)

    # calculate clip score directly, such as for rerank
    @torch.no_grad()
    def __call__(
        self,
        prompts: Union[str, List[str]],
        images: List[Image.Image]
    ) -> List[float]:
        """Return one cosine-similarity score per (prompt, image) pair."""
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)
        if len(prompts) != len(images):
            raise ValueError("prompts must have the same length as images")

        scores = []
        for prompt, image in zip(prompts, images):
            pixels = self.preprocess(image).unsqueeze(0).to(self.device)
            tokens = self.tokenizer(
                [prompt],
                context_length=self.model.context_length
            ).to(self.device)

            img_emb = F.normalize(self.model.encode_image(pixels), dim=-1)
            txt_emb = F.normalize(self.model.encode_text(tokens), dim=-1)

            scores.append((img_emb @ txt_emb.T).item())

        return scores
73
+
74
+
75
+
76
if __name__ == "__main__":
    # CLI demo: score one bundled image against a fixed prompt.
    # NOTE(review): the description string looks copy-pasted from the
    # PickScore script; consider renaming it to "CLIP Score Reward Model".
    parser = argparse.ArgumentParser(description="PickScore Reward Model")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on (e.g., 'cuda', 'cpu')")
    parser.add_argument("--http_proxy", type=str, default=None, help="HTTP proxy URL")
    parser.add_argument("--https_proxy", type=str, default=None, help="HTTPS proxy URL")
    args = parser.parse_args()

    # Example usage
    clip_model_path = 'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384'
    reward_model = CLIPScoreRewardModel(
        clip_model_path,
        device=args.device,
        http_proxy=args.http_proxy,
        https_proxy=args.https_proxy
    )

    image_path = "assets/reward_demo.jpg"
    prompt = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting."

    image = Image.open(image_path).convert("RGB")
    clip_score = reward_model(prompt, [image])

    print(f"CLIP Score: {clip_score}")
fastvideo/reward_model/hps_score.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List
2
+ import argparse
3
+ import torch
4
+ from PIL import Image
5
+
6
+ from HPSv2.hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
7
+
8
+
9
class HPSClipRewardModel(object):
    """Human Preference Score (HPSv2) reward model.

    Loads an open_clip backbone via HPSv2's `create_model_and_transforms`,
    applies an HPS checkpoint on top, and scores image/text pairs by the
    diagonal of the image-text logit matrix.
    """

    def __init__(self, device, clip_ckpt_path, hps_ckpt_path, model_name='ViT-H-14'):
        # `device` may be an int GPU index, a bare digit string, or a full
        # torch device string such as "cuda:0" — see build_reward_model.
        self.device = device
        self.clip_ckpt_path = clip_ckpt_path
        self.hps_ckpt_path = hps_ckpt_path
        self.model_name = model_name
        self.reward_model, self.text_processor, self.img_processor = self.build_reward_model()

    def build_reward_model(self):
        """Create the CLIP model, load the HPS checkpoint, return (model, tokenizer, image preprocessor)."""
        model, preprocess_train, img_preprocess_val = create_model_and_transforms(
            self.model_name,
            self.clip_ckpt_path,
            precision='amp',
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False
        )

        # Convert device name to proper format
        if isinstance(self.device, int):
            ml_device = str(self.device)
        else:
            ml_device = self.device

        # A bare GPU index (int or digit string) becomes "cuda:<idx>";
        # anything already prefixed with "cuda" (or a non-digit string like
        # "cpu") is left untouched.
        if not ml_device.startswith('cuda'):
            ml_device = f'cuda:{ml_device}' if ml_device.isdigit() else ml_device

        checkpoint = torch.load(self.hps_ckpt_path, map_location=ml_device)
        model.load_state_dict(checkpoint['state_dict'])
        text_processor = get_tokenizer(self.model_name)
        reward_model = model.to(self.device)
        reward_model.eval()

        return reward_model, text_processor, img_preprocess_val

    @torch.no_grad()
    def __call__(
        self,
        images: Union[Image.Image, List[Image.Image]],
        texts: Union[str, List[str]],
    ):
        """Score each (image, text) pair; returns a list of float HPS scores."""
        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(texts, str):
            texts = [texts]

        rewards = []
        for image, text in zip(images, texts):
            image = self.img_processor(image).unsqueeze(0).to(self.device, non_blocking=True)
            text = self.text_processor([text]).to(device=self.device, non_blocking=True)
            # Mixed-precision forward; assumes a CUDA device — TODO confirm CPU path.
            with torch.amp.autocast('cuda'):
                outputs = self.reward_model(image, text)
                image_features, text_features = outputs["image_features"], outputs["text_features"]
                logits_per_image = image_features @ text_features.T
                hps_score = torch.diagonal(logits_per_image)

            # reward is a tensor of shape (1,) --> list
            rewards.append(hps_score.float().cpu().item())

        return rewards
fastvideo/reward_model/image_reward.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Image-Reward: Copyied from https://github.com/THUDM/ImageReward
2
+ import os
3
+ from typing import Union, List
4
+ from PIL import Image
5
+
6
+ import torch
7
try:
    import ImageReward as RM
except ImportError as err:
    # A bare `except:` previously swallowed every exception (including
    # KeyboardInterrupt), and `raise Warning(...)` raised a Warning *as an
    # exception* instead of signalling a missing dependency. Re-raise as a
    # proper ImportError with the install hint, chained to the original.
    raise ImportError(
        "ImageReward is required to be installed (`pip install image-reward`) "
        "when using ImageReward for post-training.") from err
11
+
12
+
13
class ImageRewardModel(object):
    """Thin wrapper around THUDM's ImageReward for scoring image/text pairs."""

    def __init__(self, model_name, device, http_proxy=None, https_proxy=None, med_config=None):
        # Optionally route model downloads through a proxy.
        for env_key, value in (("http_proxy", http_proxy),
                               ("https_proxy", https_proxy)):
            if value:
                os.environ[env_key] = value
        self.model_name = model_name or "ImageReward-v1.0"
        self.device = device
        self.med_config = med_config
        self.build_reward_model()

    def build_reward_model(self):
        """Load the ImageReward checkpoint onto `self.device`."""
        self.model = RM.load(self.model_name,
                             device=self.device,
                             med_config=self.med_config)

    @torch.no_grad()
    def __call__(
        self,
        images,
        texts,
    ):
        """Return one reward per (image, text) pair; a single text string is broadcast."""
        if isinstance(texts, str):
            texts = [texts] * len(images)

        rewards = []
        for image, text in zip(images, texts):
            _ranking, reward = self.model.inference_rank(text, [image])
            rewards.append(reward)
        return rewards
fastvideo/reward_model/pick_score.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ from typing import List, Tuple, Union
5
+ from transformers import AutoProcessor, AutoModel
6
+ from PIL import Image
7
+
8
+
9
class PickScoreRewardModel(object):
    """PickScore (CLIP-based human-preference) reward model.

    Embeds each prompt and image with the PickScore CLIP checkpoint and
    returns the logit-scaled similarity, standardized by `mean`/`std`.
    """

    def __init__(self, device: str = "cuda", http_proxy=None, https_proxy=None, mean=18.0, std=8.0):
        """
        Initialize PickScore reward model.

        Args:
            device: Device to run the model on ('cuda' or 'cpu')
            http_proxy / https_proxy: Optional proxy URLs for model download.
            mean / std: Standardization constants applied to raw scores.
        """
        if http_proxy:
            os.environ["http_proxy"] = http_proxy
        if https_proxy:
            os.environ["https_proxy"] = https_proxy
        self.device = device
        self.processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        self.model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"
        self.mean = mean
        self.std = std

        # Initialize model and processor
        self.processor = AutoProcessor.from_pretrained(self.processor_name_or_path)
        self.model = AutoModel.from_pretrained(self.model_pretrained_name_or_path).eval().to(device)

    @torch.no_grad()
    def __call__(
        self,
        images: List[Image.Image],
        prompts: Union[str, List[str]],
    ) -> List[float]:
        # NOTE: the previous annotation/docstring claimed a
        # (probabilities, scores) tuple, but the method has only ever
        # returned a single list of scores.
        """
        Calculate standardized PickScore values for image/prompt pairs.

        Args:
            images: List of PIL Images to evaluate
            prompts: Text prompt(s) to evaluate images against; a single
                string is broadcast to all images.

        Returns:
            List of standardized scores ((raw - mean) / std), one per image.

        Raises:
            ValueError: If `prompts` and `images` have different lengths.
        """
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)
        if len(prompts) != len(images):
            raise ValueError("prompts must have the same length as images")

        scores = []
        for prompt, image in zip(prompts, images):
            # Preprocess images
            image_inputs = self.processor(
                images=[image],
                padding=True,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(self.device)

            # Preprocess text
            text_inputs = self.processor(
                text=prompt,
                padding=True,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(self.device)

            # Get L2-normalized embeddings
            image_embs = self.model.get_image_features(**image_inputs)
            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

            text_embs = self.model.get_text_features(**text_inputs)
            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

            # Logit-scaled cosine similarity, then standardized.
            score = self.model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
            score = (score - self.mean) / self.std
            scores.extend(score.cpu().tolist())

        return scores
85
+
86
+
87
if __name__ == "__main__":
    # CLI demo: score one bundled image against a fixed prompt.
    parser = argparse.ArgumentParser(description="PickScore Reward Model")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on (e.g., 'cuda', 'cpu')")
    parser.add_argument("--http_proxy", type=str, default=None, help="HTTP proxy URL")
    parser.add_argument("--https_proxy", type=str, default=None, help="HTTPS proxy URL")
    args = parser.parse_args()

    # Example usage
    reward_model = PickScoreRewardModel(
        device=args.device,
        http_proxy=args.http_proxy,
        https_proxy=args.https_proxy,
    )
    pil_images = [Image.open("assets/reward_demo.jpg")]

    prompt = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting."

    scores = reward_model(pil_images, [prompt] * len(pil_images))
    # Undo the model's (x - mean) / std standardization and rescale to ~[0, 1]
    # for display.
    scores = [(s * reward_model.std + reward_model.mean) / 100.0 for s in scores]
    print(f"scores: {scores}")
+ print(f"scores: {scores}")
107
+
fastvideo/reward_model/unified_reward.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import os
4
+ import re
5
+ import requests
6
+ import time
7
+ import concurrent.futures
8
+ from io import BytesIO
9
+ from typing import List, Optional, Union
10
+
11
+ from PIL import Image
12
+
13
+
14
# Prompt template for the "semantic" evaluation mode: asks the VLM for
# word-level mismatch scores plus 1-5 Alignment / Coherence / Style ratings.
# The single `{}` placeholder is filled with the image's text caption.
QUESTION_TEMPLATE_SEMANTIC = (
    "You are presented with a generated image and its associated text caption. Your task is to analyze the image across multiple dimensions in relation to the caption. Specifically:\n\n"
    "1. Evaluate each word in the caption based on how well it is visually represented in the image. Assign a numerical score to each word using the format:\n"
    " Word-wise Scores: [[\"word1\", score1], [\"word2\", score2], ..., [\"wordN\", scoreN], [\"[No_mistakes]\", scoreM]]\n"
    " - A higher score indicates that the word is less well represented in the image.\n"
    " - The special token [No_mistakes] represents whether all elements in the caption were correctly depicted. A high score suggests no mistakes; a low score suggests missing or incorrect elements.\n\n"
    "2. Provide overall assessments for the image along the following axes (each rated from 1 to 5):\n"
    "- Alignment Score: How well the image matches the caption in terms of content.\n"
    "- Coherence Score: How logically consistent the image is (absence of visual glitches, object distortions, etc.).\n"
    "- Style Score: How aesthetically appealing the image looks, regardless of caption accuracy.\n\n"
    "Output your evaluation using the format below:\n\n"
    "---\n\n"
    "Word-wise Scores: [[\"word1\", score1], ..., [\"[No_mistakes]\", scoreM]]\n\n"
    "Alignment Score (1-5): X\n"
    "Coherence Score (1-5): Y\n"
    "Style Score (1-5): Z\n\n"
    "Your task is provided as follows:\nText Caption: [{}]"
)

# Prompt template for the "score" evaluation mode: asks the VLM to check
# caption elements one by one and emit a single 'Final Score:' in [1, 5].
# The single `{}` placeholder is filled with the image's text caption.
QUESTION_TEMPLATE_SCORE = (
    "You are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n"
    "1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n"
    "2. Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\n"
    "Extract key elements from the provided text caption, evaluate their presence in the generated image using the format: \'element (type): value\' (where value=0 means not generated, and value=1 means generated), and assign a score from 1 to 5 after \'Final Score:\'.\n"
    "Your task is provided as follows:\nText Caption: [{}]"
)
40
+
41
+
42
class VLMessageClient:
    """Thin HTTP client for an OpenAI-compatible vision-language reward server.

    Builds chat messages containing one image (file path or PIL image) plus a
    text question, POSTs them to ``{api_url}/v1/chat/completions`` and returns
    the model's raw text output. Usable as a context manager so the underlying
    ``requests.Session`` is closed deterministically.
    """

    def __init__(self, api_url):
        self.api_url = api_url
        self._session = None  # created lazily on first request

    @property
    def session(self):
        # Lazy creation: constructing a client is cheap and does no network
        # work until the first request is actually sent.
        if self._session is None:
            self._session = requests.Session()
        return self._session

    def close(self):
        """Close the session if it exists."""
        if self._session is not None:
            self._session.close()
            self._session = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _encode_image_base64(self, image):
        """Encode *image* (file path or PIL.Image) as a base64 JPEG string."""
        if isinstance(image, str):
            with Image.open(image) as img:
                img = img.convert("RGB")
                buffered = BytesIO()
                img.save(buffered, format="JPEG", quality=95)
                return base64.b64encode(buffered.getvalue()).decode("utf-8")
        elif isinstance(image, Image.Image):
            buffered = BytesIO()
            image.save(buffered, format="JPEG", quality=95)
            return base64.b64encode(buffered.getvalue()).decode("utf-8")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

    def build_messages(self, item, image_root=""):
        """Build the raw message list for one ``{'image': ..., 'question': ...}`` item.

        ``item['image']`` may be a path (joined with *image_root*) encoded as a
        ``file://`` URL, or a PIL image tagged with the internal ``pil_image``
        type that ``format_messages`` later converts to base64.
        """
        if isinstance(item['image'], str):
            image_path = os.path.join(image_root, item['image'])
            return [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"file://{image_path}"}},
                        {
                            "type": "text",
                            "text": f"{item['question']}"
                        }
                    ]
                }
            ]
        assert isinstance(item['image'], Image.Image), f"image must be a PIL.Image.Image, but got {type(item['image'])}"
        return [
            {
                "role": "user",
                "content": [
                    {"type": "pil_image", "pil_image": item['image']},
                    {
                        "type": "text",
                        "text": f"{item['question']}"
                    }
                ]
            }
        ]

    def format_messages(self, messages):
        """Convert raw messages into wire format.

        System messages are flattened to a plain string; ``image_url`` and
        ``pil_image`` parts are both re-encoded as base64 data URLs; any other
        part (e.g. text) passes through unchanged.
        """
        formatted = []
        for msg in messages:
            new_msg = {"role": msg["role"], "content": []}

            if msg["role"] == "system":
                new_msg["content"] = msg["content"][0]["text"]
            else:
                for part in msg["content"]:
                    if part["type"] == "image_url":
                        img_path = part["image_url"]["url"].replace("file://", "")
                        base64_image = self._encode_image_base64(img_path)
                        new_part = {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
                        }
                        new_msg["content"].append(new_part)
                    elif part["type"] == "pil_image":
                        base64_image = self._encode_image_base64(part["pil_image"])
                        new_part = {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
                        }
                        new_msg["content"].append(new_part)
                    else:
                        new_msg["content"].append(part)
            formatted.append(new_msg)
        return formatted

    def process_item(self, item, image_root=""):
        """Send one item to the server with up to 3 retries.

        Returns ``(result_dict, success_flag)`` on success. If the final
        attempt also fails, the exception is re-raised (callers such as
        ``UnifiedRewardModel._process_item_wrapper`` catch it).
        """
        max_retries = 3
        attempt = 0
        result = None

        while attempt < max_retries:
            try:
                attempt += 1
                raw_messages = self.build_messages(item, image_root)
                formatted_messages = self.format_messages(raw_messages)

                payload = {
                    "model": "UnifiedReward",
                    "messages": formatted_messages,
                    "temperature": 0,
                    "max_tokens": 4096,
                }

                # Later attempts get a slightly longer timeout.
                response = self.session.post(
                    f"{self.api_url}/v1/chat/completions",
                    json=payload,
                    timeout=30 + attempt*5
                )
                response.raise_for_status()

                output = response.json()["choices"][0]["message"]["content"]

                result = {
                    "question": item["question"],
                    "image_path": item["image"] if isinstance(item["image"], str) else "PIL_Image",
                    "model_output": output,
                    "attempt": attempt,
                    "success": True
                }
                break

            except Exception as e:
                if attempt == max_retries:
                    result = {
                        "question": item["question"],
                        "image_path": item["image"] if isinstance(item["image"], str) else "PIL_Image",
                        "error": str(e),
                        "attempt": attempt,
                        "success": False
                    }
                    # Bare `raise` re-raises the active exception with its
                    # original traceback (was `raise(e)`, which restarts it).
                    raise
                else:
                    # Exponential backoff, capped at 10 seconds.
                    sleep_time = min(2 ** attempt, 10)
                    time.sleep(sleep_time)

        return result, result.get("success", False)
188
+
189
+
190
class UnifiedRewardModel(object):
    """Scores generated images against prompts via a remote UnifiedReward
    VLM server, fanning individual requests out over a thread pool.
    """

    def __init__(self, api_url, default_question_type="score", num_workers=8):
        self.api_url = api_url
        self.num_workers = num_workers
        self.default_question_type = default_question_type
        self.question_template_score = QUESTION_TEMPLATE_SCORE
        self.question_template_semantic = QUESTION_TEMPLATE_SEMANTIC
        # self.client = VLMessageClient(self.api_url)

    def question_constructor(self, prompt, question_type=None):
        """Render the evaluation question for *prompt* using the template
        selected by *question_type* ('score' or 'semantic')."""
        if question_type is None:
            question_type = self.default_question_type
        templates = {
            "score": self.question_template_score,
            "semantic": self.question_template_semantic,
        }
        if question_type not in templates:
            raise ValueError(f"Invalid question type: {question_type}")
        return templates[question_type].format(prompt)

    def _process_item_wrapper(self, client, image, question):
        """Run one request through *client*; return its result dict, or None on error."""
        try:
            outcome, _ = client.process_item({"image": image, "question": question})
            return outcome
        except Exception as e:
            print(f"Encountered error in unified reward model processing: {str(e)}")
            return None

    def _reset_proxy(self):
        # Drop any proxy settings so requests go straight to the server URL.
        for key in ('http_proxy', 'https_proxy'):
            os.environ.pop(key, None)

    def __call__(self,
                 images: Union[List[Image.Image], List[str]],
                 prompts: Union[str, List[str]],
                 question_type: Optional[str] = None,
                 ):
        """Score each image against its prompt.

        Returns ``(results, successes)``: a parsed float score (or None) and a
        success flag per input, aligned with the input order.
        """
        # Reset proxy, otherwise cannot access the server url
        self._reset_proxy()
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)
        if len(prompts) != len(images):
            raise ValueError("prompts must have the same length as images")

        # Pre-fill with failure defaults; successful completions overwrite.
        parsed_scores = [None] * len(images)
        ok_flags = [False] * len(images)

        with VLMessageClient(self.api_url) as client:
            questions = [self.question_constructor(p, question_type) for p in prompts]

            with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_workers) as pool:
                # Remember each future's original position so out-of-order
                # completion still fills the right slot.
                pending = {
                    pool.submit(self._process_item_wrapper, client, img, q): pos
                    for pos, (img, q) in enumerate(zip(images, questions))
                }

                for done in concurrent.futures.as_completed(pending):
                    pos = pending[done]
                    outcome = done.result()
                    if outcome is not None and outcome.get("success", False):
                        raw_text = outcome.get("model_output", "")
                        parsed_scores[pos] = self.score_parser(raw_text, question_type)
                        ok_flags[pos] = True

        return parsed_scores, ok_flags

    def score_parser(self, text, question_type=None):
        """Parse the server's free-text output into a float score (or None)."""
        if question_type is None:
            question_type = self.default_question_type
        parsers = {
            "score": self.extract_final_score,
            "semantic": self.extract_alignment_score,
        }
        if question_type not in parsers:
            raise ValueError(f"Invalid question type: {question_type}")
        return parsers[question_type](text)

    @staticmethod
    def extract_alignment_score(text):
        """
        Extract Alignment Score (1-5) from the evaluation text.
        Returns a float score if found, None otherwise.
        """
        found = re.search(r'Alignment Score \(1-5\):\s*([0-5](?:\.\d+)?)', text)
        return float(found.group(1)) if found else None

    @staticmethod
    def extract_final_score(text):
        """
        Extract Final Score from the evaluation text.
        Returns a float score if found, None otherwise.
        Example input:
        'ocean (location): 0
        clouds (object): 1
        birds (animal): 0
        day time (attribute): 1
        low depth field effect (attribute): 1
        painting (attribute): 1
        Final Score: 2.33'
        """
        found = re.search(r'Final Score:\s*([0-5](?:\.\d+)?)', text)
        return float(found.group(1)) if found else None
307
+
308
+
309
if __name__ == "__main__":
    # Smoke test: score copies of the demo image against one prompt via a
    # running UnifiedReward server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--api_url", type=str)
    parser.add_argument("--max_workers", type=int)
    args = parser.parse_args()

    unified_reward_model = UnifiedRewardModel(args.api_url, num_workers=args.max_workers)
    img_path = "assets/reward_demo.jpg"
    # 16 copies of the same image to exercise the thread pool.
    images = [
        Image.open(img_path).convert("RGB")
        for i in range(1, 5)
    ] * 4
    prompts = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting."
    results, successes = unified_reward_model(images, prompts, question_type="semantic")
    print(results)
    print(successes)

    # # Concurrency test
    # proc_num = 32

    # for i in range(5):
    #     with concurrent.futures.ThreadPoolExecutor(max_workers=proc_num) as executor:
    #         futures = [executor.submit(unified_reward_model, images, prompts, question_type="semantic") for _ in range(proc_num)]
    #         results = [future.result() for future in concurrent.futures.as_completed(futures)]
    #         print(results)
fastvideo/reward_model/utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import random
3
+
4
+ def _compute_single_reward(reward_model, images, input_prompts):
5
+ """Compute reward for a single reward model."""
6
+ reward_model_name = type(reward_model).__name__
7
+ try:
8
+ if reward_model_name == 'HPSClipRewardModel':
9
+ rewards = reward_model(images, input_prompts)
10
+ successes = [1] * len(rewards)
11
+
12
+ elif reward_model_name == 'CLIPScoreRewardModel':
13
+ rewards = reward_model(input_prompts, images)
14
+ successes = [1] * len(rewards)
15
+
16
+ elif reward_model_name == 'ImageRewardModel':
17
+ rewards = reward_model(images, input_prompts)
18
+ successes = [1] * len(rewards)
19
+
20
+ elif reward_model_name == 'UnifiedRewardModel':
21
+ rewards, successes_bool = reward_model(images, input_prompts)
22
+ rewards = [float(reward) if success else 0.0 for reward, success in zip(rewards, successes_bool)]
23
+ successes = [1 if success else 0 for success in successes_bool]
24
+
25
+ elif reward_model_name == 'PickScoreRewardModel':
26
+ rewards = reward_model(images, input_prompts)
27
+ successes = [1] * len(rewards)
28
+
29
+ else:
30
+ raise ValueError(f"Unknown reward model: {reward_model_name}")
31
+
32
+ # Verify the length of results matches input
33
+ assert len(rewards) == len(input_prompts), \
34
+ f"Length mismatch in {reward_model_name}: rewards ({len(rewards)}) != input_prompts ({len(input_prompts)})"
35
+ assert len(successes) == len(input_prompts), \
36
+ f"Length mismatch in {reward_model_name}: successes ({len(successes)}) != input_prompts ({len(input_prompts)})"
37
+
38
+ return rewards, successes
39
+
40
+ except Exception as e:
41
+ raise ValueError(f"Error in _compute_single_reward with {reward_model_name}: {e}") from e
42
+
43
+ def compute_reward(images, input_prompts, reward_models, reward_weights):
44
+ assert (
45
+ len(images) == len(input_prompts)
46
+ ), f"length of `images` ({len(images)}) must be equal to length of `input_prompts` ({len(input_prompts)})"
47
+
48
+ # Initialize results
49
+ rewards_dict = {}
50
+ successes_dict = {}
51
+
52
+ # Create a thread pool for parallel reward computation
53
+ with concurrent.futures.ThreadPoolExecutor(max_workers=len(reward_models)) as executor:
54
+ # Submit all reward computation tasks
55
+ future_to_model = {
56
+ executor.submit(_compute_single_reward, reward_model, images, input_prompts): reward_model
57
+ for reward_model in reward_models
58
+ }
59
+
60
+ # Process results as they complete
61
+ for future in concurrent.futures.as_completed(future_to_model):
62
+ reward_model = future_to_model[future]
63
+ model_name = type(reward_model).__name__
64
+ try:
65
+ model_rewards, model_successes = future.result()
66
+ rewards_dict[model_name] = model_rewards
67
+ successes_dict[model_name] = model_successes
68
+ except Exception as e:
69
+ print(f"Error computing reward with {model_name}: {e}")
70
+ rewards_dict[model_name] = [0.0] * len(input_prompts)
71
+ successes_dict[model_name] = [0] * len(input_prompts)
72
+ continue
73
+
74
+ # Merge rewards based on weights
75
+ merged_rewards = [0.0] * len(input_prompts)
76
+ merged_successes = [0] * len(input_prompts)
77
+
78
+ # First check if all models are successful for each sample
79
+ for i in range(len(merged_rewards)):
80
+ all_success = True
81
+ for model_name in reward_weights.keys():
82
+ if model_name in successes_dict and successes_dict[model_name][i] != 1:
83
+ all_success = False
84
+ break
85
+
86
+ if all_success:
87
+ # Only compute weighted sum if all models are successful
88
+ for model_name, weight in reward_weights.items():
89
+ if model_name in rewards_dict:
90
+ merged_rewards[i] += rewards_dict[model_name][i] * weight
91
+ merged_successes[i] = 1
92
+
93
+ return merged_rewards, merged_successes, rewards_dict, successes_dict
94
+
95
def balance_pos_neg(samples, use_random=False):
    """Reorder *samples* so positive- and negative-advantage samples alternate.

    With ``use_random=True`` a plain random permutation is returned instead.
    NOTE(review): in the balanced path, samples whose advantage is exactly 0
    are dropped — confirm this is intended upstream.
    """
    if use_random:
        return random.sample(samples, len(samples))

    pos = [s for s in samples if s['advantages'].item() > 0]
    neg = [s for s in samples if s['advantages'].item() < 0]

    # Shuffle each group independently before interleaving.
    pos = random.sample(pos, len(pos))
    neg = random.sample(neg, len(neg))

    # On a tie the negative group is treated as the smaller one, matching the
    # strict `<` comparison of the original ordering.
    small, large = (pos, neg) if len(pos) < len(neg) else (neg, pos)

    interleaved = []
    for a, b in zip(small, large):
        interleaved.append(a)
        interleaved.append(b)

    # Whatever remains in the larger group goes at the end.
    interleaved.extend(large[len(small):])
    return interleaved
126
+
fastvideo/utils/.DS_Store ADDED
Binary file (6.15 kB). View file
 
fastvideo/utils/checkpoint.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import json
4
+ import os
5
+
6
+ import torch
7
+ import torch.distributed.checkpoint as dist_cp
8
+ from peft import get_peft_model_state_dict
9
+ from safetensors.torch import load_file, save_file
10
+ from torch.distributed.checkpoint.default_planner import (DefaultLoadPlanner,
11
+ DefaultSavePlanner)
12
+ from torch.distributed.checkpoint.optimizer import \
13
+ load_sharded_optimizer_state_dict
14
+ from torch.distributed.fsdp import (FullOptimStateDictConfig,
15
+ FullStateDictConfig)
16
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
17
+ from torch.distributed.fsdp import StateDictType
18
+
19
+ from fastvideo.utils.logging_ import main_print
20
+
21
+
22
def save_checkpoint_optimizer(model,
                              optimizer,
                              rank,
                              output_dir,
                              step,
                              discriminator=False):
    """Save full (unsharded) model weights + optimizer state under
    ``output_dir/checkpoint-{step}``.

    Gathers FSDP state onto CPU (rank 0 only) and writes either the
    diffusion-model files or, when ``discriminator=True``, the discriminator
    files. Must be called on all ranks — the state-dict gather is collective.
    NOTE(review): on ranks > 0 with ``discriminator=False`` control falls into
    the else branch and writes discriminator files with an (empty) state dict;
    confirm whether a rank guard is missing here.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(
            model,
            optimizer,
        )

    # todo move to get_state_dict
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    # save using safetensors
    if rank <= 0 and not discriminator:
        weight_path = os.path.join(save_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        config_dict = dict(model.config)
        # 'dtype' is not JSON-serializable; drop it before dumping.
        config_dict.pop('dtype')
        config_path = os.path.join(save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        optimizer_path = os.path.join(save_dir, "optimizer.pt")
        torch.save(optim_state, optimizer_path)
    else:
        weight_path = os.path.join(save_dir,
                                   "discriminator_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        optimizer_path = os.path.join(save_dir, "discriminator_optimizer.pt")
        torch.save(optim_state, optimizer_path)
    main_print(f"--> checkpoint saved at step {step}")
63
+
64
+
65
def save_checkpoint(transformer, rank, output_dir, step, epoch):
    """Save a full (unsharded) transformer checkpoint (weights + config only).

    Gathers the FSDP-sharded weights onto CPU and, on rank 0, writes
    ``output_dir/checkpoint-{step}-{epoch}/diffusion_pytorch_model.safetensors``
    plus a ``config.json``. All ranks must enter (the gather is collective);
    only rank 0 touches the filesystem.
    """
    main_print(f"--> saving checkpoint at step {step}")
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = transformer.state_dict()
    # todo move to get_state_dict
    if rank <= 0:
        save_dir = os.path.join(output_dir, f"checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)
        # save using safetensors
        weight_path = os.path.join(save_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        config_dict = dict(transformer.config)
        # 'dtype' is not JSON-serializable; drop it before dumping.
        if "dtype" in config_dict:
            del config_dict["dtype"]  # TODO
        config_path = os.path.join(save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
    main_print(f"--> checkpoint saved at step {step}")
89
+
90
+
91
def save_checkpoint_generator_discriminator(
    model,
    optimizer,
    discriminator,
    discriminator_optimizer,
    rank,
    output_dir,
    step,
):
    """Save a resumable GAN-style checkpoint under ``output_dir/checkpoint-{step}``.

    Writes three artifacts:
      * ``hf_weights/``: full generator weights (safetensors) + config, rank 0 only;
      * ``model_weights_state/`` and ``model_optimizer_state/``: sharded
        generator weights/optimizer via torch.distributed.checkpoint
        (every rank writes its own shard);
      * ``discriminator_fsdp_state/discriminator_state.pt``: full
        discriminator model + optimizer state, rank 0 only.
    All ranks must call this — the state-dict gathers and distributed-checkpoint
    writes are collective operations.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()

    # todo move to get_state_dict
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    hf_weight_dir = os.path.join(save_dir, "hf_weights")
    os.makedirs(hf_weight_dir, exist_ok=True)
    # save using safetensors
    if rank <= 0:
        config_dict = dict(model.config)
        config_path = os.path.join(hf_weight_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        weight_path = os.path.join(hf_weight_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)

    main_print(f"--> saved HF weight checkpoint at path {hf_weight_dir}")
    model_weight_dir = os.path.join(save_dir, "model_weights_state")
    os.makedirs(model_weight_dir, exist_ok=True)
    model_optimizer_dir = os.path.join(save_dir, "model_optimizer_state")
    os.makedirs(model_optimizer_dir, exist_ok=True)
    # Sharded save: each rank writes only its own shard of the generator.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        optim_state = FSDP.optim_state_dict(model, optimizer)
        model_state = model.state_dict()
        weight_state_dict = {"model": model_state}
        dist_cp.save_state_dict(
            state_dict=weight_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_weight_dir),
            planner=DefaultSavePlanner(),
        )
        optimizer_state_dict = {"optimizer": optim_state}
        dist_cp.save_state_dict(
            state_dict=optimizer_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_optimizer_dir),
            planner=DefaultSavePlanner(),
        )

    discriminator_fsdp_state_dir = os.path.join(save_dir,
                                                "discriminator_fsdp_state")
    os.makedirs(discriminator_fsdp_state_dir, exist_ok=True)
    # Discriminator is small enough to save as one full state dict on rank 0.
    with FSDP.state_dict_type(
            discriminator,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        optim_state = FSDP.optim_state_dict(discriminator,
                                            discriminator_optimizer)
        model_state = discriminator.state_dict()
        state_dict = {"optimizer": optim_state, "model": model_state}
        if rank <= 0:
            discriminator_fsdp_state_fil = os.path.join(
                discriminator_fsdp_state_dir, "discriminator_state.pt")
            torch.save(state_dict, discriminator_fsdp_state_fil)

    main_print("--> saved FSDP state checkpoint")
163
+
164
+
165
def load_sharded_model(model, optimizer, model_dir, optimizer_dir):
    """Load sharded FSDP model + optimizer state written by
    ``save_checkpoint_generator_discriminator``.

    Every rank reads its own shard from *model_dir* / *optimizer_dir* via
    torch.distributed.checkpoint. Returns the updated ``(model, optimizer)``.
    """
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        # The current (sharded) state dict acts as the template describing
        # which tensors/shards this rank should read from storage.
        weight_state_dict = {"model": model.state_dict()}

        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=weight_state_dict["model"],
            optimizer_key="optimizer",
            storage_reader=dist_cp.FileSystemReader(optimizer_dir),
        )
        optim_state = optim_state["optimizer"]
        # Re-flatten the optimizer state into the layout FSDP expects.
        flattened_osd = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(flattened_osd)
        # Load the model weight shards in place into weight_state_dict.
        dist_cp.load_state_dict(
            state_dict=weight_state_dict,
            storage_reader=dist_cp.FileSystemReader(model_dir),
            planner=DefaultLoadPlanner(),
        )
        model_state = weight_state_dict["model"]
        model.load_state_dict(model_state)
    main_print(f"--> loaded model and optimizer from path {model_dir}")
    return model, optimizer
187
+
188
+
189
def load_full_state_model(model, optimizer, checkpoint_file, rank):
    """Load a full (unsharded) model + optimizer state from *checkpoint_file*.

    The file holds ``{"model": ..., "optimizer": ...}`` as written by
    ``save_checkpoint_generator_discriminator`` for the discriminator.
    NOTE: ``torch.load`` uses default (pickle) semantics — load trusted
    checkpoints only.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        discriminator_state = torch.load(checkpoint_file)
        model_state = discriminator_state["model"]
        if rank <= 0:
            optim_state = discriminator_state["optimizer"]
        else:
            # NOTE(review): non-zero ranks pass None here — assumes FSDP
            # redistributes rank-0 optimizer state under rank0_only; confirm.
            optim_state = None
        model.load_state_dict(model_state)
        discriminator_optim_state = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(discriminator_optim_state)
    main_print(
        f"--> loaded discriminator and discriminator optimizer from path {checkpoint_file}"
    )
    return model, optimizer
210
+
211
+
212
def resume_training_generator_discriminator(model, optimizer, discriminator,
                                            discriminator_optimizer,
                                            checkpoint_dir, rank):
    """Resume generator + discriminator training from *checkpoint_dir*.

    Loads the sharded generator state and the full discriminator state that
    ``save_checkpoint_generator_discriminator`` wrote. The global step is
    parsed from the directory name (``.../checkpoint-{step}``).
    Returns ``(model, optimizer, discriminator, discriminator_optimizer, step)``.
    """
    step = int(checkpoint_dir.split("-")[-1])
    model_weight_dir = os.path.join(checkpoint_dir, "model_weights_state")
    model_optimizer_dir = os.path.join(checkpoint_dir, "model_optimizer_state")
    model, optimizer = load_sharded_model(model, optimizer, model_weight_dir,
                                          model_optimizer_dir)
    discriminator_ckpt_file = os.path.join(checkpoint_dir,
                                           "discriminator_fsdp_state",
                                           "discriminator_state.pt")
    discriminator, discriminator_optimizer = load_full_state_model(
        discriminator, discriminator_optimizer, discriminator_ckpt_file, rank)
    return model, optimizer, discriminator, discriminator_optimizer, step
226
+
227
+
228
def resume_training(model, optimizer, checkpoint_dir, discriminator=False):
    """Resume from a checkpoint written by ``save_checkpoint_optimizer``.

    Loads the safetensors weights (generator or discriminator files depending
    on *discriminator*) and the matching optimizer state. The global step is
    parsed from the directory name (``.../checkpoint-{step}``).
    Returns ``(model, optimizer, step)``.
    """
    weight_path = os.path.join(checkpoint_dir,
                               "diffusion_pytorch_model.safetensors")
    if discriminator:
        weight_path = os.path.join(checkpoint_dir,
                                   "discriminator_pytorch_model.safetensors")
    model_weights = load_file(weight_path)

    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        # Merge checkpoint tensors over the current state so keys missing
        # from the checkpoint keep their current values (strict=False).
        current_state = model.state_dict()
        current_state.update(model_weights)
        model.load_state_dict(current_state, strict=False)
    if discriminator:
        optim_path = os.path.join(checkpoint_dir, "discriminator_optimizer.pt")
    else:
        optim_path = os.path.join(checkpoint_dir, "optimizer.pt")
    # weights_only=False: optimizer state contains arbitrary pickled objects.
    optimizer_state_dict = torch.load(optim_path, weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=model, optim=optimizer, optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)
    step = int(checkpoint_dir.split("-")[-1])
    return model, optimizer, step
255
+
256
+
257
def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step,
                         pipeline, epoch):
    """Save LoRA adapter weights + optimizer state under
    ``output_dir/lora-checkpoint-{step}-{epoch}``.

    Gathers the full FSDP state (collective — call on all ranks), then on
    rank 0 extracts only the LoRA layers via peft and saves them through
    ``pipeline.save_lora_weights``, along with the optimizer state and a
    ``lora_config.json`` recording the step and LoRA hyperparameters.
    """
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        full_state_dict = transformer.state_dict()
        lora_optim_state = FSDP.optim_state_dict(
            transformer,
            optimizer,
        )

    if rank <= 0:
        save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)

        # save optimizer
        optim_path = os.path.join(save_dir, "lora_optimizer.pt")
        torch.save(lora_optim_state, optim_path)
        # save lora weight
        main_print(f"--> saving LoRA checkpoint at step {step}")
        # Filter the full state dict down to just the LoRA adapter tensors.
        transformer_lora_layers = get_peft_model_state_dict(
            model=transformer, state_dict=full_state_dict)
        pipeline.save_lora_weights(
            save_directory=save_dir,
            transformer_lora_layers=transformer_lora_layers,
            is_main_process=True,
        )
        # save config
        lora_config = {
            "step": step,
            "lora_params": {
                "lora_rank": transformer.config.lora_rank,
                "lora_alpha": transformer.config.lora_alpha,
                "target_modules": transformer.config.lora_target_modules,
            },
        }
        config_path = os.path.join(save_dir, "lora_config.json")
        with open(config_path, "w") as f:
            json.dump(lora_config, f, indent=4)
        main_print(f"--> LoRA checkpoint saved at step {step}")
299
+
300
+
301
def resume_lora_optimizer(transformer, checkpoint_dir, optimizer):
    """Restore the LoRA optimizer state saved by ``save_lora_checkpoint``.

    Reads the step from ``lora_config.json`` and loads ``lora_optimizer.pt``
    into *optimizer*. Note: the transformer weights themselves are NOT loaded
    here — only the optimizer state and step counter.
    Returns ``(transformer, optimizer, step)``.
    """
    config_path = os.path.join(checkpoint_dir, "lora_config.json")
    with open(config_path, "r") as f:
        config_dict = json.load(f)
    optim_path = os.path.join(checkpoint_dir, "lora_optimizer.pt")
    # weights_only=False: optimizer state contains arbitrary pickled objects.
    optimizer_state_dict = torch.load(optim_path, weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=transformer,
        optim=optimizer,
        optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)
    step = config_dict["step"]
    main_print(f"--> Successfully resuming LoRA optimizer from step {step}")
    return transformer, optimizer, step
fastvideo/utils/communications.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ from typing import Any, Tuple
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ from torch import Tensor
8
+
9
+ from fastvideo.utils.parallel_states import nccl_info
10
+
11
+
12
def broadcast(input_: torch.Tensor):
    """Broadcast ``input_`` in place from the first rank of this sequence-parallel group."""
    # The group leader is the first global rank belonging to this SP group.
    source_rank = nccl_info.group_id * nccl_info.sp_size
    dist.broadcast(input_, src=source_rank, group=nccl_info.group)
15
+
16
+
17
def _all_to_all_4D(input: torch.tensor,
                   scatter_idx: int = 2,
                   gather_idx: int = 1,
                   group=None) -> torch.tensor:
    """
    all-to-all for QKV

    Reshards a 4-D ``(bs, seq, heads, head_dim)`` tensor across the process
    group: scattering along ``scatter_idx`` while gathering along
    ``gather_idx``. Only the two head/sequence resharding combinations are
    supported.

    Args:
        input (torch.tensor): a tensor sharded along dim ``gather_idx``
        scatter_idx (int): dimension to scatter (default 2, the head dim)
        gather_idx (int): dimension to gather (default 1, the sequence dim)
        group : torch process group

    Returns:
        torch.tensor: resharded tensor, e.g. (bs, seqlen, hc/P, hs) for the
        default indices.

    Raises:
        RuntimeError: if ``(scatter_idx, gather_idx)`` is not (2, 1) or (1, 2).
    """
    assert (
        input.dim() == 4
    ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = input.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
        input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc,
                                 hs).transpose(0, 2).contiguous())

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t
        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(seqlen, bs, shard_hc, hs)

        # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
        output = output.transpose(0, 1).contiguous().reshape(
            bs, seqlen, shard_hc, hs)

        return output

    elif scatter_idx == 1 and gather_idx == 2:
        # input (torch.tensor): a tensor sharded along dim 2 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
        bs, seqlen, shard_hc, hs = input.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size
        # NOTE: seq_world_size was already computed above; the original code
        # redundantly re-queried it here.

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
        input_t = (input.reshape(
            bs, seq_world_size, shard_seqlen, shard_hc,
            hs).transpose(0, 3).transpose(0, 1).contiguous().reshape(
                seq_world_size, shard_hc, shard_seqlen, bs, hs))

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(hc, shard_seqlen, bs, hs)

        # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
        output = output.transpose(0, 2).contiguous().reshape(
            bs, shard_seqlen, hc, hs)

        return output
    else:
        raise RuntimeError(
            "scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
101
+
102
+
103
class SeqAllToAll4D(torch.autograd.Function):
    """Autograd wrapper around :func:`_all_to_all_4D`.

    Forward reshards a 4-D tensor between the sequence and head dimensions;
    backward applies the inverse resharding (scatter/gather indices swapped)
    to the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        # Stash communication parameters so backward can run the inverse op.
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any,
                 *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # The inverse of an all-to-all is an all-to-all with the scatter and
        # gather indices swapped. Non-tensor forward args (group, indices)
        # receive None gradients.
        return (
            None,
            SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx,
                                ctx.scatter_idx),
            None,
            None,
        )
129
+
130
+
131
def all_to_all_4D(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable all-to-all resharding of a 4-D QKV tensor.

    Runs over the sequence-parallel group from ``nccl_info``; by default
    scatters the head dimension (2) and gathers the sequence dimension (1).
    """
    return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim)
138
+
139
+
140
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    """Exchange shards of ``input_`` among all ranks of ``group``.

    Splits ``input_`` into ``world_size`` pieces along ``scatter_dim``, sends
    piece ``i`` to rank ``i``, and concatenates the received pieces along
    ``gather_dim``.
    """
    shards = [
        shard.contiguous()
        for shard in torch.tensor_split(input_, world_size, scatter_dim)
    ]
    # Receive buffers; assumes every rank contributes equally-sized shards.
    received = [torch.empty_like(shards[0]) for _ in range(world_size)]
    dist.all_to_all(received, shards, group=group)
    return torch.cat(received, dim=gather_dim).contiguous()
154
+
155
+
156
class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Differentiable wrapper over :func:`_all_to_all`; backward performs the
    inverse exchange with scatter/gather dimensions swapped.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        # Save the communication parameters for the backward pass.
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        output = _all_to_all(input_, ctx.world_size, process_group,
                             scatter_dim, gather_dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Inverse all-to-all: swap the scatter and gather dimensions.
        grad_output = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        # None gradients for the non-tensor forward arguments.
        return (
            grad_output,
            None,
            None,
            None,
        )
191
+
192
+
193
def all_to_all(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable all-to-all over the sequence-parallel group.

    Scatters ``input_`` along ``scatter_dim`` and gathers along ``gather_dim``
    using the group configured in ``nccl_info``.
    """
    resharded = _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
    return resharded
199
+
200
+
201
class _AllGather(torch.autograd.Function):
    """All-gather communication with autograd support.

    Forward concatenates each rank's tensor along ``dim``; backward hands each
    rank the slice of the gradient corresponding to its own contribution.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        world_size = nccl_info.sp_size
        group = nccl_info.group
        input_size = list(input_.size())

        # Remember this rank's extent along `dim` so backward can slice the
        # gradient back out. Assumes all ranks contribute equal-sized tensors.
        ctx.input_size = input_size[dim]

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        input_ = input_.contiguous()
        dist.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        world_size = nccl_info.sp_size
        rank = nccl_info.rank_within_group
        dim = ctx.dim
        input_size = ctx.input_size

        # Equal-sized split: one chunk per rank along the gathered dimension.
        sizes = [input_size] * world_size

        grad_input_list = torch.split(grad_output, sizes, dim=dim)
        # Each rank keeps only the gradient of its own forward input.
        grad_input = grad_input_list[rank]

        return grad_input, None
238
+
239
+
240
def all_gather(input_: torch.Tensor, dim: int = 1):
    """Differentiable all-gather over the sequence-parallel group.

    Args:
        input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
        dim (int, optional): Dimension along which to concatenate. Defaults to 1.

    Returns:
        torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
    """
    gathered = _AllGather.apply(input_, dim)
    return gathered
251
+
252
+
253
def prepare_sequence_parallel_data(
    encoder_hidden_states, encoder_attention_mask, caption
):
    """Reshard conditioning tensors for sequence parallelism.

    Each rank's text embeddings and attention mask are repeated ``sp_size``
    times along the sequence dimension, then an all-to-all scatters that
    dimension and gathers the batch dimension, so every rank in the SP group
    ends up with the full batch. No-op when ``sp_size == 1``.

    Returns:
        ``(encoder_hidden_states, encoder_attention_mask, caption)`` resharded
        (or unchanged for ``sp_size == 1``).
    """
    sp_size = nccl_info.sp_size
    if sp_size == 1:
        return encoder_hidden_states, encoder_attention_mask, caption

    # Repeat along seq dim so each rank's shard after scatter is a full copy.
    encoder_hidden_states = all_to_all(
        encoder_hidden_states.repeat(1, sp_size, 1),
        scatter_dim=1,
        gather_dim=0,
    )
    encoder_attention_mask = all_to_all(
        encoder_attention_mask.repeat(1, sp_size),
        scatter_dim=1,
        gather_dim=0,
    )
    return encoder_hidden_states, encoder_attention_mask, caption
299
+
300
+
301
def sp_parallel_dataloader_wrapper(
    dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
    """Infinite generator adapting a dataloader for sequence parallelism.

    Moves each batch's conditioning tensors to ``device``, reshards them
    across the sequence-parallel group, and yields micro-batches of size
    ``train_sp_batch_size`` sliced from the globally gathered batch.

    Args:
        dataloader: iterable yielding ``(cond, cond_mask, caption)`` items.
        device: target device for the tensors.
        train_batch_size: per-rank batch size of ``dataloader``.
        sp_size: sequence-parallel group size.
        train_sp_batch_size: micro-batch size yielded per step.

    Yields:
        ``(encoder_hidden_states, encoder_attention_mask, caption)`` tuples.
    """
    while True:
        for data_item in dataloader:
            cond, cond_mask, caption = data_item
            cond = cond.to(device)
            cond_mask = cond_mask.to(device)
            # NOTE(review): frame count is hard-coded; the original version
            # derived it from latents.shape[2] (see commented-out lines).
            # With frame == 19 the single-frame branch below never runs.
            frame = 19
            if frame == 1:
                yield cond, cond_mask, caption
            else:
                cond, cond_mask, caption = prepare_sequence_parallel_data(
                    cond, cond_mask, caption
                )
                assert (
                    train_batch_size * sp_size >= train_sp_batch_size
                ), "train_batch_size * sp_size should be greater than train_sp_batch_size"
                # Slice the gathered global batch into SP micro-batches.
                # (renamed from `iter`, which shadowed the builtin)
                for micro_step in range(
                        train_batch_size * sp_size // train_sp_batch_size):
                    st_idx = micro_step * train_sp_batch_size
                    ed_idx = (micro_step + 1) * train_sp_batch_size
                    encoder_hidden_states = cond[st_idx:ed_idx]
                    encoder_attention_mask = cond_mask[st_idx:ed_idx]
                    yield (
                        encoder_hidden_states,
                        encoder_attention_mask,
                        caption
                    )
335
+