diff --git a/fastvideo/config_sd/__pycache__/base.cpython-310.pyc b/fastvideo/config_sd/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d885beb4eb3b859252d38595f02f9ee402cf001 Binary files /dev/null and b/fastvideo/config_sd/__pycache__/base.cpython-310.pyc differ diff --git a/fastvideo/config_sd/base.py b/fastvideo/config_sd/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4046de1901a9bd1190a465d2bb3b0975cc8cf82a --- /dev/null +++ b/fastvideo/config_sd/base.py @@ -0,0 +1,113 @@ +import ml_collections + + +def get_config(): + config = ml_collections.ConfigDict() + + ###### General ###### + # run name for wandb logging and checkpoint saving -- if not provided, will be auto-generated based on the datetime. + config.run_name = "" + # random seed for reproducibility. + config.seed = 42 + # top-level logging directory for checkpoint saving. + config.logdir = "logs" + # number of epochs to train for. each epoch is one round of sampling from the model followed by training on those + # samples. + config.num_epochs = 300 + # number of epochs between saving model checkpoints. + config.save_freq = 20 + # number of checkpoints to keep before overwriting old ones. + config.num_checkpoint_limit = 5 + # mixed precision training. options are "fp16", "bf16", and "no". half-precision speeds up training significantly. + config.mixed_precision = "bf16" + # allow tf32 on Ampere GPUs, which can speed up training. + config.allow_tf32 = True + # resume training from a checkpoint. either an exact checkpoint directory (e.g. checkpoint_50), or a directory + # containing checkpoints, in which case the latest one will be used. `config.use_lora` must be set to the same value + # as the run that generated the saved checkpoint. + config.resume_from = "" + # whether or not to use LoRA. LoRA reduces memory usage significantly by injecting small weight matrices into the + # attention layers of the UNet. with LoRA, fp16, and a batch size of 1, finetuning Stable Diffusion should take + # about 10GB of GPU memory. beware that if LoRA is disabled, training will take a lot of memory and saved checkpoint + # files will also be large. + config.use_lora = False + + ###### Pretrained Model ###### + config.pretrained = pretrained = ml_collections.ConfigDict() + # base model to load. either a path to a local directory, or a model name from the HuggingFace model hub. + pretrained.model = "./data/StableDiffusion" + # revision of the model to load. + pretrained.revision = "main" + + ###### Sampling ###### + config.sample = sample = ml_collections.ConfigDict() + # number of sampler inference steps. + sample.num_steps = 50 + # eta parameter for the DDIM sampler. this controls the amount of noise injected into the sampling process, with 0.0 + # being fully deterministic and 1.0 being equivalent to the DDPM sampler. + sample.eta = 1.0 + # classifier-free guidance weight. 1.0 is no guidance. + sample.guidance_scale = 5.0 + # batch size (per GPU!) to use for sampling. + sample.batch_size = 1 + # number of batches to sample per epoch. the total number of samples per epoch is `num_batches_per_epoch * + # batch_size * num_gpus`. + sample.num_batches_per_epoch = 2 + + ###### Training ###### + config.train = train = ml_collections.ConfigDict() + # batch size (per GPU!) to use for training. + train.batch_size = 1 + # whether to use the 8bit Adam optimizer from bitsandbytes. + train.use_8bit_adam = False + # learning rate. + train.learning_rate = 1e-5 + # Adam beta1. + train.adam_beta1 = 0.9 + # Adam beta2. + train.adam_beta2 = 0.999 + # Adam weight decay. + train.adam_weight_decay = 1e-4 + # Adam epsilon. + train.adam_epsilon = 1e-8 + # number of gradient accumulation steps. the effective batch size is `batch_size * num_gpus * + # gradient_accumulation_steps`. + train.gradient_accumulation_steps = 1 + # maximum gradient norm for gradient clipping. + train.max_grad_norm = 1.0 + # number of inner epochs per outer epoch. each inner epoch is one iteration through the data collected during one + # outer epoch's round of sampling. + train.num_inner_epochs = 1 + # whether or not to use classifier-free guidance during training. if enabled, the same guidance scale used during + # sampling will be used during training. + train.cfg = True + # clip advantages to the range [-adv_clip_max, adv_clip_max]. + train.adv_clip_max = 5 + # the PPO clip range. + train.clip_range = 1e-4 + # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the + # timesteps for each sample. this will speed up training but reduce the accuracy of policy gradient estimates. + train.timestep_fraction = 1.0 + + ###### Prompt Function ###### + # prompt function to use. see `prompts.py` for available prompt functions. + config.prompt_fn = "imagenet_animals" + # kwargs to pass to the prompt function. + config.prompt_fn_kwargs = {} + + ###### Reward Function ###### + # reward function to use. see `rewards.py` for available reward functions. + config.reward_fn = "hpsv2" + + ###### Per-Prompt Stat Tracking ###### + # when enabled, the model will track the mean and std of reward on a per-prompt basis and use that to compute + # advantages. set `config.per_prompt_stat_tracking` to None to disable per-prompt stat tracking, in which case + # advantages will be calculated using the mean and std of the entire batch. + #config.per_prompt_stat_tracking = ml_collections.ConfigDict() + # number of reward values to store in the buffer for each prompt. the buffer persists across epochs. + #config.per_prompt_stat_tracking.buffer_size = 16 + # the minimum number of reward values to store in the buffer before using the per-prompt mean and std. if the buffer + # contains fewer than `min_count` values, the mean and std of the entire batch will be used instead. + #config.per_prompt_stat_tracking.min_count = 16 + + return config diff --git a/fastvideo/config_sd/dgx.py b/fastvideo/config_sd/dgx.py new file mode 100644 index 0000000000000000000000000000000000000000..d01cb530b82f25c1179aa13a6428d9c34e90c945 --- /dev/null +++ b/fastvideo/config_sd/dgx.py @@ -0,0 +1,60 @@ +import ml_collections +import imp +import os + +base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py")) + + +def compressibility(): + config = base.get_config() + + config.pretrained.model = "CompVis/stable-diffusion-v1-4" + + config.num_epochs = 300 + config.save_freq = 50 + config.num_checkpoint_limit = 100000000 + + # the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch. + config.sample.batch_size = 8 + config.sample.num_batches_per_epoch = 4 + + # this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch. + config.train.batch_size = 1 + config.train.gradient_accumulation_steps = 4 + + # prompting + config.prompt_fn = "imagenet_animals" + config.prompt_fn_kwargs = {} + + # rewards + config.reward_fn = "jpeg_compressibility" + + config.per_prompt_stat_tracking = { + "buffer_size": 16, + "min_count": 16, + } + + return config + +def hps(): + config = compressibility() + config.num_epochs = 300 + config.reward_fn = "aesthetic_score" + + # this reward is a bit harder to optimize, so I used 2 gradient updates per epoch. + config.train.gradient_accumulation_steps = 8 + + # the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch. + config.sample.batch_size = 4 + + # this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch. + config.train.batch_size = 4 + + config.prompt_fn = "aes" + config.chosen_number = 16 + config.num_generations = 16 + return config + + +def get_config(name): + return globals()[name]() diff --git a/fastvideo/data_preprocess/.DS_Store b/fastvideo/data_preprocess/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/fastvideo/data_preprocess/.DS_Store differ diff --git a/fastvideo/data_preprocess/preprocess_flux_embedding.py b/fastvideo/data_preprocess/preprocess_flux_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d6137d329c268cb6d9272e4ec30ab767b17845 --- /dev/null +++ b/fastvideo/data_preprocess/preprocess_flux_embedding.py @@ -0,0 +1,170 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + + +import argparse +import torch +from accelerate.logging import get_logger +import cv2 +import json +import os +import torch.distributed as dist +from pathlib import Path + +logger = get_logger(__name__) +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader +from tqdm import tqdm +import re +from diffusers import FluxPipeline + +def contains_chinese(text): + return bool(re.search(r'[\u4e00-\u9fff]', text)) + +class T5dataset(Dataset): + def __init__( + self, txt_path, vae_debug, + ): + self.txt_path = txt_path + self.vae_debug = vae_debug + with open(self.txt_path, "r", encoding="utf-8") as f: + self.train_dataset = [ + line for line in f.read().splitlines() if not contains_chinese(line) + ][:50000] + + def __getitem__(self, idx): + #import pdb;pdb.set_trace() + caption = self.train_dataset[idx] + filename = str(idx) + #length = self.train_dataset[idx]["length"] + if self.vae_debug: + latents = torch.load( + os.path.join( + args.output_dir, "latent", self.train_dataset[idx]["latent_path"] + ), + map_location="cpu", + ) + else: + latents = [] + + return dict(caption=caption, latents=latents, filename=filename) + + def __len__(self): + return len(self.train_dataset) + + +def main(args): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size, "local rank", local_rank) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=local_rank + ) + + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True) + + latents_txt_path = args.prompt_dir + train_dataset = T5dataset(latents_txt_path, args.vae_debug) + sampler = DistributedSampler( + train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True + ) + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + flux_path = args.model_path + pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16).to(device) + + json_data = [] + for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0): + try: + with torch.inference_mode(): + if args.vae_debug: + latents = data["latents"] + for idx, video_name in enumerate(data["filename"]): + prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt( + prompt=data["caption"], prompt_2=data["caption"] + ) + prompt_embed_path = os.path.join( + args.output_dir, "prompt_embed", video_name + ".pt" + ) + pooled_prompt_embeds_path = os.path.join( + args.output_dir, "pooled_prompt_embeds", video_name + ".pt" + ) + + text_ids_path = os.path.join( + args.output_dir, "text_ids", video_name + ".pt" + ) + # save latent + torch.save(prompt_embeds[idx], prompt_embed_path) + torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path) + torch.save(text_ids[idx], text_ids_path) + item = {} + item["prompt_embed_path"] = video_name + ".pt" + item["text_ids"] = video_name + ".pt" + item["pooled_prompt_embeds_path"] = video_name + ".pt" + item["caption"] = data["caption"][idx] + json_data.append(item) + except Exception as e: + print(f"Rank {local_rank} Error: {repr(e)}") + dist.barrier() + raise + dist.barrier() + local_data = json_data + gathered_data = [None] * world_size + dist.all_gather_object(gathered_data, local_data) + if local_rank == 0: + # os.remove(latents_json_path) + all_json_data = [item for sublist in gathered_data for item in sublist] + with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f: + json.dump(all_json_data, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--model_type", type=str, default="mochi") + # text encoder & vae & diffusion model + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--vae_debug", action="store_true") + parser.add_argument("--prompt_dir", type=str, default="./empty.txt") + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/fastvideo/data_preprocess/preprocess_flux_embedding_rlpt.py b/fastvideo/data_preprocess/preprocess_flux_embedding_rlpt.py new file mode 100644 index 0000000000000000000000000000000000000000..f749d8d68bdb6d076d500e7c3722e997e8569ec1 --- /dev/null +++ b/fastvideo/data_preprocess/preprocess_flux_embedding_rlpt.py @@ -0,0 +1,172 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + + +import argparse +import torch +from accelerate.logging import get_logger +import cv2 +import json +import os +import torch.distributed as dist +from pathlib import Path + +logger = get_logger(__name__) +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader +from tqdm import tqdm +import re +from diffusers import FluxPipeline + +def contains_chinese(text): + return bool(re.search(r'[\u4e00-\u9fff]', text)) + +class T5dataset(Dataset): + def __init__( + self, txt_path, vae_debug, + ): + self.txt_path = txt_path + self.vae_debug = vae_debug + print(f"[DEBUG] Loading captions from: {self.txt_path}") + with open(self.txt_path, "r", encoding="utf-8") as f: + self.train_dataset = [ + line.strip() for line in f.read().splitlines() if line.strip() and not contains_chinese(line) + ][:50000] + print(f"[DEBUG] Loaded {len(self.train_dataset)} captions after filtering") + + def __getitem__(self, idx): + #import pdb;pdb.set_trace() + caption = self.train_dataset[idx] + filename = str(idx) + #length = self.train_dataset[idx]["length"] + if self.vae_debug: + latents = torch.load( + os.path.join( + args.output_dir, "latent", self.train_dataset[idx]["latent_path"] + ), + map_location="cpu", + ) + else: + latents = [] + + return dict(caption=caption, latents=latents, filename=filename) + + def __len__(self): + return len(self.train_dataset) + + +def main(args): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size, "local rank", local_rank) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=local_rank + ) + + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True) + + latents_txt_path = args.prompt_dir + train_dataset = T5dataset(latents_txt_path, args.vae_debug) + sampler = DistributedSampler( + train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True + ) + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + flux_path = args.model_path + pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16).to(device) + + json_data = [] + for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0): + try: + with torch.inference_mode(): + if args.vae_debug: + latents = data["latents"] + for idx, video_name in enumerate(data["filename"]): + prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt( + prompt=data["caption"], prompt_2=data["caption"] + ) + prompt_embed_path = os.path.join( + args.output_dir, "prompt_embed", video_name + ".pt" + ) + pooled_prompt_embeds_path = os.path.join( + args.output_dir, "pooled_prompt_embeds", video_name + ".pt" + ) + + text_ids_path = os.path.join( + args.output_dir, "text_ids", video_name + ".pt" + ) + # save latent + torch.save(prompt_embeds[idx], prompt_embed_path) + torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path) + torch.save(text_ids[idx], text_ids_path) + item = {} + item["prompt_embed_path"] = video_name + ".pt" + item["text_ids"] = video_name + ".pt" + item["pooled_prompt_embeds_path"] = video_name + ".pt" + item["caption"] = data["caption"][idx] + json_data.append(item) + except Exception as e: + print(f"Rank {local_rank} Error: {repr(e)}") + dist.barrier() + raise + dist.barrier() + local_data = json_data + gathered_data = [None] * world_size + dist.all_gather_object(gathered_data, local_data) + if local_rank == 0: + # os.remove(latents_json_path) + all_json_data = [item for sublist in gathered_data for item in sublist] + with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f: + json.dump(all_json_data, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--model_type", type=str, default="mochi") + # text encoder & vae & diffusion model + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--vae_debug", action="store_true") + parser.add_argument("--prompt_dir", type=str, default="./empty.txt") + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/fastvideo/data_preprocess/preprocess_flux_rfpt_embedding.py b/fastvideo/data_preprocess/preprocess_flux_rfpt_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc8c655a0d094a0aa978dd8e2a71eceaff796cb --- /dev/null +++ b/fastvideo/data_preprocess/preprocess_flux_rfpt_embedding.py @@ -0,0 +1,224 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + + +import argparse +import torch +from accelerate.logging import get_logger +import cv2 +import json +import os +import torch.distributed as dist +import pandas as pd +from torch.utils.data.dataset import ConcatDataset, Dataset +import io +import torchvision.transforms as transforms +logger = get_logger(__name__) +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader +from tqdm import tqdm +import re +from diffusers import FluxPipeline +from PIL import Image +from diffusers.image_processor import VaeImageProcessor + +def contains_chinese(text): + return bool(re.search(r'[\u4e00-\u9fff]', text)) + +class RFPTdataset(Dataset): + def __init__( + self, file_path, + ): + self.file_path = file_path + file_names = os.listdir(self.file_path) # each file contains 5,000 images + self.file_names = [os.path.join(self.file_path, file_name) for file_name in file_names] + self.train_dataset = self.read_data() + self.transform = transforms.ToTensor() + + def read_data(self): + df_list = [pd.read_parquet(file_name) for file_name in self.file_names] + combined_df = pd.concat(df_list, axis=0, ignore_index=True) + return combined_df + + def __len__(self): + return len(self.train_dataset) + + def __getitem__(self, index): + + image = self.train_dataset.iloc[index]['image']['bytes'] + image = self.transform(Image.open(io.BytesIO(image)).convert('RGB')) + # print(image.shape) + + caption = self.train_dataset.iloc[index]['caption_composition'] + # print(caption) + filename = str(index) + if caption == None or image == None: + return self.__getitem__(index+1) + return dict(caption=caption, image=image, filename=filename) + +class T5dataset(Dataset): + def __init__( + self, txt_path, vae_debug, + ): + self.txt_path = txt_path + self.vae_debug = vae_debug + with open(self.txt_path, "r", encoding="utf-8") as f: + self.train_dataset = [ + line for line in f.read().splitlines() if not contains_chinese(line) + ][:50000] + + def __getitem__(self, idx): + #import pdb;pdb.set_trace() + caption = self.train_dataset[idx] + filename = str(idx) + #length = self.train_dataset[idx]["length"] + if self.vae_debug: + latents = torch.load( + os.path.join( + args.output_dir, "latent", self.train_dataset[idx]["latent_path"] + ), + map_location="cpu", + ) + else: + latents = [] + + return dict(caption=caption, latents=latents, filename=filename) + + def __len__(self): + return len(self.train_dataset) + + +def main(args): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size, "local rank", local_rank) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=local_rank + ) + + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "text_ids"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "pooled_prompt_embeds"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "images"), exist_ok=True) + + # latents_txt_path = args.prompt_dir + # train_dataset = T5dataset(latents_txt_path, args.vae_debug) + + train_dataset = RFPTdataset(args.prompt_dir) + + + sampler = DistributedSampler( + train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True + ) + + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + flux_path = args.model_path + pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16).to(device) + image_processor = VaeImageProcessor(16) + + json_data = [] + for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0): + try: + with torch.inference_mode(): + if args.vae_debug: + latents = data["latents"] + for idx, video_name in enumerate(data["filename"]): + # prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt( + # prompt=data["caption"], prompt_2=data["caption"] + # ) + # image_latents = pipe.vae.encode(data["image"].to(torch.bfloat16).to(device)).latent_dist.sample() + # output_image = pipe.vae.decode(image_latents, return_dict=False)[0] + # output_image = image_processor.postprocess(output_image) + # output_image[0].save('output.png') + # print(image_latents.latent_dist.sample()) + # print(image_latents.latent_dist.sample().shape) + + prompt_embed_path = os.path.join( + args.output_dir, "prompt_embed", video_name + ".pt" + ) + pooled_prompt_embeds_path = os.path.join( + args.output_dir, "pooled_prompt_embeds", video_name + ".pt" + ) + + text_ids_path = os.path.join( + args.output_dir, "text_ids", video_name + ".pt" + ) + + image_latents_path = os.path.join( + args.output_dir, "images", video_name + ".pt" + ) + # save latent + # torch.save(prompt_embeds[idx], prompt_embed_path) + # torch.save(pooled_prompt_embeds[idx], pooled_prompt_embeds_path) + # torch.save(text_ids[idx], text_ids_path) + torch.save(data["image"].to(torch.bfloat16), image_latents_path) + item = {} + item["prompt_embed_path"] = video_name + ".pt" + item["text_ids"] = video_name + ".pt" + item["pooled_prompt_embeds_path"] = video_name + ".pt" + item["caption"] = data["caption"][idx] + json_data.append(item) + except Exception as e: + print(f"Rank {local_rank} Error: {repr(e)}") + dist.barrier() + raise + dist.barrier() + local_data = json_data + gathered_data = [None] * world_size + dist.all_gather_object(gathered_data, local_data) + if local_rank == 0: + # os.remove(latents_json_path) + all_json_data = [item for sublist in gathered_data for item in sublist] + with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f: + json.dump(all_json_data, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--model_type", type=str, default="mochi") + # text encoder & vae & diffusion model + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--vae_debug", action="store_true") + parser.add_argument("--prompt_dir", type=str, default="./empty.txt") + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/fastvideo/data_preprocess/preprocess_qwenimage_embedding.py b/fastvideo/data_preprocess/preprocess_qwenimage_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..fc6796d20e2759087f0568b717e66bf1ef7453b1 --- /dev/null +++ b/fastvideo/data_preprocess/preprocess_qwenimage_embedding.py @@ -0,0 +1,220 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +import argparse +import torch +from accelerate.logging import get_logger +# from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from diffusers.utils import export_to_video +from fastvideo.models.qwenimage.pipeline_qwenimage import QwenImagePipeline +import json +import os +import torch.distributed as dist + +logger = get_logger(__name__) +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader +from fastvideo.utils.load import load_text_encoder, load_vae +from diffusers.video_processor import VideoProcessor +from tqdm import tqdm +import re +from diffusers import DiffusionPipeline +import torch.nn.functional as F + +def contains_chinese(text): + """检查字符串是否包含中文字符""" + return bool(re.search(r'[\u4e00-\u9fff]', text)) + +class T5dataset(Dataset): + def __init__( + self, txt_path, vae_debug, + ): + self.txt_path = txt_path + self.vae_debug = vae_debug + with open(self.txt_path, "r", encoding="utf-8") as f: + self.train_dataset = [ + line for line in f.read().splitlines() if not contains_chinese(line) + ] + #self.train_dataset = sorted(train_dataset) + + def __getitem__(self, idx): + #import pdb;pdb.set_trace() + caption = self.train_dataset[idx] + filename = str(idx) + #length = self.train_dataset[idx]["length"] + if self.vae_debug: + latents = torch.load( + os.path.join( + args.output_dir, "latent", self.train_dataset[idx]["latent_path"] + ), + map_location="cpu", + ) + else: + latents = [] + + return dict(caption=caption, latents=latents, filename=filename) + + def __len__(self): + return len(self.train_dataset) + + +def main(args): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size, "local rank", local_rank) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=local_rank + ) + + #videoprocessor = VideoProcessor(vae_scale_factor=8) + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True) + + latents_txt_path = args.prompt_dir + train_dataset = T5dataset(latents_txt_path, args.vae_debug) + #text_encoder = load_text_encoder(args.model_type, args.model_path, device=device) + #vae, autocast_type, fps = load_vae(args.model_type, args.model_path) + #vae.enable_tiling() + sampler = DistributedSampler( + train_dataset, rank=local_rank, num_replicas=world_size, shuffle=False + ) + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + # Load pipeline but don't move everything to GPU yet + pipe = QwenImagePipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16) + + # Only move text_encoder to GPU for embedding generation + pipe.text_encoder = pipe.text_encoder.to(device) + + # Delete unused components to free up RAM/VRAM + if not args.vae_debug: + # Remove from attributes + if hasattr(pipe, "transformer"): + del pipe.transformer + if hasattr(pipe, "vae"): + del pipe.vae + + # Remove from components dictionary to ensure garbage collection + if "transformer" in pipe.components: + del pipe.components["transformer"] + if "vae" in pipe.components: + del pipe.components["vae"] + + import gc + gc.collect() + torch.cuda.empty_cache() + + # pipe._execution_device = device # This causes AttributeError, removing it. + + json_data = [] + for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0): + with torch.inference_mode(): + with torch.autocast("cuda"): + prompt_embeds, prompt_attention_mask = pipe.encode_prompt( + prompt=data["caption"], + device=device # Explicitly pass device + ) + + # ==================== 代码修改开始 ==================== + + # 1. 记录原始的序列长度 (第二个维度的大小) + original_length = prompt_embeds.shape[1] + target_length = 1024 + + # 2. 计算需要填充的长度 + # 假设 original_length 不会超过 target_length + pad_len = target_length - original_length + + # 3. 填充 prompt_embeds + # prompt_embeds 是一个3D张量 (B, L, D),我们需要填充第二个维度 L + # F.pad 的填充参数顺序是从最后一个维度开始的 (pad_dim_D_left, pad_dim_D_right, pad_dim_L_left, pad_dim_L_right, ...) + # 我们在维度1(序列长度L)的右侧进行填充 + prompt_embeds = F.pad(prompt_embeds, (0, 0, 0, pad_len), "constant", 0) + + # 4. 填充 prompt_attention_mask + # prompt_attention_mask 是一个2D张量 (B, L),我们同样填充第二个维度 L + # 我们在维度1(序列长度L)的右侧进行填充 + prompt_attention_mask = F.pad(prompt_attention_mask, (0, pad_len), "constant", 0) + + # ==================== 代码修改结束 ==================== + + if args.vae_debug: + latents = data["latents"] + for idx, video_name in enumerate(data["filename"]): + prompt_embed_path = os.path.join( + args.output_dir, "prompt_embed", video_name + ".pt" + ) + prompt_attention_mask_path = os.path.join( + args.output_dir, "prompt_attention_mask", video_name + ".pt" + ) + # 保存 latent (注意这里保存的是填充后的张量) + torch.save(prompt_embeds[idx], prompt_embed_path) + torch.save(prompt_attention_mask[idx], prompt_attention_mask_path) + item = {} + item["prompt_embed_path"] = video_name + ".pt" + item["prompt_attention_mask"] = video_name + ".pt" + item["caption"] = data["caption"][idx] + + # [新增] 将原始长度记录到 item 字典中 + item["original_length"] = original_length + + json_data.append(item) + dist.barrier() + local_data = json_data + gathered_data = [None] * world_size + dist.all_gather_object(gathered_data, local_data) + if local_rank == 0: + # os.remove(latents_json_path) + all_json_data = [item for sublist in gathered_data for item in sublist] + with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f: + json.dump(all_json_data, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--model_type", type=str, default="mochi") + # text encoder & vae & diffusion model + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--vae_debug", action="store_true") + parser.add_argument("--prompt_dir", type=str, default="./empty.txt") + args = parser.parse_args() + main(args) diff --git a/fastvideo/data_preprocess/preprocess_rl_embeddings.py b/fastvideo/data_preprocess/preprocess_rl_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..54dc7f5a8625cf69010c867a66fc994d132d7279 --- /dev/null +++ b/fastvideo/data_preprocess/preprocess_rl_embeddings.py @@ -0,0 +1,175 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +import argparse +import torch +from accelerate.logging import get_logger +from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from diffusers.utils import export_to_video +import json +import os +import torch.distributed as dist + +logger = get_logger(__name__) +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader +from fastvideo.utils.load import load_text_encoder, load_vae +from diffusers.video_processor import VideoProcessor +from tqdm import tqdm +import re + +def contains_chinese(text): + """检查字符串是否包含中文字符""" + return bool(re.search(r'[\u4e00-\u9fff]', text)) + +class T5dataset(Dataset): + def __init__( + self, txt_path, vae_debug, + ): + self.txt_path = txt_path + self.vae_debug = vae_debug + with open(self.txt_path, "r", encoding="utf-8") as f: + self.train_dataset = [ + line for line in f.read().splitlines() if not contains_chinese(line) + ] + #self.train_dataset = sorted(train_dataset) + + def __getitem__(self, idx): + #import pdb;pdb.set_trace() + caption = self.train_dataset[idx] + filename = str(idx) + #length = self.train_dataset[idx]["length"] + if self.vae_debug: + latents = torch.load( + os.path.join( + args.output_dir, "latent", self.train_dataset[idx]["latent_path"] + ), + map_location="cpu", + ) + else: + latents = [] + + return dict(caption=caption, latents=latents, filename=filename) + + def __len__(self): + return len(self.train_dataset) + + +def main(args): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size, "local rank", local_rank) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=local_rank + ) + + #videoprocessor = VideoProcessor(vae_scale_factor=8) + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True) + #os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True) + + latents_txt_path = args.prompt_dir + train_dataset = T5dataset(latents_txt_path, args.vae_debug) + text_encoder = load_text_encoder(args.model_type, args.model_path, device=device) + #vae, autocast_type, fps = load_vae(args.model_type, args.model_path) + #vae.enable_tiling() + sampler = DistributedSampler( + train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True + ) + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + json_data = [] + for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0): + with torch.inference_mode(): + with torch.autocast("cuda"): + prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt( + prompt=data["caption"], + ) + if args.vae_debug: + latents = data["latents"] + #video = vae.decode(latents.to(device), return_dict=False)[0] + #video = videoprocessor.postprocess_video(video) + for idx, video_name in enumerate(data["filename"]): + prompt_embed_path = os.path.join( + args.output_dir, "prompt_embed", video_name + ".pt" + ) + #video_path = os.path.join( + # args.output_dir, "video", video_name + ".mp4" + #) + prompt_attention_mask_path = os.path.join( + args.output_dir, "prompt_attention_mask", video_name + ".pt" + ) + # save latent + torch.save(prompt_embeds[idx], prompt_embed_path) + torch.save(prompt_attention_mask[idx], prompt_attention_mask_path) + #print(f"sample {video_name} saved") + #if args.vae_debug: + # export_to_video(video[idx], video_path, fps=fps) + item = {} + #item["length"] = int(data["length"][idx]) + #item["latent_path"] = video_name + ".pt" + item["prompt_embed_path"] = video_name + ".pt" + item["prompt_attention_mask"] = video_name + ".pt" + item["caption"] = data["caption"][idx] + json_data.append(item) + dist.barrier() + local_data = json_data + gathered_data = [None] * world_size + dist.all_gather_object(gathered_data, local_data) + if local_rank == 0: + # os.remove(latents_json_path) + all_json_data = [item for sublist in gathered_data for item in sublist] + with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f: + json.dump(all_json_data, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--model_type", type=str, default="mochi") + # text encoder & vae & diffusion model + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--vae_debug", action="store_true") + parser.add_argument("--prompt_dir", type=str, default="./empty.txt") + args = parser.parse_args() + main(args) diff --git a/fastvideo/data_preprocess/preprocess_text_embeddings.py b/fastvideo/data_preprocess/preprocess_text_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..f82b3d97f33417f19862f67e431a0aa10072545e --- /dev/null +++ b/fastvideo/data_preprocess/preprocess_text_embeddings.py @@ -0,0 +1,175 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import argparse +import json +import os + +import torch +import torch.distributed as dist +from accelerate.logging import get_logger +from diffusers.utils import export_to_video +from diffusers.video_processor import VideoProcessor +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm + +from fastvideo.utils.load import load_text_encoder, load_vae + +logger = get_logger(__name__) + + +class T5dataset(Dataset): + + def __init__( + self, + json_path, + vae_debug, + ): + self.json_path = json_path + self.vae_debug = vae_debug + with open(self.json_path, "r") as f: + train_dataset = json.load(f) + self.train_dataset = sorted(train_dataset, + key=lambda x: x["latent_path"]) + + def __getitem__(self, idx): + caption = self.train_dataset[idx]["caption"] + filename = self.train_dataset[idx]["latent_path"].split(".")[0] + length = self.train_dataset[idx]["length"] + if self.vae_debug: + latents = torch.load( + os.path.join(args.output_dir, "latent", + self.train_dataset[idx]["latent_path"]), + map_location="cpu", + ) + else: + latents = [] + + return dict(caption=caption, + latents=latents, + filename=filename, + length=length) + + def __len__(self): + return len(self.train_dataset) + + +def main(args): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size, "local rank", local_rank) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) + + videoprocessor = VideoProcessor(vae_scale_factor=8) + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), + exist_ok=True) + + latents_json_path = os.path.join(args.output_dir, + "videos2caption_temp.json") + train_dataset = T5dataset(latents_json_path, args.vae_debug) + text_encoder = load_text_encoder(args.model_type, + args.model_path, + device=device) + vae, autocast_type, fps = load_vae(args.model_type, args.model_path) + vae.enable_tiling() + sampler = DistributedSampler(train_dataset, + rank=local_rank, + num_replicas=world_size, + shuffle=True) + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + json_data = [] + for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0): + with torch.inference_mode(): + with torch.autocast("cuda", dtype=autocast_type): + prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt( + prompt=data["caption"], ) + if args.vae_debug: + latents = data["latents"] + video = vae.decode(latents.to(device), + return_dict=False)[0] + video = videoprocessor.postprocess_video(video) + for idx, video_name in enumerate(data["filename"]): + prompt_embed_path = os.path.join(args.output_dir, + "prompt_embed", + video_name + ".pt") + video_path = os.path.join(args.output_dir, "video", + video_name + ".mp4") + prompt_attention_mask_path = os.path.join( + args.output_dir, "prompt_attention_mask", + video_name + ".pt") + # save latent + torch.save(prompt_embeds[idx], prompt_embed_path) + torch.save(prompt_attention_mask[idx], + prompt_attention_mask_path) + print(f"sample {video_name} saved") + if args.vae_debug: + export_to_video(video[idx], video_path, fps=fps) + item = {} + item["length"] = int(data["length"][idx]) + item["latent_path"] = video_name + ".pt" + item["prompt_embed_path"] = video_name + ".pt" + item["prompt_attention_mask"] = video_name + ".pt" + item["caption"] = data["caption"][idx] + json_data.append(item) + dist.barrier() + local_data = json_data + gathered_data = [None] * world_size + dist.all_gather_object(gathered_data, local_data) + if local_rank == 0: + # os.remove(latents_json_path) + all_json_data = [item for sublist in gathered_data for item in sublist] + with open(os.path.join(args.output_dir, "videos2caption.json"), + "w") as f: + json.dump(all_json_data, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--model_type", type=str, default="mochi") + # text encoder & vae & diffusion model + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help= + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--text_encoder_name", + type=str, + default="google/t5-v1_1-xxl") + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + parser.add_argument( + "--output_dir", + type=str, + default=None, + help= + "The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--vae_debug", action="store_true") + args = parser.parse_args() + main(args) diff --git a/fastvideo/data_preprocess/preprocess_vae_latents.py b/fastvideo/data_preprocess/preprocess_vae_latents.py new file mode 100644 index 0000000000000000000000000000000000000000..961c51322627795506b2fa04ccb0ca802854b624 --- /dev/null +++ b/fastvideo/data_preprocess/preprocess_vae_latents.py @@ -0,0 +1,137 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import argparse +import json +import os + +import torch +import torch.distributed as dist +from accelerate.logging import get_logger +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm + +from fastvideo.dataset import getdataset +from fastvideo.utils.load import load_vae + +logger = get_logger(__name__) + + +def main(args): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size, "local rank", local_rank) + train_dataset = getdataset(args) + sampler = DistributedSampler(train_dataset, + rank=local_rank, + num_replicas=world_size, + shuffle=True) + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + encoder_device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) + vae, autocast_type, fps = load_vae(args.model_type, args.model_path) + vae.enable_tiling() + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True) + + json_data = [] + for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0): + with torch.inference_mode(): + with torch.autocast("cuda", dtype=autocast_type): + latents = vae.encode(data["pixel_values"].to( + encoder_device))["latent_dist"].sample() + for idx, video_path in enumerate(data["path"]): + video_name = os.path.basename(video_path).split(".")[0] + latent_path = os.path.join(args.output_dir, "latent", + video_name + ".pt") + torch.save(latents[idx].to(torch.bfloat16), latent_path) + item = {} + item["length"] = latents[idx].shape[1] + item["latent_path"] = video_name + ".pt" + item["caption"] = data["text"][idx] + json_data.append(item) + print(f"{video_name} processed") + dist.barrier() + local_data = json_data + gathered_data = [None] * world_size + dist.all_gather_object(gathered_data, local_data) + if local_rank == 0: + all_json_data = [item for sublist in gathered_data for item in sublist] + with open(os.path.join(args.output_dir, "videos2caption_temp.json"), + "w") as f: + json.dump(all_json_data, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--model_type", type=str, default="mochi") + parser.add_argument("--data_merge_path", type=str, required=True) + parser.add_argument("--num_frames", type=int, default=163) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help= + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=16, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_latent_t", + type=int, + default=28, + help="Number of latent timesteps.") + parser.add_argument("--max_height", type=int, default=480) + parser.add_argument("--max_width", type=int, default=848) + parser.add_argument("--video_length_tolerance_range", + type=int, + default=2.0) + parser.add_argument("--group_frame", action="store_true") # TODO + parser.add_argument("--group_resolution", action="store_true") # TODO + parser.add_argument("--dataset", default="t2v") + parser.add_argument("--train_fps", type=int, default=30) + parser.add_argument("--use_image_num", type=int, default=0) + parser.add_argument("--text_max_length", type=int, default=256) + parser.add_argument("--speed_factor", type=float, default=1.0) + parser.add_argument("--drop_short_ratio", type=float, default=1.0) + # text encoder & vae & diffusion model + parser.add_argument("--text_encoder_name", + type=str, + default="google/t5-v1_1-xxl") + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + parser.add_argument("--cfg", type=float, default=0.0) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help= + "The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help= + ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + + args = parser.parse_args() + main(args) diff --git a/fastvideo/data_preprocess/preprocess_validation_text_embeddings.py b/fastvideo/data_preprocess/preprocess_validation_text_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..d0bea086e900ea411b4bd2d0b88ace25782bc82a --- /dev/null +++ b/fastvideo/data_preprocess/preprocess_validation_text_embeddings.py @@ -0,0 +1,80 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import argparse +import os + +import torch +import torch.distributed as dist +from accelerate.logging import get_logger + +from fastvideo.utils.load import load_text_encoder + +logger = get_logger(__name__) + + +def main(args): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size, "local rank", local_rank) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) + + text_encoder = load_text_encoder(args.model_type, + args.model_path, + device=device) + autocast_type = torch.float16 if args.model_type == "hunyuan" else torch.bfloat16 + # output_dir/validation/prompt_attention_mask + # output_dir/validation/prompt_embed + os.makedirs(os.path.join(args.output_dir, "validation"), exist_ok=True) + os.makedirs( + os.path.join(args.output_dir, "validation", "prompt_attention_mask"), + exist_ok=True, + ) + os.makedirs(os.path.join(args.output_dir, "validation", "prompt_embed"), + exist_ok=True) + + with open(args.validation_prompt_txt, "r", encoding="utf-8") as file: + lines = file.readlines() + prompts = [line.strip() for line in lines] + for prompt in prompts: + with torch.inference_mode(): + with torch.autocast("cuda", dtype=autocast_type): + prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt( + prompt) + file_name = prompt.split(".")[0] + prompt_embed_path = os.path.join(args.output_dir, "validation", + "prompt_embed", + f"{file_name}.pt") + prompt_attention_mask_path = os.path.join( + args.output_dir, + "validation", + "prompt_attention_mask", + f"{file_name}.pt", + ) + torch.save(prompt_embeds[0], prompt_embed_path) + torch.save(prompt_attention_mask[0], + prompt_attention_mask_path) + print(f"sample {file_name} saved") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--model_type", type=str, default="mochi") + parser.add_argument("--validation_prompt_txt", type=str) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help= + "The output directory where the model predictions and checkpoints will be written.", + ) + args = parser.parse_args() + main(args) diff --git a/fastvideo/dataset/.DS_Store b/fastvideo/dataset/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/fastvideo/dataset/.DS_Store differ diff --git a/fastvideo/dataset/__init__.py b/fastvideo/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..781f232ad8c7ece117bec989414c7959227d9a29 --- /dev/null +++ b/fastvideo/dataset/__init__.py @@ -0,0 +1,104 @@ +from torchvision import transforms +from torchvision.transforms import Lambda +from transformers import AutoTokenizer + +from fastvideo.dataset.t2v_datasets import T2V_dataset +from fastvideo.dataset.transform import (CenterCropResizeVideo, Normalize255, + TemporalRandomCrop) + + +def getdataset(args): + temporal_sample = TemporalRandomCrop(args.num_frames) # 16 x + norm_fun = Lambda(lambda x: 2.0 * x - 1.0) + resize_topcrop = [ + CenterCropResizeVideo((args.max_height, args.max_width), + top_crop=True), + ] + resize = [ + CenterCropResizeVideo((args.max_height, args.max_width)), + ] + transform = transforms.Compose([ + # Normalize255(), + *resize, + ]) + transform_topcrop = transforms.Compose([ + Normalize255(), + *resize_topcrop, + norm_fun, + ]) + # tokenizer = AutoTokenizer.from_pretrained("/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl", cache_dir=args.cache_dir) + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, + cache_dir=args.cache_dir) + if args.dataset == "t2v": + return T2V_dataset( + args, + transform=transform, + temporal_sample=temporal_sample, + tokenizer=tokenizer, + transform_topcrop=transform_topcrop, + ) + + raise NotImplementedError(args.dataset) + + +if __name__ == "__main__": + import random + + from accelerate import Accelerator + from tqdm import tqdm + + from fastvideo.dataset.t2v_datasets import dataset_prog + + args = type( + "args", + (), + { + "ae": "CausalVAEModel_4x8x8", + "dataset": "t2v", + "attention_mode": "xformers", + "use_rope": True, + "text_max_length": 300, + "max_height": 320, + "max_width": 240, + "num_frames": 1, + "use_image_num": 0, + "interpolation_scale_t": 1, + "interpolation_scale_h": 1, + "interpolation_scale_w": 1, + "cache_dir": "../cache_dir", + "image_data": + "/storage/ongoing/new/Open-Sora-Plan-bak/7.14bak/scripts/train_data/image_data.txt", + "video_data": "1", + "train_fps": 24, + "drop_short_ratio": 1.0, + "use_img_from_vid": False, + "speed_factor": 1.0, + "cfg": 0.1, + "text_encoder_name": "google/mt5-xxl", + "dataloader_num_workers": 10, + }, + ) + accelerator = Accelerator() + dataset = getdataset(args) + num = len(dataset_prog.img_cap_list) + zero = 0 + for idx in tqdm(range(num)): + image_data = dataset_prog.img_cap_list[idx] + caps = [ + i["cap"] if isinstance(i["cap"], list) else [i["cap"]] + for i in image_data + ] + try: + caps = [[random.choice(i)] for i in caps] + except Exception as e: + print(e) + # import ipdb;ipdb.set_trace() + print(image_data) + zero += 1 + continue + assert caps[0] is not None and len(caps[0]) > 0 + print(num, zero) + import ipdb + + ipdb.set_trace() + print("end") diff --git a/fastvideo/dataset/__pycache__/__init__.cpython-310.pyc b/fastvideo/dataset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb884f79e51c0529ec9f2a1150c6a71bdbfae3aa Binary files /dev/null and b/fastvideo/dataset/__pycache__/__init__.cpython-310.pyc differ diff --git a/fastvideo/dataset/__pycache__/__init__.cpython-312.pyc b/fastvideo/dataset/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e925f1cf94d8788ecb2f85daa58cd235d906eb1 Binary files /dev/null and b/fastvideo/dataset/__pycache__/__init__.cpython-312.pyc differ diff --git a/fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets.cpython-312.pyc b/fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a924d31d3dbfec16c859f8a97559b2b69c4259b8 Binary files /dev/null and b/fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets.cpython-312.pyc differ diff --git a/fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets_all.cpython-312.pyc b/fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets_all.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd3f44eee72a13562e0b312b6dfeb414fae139ee Binary files /dev/null and b/fastvideo/dataset/__pycache__/latent_flux_rfpt_datasets_all.cpython-312.pyc differ diff --git a/fastvideo/dataset/__pycache__/latent_flux_rl_datasets.cpython-312.pyc b/fastvideo/dataset/__pycache__/latent_flux_rl_datasets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf4d6776e16bd908fb1ab8076689c62a674f04de Binary files /dev/null and b/fastvideo/dataset/__pycache__/latent_flux_rl_datasets.cpython-312.pyc differ diff --git a/fastvideo/dataset/__pycache__/latent_qwenimage_rl_datasets.cpython-310.pyc b/fastvideo/dataset/__pycache__/latent_qwenimage_rl_datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c29903028f7ac8351a39a5116e059f3ef5198b40 Binary files /dev/null and b/fastvideo/dataset/__pycache__/latent_qwenimage_rl_datasets.cpython-310.pyc differ diff --git a/fastvideo/dataset/__pycache__/t2v_datasets.cpython-310.pyc b/fastvideo/dataset/__pycache__/t2v_datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc23a4a45360364ab16d99efeff89e5e41a8cbb9 Binary files /dev/null and b/fastvideo/dataset/__pycache__/t2v_datasets.cpython-310.pyc differ diff --git a/fastvideo/dataset/__pycache__/t2v_datasets.cpython-312.pyc b/fastvideo/dataset/__pycache__/t2v_datasets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be3bd390a821693f6b2ae69f2c83bfb8dafc6e19 Binary files /dev/null and b/fastvideo/dataset/__pycache__/t2v_datasets.cpython-312.pyc differ diff --git a/fastvideo/dataset/__pycache__/transform.cpython-310.pyc b/fastvideo/dataset/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bf7755a45fd9f20819b98616e5acfe899c38002 Binary files /dev/null and b/fastvideo/dataset/__pycache__/transform.cpython-310.pyc differ diff --git a/fastvideo/dataset/__pycache__/transform.cpython-312.pyc b/fastvideo/dataset/__pycache__/transform.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1f70f7ec13e12a16f78a449ffd0381bba6421d6 Binary files /dev/null and b/fastvideo/dataset/__pycache__/transform.cpython-312.pyc differ diff --git a/fastvideo/dataset/latent_datasets.py b/fastvideo/dataset/latent_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6c2e3907be93ceb355715d216078f88a0340d0 --- /dev/null +++ b/fastvideo/dataset/latent_datasets.py @@ -0,0 +1,132 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import json +import os +import random + +import torch +from torch.utils.data import Dataset + + +class LatentDataset(Dataset): + + def __init__( + self, + json_path, + num_latent_t, + cfg_rate, + ): + # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path + self.json_path = json_path + self.cfg_rate = cfg_rate + self.datase_dir_path = os.path.dirname(json_path) + self.video_dir = os.path.join(self.datase_dir_path, "video") + self.latent_dir = os.path.join(self.datase_dir_path, "latent") + self.prompt_embed_dir = os.path.join(self.datase_dir_path, + "prompt_embed") + self.prompt_attention_mask_dir = os.path.join(self.datase_dir_path, + "prompt_attention_mask") + with open(self.json_path, "r") as f: + self.data_anno = json.load(f) + # json.load(f) already keeps the order + # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path']) + self.num_latent_t = num_latent_t + # just zero embeddings [256, 4096] + self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32) + # 256 zeros + self.uncond_prompt_mask = torch.zeros(256).bool() + self.lengths = [ + data_item["length"] if "length" in data_item else 1 + for data_item in self.data_anno + ] + + def __getitem__(self, idx): + latent_file = self.data_anno[idx]["latent_path"] + prompt_embed_file = self.data_anno[idx]["prompt_embed_path"] + prompt_attention_mask_file = self.data_anno[idx][ + "prompt_attention_mask"] + # load + latent = torch.load( + os.path.join(self.latent_dir, latent_file), + map_location="cpu", + weights_only=True, + ) + latent = latent.squeeze(0)[:, -self.num_latent_t:] + if random.random() < self.cfg_rate: + prompt_embed = self.uncond_prompt_embed + prompt_attention_mask = self.uncond_prompt_mask + else: + prompt_embed = torch.load( + os.path.join(self.prompt_embed_dir, prompt_embed_file), + map_location="cpu", + weights_only=True, + ) + prompt_attention_mask = torch.load( + os.path.join(self.prompt_attention_mask_dir, + prompt_attention_mask_file), + map_location="cpu", + weights_only=True, + ) + return latent, prompt_embed, prompt_attention_mask + + def __len__(self): + return len(self.data_anno) + + +def latent_collate_function(batch): + # return latent, prompt, latent_attn_mask, text_attn_mask + # latent_attn_mask: # b t h w + # text_attn_mask: b 1 l + # needs to check if the latent/prompt' size and apply padding & attn mask + latents, prompt_embeds, prompt_attention_masks = zip(*batch) + # calculate max shape + max_t = max([latent.shape[1] for latent in latents]) + max_h = max([latent.shape[2] for latent in latents]) + max_w = max([latent.shape[3] for latent in latents]) + + # padding + latents = [ + torch.nn.functional.pad( + latent, + ( + 0, + max_t - latent.shape[1], + 0, + max_h - latent.shape[2], + 0, + max_w - latent.shape[3], + ), + ) for latent in latents + ] + # attn mask + latent_attn_mask = torch.ones(len(latents), max_t, max_h, max_w) + # set to 0 if padding + for i, latent in enumerate(latents): + latent_attn_mask[i, latent.shape[1]:, :, :] = 0 + latent_attn_mask[i, :, latent.shape[2]:, :] = 0 + latent_attn_mask[i, :, :, latent.shape[3]:] = 0 + + prompt_embeds = torch.stack(prompt_embeds, dim=0) + prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0) + latents = torch.stack(latents, dim=0) + return latents, prompt_embeds, latent_attn_mask, prompt_attention_masks + + +if __name__ == "__main__": + dataset = LatentDataset("data/Mochi-Synthetic-Data/merge.txt", + num_latent_t=28) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=2, + shuffle=False, + collate_fn=latent_collate_function) + for latent, prompt_embed, latent_attn_mask, prompt_attention_mask in dataloader: + print( + latent.shape, + prompt_embed.shape, + latent_attn_mask.shape, + prompt_attention_mask.shape, + ) + import pdb + + pdb.set_trace() diff --git a/fastvideo/dataset/latent_flux_rfpt_datasets.py b/fastvideo/dataset/latent_flux_rfpt_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2e789ea5d7f583fb0a2ed00cd8b8fda976a213 --- /dev/null +++ b/fastvideo/dataset/latent_flux_rfpt_datasets.py @@ -0,0 +1,122 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +import torch +from torch.utils.data import Dataset +import json +import os +import random + + +class LatentDataset(Dataset): + def __init__( + self, json_path, num_latent_t, cfg_rate, + ): + # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path + self.json_path = json_path + self.cfg_rate = cfg_rate + self.datase_dir_path = os.path.dirname(json_path) + #self.video_dir = os.path.join(self.datase_dir_path, "video") + #self.latent_dir = os.path.join(self.datase_dir_path, "latent") + self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed") + self.pooled_prompt_embeds_dir = os.path.join( + self.datase_dir_path, "pooled_prompt_embeds" + ) + self.text_ids_dir = os.path.join( + self.datase_dir_path, "text_ids" + ) + self.latents_dir = os.path.join( + self.datase_dir_path, "images" + ) + with open(self.json_path, "r") as f: + self.data_anno = json.load(f) + # json.load(f) already keeps the order + # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path']) + self.num_latent_t = num_latent_t + # just zero embeddings [256, 4096] + self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32) + # 256 zeros + self.uncond_prompt_mask = torch.zeros(256).bool() + self.lengths = [ + data_item["length"] if "length" in data_item else 1 + for data_item in self.data_anno + ] + + def __getitem__(self, idx): + #latent_file = self.data_anno[idx]["latent_path"] + prompt_embed_file = self.data_anno[idx]["prompt_embed_path"] + pooled_prompt_embeds_file = self.data_anno[idx]["pooled_prompt_embeds_path"] + text_ids_file = self.data_anno[idx]["text_ids"] + latent_file = text_ids_file + if random.random() < self.cfg_rate: + prompt_embed = self.uncond_prompt_embed + else: + prompt_embed = torch.load( + os.path.join(self.prompt_embed_dir, prompt_embed_file), + map_location="cpu", + weights_only=True, + ) + pooled_prompt_embeds = torch.load( + os.path.join( + self.pooled_prompt_embeds_dir, pooled_prompt_embeds_file + ), + map_location="cpu", + weights_only=True, + ) + text_ids = torch.load( + os.path.join( + self.text_ids_dir, text_ids_file + ), + map_location="cpu", + weights_only=True, + ) + latents = torch.load( + os.path.join( + self.latents_dir, latent_file + ), + map_location="cpu", + weights_only=True, + ) + return prompt_embed, pooled_prompt_embeds, text_ids, self.data_anno[idx]['caption'], latents + + def __len__(self): + return len(self.data_anno) + + +def latent_collate_function(batch): + # return latent, prompt, latent_attn_mask, text_attn_mask + # latent_attn_mask: # b t h w + # text_attn_mask: b 1 l + # needs to check if the latent/prompt' size and apply padding & attn mask + prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents = zip(*batch) + # attn mask + prompt_embeds = torch.stack(prompt_embeds, dim=0) + pooled_prompt_embeds = torch.stack(pooled_prompt_embeds, dim=0) + text_ids = torch.stack(text_ids, dim=0) + latents= torch.stack(latents, dim=0) + #latents = torch.stack(latents, dim=0) + return prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents + + +if __name__ == "__main__": + dataset = LatentDataset("data/rl_embeddings/videos2caption.json", num_latent_t=28, cfg_rate=0.0) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function + ) + for prompt_embed, prompt_attention_mask, caption in dataloader: + print( + prompt_embed.shape, + prompt_attention_mask.shape, + caption + ) + import pdb + + pdb.set_trace() \ No newline at end of file diff --git a/fastvideo/dataset/latent_flux_rfpt_datasets_all.py b/fastvideo/dataset/latent_flux_rfpt_datasets_all.py new file mode 100644 index 0000000000000000000000000000000000000000..96c66007dd4d4328b4e17954c4455090c6d40b67 --- /dev/null +++ b/fastvideo/dataset/latent_flux_rfpt_datasets_all.py @@ -0,0 +1,134 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +import torch +from torch.utils.data import Dataset +import json +import os +import random + + +class LatentDataset(Dataset): + def __init__( + self, json_path, num_latent_t, cfg_rate, + ): + # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path + self.json_path = json_path + self.cfg_rate = cfg_rate + self.datase_dir_path = os.path.dirname(json_path) + #self.video_dir = os.path.join(self.datase_dir_path, "video") + #self.latent_dir = os.path.join(self.datase_dir_path, "latent") + self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed") + self.pooled_prompt_embeds_dir = os.path.join( + self.datase_dir_path, "pooled_prompt_embeds" + ) + self.text_ids_dir = os.path.join( + self.datase_dir_path, "text_ids" + ) + self.images_dir = os.path.join( + self.datase_dir_path, "images" + ) + self.latents_dir = os.path.join( + self.datase_dir_path, "latents" + ) + with open(self.json_path, "r") as f: + self.data_anno = json.load(f) + # json.load(f) already keeps the order + # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path']) + self.num_latent_t = num_latent_t + # just zero embeddings [256, 4096] + self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32) + # 256 zeros + self.uncond_prompt_mask = torch.zeros(256).bool() + self.lengths = [ + data_item["length"] if "length" in data_item else 1 + for data_item in self.data_anno + ] + + def __getitem__(self, idx): + #latent_file = self.data_anno[idx]["latent_path"] + prompt_embed_file = self.data_anno[idx]["prompt_embed_path"] + pooled_prompt_embeds_file = self.data_anno[idx]["pooled_prompt_embeds_path"] + text_ids_file = self.data_anno[idx]["text_ids"] + latent_file = text_ids_file + image_file = text_ids_file + if random.random() < self.cfg_rate: + prompt_embed = self.uncond_prompt_embed + else: + prompt_embed = torch.load( + os.path.join(self.prompt_embed_dir, prompt_embed_file), + map_location="cpu", + weights_only=True, + ) + pooled_prompt_embeds = torch.load( + os.path.join( + self.pooled_prompt_embeds_dir, pooled_prompt_embeds_file + ), + map_location="cpu", + weights_only=True, + ) + text_ids = torch.load( + os.path.join( + self.text_ids_dir, text_ids_file + ), + map_location="cpu", + weights_only=True, + ) + latents = torch.load( + os.path.join( + self.latents_dir, latent_file + ), + map_location="cpu", + weights_only=True, + ) + images = torch.load( + os.path.join( + self.images_dir, image_file + ), + map_location="cpu", + weights_only=True, + ) + return prompt_embed, pooled_prompt_embeds, text_ids, self.data_anno[idx]['caption'], latents, images + + def __len__(self): + return len(self.data_anno) + + +def latent_collate_function(batch): + # return latent, prompt, latent_attn_mask, text_attn_mask + # latent_attn_mask: # b t h w + # text_attn_mask: b 1 l + # needs to check if the latent/prompt' size and apply padding & attn mask + prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents, images = zip(*batch) + # attn mask + prompt_embeds = torch.stack(prompt_embeds, dim=0) + pooled_prompt_embeds = torch.stack(pooled_prompt_embeds, dim=0) + text_ids = torch.stack(text_ids, dim=0) + latents= torch.stack(latents, dim=0) + images= torch.stack(images, dim=0) + #latents = torch.stack(latents, dim=0) + return prompt_embeds, pooled_prompt_embeds, text_ids, caption, latents, images + + +if __name__ == "__main__": + dataset = LatentDataset("data/rl_embeddings/videos2caption.json", num_latent_t=28, cfg_rate=0.0) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function + ) + for prompt_embed, prompt_attention_mask, caption in dataloader: + print( + prompt_embed.shape, + prompt_attention_mask.shape, + caption + ) + import pdb + + pdb.set_trace() \ No newline at end of file diff --git a/fastvideo/dataset/latent_flux_rl_datasets.py b/fastvideo/dataset/latent_flux_rl_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..2c80cfa5862dfedfcf9efb4bdf9a58b2b766dea4 --- /dev/null +++ b/fastvideo/dataset/latent_flux_rl_datasets.py @@ -0,0 +1,110 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +import torch +from torch.utils.data import Dataset +import json +import os +import random + + +class LatentDataset(Dataset): + def __init__( + self, json_path, num_latent_t, cfg_rate, + ): + # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path + self.json_path = json_path + self.cfg_rate = cfg_rate + self.datase_dir_path = os.path.dirname(json_path) + #self.video_dir = os.path.join(self.datase_dir_path, "video") + #self.latent_dir = os.path.join(self.datase_dir_path, "latent") + self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed") + self.pooled_prompt_embeds_dir = os.path.join( + self.datase_dir_path, "pooled_prompt_embeds" + ) + self.text_ids_dir = os.path.join( + self.datase_dir_path, "text_ids" + ) + with open(self.json_path, "r") as f: + self.data_anno = json.load(f) + # json.load(f) already keeps the order + # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path']) + self.num_latent_t = num_latent_t + # just zero embeddings [256, 4096] + self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32) + # 256 zeros + self.uncond_prompt_mask = torch.zeros(256).bool() + self.lengths = [ + data_item["length"] if "length" in data_item else 1 + for data_item in self.data_anno + ] + + def __getitem__(self, idx): + #latent_file = self.data_anno[idx]["latent_path"] + prompt_embed_file = self.data_anno[idx]["prompt_embed_path"] + pooled_prompt_embeds_file = self.data_anno[idx]["pooled_prompt_embeds_path"] + text_ids_file = self.data_anno[idx]["text_ids"] + if random.random() < self.cfg_rate: + prompt_embed = self.uncond_prompt_embed + else: + prompt_embed = torch.load( + os.path.join(self.prompt_embed_dir, prompt_embed_file), + map_location="cpu", + weights_only=True, + ) + pooled_prompt_embeds = torch.load( + os.path.join( + self.pooled_prompt_embeds_dir, pooled_prompt_embeds_file + ), + map_location="cpu", + weights_only=True, + ) + text_ids = torch.load( + os.path.join( + self.text_ids_dir, text_ids_file + ), + map_location="cpu", + weights_only=True, + ) + return prompt_embed, pooled_prompt_embeds, text_ids, self.data_anno[idx]['caption'] + + def __len__(self): + return len(self.data_anno) + + +def latent_collate_function(batch): + # return latent, prompt, latent_attn_mask, text_attn_mask + # latent_attn_mask: # b t h w + # text_attn_mask: b 1 l + # needs to check if the latent/prompt' size and apply padding & attn mask + prompt_embeds, pooled_prompt_embeds, text_ids, caption = zip(*batch) + # attn mask + prompt_embeds = torch.stack(prompt_embeds, dim=0) + pooled_prompt_embeds = torch.stack(pooled_prompt_embeds, dim=0) + text_ids = torch.stack(text_ids, dim=0) + #latents = torch.stack(latents, dim=0) + return prompt_embeds, pooled_prompt_embeds, text_ids, caption + + +if __name__ == "__main__": + dataset = LatentDataset("data/rl_embeddings/videos2caption.json", num_latent_t=28, cfg_rate=0.0) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function + ) + for prompt_embed, prompt_attention_mask, caption in dataloader: + print( + prompt_embed.shape, + prompt_attention_mask.shape, + caption + ) + import pdb + + pdb.set_trace() \ No newline at end of file diff --git a/fastvideo/dataset/latent_qwenimage_rl_datasets.py b/fastvideo/dataset/latent_qwenimage_rl_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f2b4457b0f571028090040625f3ff14bfc26ad --- /dev/null +++ b/fastvideo/dataset/latent_qwenimage_rl_datasets.py @@ -0,0 +1,90 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +import torch +from torch.utils.data import Dataset +import json +import os +import random + + +class LatentDataset(Dataset): + def __init__( + self, json_path, num_latent_t, cfg_rate, + ): + # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path + self.json_path = json_path + self.cfg_rate = cfg_rate + self.datase_dir_path = os.path.dirname(json_path) + #self.video_dir = os.path.join(self.datase_dir_path, "video") + #self.latent_dir = os.path.join(self.datase_dir_path, "latent") + self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed") + self.prompt_attention_mask_dir = os.path.join( + self.datase_dir_path, "prompt_attention_mask" + ) + with open(self.json_path, "r") as f: + self.data_anno = json.load(f) + # json.load(f) already keeps the order + # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path']) + self.num_latent_t = num_latent_t + # just zero embeddings [256, 4096] + self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32) + # 256 zeros + self.uncond_prompt_mask = torch.zeros(256).bool() + self.lengths = [ + data_item["length"] if "length" in data_item else 1 + for data_item in self.data_anno + ] + + def __getitem__(self, idx): + #latent_file = self.data_anno[idx]["latent_path"] + prompt_embed_file = self.data_anno[idx]["prompt_embed_path"] + prompt_attention_mask_file = self.data_anno[idx]["prompt_attention_mask"] + if random.random() < self.cfg_rate: + prompt_embed = self.uncond_prompt_embed + prompt_attention_mask = self.uncond_prompt_mask + else: + prompt_embed = torch.load( + os.path.join(self.prompt_embed_dir, prompt_embed_file), + map_location="cpu", + weights_only=True, + ) + prompt_attention_mask = torch.load( + os.path.join( + self.prompt_attention_mask_dir, prompt_attention_mask_file + ), + map_location="cpu", + weights_only=True, + ) + return prompt_embed, prompt_attention_mask, self.data_anno[idx]['caption'], self.data_anno[idx]['original_length'] + + def __len__(self): + return len(self.data_anno) + + +def latent_collate_function(batch): + # return latent, prompt, latent_attn_mask, text_attn_mask + # latent_attn_mask: # b t h w + # text_attn_mask: b 1 l + # needs to check if the latent/prompt' size and apply padding & attn mask + prompt_embeds, prompt_attention_masks, caption, original_length = zip(*batch) + # attn mask + prompt_embeds = torch.stack(prompt_embeds, dim=0) + prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0) + + # Convert original_length to tensor + original_length = torch.tensor(original_length, dtype=torch.long) + + # Convert caption to list + caption = list(caption) + + #latents = torch.stack(latents, dim=0) + return prompt_embeds, prompt_attention_masks, caption, original_length \ No newline at end of file diff --git a/fastvideo/dataset/latent_rl_datasets.py b/fastvideo/dataset/latent_rl_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c8d6107c931a507ad477fb0abdaa10ee7fe40c --- /dev/null +++ b/fastvideo/dataset/latent_rl_datasets.py @@ -0,0 +1,99 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +import torch +from torch.utils.data import Dataset +import json +import os +import random + + +class LatentDataset(Dataset): + def __init__( + self, json_path, num_latent_t, cfg_rate, + ): + # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path + self.json_path = json_path + self.cfg_rate = cfg_rate + self.datase_dir_path = os.path.dirname(json_path) + #self.video_dir = os.path.join(self.datase_dir_path, "video") + #self.latent_dir = os.path.join(self.datase_dir_path, "latent") + self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed") + self.prompt_attention_mask_dir = os.path.join( + self.datase_dir_path, "prompt_attention_mask" + ) + with open(self.json_path, "r") as f: + self.data_anno = json.load(f) + # json.load(f) already keeps the order + # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path']) + self.num_latent_t = num_latent_t + # just zero embeddings [256, 4096] + self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32) + # 256 zeros + self.uncond_prompt_mask = torch.zeros(256).bool() + self.lengths = [ + data_item["length"] if "length" in data_item else 1 + for data_item in self.data_anno + ] + + def __getitem__(self, idx): + #latent_file = self.data_anno[idx]["latent_path"] + prompt_embed_file = self.data_anno[idx]["prompt_embed_path"] + prompt_attention_mask_file = self.data_anno[idx]["prompt_attention_mask"] + if random.random() < self.cfg_rate: + prompt_embed = self.uncond_prompt_embed + prompt_attention_mask = self.uncond_prompt_mask + else: + prompt_embed = torch.load( + os.path.join(self.prompt_embed_dir, prompt_embed_file), + map_location="cpu", + weights_only=True, + ) + prompt_attention_mask = torch.load( + os.path.join( + self.prompt_attention_mask_dir, prompt_attention_mask_file + ), + map_location="cpu", + weights_only=True, + ) + return prompt_embed, prompt_attention_mask, self.data_anno[idx]['caption'] + + def __len__(self): + return len(self.data_anno) + + +def latent_collate_function(batch): + # return latent, prompt, latent_attn_mask, text_attn_mask + # latent_attn_mask: # b t h w + # text_attn_mask: b 1 l + # needs to check if the latent/prompt' size and apply padding & attn mask + prompt_embeds, prompt_attention_masks, caption = zip(*batch) + # attn mask + prompt_embeds = torch.stack(prompt_embeds, dim=0) + prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0) + #latents = torch.stack(latents, dim=0) + return prompt_embeds, prompt_attention_masks, caption + + +if __name__ == "__main__": + dataset = LatentDataset("data/rl_embeddings/videos2caption.json", num_latent_t=28, cfg_rate=0.0) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function + ) + for prompt_embed, prompt_attention_mask, caption in dataloader: + print( + prompt_embed.shape, + prompt_attention_mask.shape, + caption + ) + import pdb + + pdb.set_trace() diff --git a/fastvideo/dataset/t2v_datasets.py b/fastvideo/dataset/t2v_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..a177f9df7008d69d329926165c5dae831cbeb55d --- /dev/null +++ b/fastvideo/dataset/t2v_datasets.py @@ -0,0 +1,351 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import json +import math +import os +import random +from collections import Counter +from os.path import join as opj + +import numpy as np +import torch +import torchvision +from einops import rearrange +from PIL import Image +from torch.utils.data import Dataset + +from fastvideo.utils.dataset_utils import DecordInit +from fastvideo.utils.logging_ import main_print + + +class SingletonMeta(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + +class DataSetProg(metaclass=SingletonMeta): + + def __init__(self): + self.cap_list = [] + self.elements = [] + self.num_workers = 1 + self.n_elements = 0 + self.worker_elements = dict() + self.n_used_elements = dict() + + def set_cap_list(self, num_workers, cap_list, n_elements): + self.num_workers = num_workers + self.cap_list = cap_list + self.n_elements = n_elements + self.elements = list(range(n_elements)) + random.shuffle(self.elements) + print(f"n_elements: {len(self.elements)}", flush=True) + + for i in range(self.num_workers): + self.n_used_elements[i] = 0 + per_worker = int( + math.ceil(len(self.elements) / float(self.num_workers))) + start = i * per_worker + end = min(start + per_worker, len(self.elements)) + self.worker_elements[i] = self.elements[start:end] + + def get_item(self, work_info): + if work_info is None: + worker_id = 0 + else: + worker_id = work_info.id + + idx = self.worker_elements[worker_id][ + self.n_used_elements[worker_id] % + len(self.worker_elements[worker_id])] + self.n_used_elements[worker_id] += 1 + return idx + + +dataset_prog = DataSetProg() + + +def filter_resolution(h, + w, + max_h_div_w_ratio=17 / 16, + min_h_div_w_ratio=8 / 16): + if h / w <= max_h_div_w_ratio and h / w >= min_h_div_w_ratio: + return True + return False + + +class T2V_dataset(Dataset): + + def __init__(self, args, transform, temporal_sample, tokenizer, + transform_topcrop): + self.data = args.data_merge_path + self.num_frames = args.num_frames + self.train_fps = args.train_fps + self.use_image_num = args.use_image_num + self.transform = transform + self.transform_topcrop = transform_topcrop + self.temporal_sample = temporal_sample + self.tokenizer = tokenizer + self.text_max_length = args.text_max_length + self.cfg = args.cfg + self.speed_factor = args.speed_factor + self.max_height = args.max_height + self.max_width = args.max_width + self.drop_short_ratio = args.drop_short_ratio + assert self.speed_factor >= 1 + self.v_decoder = DecordInit() + self.video_length_tolerance_range = args.video_length_tolerance_range + self.support_Chinese = True + if "mt5" not in args.text_encoder_name: + self.support_Chinese = False + + cap_list = self.get_cap_list() + + assert len(cap_list) > 0 + cap_list, self.sample_num_frames = self.define_frame_index(cap_list) + self.lengths = self.sample_num_frames + + n_elements = len(cap_list) + dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list, + n_elements) + + print(f"video length: {len(dataset_prog.cap_list)}", flush=True) + + def set_checkpoint(self, n_used_elements): + for i in range(len(dataset_prog.n_used_elements)): + dataset_prog.n_used_elements[i] = n_used_elements + + def __len__(self): + return dataset_prog.n_elements + + def __getitem__(self, idx): + + data = self.get_data(idx) + return data + + def get_data(self, idx): + path = dataset_prog.cap_list[idx]["path"] + if path.endswith(".mp4"): + return self.get_video(idx) + else: + return self.get_image(idx) + + def get_video(self, idx): + video_path = dataset_prog.cap_list[idx]["path"] + assert os.path.exists(video_path), f"file {video_path} do not exist!" + frame_indices = dataset_prog.cap_list[idx]["sample_frame_index"] + torchvision_video, _, metadata = torchvision.io.read_video( + video_path, output_format="TCHW") + video = torchvision_video[frame_indices] + video = self.transform(video) + video = rearrange(video, "t c h w -> c t h w") + video = video.to(torch.uint8) + assert video.dtype == torch.uint8 + + h, w = video.shape[-2:] + assert ( + h / w <= 17 / 16 and h / w >= 8 / 16 + ), f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({video_path}) found ratio is {round(h / w, 2)} with the shape of {video.shape}" + + video = video.float() / 127.5 - 1.0 + + text = dataset_prog.cap_list[idx]["cap"] + if not isinstance(text, list): + text = [text] + text = [random.choice(text)] + + text = text[0] if random.random() > self.cfg else "" + text_tokens_and_mask = self.tokenizer( + text, + max_length=self.text_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = text_tokens_and_mask["input_ids"] + cond_mask = text_tokens_and_mask["attention_mask"] + return dict( + pixel_values=video, + text=text, + input_ids=input_ids, + cond_mask=cond_mask, + path=video_path, + ) + + def get_image(self, idx): + image_data = dataset_prog.cap_list[ + idx] # [{'path': path, 'cap': cap}, ...] + + image = Image.open(image_data["path"]).convert("RGB") # [h, w, c] + image = torch.from_numpy(np.array(image)) # [h, w, c] + image = rearrange(image, "h w c -> c h w").unsqueeze(0) # [1 c h w] + # for i in image: + # h, w = i.shape[-2:] + # assert h / w <= 17 / 16 and h / w >= 8 / 16, f'Only image with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But found ratio is {round(h / w, 2)} with the shape of {i.shape}' + + image = (self.transform_topcrop(image) if "human_images" + in image_data["path"] else self.transform(image) + ) # [1 C H W] -> num_img [1 C H W] + image = image.transpose(0, 1) # [1 C H W] -> [C 1 H W] + + image = image.float() / 127.5 - 1.0 + + caps = (image_data["cap"] if isinstance(image_data["cap"], list) else + [image_data["cap"]]) + caps = [random.choice(caps)] + text = caps + input_ids, cond_mask = [], [] + text = text[0] if random.random() > self.cfg else "" + text_tokens_and_mask = self.tokenizer( + text, + max_length=self.text_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = text_tokens_and_mask["input_ids"] # 1, l + cond_mask = text_tokens_and_mask["attention_mask"] # 1, l + return dict( + pixel_values=image, + text=text, + input_ids=input_ids, + cond_mask=cond_mask, + path=image_data["path"], + ) + + def define_frame_index(self, cap_list): + new_cap_list = [] + sample_num_frames = [] + cnt_too_long = 0 + cnt_too_short = 0 + cnt_no_cap = 0 + cnt_no_resolution = 0 + cnt_resolution_mismatch = 0 + cnt_movie = 0 + cnt_img = 0 + for i in cap_list: + path = i["path"] + cap = i.get("cap", None) + # ======no caption===== + if cap is None: + cnt_no_cap += 1 + continue + if path.endswith(".mp4"): + # ======no fps and duration===== + duration = i.get("duration", None) + fps = i.get("fps", None) + if fps is None or duration is None: + continue + + # ======resolution mismatch===== + resolution = i.get("resolution", None) + if resolution is None: + cnt_no_resolution += 1 + continue + else: + if (resolution.get("height", None) is None + or resolution.get("width", None) is None): + cnt_no_resolution += 1 + continue + height, width = i["resolution"]["height"], i["resolution"][ + "width"] + aspect = self.max_height / self.max_width + hw_aspect_thr = 1.5 + is_pick = filter_resolution( + height, + width, + max_h_div_w_ratio=hw_aspect_thr * aspect, + min_h_div_w_ratio=1 / hw_aspect_thr * aspect, + ) + if not is_pick: + print("resolution mismatch") + cnt_resolution_mismatch += 1 + continue + + # import ipdb;ipdb.set_trace() + i["num_frames"] = math.ceil(fps * duration) + # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration. + if i["num_frames"] / fps > self.video_length_tolerance_range * ( + self.num_frames / self.train_fps * self.speed_factor + ): # too long video is not suitable for this training stage (self.num_frames) + cnt_too_long += 1 + continue + + # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24) + frame_interval = fps / self.train_fps + start_frame_idx = 0 + frame_indices = np.arange(start_frame_idx, i["num_frames"], + frame_interval).astype(int) + + # comment out it to enable dynamic frames training + if (len(frame_indices) < self.num_frames + and random.random() < self.drop_short_ratio): + cnt_too_short += 1 + continue + + # too long video will be temporal-crop randomly + if len(frame_indices) > self.num_frames: + begin_index, end_index = self.temporal_sample( + len(frame_indices)) + frame_indices = frame_indices[begin_index:end_index] + # frame_indices = frame_indices[:self.num_frames] # head crop + i["sample_frame_index"] = frame_indices.tolist() + new_cap_list.append(i) + i["sample_num_frames"] = len( + i["sample_frame_index"] + ) # will use in dataloader(group sampler) + sample_num_frames.append(i["sample_num_frames"]) + elif path.endswith(".jpg"): # image + cnt_img += 1 + new_cap_list.append(i) + i["sample_num_frames"] = 1 + sample_num_frames.append(i["sample_num_frames"]) + else: + raise NameError( + f"Unknown file extension {path.split('.')[-1]}, only support .mp4 for video and .jpg for image" + ) + # import ipdb;ipdb.set_trace() + main_print( + f"no_cap: {cnt_no_cap}, too_long: {cnt_too_long}, too_short: {cnt_too_short}, " + f"no_resolution: {cnt_no_resolution}, resolution_mismatch: {cnt_resolution_mismatch}, " + f"Counter(sample_num_frames): {Counter(sample_num_frames)}, cnt_movie: {cnt_movie}, cnt_img: {cnt_img}, " + f"before filter: {len(cap_list)}, after filter: {len(new_cap_list)}" + ) + return new_cap_list, sample_num_frames + + def decord_read(self, path, frame_indices): + decord_vr = self.v_decoder(path) + video_data = decord_vr.get_batch(frame_indices).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, + 2) # (T, H, W, C) -> (T C H W) + return video_data + + def read_jsons(self, data): + cap_lists = [] + with open(data, "r") as f: + folder_anno = [ + i.strip().split(",") for i in f.readlines() + if len(i.strip()) > 0 + ] + print(folder_anno) + for folder, anno in folder_anno: + with open(anno, "r") as f: + sub_list = json.load(f) + for i in range(len(sub_list)): + sub_list[i]["path"] = opj(folder, sub_list[i]["path"]) + cap_lists += sub_list + return cap_lists + + def get_cap_list(self): + cap_lists = self.read_jsons(self.data) + return cap_lists diff --git a/fastvideo/dataset/transform.py b/fastvideo/dataset/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..37cad8de637bbb4eff23e076f8eeff53aa4c0414 --- /dev/null +++ b/fastvideo/dataset/transform.py @@ -0,0 +1,647 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import numbers +import random + +import torch +from PIL import Image + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), + resample=Image.BOX) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize(tuple( + round(x * scale) for x in pil_image.size), + resample=Image.BICUBIC) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y:crop_y + image_size, + crop_x:crop_x + image_size]) + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i:i + h, j:j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError( + f"target size should be tuple (height, width), instead got {target_size}" + ) + return torch.nn.functional.interpolate( + clip, + size=target_size, + mode=interpolation_mode, + align_corners=True, + antialias=True, + ) + + +def resize_scale(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError( + f"target size should be tuple (height, width), instead got {target_size}" + ) + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size[0] / min(H, W) + return torch.nn.functional.interpolate( + clip, + scale_factor=scale_, + mode=interpolation_mode, + align_corners=True, + antialias=True, + ) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def center_crop_using_short_edge(clip): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + if h < w: + th, tw = h, h + i = 0 + j = int(round((w - tw) / 2.0)) + else: + th, tw = w, w + i = int(round((h - th) / 2.0)) + j = 0 + return crop(clip, i, j, th, tw) + + +def center_crop_th_tw(clip, th, tw, top_crop): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + + # import ipdb;ipdb.set_trace() + h, w = clip.size(-2), clip.size(-1) + tr = th / tw + if h / w > tr: + new_h = int(w * tr) + new_w = w + else: + new_h = h + new_w = int(h / tr) + + i = 0 if top_crop else int(round((h - new_h) / 2.0)) + j = int(round((w - new_w) / 2.0)) + return crop(clip, i, j, new_h, new_w) + + +def random_shift_crop(clip): + """ + Slide along the long edge, with the short edge as crop size + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + + if h <= w: + short_edge = h + else: + short_edge = w + + th, tw = short_edge, short_edge + + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, )).item() + return crop(clip, i, j, th, tw) + + +def normalize_video(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % + str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + # print(mean) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + Returns: + flipped clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) + + +class RandomCropVideo: + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: randomly cropped video clip. + size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + th, tw = self.size + + if h < th or w < tw: + raise ValueError( + f"Required crop size {(th, tw)} is larger than input image size {(h, w)}" + ) + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, )).item() + + return i, j, th, tw + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class SpatialStrideCropVideo: + + def __init__(self, stride): + self.stride = stride + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: cropped video clip by stride. + size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + + th, tw = h // self.stride * self.stride, w // self.stride * self.stride + + return 0, 0, th, tw # from top-left + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class LongSideResizeVideo: + """ + First use the long side, + then resize to the specified size + """ + + def __init__( + self, + size, + skip_low_resolution=False, + interpolation_mode="bilinear", + ): + self.size = size + self.skip_low_resolution = skip_low_resolution + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized video clip. + size is (T, C, 512, *) or (T, C, *, 512) + """ + _, _, h, w = clip.shape + if self.skip_low_resolution and max(h, w) <= self.size: + return clip + if h > w: + w = int(w * self.size / h) + h = self.size + else: + h = int(h * self.size / w) + w = self.size + resize_clip = resize(clip, + target_size=(h, w), + interpolation_mode=self.interpolation_mode) + return resize_clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class CenterCropResizeVideo: + """ + First use the short side for cropping length, + center crop video, then resize to the specified size + """ + + def __init__( + self, + size, + top_crop=False, + interpolation_mode="bilinear", + ): + if len(size) != 2: + raise ValueError( + f"size should be tuple (height, width), instead got {size}") + self.size = size + self.top_crop = top_crop + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + # clip_center_crop = center_crop_using_short_edge(clip) + clip_center_crop = center_crop_th_tw(clip, + self.size[0], + self.size[1], + top_crop=self.top_crop) + # import ipdb;ipdb.set_trace() + clip_center_crop_resize = resize( + clip_center_crop, + target_size=self.size, + interpolation_mode=self.interpolation_mode, + ) + return clip_center_crop_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class UCFCenterCropVideo: + """ + First scale to the specified size in equal proportion to the short edge, + then center cropping + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError( + f"size should be tuple (height, width), instead got {size}" + ) + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_resize = resize_scale(clip=clip, + target_size=self.size, + interpolation_mode=self.interpolation_mode) + clip_center_crop = center_crop(clip_resize, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class KineticsRandomCropResizeVideo: + """ + Slide along the long edge, with the short edge as crop size. And resie to the desired size. + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError( + f"size should be tuple (height, width), instead got {size}" + ) + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + clip_random_crop = random_shift_crop(clip) + clip_resize = resize(clip_random_crop, self.size, + self.interpolation_mode) + return clip_resize + + +class CenterCropVideo: + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError( + f"size should be tuple (height, width), instead got {size}" + ) + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop(clip, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class Normalize: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) + """ + return normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class Normalize255: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return normalize_video(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class RandomHorizontalFlipVideo: + """ + Flip the video clip along the horizontal direction with a given probability + Args: + p (float): probability of the clip being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Size is (T, C, H, W) + Return: + clip (torch.tensor): Size is (T, C, H, W) + """ + if random.random() < self.p: + clip = hflip(clip) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +# ------------------------------------------------------------ +# --------------------- Sampling --------------------------- +# ------------------------------------------------------------ +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. + """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index + + +class DynamicSampleDuration(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. + """ + + def __init__(self, t_stride, extra_1): + self.t_stride = t_stride + self.extra_1 = extra_1 + + def __call__(self, t, h, w): + if self.extra_1: + t = t - 1 + truncate_t_list = list( + range(t + 1))[t // 2:][::self.t_stride] # need half at least + truncate_t = random.choice(truncate_t_list) + if self.extra_1: + truncate_t = truncate_t + 1 + return 0, truncate_t + + +if __name__ == "__main__": + import os + + import numpy as np + import torchvision.io as io + from torchvision import transforms + from torchvision.utils import save_image + + vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", + pts_unit="sec", + output_format="TCHW") + + trans = transforms.Compose([ + Normalize255(), + RandomHorizontalFlipVideo(), + UCFCenterCropVideo(512), + # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + transforms.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + inplace=True), + ]) + + target_video_len = 32 + frame_interval = 1 + total_frames = len(vframes) + print(total_frames) + + temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) + + # Sampling video frames + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + # print(start_frame_ind) + # print(end_frame_ind) + assert end_frame_ind - start_frame_ind >= target_video_len + frame_indice = np.linspace(start_frame_ind, + end_frame_ind - 1, + target_video_len, + dtype=int) + print(frame_indice) + + select_vframes = vframes[frame_indice] + print(select_vframes.shape) + print(select_vframes.dtype) + + select_vframes_trans = trans(select_vframes) + print(select_vframes_trans.shape) + print(select_vframes_trans.dtype) + + select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * + 255).to(dtype=torch.uint8) + print(select_vframes_trans_int.dtype) + print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) + + io.write_video("./test.avi", + select_vframes_trans_int.permute(0, 2, 3, 1), + fps=8) + + for i in range(target_video_len): + save_image( + select_vframes_trans[i], + os.path.join("./test000", "%04d.png" % i), + normalize=True, + value_range=(-1, 1), + ) diff --git a/fastvideo/distill/__init__.py b/fastvideo/distill/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fastvideo/distill/__pycache__/__init__.cpython-312.pyc b/fastvideo/distill/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12ea6ae82299391c994893a2e451a246fc20fbd2 Binary files /dev/null and b/fastvideo/distill/__pycache__/__init__.cpython-312.pyc differ diff --git a/fastvideo/distill/__pycache__/solver.cpython-312.pyc b/fastvideo/distill/__pycache__/solver.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b566bb6d4ed2da1583aa9331f4c5c2f57f64235e Binary files /dev/null and b/fastvideo/distill/__pycache__/solver.cpython-312.pyc differ diff --git a/fastvideo/distill/discriminator.py b/fastvideo/distill/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..07931aa2c832ad64ce1b685b3ad31f0a8011ccb2 --- /dev/null +++ b/fastvideo/distill/discriminator.py @@ -0,0 +1,84 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import torch.nn as nn +from diffusers.utils import logging + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DiscriminatorHead(nn.Module): + + def __init__(self, input_channel, output_channel=1): + super().__init__() + inner_channel = 1024 + self.conv1 = nn.Sequential( + nn.Conv2d(input_channel, inner_channel, 1, 1, 0), + nn.GroupNorm(32, inner_channel), + nn.LeakyReLU( + inplace=True + ), # use LeakyReLu instead of GELU shown in the paper to save memory + ) + self.conv2 = nn.Sequential( + nn.Conv2d(inner_channel, inner_channel, 1, 1, 0), + nn.GroupNorm(32, inner_channel), + nn.LeakyReLU( + inplace=True + ), # use LeakyReLu instead of GELU shown in the paper to save memory + ) + + self.conv_out = nn.Conv2d(inner_channel, output_channel, 1, 1, 0) + + def forward(self, x): + b, twh, c = x.shape + t = twh // (30 * 53) + x = x.view(-1, 30 * 53, c) + x = x.permute(0, 2, 1) + x = x.view(b * t, c, 30, 53) + x = self.conv1(x) + x = self.conv2(x) + x + x = self.conv_out(x) + return x + + +class Discriminator(nn.Module): + + def __init__( + self, + stride=8, + num_h_per_head=1, + adapter_channel_dims=[3072], + total_layers=48, + ): + super().__init__() + adapter_channel_dims = adapter_channel_dims * (total_layers // stride) + self.stride = stride + self.num_h_per_head = num_h_per_head + self.head_num = len(adapter_channel_dims) + self.heads = nn.ModuleList([ + nn.ModuleList([ + DiscriminatorHead(adapter_channel) + for _ in range(self.num_h_per_head) + ]) for adapter_channel in adapter_channel_dims + ]) + + def forward(self, features): + outputs = [] + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + assert len(features) == len(self.heads) + for i in range(0, len(features)): + for h in self.heads[i]: + # out = torch.utils.checkpoint.checkpoint( + # create_custom_forward(h), + # features[i], + # use_reentrant=False + # ) + out = h(features[i]) + outputs.append(out) + return outputs diff --git a/fastvideo/distill/solver.py b/fastvideo/distill/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..11591ba3b9d81d9657a0c599a165827538038916 --- /dev/null +++ b/fastvideo/distill/solver.py @@ -0,0 +1,310 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, logging + +from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class PCMFMSchedulerOutput(BaseOutput): + prev_sample: torch.FloatTensor + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1, ) * (len(x_shape) - 1))) + + +class PCMFMScheduler(SchedulerMixin, ConfigMixin): + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + pcm_timesteps: int = 50, + linear_quadratic=False, + linear_quadratic_threshold=0.025, + linear_range=0.5, + ): + if linear_quadratic: + linear_steps = int(num_train_timesteps * linear_range) + sigmas = linear_quadratic_schedule(num_train_timesteps, + linear_quadratic_threshold, + linear_steps) + sigmas = torch.tensor(sigmas).to(dtype=torch.float32) + else: + timesteps = np.linspace(1, + num_train_timesteps, + num_train_timesteps, + dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + self.euler_timesteps = (np.arange(1, pcm_timesteps + 1) * + (num_train_timesteps // + pcm_timesteps)).round().astype(np.int64) - 1 + self.sigmas = sigmas.numpy()[::-1][self.euler_timesteps] + self.sigmas = torch.from_numpy((self.sigmas[::-1].copy())) + self.timesteps = self.sigmas * num_train_timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, + num_inference_steps: int, + device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + inference_indices = np.linspace(0, + self.config.pcm_timesteps, + num=num_inference_steps, + endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + inference_indices = torch.from_numpy(inference_indices).long() + + self.sigmas_ = self.sigmas[inference_indices] + timesteps = self.sigmas_ * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas_ = torch.cat( + [self.sigmas_, + torch.zeros(1, device=self.sigmas_.device)]) + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[PCMFMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if (isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor)): + raise ValueError(( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep."), ) + + if self.step_index is None: + self._init_step_index(timestep) + + sample = sample.to(torch.float32) + + sigma = self.sigmas_[self.step_index] + + denoised = sample - model_output * sigma + derivative = (sample - denoised) / sigma + + dt = self.sigmas_[self.step_index + 1] - sigma + prev_sample = sample + derivative * dt + prev_sample = prev_sample.to(model_output.dtype) + self._step_index += 1 + + if not return_dict: + return (prev_sample, ) + + return PCMFMSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + + +class EulerSolver: + + def __init__(self, sigmas, timesteps=1000, euler_timesteps=50): + self.step_ratio = timesteps // euler_timesteps + self.euler_timesteps = (np.arange(1, euler_timesteps + 1) * + self.step_ratio).round().astype(np.int64) - 1 + self.euler_timesteps_prev = np.asarray( + [0] + self.euler_timesteps[:-1].tolist()) + self.sigmas = sigmas[self.euler_timesteps] + self.sigmas_prev = np.asarray( + [sigmas[0]] + sigmas[self.euler_timesteps[:-1]].tolist() + ) # either use sigma0 or 0 + + self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long() + self.euler_timesteps_prev = torch.from_numpy( + self.euler_timesteps_prev).long() + self.sigmas = torch.from_numpy(self.sigmas) + self.sigmas_prev = torch.from_numpy(self.sigmas_prev) + + def to(self, device): + self.euler_timesteps = self.euler_timesteps.to(device) + self.euler_timesteps_prev = self.euler_timesteps_prev.to(device) + + self.sigmas = self.sigmas.to(device) + self.sigmas_prev = self.sigmas_prev.to(device) + return self + + def euler_step(self, sample, model_pred, timestep_index): + sigma = extract_into_tensor(self.sigmas, timestep_index, + model_pred.shape) + sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index, + model_pred.shape) + x_prev = sample + (sigma_prev - sigma) * model_pred + return x_prev + + def euler_style_multiphase_pred( + self, + sample, + model_pred, + timestep_index, + multiphase, + is_target=False, + ): + inference_indices = np.linspace(0, + len(self.euler_timesteps), + num=multiphase, + endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + inference_indices = (torch.from_numpy(inference_indices).long().to( + self.euler_timesteps.device)) + expanded_timestep_index = timestep_index.unsqueeze(1).expand( + -1, inference_indices.size(0)) + valid_indices_mask = expanded_timestep_index >= inference_indices + last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax( + dim=1) + last_valid_index = inference_indices.size(0) - 1 - last_valid_index + timestep_index_end = inference_indices[last_valid_index] + + if is_target: + sigma = extract_into_tensor(self.sigmas_prev, timestep_index, + sample.shape) + else: + sigma = extract_into_tensor(self.sigmas, timestep_index, + sample.shape) + sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index_end, + sample.shape) + x_prev = sample + (sigma_prev - sigma) * model_pred + + return x_prev, timestep_index_end diff --git a/fastvideo/models/.DS_Store b/fastvideo/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/fastvideo/models/.DS_Store differ diff --git a/fastvideo/models/__pycache__/flash_attn_no_pad.cpython-310.pyc b/fastvideo/models/__pycache__/flash_attn_no_pad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fdd10ea9f82da01db1c221023a1a8e144e5bf7c Binary files /dev/null and b/fastvideo/models/__pycache__/flash_attn_no_pad.cpython-310.pyc differ diff --git a/fastvideo/models/__pycache__/flash_attn_no_pad.cpython-312.pyc b/fastvideo/models/__pycache__/flash_attn_no_pad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00ee02102a42de3c2a0aba427dc12a0ad1e626c0 Binary files /dev/null and b/fastvideo/models/__pycache__/flash_attn_no_pad.cpython-312.pyc differ diff --git a/fastvideo/models/flash_attn_no_pad.py b/fastvideo/models/flash_attn_no_pad.py new file mode 100644 index 0000000000000000000000000000000000000000..3bffc694707eb658dddb227f7117377386603a35 --- /dev/null +++ b/fastvideo/models/flash_attn_no_pad.py @@ -0,0 +1,37 @@ +from einops import rearrange +from flash_attn import flash_attn_varlen_qkvpacked_func +from flash_attn.bert_padding import pad_input, unpad_input + + +def flash_attn_no_pad(qkv, + key_padding_mask, + causal=False, + dropout_p=0.0, + softmax_scale=None): + # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27 + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + nheads = qkv.shape[-2] + x = rearrange(qkv, "b s three h d -> b s (three h d)") + x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input( + x, key_padding_mask) + + x_unpad = rearrange(x_unpad, + "nnz (three h d) -> nnz three h d", + three=3, + h=nheads) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, + cu_seqlens, + max_s, + dropout_p, + softmax_scale=softmax_scale, + causal=causal, + ) + output = rearrange( + pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, + batch_size, seqlen), + "b s (h d) -> b s h d", + h=nheads, + ) + return output diff --git a/fastvideo/reward_model/clip_score.py b/fastvideo/reward_model/clip_score.py new file mode 100644 index 0000000000000000000000000000000000000000..6266b5a8ec4f8363c1318adda9dab24a2c287e70 --- /dev/null +++ b/fastvideo/reward_model/clip_score.py @@ -0,0 +1,98 @@ +import numpy as np +import torch +from torchvision import transforms +import torch.nn.functional as F +import clip +from PIL import Image +from typing import List, Tuple, Union +from PIL import Image +import os +from open_clip import create_model_from_pretrained, get_tokenizer +import argparse + + + +@torch.no_grad() +def calculate_clip_score(prompts, images, clip_model, device): + texts = clip.tokenize(prompts, truncate=True).to(device=device) + + image_features = clip_model.encode_image(images) + text_features = clip_model.encode_text(texts) + + scores = F.cosine_similarity(image_features, text_features) + return scores + + +class CLIPScoreRewardModel(): + def __init__(self, clip_model_path, device, http_proxy=None, https_proxy=None, clip_model_type='ViT-H-14'): + super().__init__() + if http_proxy: + os.environ["http_proxy"] = http_proxy + if https_proxy: + os.environ["https_proxy"] = https_proxy + self.clip_model_path = clip_model_path + self.clip_model_type = clip_model_type + self.device = device + self.load_model() + + def load_model(self, logger=None): + self.model, self.preprocess = create_model_from_pretrained(self.clip_model_path) + self.tokenizer = get_tokenizer(self.clip_model_type) + self.model.to(self.device) + + # calculate clip score directly, such as for rerank + @torch.no_grad() + def __call__( + self, + prompts: Union[str, List[str]], + images: List[Image.Image] + ) -> List[float]: + if isinstance(prompts, str): + prompts = [prompts] * len(images) + if len(prompts) != len(images): + raise ValueError("prompts must have the same length as images") + + scores = [] + for prompt, image in zip(prompts, images): + image_proc = self.preprocess(image).unsqueeze(0).to(self.device) + text = self.tokenizer( + [prompt], + context_length=self.model.context_length + ).to(self.device) + + image_features = self.model.encode_image(image_proc) + text_features = self.model.encode_text(text) + image_features = F.normalize(image_features, dim=-1) + text_features = F.normalize(text_features, dim=-1) + + clip_score = image_features @ text_features.T + + scores.append(clip_score.item()) + + return scores + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="PickScore Reward Model") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on (e.g., 'cuda', 'cpu')") + parser.add_argument("--http_proxy", type=str, default=None, help="HTTP proxy URL") + parser.add_argument("--https_proxy", type=str, default=None, help="HTTPS proxy URL") + args = parser.parse_args() + + # Example usage + clip_model_path = 'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384' + reward_model = CLIPScoreRewardModel( + clip_model_path, + device=args.device, + http_proxy=args.http_proxy, + https_proxy=args.https_proxy + ) + + image_path = "assets/reward_demo.jpg" + prompt = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting." + + image = Image.open(image_path).convert("RGB") + clip_score = reward_model(prompt, [image]) + + print(f"CLIP Score: {clip_score}") \ No newline at end of file diff --git a/fastvideo/reward_model/hps_score.py b/fastvideo/reward_model/hps_score.py new file mode 100644 index 0000000000000000000000000000000000000000..ee62cc19ffb599a8bd27d8efddbf75a053905b86 --- /dev/null +++ b/fastvideo/reward_model/hps_score.py @@ -0,0 +1,79 @@ +from typing import Union, List +import argparse +import torch +from PIL import Image + +from HPSv2.hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer + + +class HPSClipRewardModel(object): + def __init__(self, device, clip_ckpt_path, hps_ckpt_path, model_name='ViT-H-14'): + self.device = device + self.clip_ckpt_path = clip_ckpt_path + self.hps_ckpt_path = hps_ckpt_path + self.model_name = model_name + self.reward_model, self.text_processor, self.img_processor = self.build_reward_model() + + def build_reward_model(self): + model, preprocess_train, img_preprocess_val = create_model_and_transforms( + self.model_name, + self.clip_ckpt_path, + precision='amp', + device=self.device, + jit=False, + force_quick_gelu=False, + force_custom_text=False, + force_patch_dropout=False, + force_image_size=None, + pretrained_image=False, + image_mean=None, + image_std=None, + light_augmentation=True, + aug_cfg={}, + output_dict=True, + with_score_predictor=False, + with_region_predictor=False + ) + + # Convert device name to proper format + if isinstance(self.device, int): + ml_device = str(self.device) + else: + ml_device = self.device + + if not ml_device.startswith('cuda'): + ml_device = f'cuda:{ml_device}' if ml_device.isdigit() else ml_device + + checkpoint = torch.load(self.hps_ckpt_path, map_location=ml_device) + model.load_state_dict(checkpoint['state_dict']) + text_processor = get_tokenizer(self.model_name) + reward_model = model.to(self.device) + reward_model.eval() + + return reward_model, text_processor, img_preprocess_val + + @torch.no_grad() + def __call__( + self, + images: Union[Image.Image, List[Image.Image]], + texts: Union[str, List[str]], + ): + if isinstance(images, Image.Image): + images = [images] + if isinstance(texts, str): + texts = [texts] + + rewards = [] + for image, text in zip(images, texts): + image = self.img_processor(image).unsqueeze(0).to(self.device, non_blocking=True) + text = self.text_processor([text]).to(device=self.device, non_blocking=True) + with torch.amp.autocast('cuda'): + outputs = self.reward_model(image, text) + image_features, text_features = outputs["image_features"], outputs["text_features"] + logits_per_image = image_features @ text_features.T + hps_score = torch.diagonal(logits_per_image) + + # reward is a tensor of shape (1,) --> list + rewards.append(hps_score.float().cpu().item()) + + return rewards diff --git a/fastvideo/reward_model/image_reward.py b/fastvideo/reward_model/image_reward.py new file mode 100644 index 0000000000000000000000000000000000000000..cb227f9557d5f3a6cc7767d67182bb936c55ec5e --- /dev/null +++ b/fastvideo/reward_model/image_reward.py @@ -0,0 +1,40 @@ +# Image-Reward: Copyied from https://github.com/THUDM/ImageReward +import os +from typing import Union, List +from PIL import Image + +import torch +try: + import ImageReward as RM +except: + raise Warning("ImageReward is required to be installed (`pip install image-reward`) when using ImageReward for post-training.") + + +class ImageRewardModel(object): + def __init__(self, model_name, device, http_proxy=None, https_proxy=None, med_config=None): + if http_proxy: + os.environ["http_proxy"] = http_proxy + if https_proxy: + os.environ["https_proxy"] = https_proxy + self.model_name = model_name if model_name else "ImageReward-v1.0" + self.device = device + self.med_config = med_config + self.build_reward_model() + + def build_reward_model(self): + self.model = RM.load(self.model_name, device=self.device, med_config=self.med_config) + + @torch.no_grad() + def __call__( + self, + images, + texts, + ): + if isinstance(texts, str): + texts = [texts] * len(images) + + rewards = [] + for image, text in zip(images, texts): + ranking, reward = self.model.inference_rank(text, [image]) + rewards.append(reward) + return rewards diff --git a/fastvideo/reward_model/pick_score.py b/fastvideo/reward_model/pick_score.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ba70bb5be1bfc21243338920e32e013bb48854 --- /dev/null +++ b/fastvideo/reward_model/pick_score.py @@ -0,0 +1,107 @@ +import os +import torch +import argparse +from typing import List, Tuple, Union +from transformers import AutoProcessor, AutoModel +from PIL import Image + + +class PickScoreRewardModel(object): + def __init__(self, device: str = "cuda", http_proxy=None, https_proxy=None, mean=18.0, std=8.0): + """ + Initialize PickScore reward model. + + Args: + device: Device to run the model on ('cuda' or 'cpu') + """ + if http_proxy: + os.environ["http_proxy"] = http_proxy + if https_proxy: + os.environ["https_proxy"] = https_proxy + self.device = device + self.processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" + self.model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1" + self.mean = mean + self.std = std + + # Initialize model and processor + self.processor = AutoProcessor.from_pretrained(self.processor_name_or_path) + self.model = AutoModel.from_pretrained(self.model_pretrained_name_or_path).eval().to(device) + + @torch.no_grad() + def __call__( + self, + images: List[Image.Image], + prompts: Union[str, List[str]], + ) -> Tuple[List[float], List[float]]: + """ + Calculate probabilities and scores for images given a prompt. + + Args: + prompts: Text prompt to evaluate images against + images: List of PIL Images to evaluate + + Returns: + Tuple of (probabilities, scores) for each image + """ + if isinstance(prompts, str): + prompts = [prompts] * len(images) + if len(prompts) != len(images): + raise ValueError("prompts must have the same length as images") + + scores = [] + for prompt, image in zip(prompts, images): + # Preprocess images + image_inputs = self.processor( + images=[image], + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ).to(self.device) + + # Preprocess text + text_inputs = self.processor( + text=prompt, + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ).to(self.device) + + # Get embeddings + image_embs = self.model.get_image_features(**image_inputs) + image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True) + + text_embs = self.model.get_text_features(**text_inputs) + text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True) + + # Calculate scores + score = self.model.logit_scale.exp() * (text_embs @ image_embs.T)[0] + score = (score - self.mean) / self.std + scores.extend(score.cpu().tolist()) + + return scores + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="PickScore Reward Model") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on (e.g., 'cuda', 'cpu')") + parser.add_argument("--http_proxy", type=str, default=None, help="HTTP proxy URL") + parser.add_argument("--https_proxy", type=str, default=None, help="HTTPS proxy URL") + args = parser.parse_args() + + # Example usage + reward_model = PickScoreRewardModel( + device=args.device, + http_proxy=args.http_proxy, + https_proxy=args.https_proxy, + ) + pil_images = [Image.open("assets/reward_demo.jpg")] + + prompt = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting." + + scores = reward_model(pil_images, [prompt] * len(pil_images)) + scores = [(s * reward_model.std + reward_model.mean) / 100.0 for s in scores] + print(f"scores: {scores}") + diff --git a/fastvideo/reward_model/unified_reward.py b/fastvideo/reward_model/unified_reward.py new file mode 100644 index 0000000000000000000000000000000000000000..12b06a22dbbdb6dc6453ad1d5a93bceb6e57a93e --- /dev/null +++ b/fastvideo/reward_model/unified_reward.py @@ -0,0 +1,333 @@ +import argparse +import base64 +import os +import re +import requests +import time +import concurrent.futures +from io import BytesIO +from typing import List, Optional, Union + +from PIL import Image + + +QUESTION_TEMPLATE_SEMANTIC = ( + "You are presented with a generated image and its associated text caption. Your task is to analyze the image across multiple dimensions in relation to the caption. Specifically:\n\n" + "1. Evaluate each word in the caption based on how well it is visually represented in the image. Assign a numerical score to each word using the format:\n" + " Word-wise Scores: [[\"word1\", score1], [\"word2\", score2], ..., [\"wordN\", scoreN], [\"[No_mistakes]\", scoreM]]\n" + " - A higher score indicates that the word is less well represented in the image.\n" + " - The special token [No_mistakes] represents whether all elements in the caption were correctly depicted. A high score suggests no mistakes; a low score suggests missing or incorrect elements.\n\n" + "2. Provide overall assessments for the image along the following axes (each rated from 1 to 5):\n" + "- Alignment Score: How well the image matches the caption in terms of content.\n" + "- Coherence Score: How logically consistent the image is (absence of visual glitches, object distortions, etc.).\n" + "- Style Score: How aesthetically appealing the image looks, regardless of caption accuracy.\n\n" + "Output your evaluation using the format below:\n\n" + "---\n\n" + "Word-wise Scores: [[\"word1\", score1], ..., [\"[No_mistakes]\", scoreM]]\n\n" + "Alignment Score (1-5): X\n" + "Coherence Score (1-5): Y\n" + "Style Score (1-5): Z\n\n" + "Your task is provided as follows:\nText Caption: [{}]" +) + +QUESTION_TEMPLATE_SCORE = ( + "You are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n" + "1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n" + "2. Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\n" + "Extract key elements from the provided text caption, evaluate their presence in the generated image using the format: \'element (type): value\' (where value=0 means not generated, and value=1 means generated), and assign a score from 1 to 5 after \'Final Score:\'.\n" + "Your task is provided as follows:\nText Caption: [{}]" +) + + +class VLMessageClient: + def __init__(self, api_url): + self.api_url = api_url + self._session = None + + @property + def session(self): + if self._session is None: + self._session = requests.Session() + return self._session + + def close(self): + """Close the session if it exists.""" + if self._session is not None: + self._session.close() + self._session = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def _encode_image_base64(self, image): + if isinstance(image, str): + with Image.open(image) as img: + img = img.convert("RGB") + buffered = BytesIO() + img.save(buffered, format="JPEG", quality=95) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + elif isinstance(image, Image.Image): + buffered = BytesIO() + image.save(buffered, format="JPEG", quality=95) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + else: + raise ValueError(f"Unsupported image type: {type(image)}") + + def build_messages(self, item, image_root=""): + if isinstance(item['image'], str): + image_path = os.path.join(image_root, item['image']) + return [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"file://{image_path}"}}, + { + "type": "text", + "text": f"{item['question']}" + } + ] + } + ] + assert isinstance(item['image'], Image.Image), f"image must be a PIL.Image.Image, but got {type(item['image'])}" + return [ + { + "role": "user", + "content": [ + {"type": "pil_image", "pil_image": item['image']}, + { + "type": "text", + "text": f"{item['question']}" + } + ] + } + ] + + def format_messages(self, messages): + formatted = [] + for msg in messages: + new_msg = {"role": msg["role"], "content": []} + + if msg["role"] == "system": + new_msg["content"] = msg["content"][0]["text"] + else: + for part in msg["content"]: + if part["type"] == "image_url": + img_path = part["image_url"]["url"].replace("file://", "") + base64_image = self._encode_image_base64(img_path) + new_part = { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} + } + new_msg["content"].append(new_part) + elif part["type"] == "pil_image": + base64_image = self._encode_image_base64(part["pil_image"]) + new_part = { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} + } + new_msg["content"].append(new_part) + else: + new_msg["content"].append(part) + formatted.append(new_msg) + return formatted + + def process_item(self, item, image_root=""): + max_retries = 3 + attempt = 0 + result = None + + while attempt < max_retries: + try: + attempt += 1 + raw_messages = self.build_messages(item, image_root) + formatted_messages = self.format_messages(raw_messages) + + payload = { + "model": "UnifiedReward", + "messages": formatted_messages, + "temperature": 0, + "max_tokens": 4096, + } + + response = self.session.post( + f"{self.api_url}/v1/chat/completions", + json=payload, + timeout=30 + attempt*5 + ) + response.raise_for_status() + + output = response.json()["choices"][0]["message"]["content"] + + result = { + "question": item["question"], + "image_path": item["image"] if isinstance(item["image"], str) else "PIL_Image", + "model_output": output, + "attempt": attempt, + "success": True + } + break + + except Exception as e: + if attempt == max_retries: + result = { + "question": item["question"], + "image_path": item["image"] if isinstance(item["image"], str) else "PIL_Image", + "error": str(e), + "attempt": attempt, + "success": False + } + raise(e) + else: + sleep_time = min(2 ** attempt, 10) + time.sleep(sleep_time) + + return result, result.get("success", False) + + +class UnifiedRewardModel(object): + def __init__(self, api_url, default_question_type="score", num_workers=8): + self.api_url = api_url + self.num_workers = num_workers + self.default_question_type = default_question_type + self.question_template_score = QUESTION_TEMPLATE_SCORE + self.question_template_semantic = QUESTION_TEMPLATE_SEMANTIC + # self.client = VLMessageClient(self.api_url) + + def question_constructor(self, prompt, question_type=None): + if question_type is None: + question_type = self.default_question_type + if question_type == "score": + return self.question_template_score.format(prompt) + elif question_type == "semantic": + return self.question_template_semantic.format(prompt) + else: + raise ValueError(f"Invalid question type: {question_type}") + + def _process_item_wrapper(self, client, image, question): + try: + item = { + "image": image, + "question": question, + } + result, _ = client.process_item(item) + return result + except Exception as e: + print(f"Encountered error in unified reward model processing: {str(e)}") + return None + + def _reset_proxy(self): + os.environ.pop('http_proxy', None) + os.environ.pop('https_proxy', None) + + def __call__(self, + images: Union[List[Image.Image], List[str]], + prompts: Union[str, List[str]], + question_type: Optional[str] = None, + ): + # Reset proxy, otherwise cannot access the server url + self._reset_proxy() + if isinstance(prompts, str): + prompts = [prompts] * len(images) + if len(prompts) != len(images): + raise ValueError("prompts must have the same length as images") + + with VLMessageClient(self.api_url) as client: + questions = [self.question_constructor(prompt, question_type) for prompt in prompts] + + # Initialize results and successes lists with None and False + results = [None] * len(images) + successes = [False] * len(images) + + with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_workers) as executor: + # Submit all tasks and keep track of their order + future_to_idx = { + executor.submit(self._process_item_wrapper, client, image, question): idx + for idx, (image, question) in enumerate(zip(images, questions)) + } + + # Get results in completion order but store them in the correct position + for future in concurrent.futures.as_completed(future_to_idx): + idx = future_to_idx[future] + result = future.result() + if result is not None and result.get("success", False): + output = result.get("model_output", "") + score = self.score_parser(output, question_type) + results[idx] = score + successes[idx] = True + else: + results[idx] = None + successes[idx] = False + + return results, successes + + def score_parser(self, text, question_type=None): + if question_type is None: + question_type = self.default_question_type + if question_type == "score": + return self.extract_final_score(text) + elif question_type == "semantic": + return self.extract_alignment_score(text) + else: + raise ValueError(f"Invalid question type: {question_type}") + + @staticmethod + def extract_alignment_score(text): + """ + Extract Alignment Score (1-5) from the evaluation text. + Returns a float score if found, None otherwise. + """ + match = re.search(r'Alignment Score \(1-5\):\s*([0-5](?:\.\d+)?)', text) + if match: + return float(match.group(1)) + else: + return None + + @staticmethod + def extract_final_score(text): + """ + Extract Final Score from the evaluation text. + Returns a float score if found, None otherwise. + Example input: + 'ocean (location): 0 + clouds (object): 1 + birds (animal): 0 + day time (attribute): 1 + low depth field effect (attribute): 1 + painting (attribute): 1 + Final Score: 2.33' + """ + match = re.search(r'Final Score:\s*([0-5](?:\.\d+)?)', text) + if match: + return float(match.group(1)) + else: + return None + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--api_url", type=str) + parser.add_argument("--max_workers", type=int) + args = parser.parse_args() + + unified_reward_model = UnifiedRewardModel(args.api_url, num_workers=args.max_workers) + img_path = "assets/reward_demo.jpg" + images = [ + Image.open(img_path).convert("RGB") + for i in range(1, 5) + ] * 4 + prompts = "A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting." + results, successes = unified_reward_model(images, prompts, question_type="semantic") + print(results) + print(successes) + + # # 并发测试 + # proc_num = 32 + + # for i in range(5): + # with concurrent.futures.ThreadPoolExecutor(max_workers=proc_num) as executor: + # futures = [executor.submit(unified_reward_model, images, prompts, question_type="semantic") for _ in range(proc_num)] + # results = [future.result() for future in concurrent.futures.as_completed(futures)] + # print(results) \ No newline at end of file diff --git a/fastvideo/reward_model/utils.py b/fastvideo/reward_model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c70c88c8883b0e624362732d790db0b6fbbfba0 --- /dev/null +++ b/fastvideo/reward_model/utils.py @@ -0,0 +1,126 @@ +import concurrent.futures +import random + +def _compute_single_reward(reward_model, images, input_prompts): + """Compute reward for a single reward model.""" + reward_model_name = type(reward_model).__name__ + try: + if reward_model_name == 'HPSClipRewardModel': + rewards = reward_model(images, input_prompts) + successes = [1] * len(rewards) + + elif reward_model_name == 'CLIPScoreRewardModel': + rewards = reward_model(input_prompts, images) + successes = [1] * len(rewards) + + elif reward_model_name == 'ImageRewardModel': + rewards = reward_model(images, input_prompts) + successes = [1] * len(rewards) + + elif reward_model_name == 'UnifiedRewardModel': + rewards, successes_bool = reward_model(images, input_prompts) + rewards = [float(reward) if success else 0.0 for reward, success in zip(rewards, successes_bool)] + successes = [1 if success else 0 for success in successes_bool] + + elif reward_model_name == 'PickScoreRewardModel': + rewards = reward_model(images, input_prompts) + successes = [1] * len(rewards) + + else: + raise ValueError(f"Unknown reward model: {reward_model_name}") + + # Verify the length of results matches input + assert len(rewards) == len(input_prompts), \ + f"Length mismatch in {reward_model_name}: rewards ({len(rewards)}) != input_prompts ({len(input_prompts)})" + assert len(successes) == len(input_prompts), \ + f"Length mismatch in {reward_model_name}: successes ({len(successes)}) != input_prompts ({len(input_prompts)})" + + return rewards, successes + + except Exception as e: + raise ValueError(f"Error in _compute_single_reward with {reward_model_name}: {e}") from e + +def compute_reward(images, input_prompts, reward_models, reward_weights): + assert ( + len(images) == len(input_prompts) + ), f"length of `images` ({len(images)}) must be equal to length of `input_prompts` ({len(input_prompts)})" + + # Initialize results + rewards_dict = {} + successes_dict = {} + + # Create a thread pool for parallel reward computation + with concurrent.futures.ThreadPoolExecutor(max_workers=len(reward_models)) as executor: + # Submit all reward computation tasks + future_to_model = { + executor.submit(_compute_single_reward, reward_model, images, input_prompts): reward_model + for reward_model in reward_models + } + + # Process results as they complete + for future in concurrent.futures.as_completed(future_to_model): + reward_model = future_to_model[future] + model_name = type(reward_model).__name__ + try: + model_rewards, model_successes = future.result() + rewards_dict[model_name] = model_rewards + successes_dict[model_name] = model_successes + except Exception as e: + print(f"Error computing reward with {model_name}: {e}") + rewards_dict[model_name] = [0.0] * len(input_prompts) + successes_dict[model_name] = [0] * len(input_prompts) + continue + + # Merge rewards based on weights + merged_rewards = [0.0] * len(input_prompts) + merged_successes = [0] * len(input_prompts) + + # First check if all models are successful for each sample + for i in range(len(merged_rewards)): + all_success = True + for model_name in reward_weights.keys(): + if model_name in successes_dict and successes_dict[model_name][i] != 1: + all_success = False + break + + if all_success: + # Only compute weighted sum if all models are successful + for model_name, weight in reward_weights.items(): + if model_name in rewards_dict: + merged_rewards[i] += rewards_dict[model_name][i] * weight + merged_successes[i] = 1 + + return merged_rewards, merged_successes, rewards_dict, successes_dict + +def balance_pos_neg(samples, use_random=False): + """Balance positive and negative samples distribution in the samples list.""" + if use_random: + return random.sample(samples, len(samples)) + else: + positive_samples = [sample for sample in samples if sample['advantages'].item() > 0] + negative_samples = [sample for sample in samples if sample['advantages'].item() < 0] + + positive_samples = random.sample(positive_samples, len(positive_samples)) + negative_samples = random.sample(negative_samples, len(negative_samples)) + + num_positive = len(positive_samples) + num_negative = len(negative_samples) + + balanced_samples = [] + + if num_positive < num_negative: + smaller_group = positive_samples + larger_group = negative_samples + else: + smaller_group = negative_samples + larger_group = positive_samples + + for i in range(len(smaller_group)): + balanced_samples.append(smaller_group[i]) + balanced_samples.append(larger_group[i]) + + # If there are remaining samples in the larger group, add them + remaining_samples = larger_group[len(smaller_group):] + balanced_samples.extend(remaining_samples) + return balanced_samples + diff --git a/fastvideo/utils/.DS_Store b/fastvideo/utils/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/fastvideo/utils/.DS_Store differ diff --git a/fastvideo/utils/checkpoint.py b/fastvideo/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..7c055cf14161ba8208941c60710dcb6b13a156be --- /dev/null +++ b/fastvideo/utils/checkpoint.py @@ -0,0 +1,314 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import json +import os + +import torch +import torch.distributed.checkpoint as dist_cp +from peft import get_peft_model_state_dict +from safetensors.torch import load_file, save_file +from torch.distributed.checkpoint.default_planner import (DefaultLoadPlanner, + DefaultSavePlanner) +from torch.distributed.checkpoint.optimizer import \ + load_sharded_optimizer_state_dict +from torch.distributed.fsdp import (FullOptimStateDictConfig, + FullStateDictConfig) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType + +from fastvideo.utils.logging_ import main_print + + +def save_checkpoint_optimizer(model, + optimizer, + rank, + output_dir, + step, + discriminator=False): + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state = model.state_dict() + optim_state = FSDP.optim_state_dict( + model, + optimizer, + ) + + # todo move to get_state_dict + save_dir = os.path.join(output_dir, f"checkpoint-{step}") + os.makedirs(save_dir, exist_ok=True) + # save using safetensors + if rank <= 0 and not discriminator: + weight_path = os.path.join(save_dir, + "diffusion_pytorch_model.safetensors") + save_file(cpu_state, weight_path) + config_dict = dict(model.config) + config_dict.pop('dtype') + config_path = os.path.join(save_dir, "config.json") + # save dict as json + with open(config_path, "w") as f: + json.dump(config_dict, f, indent=4) + optimizer_path = os.path.join(save_dir, "optimizer.pt") + torch.save(optim_state, optimizer_path) + else: + weight_path = os.path.join(save_dir, + "discriminator_pytorch_model.safetensors") + save_file(cpu_state, weight_path) + optimizer_path = os.path.join(save_dir, "discriminator_optimizer.pt") + torch.save(optim_state, optimizer_path) + main_print(f"--> checkpoint saved at step {step}") + + +def save_checkpoint(transformer, rank, output_dir, step, epoch): + main_print(f"--> saving checkpoint at step {step}") + with FSDP.state_dict_type( + transformer, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state = transformer.state_dict() + # todo move to get_state_dict + if rank <= 0: + save_dir = os.path.join(output_dir, f"checkpoint-{step}-{epoch}") + os.makedirs(save_dir, exist_ok=True) + # save using safetensors + weight_path = os.path.join(save_dir, + "diffusion_pytorch_model.safetensors") + save_file(cpu_state, weight_path) + config_dict = dict(transformer.config) + if "dtype" in config_dict: + del config_dict["dtype"] # TODO + config_path = os.path.join(save_dir, "config.json") + # save dict as json + with open(config_path, "w") as f: + json.dump(config_dict, f, indent=4) + main_print(f"--> checkpoint saved at step {step}") + + +def save_checkpoint_generator_discriminator( + model, + optimizer, + discriminator, + discriminator_optimizer, + rank, + output_dir, + step, +): + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state = model.state_dict() + + # todo move to get_state_dict + save_dir = os.path.join(output_dir, f"checkpoint-{step}") + os.makedirs(save_dir, exist_ok=True) + hf_weight_dir = os.path.join(save_dir, "hf_weights") + os.makedirs(hf_weight_dir, exist_ok=True) + # save using safetensors + if rank <= 0: + config_dict = dict(model.config) + config_path = os.path.join(hf_weight_dir, "config.json") + # save dict as json + with open(config_path, "w") as f: + json.dump(config_dict, f, indent=4) + weight_path = os.path.join(hf_weight_dir, + "diffusion_pytorch_model.safetensors") + save_file(cpu_state, weight_path) + + main_print(f"--> saved HF weight checkpoint at path {hf_weight_dir}") + model_weight_dir = os.path.join(save_dir, "model_weights_state") + os.makedirs(model_weight_dir, exist_ok=True) + model_optimizer_dir = os.path.join(save_dir, "model_optimizer_state") + os.makedirs(model_optimizer_dir, exist_ok=True) + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + optim_state = FSDP.optim_state_dict(model, optimizer) + model_state = model.state_dict() + weight_state_dict = {"model": model_state} + dist_cp.save_state_dict( + state_dict=weight_state_dict, + storage_writer=dist_cp.FileSystemWriter(model_weight_dir), + planner=DefaultSavePlanner(), + ) + optimizer_state_dict = {"optimizer": optim_state} + dist_cp.save_state_dict( + state_dict=optimizer_state_dict, + storage_writer=dist_cp.FileSystemWriter(model_optimizer_dir), + planner=DefaultSavePlanner(), + ) + + discriminator_fsdp_state_dir = os.path.join(save_dir, + "discriminator_fsdp_state") + os.makedirs(discriminator_fsdp_state_dir, exist_ok=True) + with FSDP.state_dict_type( + discriminator, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + optim_state = FSDP.optim_state_dict(discriminator, + discriminator_optimizer) + model_state = discriminator.state_dict() + state_dict = {"optimizer": optim_state, "model": model_state} + if rank <= 0: + discriminator_fsdp_state_fil = os.path.join( + discriminator_fsdp_state_dir, "discriminator_state.pt") + torch.save(state_dict, discriminator_fsdp_state_fil) + + main_print("--> saved FSDP state checkpoint") + + +def load_sharded_model(model, optimizer, model_dir, optimizer_dir): + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + weight_state_dict = {"model": model.state_dict()} + + optim_state = load_sharded_optimizer_state_dict( + model_state_dict=weight_state_dict["model"], + optimizer_key="optimizer", + storage_reader=dist_cp.FileSystemReader(optimizer_dir), + ) + optim_state = optim_state["optimizer"] + flattened_osd = FSDP.optim_state_dict_to_load( + model=model, optim=optimizer, optim_state_dict=optim_state) + optimizer.load_state_dict(flattened_osd) + dist_cp.load_state_dict( + state_dict=weight_state_dict, + storage_reader=dist_cp.FileSystemReader(model_dir), + planner=DefaultLoadPlanner(), + ) + model_state = weight_state_dict["model"] + model.load_state_dict(model_state) + main_print(f"--> loaded model and optimizer from path {model_dir}") + return model, optimizer + + +def load_full_state_model(model, optimizer, checkpoint_file, rank): + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + discriminator_state = torch.load(checkpoint_file) + model_state = discriminator_state["model"] + if rank <= 0: + optim_state = discriminator_state["optimizer"] + else: + optim_state = None + model.load_state_dict(model_state) + discriminator_optim_state = FSDP.optim_state_dict_to_load( + model=model, optim=optimizer, optim_state_dict=optim_state) + optimizer.load_state_dict(discriminator_optim_state) + main_print( + f"--> loaded discriminator and discriminator optimizer from path {checkpoint_file}" + ) + return model, optimizer + + +def resume_training_generator_discriminator(model, optimizer, discriminator, + discriminator_optimizer, + checkpoint_dir, rank): + step = int(checkpoint_dir.split("-")[-1]) + model_weight_dir = os.path.join(checkpoint_dir, "model_weights_state") + model_optimizer_dir = os.path.join(checkpoint_dir, "model_optimizer_state") + model, optimizer = load_sharded_model(model, optimizer, model_weight_dir, + model_optimizer_dir) + discriminator_ckpt_file = os.path.join(checkpoint_dir, + "discriminator_fsdp_state", + "discriminator_state.pt") + discriminator, discriminator_optimizer = load_full_state_model( + discriminator, discriminator_optimizer, discriminator_ckpt_file, rank) + return model, optimizer, discriminator, discriminator_optimizer, step + + +def resume_training(model, optimizer, checkpoint_dir, discriminator=False): + weight_path = os.path.join(checkpoint_dir, + "diffusion_pytorch_model.safetensors") + if discriminator: + weight_path = os.path.join(checkpoint_dir, + "discriminator_pytorch_model.safetensors") + model_weights = load_file(weight_path) + + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + current_state = model.state_dict() + current_state.update(model_weights) + model.load_state_dict(current_state, strict=False) + if discriminator: + optim_path = os.path.join(checkpoint_dir, "discriminator_optimizer.pt") + else: + optim_path = os.path.join(checkpoint_dir, "optimizer.pt") + optimizer_state_dict = torch.load(optim_path, weights_only=False) + optim_state = FSDP.optim_state_dict_to_load( + model=model, optim=optimizer, optim_state_dict=optimizer_state_dict) + optimizer.load_state_dict(optim_state) + step = int(checkpoint_dir.split("-")[-1]) + return model, optimizer, step + + +def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step, + pipeline, epoch): + with FSDP.state_dict_type( + transformer, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + full_state_dict = transformer.state_dict() + lora_optim_state = FSDP.optim_state_dict( + transformer, + optimizer, + ) + + if rank <= 0: + save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}-{epoch}") + os.makedirs(save_dir, exist_ok=True) + + # save optimizer + optim_path = os.path.join(save_dir, "lora_optimizer.pt") + torch.save(lora_optim_state, optim_path) + # save lora weight + main_print(f"--> saving LoRA checkpoint at step {step}") + transformer_lora_layers = get_peft_model_state_dict( + model=transformer, state_dict=full_state_dict) + pipeline.save_lora_weights( + save_directory=save_dir, + transformer_lora_layers=transformer_lora_layers, + is_main_process=True, + ) + # save config + lora_config = { + "step": step, + "lora_params": { + "lora_rank": transformer.config.lora_rank, + "lora_alpha": transformer.config.lora_alpha, + "target_modules": transformer.config.lora_target_modules, + }, + } + config_path = os.path.join(save_dir, "lora_config.json") + with open(config_path, "w") as f: + json.dump(lora_config, f, indent=4) + main_print(f"--> LoRA checkpoint saved at step {step}") + + +def resume_lora_optimizer(transformer, checkpoint_dir, optimizer): + config_path = os.path.join(checkpoint_dir, "lora_config.json") + with open(config_path, "r") as f: + config_dict = json.load(f) + optim_path = os.path.join(checkpoint_dir, "lora_optimizer.pt") + optimizer_state_dict = torch.load(optim_path, weights_only=False) + optim_state = FSDP.optim_state_dict_to_load( + model=transformer, + optim=optimizer, + optim_state_dict=optimizer_state_dict) + optimizer.load_state_dict(optim_state) + step = config_dict["step"] + main_print(f"--> Successfully resuming LoRA optimizer from step {step}") + return transformer, optimizer, step diff --git a/fastvideo/utils/communications.py b/fastvideo/utils/communications.py new file mode 100644 index 0000000000000000000000000000000000000000..97ce0c385e01375a99dea310130750ec975bddb9 --- /dev/null +++ b/fastvideo/utils/communications.py @@ -0,0 +1,335 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +from typing import Any, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor + +from fastvideo.utils.parallel_states import nccl_info + + +def broadcast(input_: torch.Tensor): + src = nccl_info.group_id * nccl_info.sp_size + dist.broadcast(input_, src=src, group=nccl_info.group) + + +def _all_to_all_4D(input: torch.tensor, + scatter_idx: int = 2, + gather_idx: int = 1, + group=None) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input.dim() == 4 + ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, + hs).transpose(0, 2).contiguous()) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + torch.cuda.synchronize() + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape( + bs, seqlen, shard_hc, hs) + + return output + + elif scatter_idx == 1 and gather_idx == 2: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = (input.reshape( + bs, seq_world_size, shard_seqlen, shard_hc, + hs).transpose(0, 3).transpose(0, 1).contiguous().reshape( + seq_world_size, shard_hc, shard_seqlen, bs, hs)) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + torch.cuda.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape( + bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError( + "scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + + +class SeqAllToAll4D(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + ) -> Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + return _all_to_all_4D(input, scatter_idx, gather_idx, group=group) + + @staticmethod + def backward(ctx: Any, + *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return ( + None, + SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, + ctx.scatter_idx), + None, + None, + ) + + +def all_to_all_4D( + input_: torch.Tensor, + scatter_dim: int = 2, + gather_dim: int = 1, +): + return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, + gather_dim) + + +def _all_to_all( + input_: torch.Tensor, + world_size: int, + group: dist.ProcessGroup, + scatter_dim: int, + gather_dim: int, +): + input_list = [ + t.contiguous() + for t in torch.tensor_split(input_, world_size, scatter_dim) + ] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.world_size = dist.get_world_size(process_group) + output = _all_to_all(input_, ctx.world_size, process_group, + scatter_dim, gather_dim) + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = _all_to_all( + grad_output, + ctx.world_size, + ctx.process_group, + ctx.gather_dim, + ctx.scatter_dim, + ) + return ( + grad_output, + None, + None, + None, + ) + + +def all_to_all( + input_: torch.Tensor, + scatter_dim: int = 2, + gather_dim: int = 1, +): + return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim) + + +class _AllGather(torch.autograd.Function): + """All-gather communication with autograd support. + + Args: + input_: input tensor + dim: dimension along which to concatenate + """ + + @staticmethod + def forward(ctx, input_, dim): + ctx.dim = dim + world_size = nccl_info.sp_size + group = nccl_info.group + input_size = list(input_.size()) + + ctx.input_size = input_size[dim] + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + input_ = input_.contiguous() + dist.all_gather(tensor_list, input_, group=group) + + output = torch.cat(tensor_list, dim=dim) + return output + + @staticmethod + def backward(ctx, grad_output): + world_size = nccl_info.sp_size + rank = nccl_info.rank_within_group + dim = ctx.dim + input_size = ctx.input_size + + sizes = [input_size] * world_size + + grad_input_list = torch.split(grad_output, sizes, dim=dim) + grad_input = grad_input_list[rank] + + return grad_input, None + + +def all_gather(input_: torch.Tensor, dim: int = 1): + """Performs an all-gather operation on the input tensor along the specified dimension. + + Args: + input_ (torch.Tensor): Input tensor of shape [B, H, S, D]. + dim (int, optional): Dimension along which to concatenate. Defaults to 1. + + Returns: + torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'. + """ + return _AllGather.apply(input_, dim) + + +def prepare_sequence_parallel_data( + encoder_hidden_states, encoder_attention_mask, caption +): + if nccl_info.sp_size == 1: + return ( + encoder_hidden_states, + encoder_attention_mask, + caption, + ) + + def prepare( + encoder_hidden_states, encoder_attention_mask, caption + ): + #hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0) + encoder_hidden_states = all_to_all( + encoder_hidden_states, scatter_dim=1, gather_dim=0 + ) + #attention_mask = all_to_all(attention_mask, scatter_dim=1, gather_dim=0) + encoder_attention_mask = all_to_all( + encoder_attention_mask, scatter_dim=1, gather_dim=0 + ) + return ( + encoder_hidden_states, + encoder_attention_mask, + caption + ) + + sp_size = nccl_info.sp_size + #frame = hidden_states.shape[2] + #assert frame % sp_size == 0, "frame should be a multiple of sp_size" + + ( + #hidden_states, + encoder_hidden_states, + #attention_mask, + encoder_attention_mask, + caption, + ) = prepare( + #hidden_states, + encoder_hidden_states.repeat(1, sp_size, 1), + #attention_mask.repeat(1, sp_size, 1, 1), + encoder_attention_mask.repeat(1, sp_size), + caption, + ) + + return encoder_hidden_states, encoder_attention_mask, caption + + +def sp_parallel_dataloader_wrapper( + dataloader, device, train_batch_size, sp_size, train_sp_batch_size +): + while True: + for data_item in dataloader: + cond, cond_mask, caption = data_item + #latents = latents.to(device) + cond = cond.to(device) + #attn_mask = attn_mask.to(device) + cond_mask = cond_mask.to(device) + #frame = latents.shape[2] + frame = 19 + if frame == 1: + yield cond, cond_mask, caption + else: + cond, cond_mask, caption = prepare_sequence_parallel_data( + cond, cond_mask, caption + ) + assert ( + train_batch_size * sp_size >= train_sp_batch_size + ), "train_batch_size * sp_size should be greater than train_sp_batch_size" + for iter in range(train_batch_size * sp_size // train_sp_batch_size): + st_idx = iter * train_sp_batch_size + ed_idx = (iter + 1) * train_sp_batch_size + encoder_hidden_states = cond[st_idx:ed_idx] + #attention_mask = attn_mask[st_idx:ed_idx] + encoder_attention_mask = cond_mask[st_idx:ed_idx] + yield ( + #latents[st_idx:ed_idx], + encoder_hidden_states, + #attention_mask, + encoder_attention_mask, + caption + ) + diff --git a/fastvideo/utils/communications_flux.py b/fastvideo/utils/communications_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..18a9fb3edb1d2b7e4b48443415b9e5996df42761 --- /dev/null +++ b/fastvideo/utils/communications_flux.py @@ -0,0 +1,345 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +from typing import Any, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor + +from fastvideo.utils.parallel_states import nccl_info + + +def broadcast(input_: torch.Tensor): + src = nccl_info.group_id * nccl_info.sp_size + dist.broadcast(input_, src=src, group=nccl_info.group) + + +def _all_to_all_4D(input: torch.tensor, + scatter_idx: int = 2, + gather_idx: int = 1, + group=None) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input.dim() == 4 + ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, + hs).transpose(0, 2).contiguous()) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + torch.cuda.synchronize() + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape( + bs, seqlen, shard_hc, hs) + + return output + + elif scatter_idx == 1 and gather_idx == 2: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = (input.reshape( + bs, seq_world_size, shard_seqlen, shard_hc, + hs).transpose(0, 3).transpose(0, 1).contiguous().reshape( + seq_world_size, shard_hc, shard_seqlen, bs, hs)) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + torch.cuda.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape( + bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError( + "scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + + +class SeqAllToAll4D(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + ) -> Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + return _all_to_all_4D(input, scatter_idx, gather_idx, group=group) + + @staticmethod + def backward(ctx: Any, + *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return ( + None, + SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, + ctx.scatter_idx), + None, + None, + ) + + +def all_to_all_4D( + input_: torch.Tensor, + scatter_dim: int = 2, + gather_dim: int = 1, +): + return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, + gather_dim) + + +def _all_to_all( + input_: torch.Tensor, + world_size: int, + group: dist.ProcessGroup, + scatter_dim: int, + gather_dim: int, +): + input_list = [ + t.contiguous() + for t in torch.tensor_split(input_, world_size, scatter_dim) + ] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.world_size = dist.get_world_size(process_group) + output = _all_to_all(input_, ctx.world_size, process_group, + scatter_dim, gather_dim) + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = _all_to_all( + grad_output, + ctx.world_size, + ctx.process_group, + ctx.gather_dim, + ctx.scatter_dim, + ) + return ( + grad_output, + None, + None, + None, + ) + + +def all_to_all( + input_: torch.Tensor, + scatter_dim: int = 2, + gather_dim: int = 1, +): + return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim) + + +class _AllGather(torch.autograd.Function): + """All-gather communication with autograd support. + + Args: + input_: input tensor + dim: dimension along which to concatenate + """ + + @staticmethod + def forward(ctx, input_, dim): + ctx.dim = dim + world_size = nccl_info.sp_size + group = nccl_info.group + input_size = list(input_.size()) + + ctx.input_size = input_size[dim] + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + input_ = input_.contiguous() + dist.all_gather(tensor_list, input_, group=group) + + output = torch.cat(tensor_list, dim=dim) + return output + + @staticmethod + def backward(ctx, grad_output): + world_size = nccl_info.sp_size + rank = nccl_info.rank_within_group + dim = ctx.dim + input_size = ctx.input_size + + sizes = [input_size] * world_size + + grad_input_list = torch.split(grad_output, sizes, dim=dim) + grad_input = grad_input_list[rank] + + return grad_input, None + + +def all_gather(input_: torch.Tensor, dim: int = 1): + """Performs an all-gather operation on the input tensor along the specified dimension. + + Args: + input_ (torch.Tensor): Input tensor of shape [B, H, S, D]. + dim (int, optional): Dimension along which to concatenate. Defaults to 1. + + Returns: + torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'. + """ + return _AllGather.apply(input_, dim) + + +def prepare_sequence_parallel_data( + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption +): + if nccl_info.sp_size == 1: + return ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + ) + + def prepare( + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + ): + #hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0) + encoder_hidden_states = all_to_all( + encoder_hidden_states, scatter_dim=1, gather_dim=0 + ) + #attention_mask = all_to_all(attention_mask, scatter_dim=1, gather_dim=0) + pooled_prompt_embeds = all_to_all( + pooled_prompt_embeds, scatter_dim=1, gather_dim=0 + ) + text_ids = all_to_all(text_ids, scatter_dim=1, gather_dim=0) + return ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + ) + + sp_size = nccl_info.sp_size + #frame = hidden_states.shape[2] + #assert frame % sp_size == 0, "frame should be a multiple of sp_size" + + ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + ) = prepare( + #hidden_states, + encoder_hidden_states.repeat(1, sp_size, 1), + pooled_prompt_embeds.repeat(1, sp_size, 1, 1), + text_ids.repeat(1, sp_size), + caption, + ) + + return encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + + +def sp_parallel_dataloader_wrapper( + dataloader, device, train_batch_size, sp_size, train_sp_batch_size +): + while True: + for data_item in dataloader: + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption = data_item + #latents = latents.to(device) + encoder_hidden_states = encoder_hidden_states.to(device) + pooled_prompt_embeds = pooled_prompt_embeds.to(device) + text_ids = text_ids.to(device) + #frame = latents.shape[2] + frame = 19 + if frame == 1: + yield encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + else: + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption = prepare_sequence_parallel_data( + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + ) + assert ( + train_batch_size * sp_size >= train_sp_batch_size + ), "train_batch_size * sp_size should be greater than train_sp_batch_size" + for iter in range(train_batch_size * sp_size // train_sp_batch_size): + st_idx = iter * train_sp_batch_size + ed_idx = (iter + 1) * train_sp_batch_size + encoder_hidden_states = encoder_hidden_states[st_idx:ed_idx] + pooled_prompt_embeds = pooled_prompt_embeds[st_idx:ed_idx] + text_ids = text_ids[st_idx:ed_idx] + yield ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + ) + diff --git a/fastvideo/utils/communications_flux_rfpt.py b/fastvideo/utils/communications_flux_rfpt.py new file mode 100644 index 0000000000000000000000000000000000000000..78fab3bd1cd3905620f5a45a87f309e67bb752e6 --- /dev/null +++ b/fastvideo/utils/communications_flux_rfpt.py @@ -0,0 +1,388 @@ +# Copyright (c) [2025] [FastVideo Team] +# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.] +# SPDX-License-Identifier: [Apache License 2.0] +# +# This file has been modified by [ByteDance Ltd. and/or its affiliates.] in 2025. +# +# Original file was released under [Apache License 2.0], with the full license text +# available at [https://github.com/hao-ai-lab/FastVideo/blob/main/LICENSE]. +# +# This modified file is released under the same license. + +from typing import Any, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor + +from fastvideo.utils.parallel_states import nccl_info + + +def broadcast(input_: torch.Tensor): + src = nccl_info.group_id * nccl_info.sp_size + dist.broadcast(input_, src=src, group=nccl_info.group) + + +def _all_to_all_4D(input: torch.tensor, + scatter_idx: int = 2, + gather_idx: int = 1, + group=None) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input.dim() == 4 + ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, + hs).transpose(0, 2).contiguous()) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + torch.cuda.synchronize() + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape( + bs, seqlen, shard_hc, hs) + + return output + + elif scatter_idx == 1 and gather_idx == 2: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = (input.reshape( + bs, seq_world_size, shard_seqlen, shard_hc, + hs).transpose(0, 3).transpose(0, 1).contiguous().reshape( + seq_world_size, shard_hc, shard_seqlen, bs, hs)) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + torch.cuda.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape( + bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError( + "scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + + +class SeqAllToAll4D(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + ) -> Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + return _all_to_all_4D(input, scatter_idx, gather_idx, group=group) + + @staticmethod + def backward(ctx: Any, + *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return ( + None, + SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, + ctx.scatter_idx), + None, + None, + ) + + +def all_to_all_4D( + input_: torch.Tensor, + scatter_dim: int = 2, + gather_dim: int = 1, +): + return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, + gather_dim) + + +def _all_to_all( + input_: torch.Tensor, + world_size: int, + group: dist.ProcessGroup, + scatter_dim: int, + gather_dim: int, +): + input_list = [ + t.contiguous() + for t in torch.tensor_split(input_, world_size, scatter_dim) + ] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.world_size = dist.get_world_size(process_group) + output = _all_to_all(input_, ctx.world_size, process_group, + scatter_dim, gather_dim) + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = _all_to_all( + grad_output, + ctx.world_size, + ctx.process_group, + ctx.gather_dim, + ctx.scatter_dim, + ) + return ( + grad_output, + None, + None, + None, + ) + + +def all_to_all( + input_: torch.Tensor, + scatter_dim: int = 2, + gather_dim: int = 1, +): + return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim) + + +class _AllGather(torch.autograd.Function): + """All-gather communication with autograd support. + + Args: + input_: input tensor + dim: dimension along which to concatenate + """ + + @staticmethod + def forward(ctx, input_, dim): + ctx.dim = dim + world_size = nccl_info.sp_size + group = nccl_info.group + input_size = list(input_.size()) + + ctx.input_size = input_size[dim] + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + input_ = input_.contiguous() + dist.all_gather(tensor_list, input_, group=group) + + output = torch.cat(tensor_list, dim=dim) + return output + + @staticmethod + def backward(ctx, grad_output): + world_size = nccl_info.sp_size + rank = nccl_info.rank_within_group + dim = ctx.dim + input_size = ctx.input_size + + sizes = [input_size] * world_size + + grad_input_list = torch.split(grad_output, sizes, dim=dim) + grad_input = grad_input_list[rank] + + return grad_input, None + + +def all_gather(input_: torch.Tensor, dim: int = 1): + """Performs an all-gather operation on the input tensor along the specified dimension. + + Args: + input_ (torch.Tensor): Input tensor of shape [B, H, S, D]. + dim (int, optional): Dimension along which to concatenate. Defaults to 1. + + Returns: + torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'. + """ + return _AllGather.apply(input_, dim) + + +def prepare_sequence_parallel_data( + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption +): + if nccl_info.sp_size == 1: + return ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + ) + + def prepare( + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + ): + #hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0) + encoder_hidden_states = all_to_all( + encoder_hidden_states, scatter_dim=1, gather_dim=0 + ) + #attention_mask = all_to_all(attention_mask, scatter_dim=1, gather_dim=0) + pooled_prompt_embeds = all_to_all( + pooled_prompt_embeds, scatter_dim=1, gather_dim=0 + ) + text_ids = all_to_all(text_ids, scatter_dim=1, gather_dim=0) + return ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + ) + + sp_size = nccl_info.sp_size + #frame = hidden_states.shape[2] + #assert frame % sp_size == 0, "frame should be a multiple of sp_size" + + ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + ) = prepare( + #hidden_states, + encoder_hidden_states.repeat(1, sp_size, 1), + pooled_prompt_embeds.repeat(1, sp_size, 1, 1), + text_ids.repeat(1, sp_size), + caption, + ) + + return encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + + +def sp_parallel_dataloader_wrapper( + dataloader, device, train_batch_size, sp_size, train_sp_batch_size +): + while True: + for data_item in dataloader: + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption, latents = data_item + #latents = latents.to(device) + encoder_hidden_states = encoder_hidden_states.to(device) + pooled_prompt_embeds = pooled_prompt_embeds.to(device) + text_ids = text_ids.to(device) + latents = latents.to(device) + #frame = latents.shape[2] + frame = 19 + if frame == 1: + yield encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + else: + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption = prepare_sequence_parallel_data( + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + ) + assert ( + train_batch_size * sp_size >= train_sp_batch_size + ), "train_batch_size * sp_size should be greater than train_sp_batch_size" + for iter in range(train_batch_size * sp_size // train_sp_batch_size): + st_idx = iter * train_sp_batch_size + ed_idx = (iter + 1) * train_sp_batch_size + encoder_hidden_states = encoder_hidden_states[st_idx:ed_idx] + pooled_prompt_embeds = pooled_prompt_embeds[st_idx:ed_idx] + text_ids = text_ids[st_idx:ed_idx] + latents = latents[st_idx:ed_idx] + yield ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + latents + ) + +def sp_parallel_dataloader_wrapper_all( + dataloader, device, train_batch_size, sp_size, train_sp_batch_size +): + while True: + for data_item in dataloader: + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption, latents, images = data_item + #latents = latents.to(device) + encoder_hidden_states = encoder_hidden_states.to(device) + pooled_prompt_embeds = pooled_prompt_embeds.to(device) + text_ids = text_ids.to(device) + latents = latents.to(device) + images= images.to(device) + #frame = latents.shape[2] + frame = 19 + if frame == 1: + yield encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + else: + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption = prepare_sequence_parallel_data( + encoder_hidden_states, pooled_prompt_embeds, text_ids, caption + ) + assert ( + train_batch_size * sp_size >= train_sp_batch_size + ), "train_batch_size * sp_size should be greater than train_sp_batch_size" + for iter in range(train_batch_size * sp_size // train_sp_batch_size): + st_idx = iter * train_sp_batch_size + ed_idx = (iter + 1) * train_sp_batch_size + encoder_hidden_states = encoder_hidden_states[st_idx:ed_idx] + pooled_prompt_embeds = pooled_prompt_embeds[st_idx:ed_idx] + text_ids = text_ids[st_idx:ed_idx] + latents = latents[st_idx:ed_idx] + images = images[st_idx:ed_idx] + yield ( + encoder_hidden_states, + pooled_prompt_embeds, + text_ids, + caption, + latents, + images + ) + diff --git a/fastvideo/utils/dataset_utils.py b/fastvideo/utils/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..559a750463dfc3db39670887f4cb9b2a8f1b6693 --- /dev/null +++ b/fastvideo/utils/dataset_utils.py @@ -0,0 +1,378 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + + +import math +import random +from collections import Counter +from typing import List, Optional + +import decord +import torch +import torch.utils +import torch.utils.data +from torch.nn import functional as F +from torch.utils.data import Sampler + +IMG_EXTENSIONS = [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG"] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +class DecordInit(object): + """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader.""" + + def __init__(self, num_threads=1): + self.num_threads = num_threads + self.ctx = decord.cpu(0) + + def __call__(self, filename): + """Perform the Decord initialization. + Args: + results (dict): The resulting dict to be modified and passed + to the next transform in pipeline. + """ + reader = decord.VideoReader(filename, + ctx=self.ctx, + num_threads=self.num_threads) + return reader + + def __repr__(self): + repr_str = (f"{self.__class__.__name__}(" + f"sr={self.sr}," + f"num_threads={self.num_threads})") + return repr_str + + +def pad_to_multiple(number, ds_stride): + remainder = number % ds_stride + if remainder == 0: + return number + else: + padding = ds_stride - remainder + return number + padding + + +# TODO +class Collate: + + def __init__(self, args): + self.batch_size = args.train_batch_size + self.group_frame = args.group_frame + self.group_resolution = args.group_resolution + + self.max_height = args.max_height + self.max_width = args.max_width + self.ae_stride = args.ae_stride + + self.ae_stride_t = args.ae_stride_t + self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride) + + self.patch_size = args.patch_size + self.patch_size_t = args.patch_size_t + + self.num_frames = args.num_frames + self.use_image_num = args.use_image_num + self.max_thw = (self.num_frames, self.max_height, self.max_width) + + def package(self, batch): + batch_tubes = [i["pixel_values"] for i in batch] # b [c t h w] + input_ids = [i["input_ids"] for i in batch] # b [1 l] + cond_mask = [i["cond_mask"] for i in batch] # b [1 l] + return batch_tubes, input_ids, cond_mask + + def __call__(self, batch): + batch_tubes, input_ids, cond_mask = self.package(batch) + + ds_stride = self.ae_stride * self.patch_size + t_ds_stride = self.ae_stride_t * self.patch_size_t + + pad_batch_tubes, attention_mask, input_ids, cond_mask = self.process( + batch_tubes, + input_ids, + cond_mask, + t_ds_stride, + ds_stride, + self.max_thw, + self.ae_stride_thw, + ) + assert not torch.any( + torch.isnan(pad_batch_tubes)), "after pad_batch_tubes" + return pad_batch_tubes, attention_mask, input_ids, cond_mask + + def process( + self, + batch_tubes, + input_ids, + cond_mask, + t_ds_stride, + ds_stride, + max_thw, + ae_stride_thw, + ): + # pad to max multiple of ds_stride + batch_input_size = [i.shape + for i in batch_tubes] # [(c t h w), (c t h w)] + assert len(batch_input_size) == self.batch_size + if self.group_frame or self.group_resolution or self.batch_size == 1: # + len_each_batch = batch_input_size + idx_length_dict = dict( + [*zip(list(range(self.batch_size)), len_each_batch)]) + count_dict = Counter(len_each_batch) + if len(count_dict) != 1: + sorted_by_value = sorted(count_dict.items(), + key=lambda item: item[1]) + pick_length = sorted_by_value[-1][0] # the highest frequency + candidate_batch = [ + idx for idx, length in idx_length_dict.items() + if length == pick_length + ] + random_select_batch = [ + random.choice(candidate_batch) + for _ in range(len(len_each_batch) - len(candidate_batch)) + ] + print( + batch_input_size, + idx_length_dict, + count_dict, + sorted_by_value, + pick_length, + candidate_batch, + random_select_batch, + ) + pick_idx = candidate_batch + random_select_batch + + batch_tubes = [batch_tubes[i] for i in pick_idx] + batch_input_size = [i.shape for i in batch_tubes + ] # [(c t h w), (c t h w)] + input_ids = [input_ids[i] for i in pick_idx] # b [1, l] + cond_mask = [cond_mask[i] for i in pick_idx] # b [1, l] + + for i in range(1, self.batch_size): + assert batch_input_size[0] == batch_input_size[i] + max_t = max([i[1] for i in batch_input_size]) + max_h = max([i[2] for i in batch_input_size]) + max_w = max([i[3] for i in batch_input_size]) + else: + max_t, max_h, max_w = max_thw + pad_max_t, pad_max_h, pad_max_w = ( + pad_to_multiple(max_t - 1 + self.ae_stride_t, t_ds_stride), + pad_to_multiple(max_h, ds_stride), + pad_to_multiple(max_w, ds_stride), + ) + pad_max_t = pad_max_t + 1 - self.ae_stride_t + each_pad_t_h_w = [[ + pad_max_t - i.shape[1], pad_max_h - i.shape[2], + pad_max_w - i.shape[3] + ] for i in batch_tubes] + pad_batch_tubes = [ + F.pad(im, (0, pad_w, 0, pad_h, 0, pad_t), value=0) + for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes) + ] + pad_batch_tubes = torch.stack(pad_batch_tubes, dim=0) + + max_tube_size = [pad_max_t, pad_max_h, pad_max_w] + max_latent_size = [ + ((max_tube_size[0] - 1) // ae_stride_thw[0] + 1), + max_tube_size[1] // ae_stride_thw[1], + max_tube_size[2] // ae_stride_thw[2], + ] + valid_latent_size = [[ + int(math.ceil((i[1] - 1) / ae_stride_thw[0])) + 1, + int(math.ceil(i[2] / ae_stride_thw[1])), + int(math.ceil(i[3] / ae_stride_thw[2])), + ] for i in batch_input_size] + attention_mask = [ + F.pad( + torch.ones(i, dtype=pad_batch_tubes.dtype), + ( + 0, + max_latent_size[2] - i[2], + 0, + max_latent_size[1] - i[1], + 0, + max_latent_size[0] - i[0], + ), + value=0, + ) for i in valid_latent_size + ] + attention_mask = torch.stack(attention_mask) # b t h w + if self.batch_size == 1 or self.group_frame or self.group_resolution: + assert torch.all(attention_mask.bool()) + + input_ids = torch.stack(input_ids) # b 1 l + cond_mask = torch.stack(cond_mask) # b 1 l + + return pad_batch_tubes, attention_mask, input_ids, cond_mask + + +def split_to_even_chunks(indices, lengths, num_chunks, batch_size): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + chunks = [indices[i::num_chunks] for i in range(num_chunks)] + else: + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float("inf") + # return chunks + + pad_chunks = [] + for idx, chunk in enumerate(chunks): + if batch_size != len(chunk): + assert batch_size > len(chunk) + if len(chunk) != 0: + chunk = chunk + [ + random.choice(chunk) + for _ in range(batch_size - len(chunk)) + ] + else: + chunk = random.choice(pad_chunks) + print(chunks[idx], "->", chunk) + pad_chunks.append(chunk) + return pad_chunks + + +def group_frame_fun(indices, lengths): + # sort by num_frames + indices.sort(key=lambda i: lengths[i], reverse=True) + return indices + + +def megabatch_frame_alignment(megabatches, lengths): + aligned_magabatches = [] + for _, megabatch in enumerate(megabatches): + assert len(megabatch) != 0 + len_each_megabatch = [lengths[i] for i in megabatch] + idx_length_dict = dict([*zip(megabatch, len_each_megabatch)]) + count_dict = Counter(len_each_megabatch) + + # mixed frame length, align megabatch inside + if len(count_dict) != 1: + sorted_by_value = sorted(count_dict.items(), + key=lambda item: item[1]) + pick_length = sorted_by_value[-1][0] # the highest frequency + candidate_batch = [ + idx for idx, length in idx_length_dict.items() + if length == pick_length + ] + random_select_batch = [ + random.choice(candidate_batch) + for i in range(len(idx_length_dict) - len(candidate_batch)) + ] + aligned_magabatch = candidate_batch + random_select_batch + aligned_magabatches.append(aligned_magabatch) + # already aligned megabatches + else: + aligned_magabatches.append(megabatch) + + return aligned_magabatches + + +def get_length_grouped_indices( + lengths, + batch_size, + world_size, + generator=None, + group_frame=False, + group_resolution=False, + seed=42, +): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + if generator is None: + generator = torch.Generator().manual_seed( + seed) # every rank will generate a fixed order but random index + + indices = torch.randperm(len(lengths), generator=generator).tolist() + + # sort dataset according to frame + indices = group_frame_fun(indices, lengths) + + # chunk dataset to megabatches + megabatch_size = world_size * batch_size + megabatches = [ + indices[i:i + megabatch_size] + for i in range(0, len(lengths), megabatch_size) + ] + + # make sure the length in each magabatch is align with each other + megabatches = megabatch_frame_alignment(megabatches, lengths) + + # aplit aligned megabatch into batches + megabatches = [ + split_to_even_chunks(megabatch, lengths, world_size, batch_size) + for megabatch in megabatches + ] + + # random megabatches to do video-image mix training + indices = torch.randperm(len(megabatches), generator=generator).tolist() + shuffled_megabatches = [megabatches[i] for i in indices] + + # expand indices and return + return [ + i for megabatch in shuffled_megabatches for batch in megabatch + for i in batch + ] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. + """ + + def __init__( + self, + batch_size: int, + rank: int, + world_size: int, + lengths: Optional[List[int]] = None, + group_frame=False, + group_resolution=False, + generator=None, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") + + self.batch_size = batch_size + self.rank = rank + self.world_size = world_size + self.lengths = lengths + self.group_frame = group_frame + self.group_resolution = group_resolution + self.generator = generator + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + indices = get_length_grouped_indices( + self.lengths, + self.batch_size, + self.world_size, + group_frame=self.group_frame, + group_resolution=self.group_resolution, + generator=self.generator, + ) + + def distributed_sampler(lst, rank, batch_size, world_size): + result = [] + index = rank * batch_size + while index < len(lst): + result.extend(lst[index:index + batch_size]) + index += batch_size * world_size + return result + + indices = distributed_sampler(indices, self.rank, self.batch_size, + self.world_size) + return iter(indices) diff --git a/fastvideo/utils/env_utils.py b/fastvideo/utils/env_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd00328aa382431eb90732bd5d40a6ee1bc99e9 --- /dev/null +++ b/fastvideo/utils/env_utils.py @@ -0,0 +1,42 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import platform + +import accelerate +import peft +import torch +import transformers +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + +VERSION = "1.2.0" + +if __name__ == "__main__": + info = { + "FastVideo version": VERSION, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version": torch.__version__, + "Transformers version": transformers.__version__, + "Accelerate version": accelerate.__version__, + "PEFT version": peft.__version__, + } + + if is_torch_cuda_available(): + info["PyTorch version"] += " (GPU)" + info["GPU type"] = torch.cuda.get_device_name() + + if is_torch_npu_available(): + info["PyTorch version"] += " (NPU)" + info["NPU type"] = torch.npu.get_device_name() + info["CANN version"] = torch.version.cann # codespell:ignore + + try: + import bitsandbytes + + info["Bitsandbytes version"] = bitsandbytes.__version__ + except Exception: + pass + + print("\n" + + "\n".join([f"- {key}: {value}" + for key, value in info.items()]) + "\n") diff --git a/fastvideo/utils/fsdp_util.py b/fastvideo/utils/fsdp_util.py new file mode 100644 index 0000000000000000000000000000000000000000..aed7fc3dc5d4872ac45d8cb482f67a1deb903db4 --- /dev/null +++ b/fastvideo/utils/fsdp_util.py @@ -0,0 +1,136 @@ +# ruff: noqa: E731 +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + + +import functools +from functools import partial + +import torch +from peft.utils.other import fsdp_auto_wrap_policy +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper) +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + +from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformerBlock +from fastvideo.utils.load import get_no_split_modules + +non_reentrant_wrapper = partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, +) + +check_fn = lambda submodule: isinstance(submodule, MochiTransformerBlock) + + +def apply_fsdp_checkpointing(model, no_split_modules, p=1): + # https://github.com/foundation-model-stack/fms-fsdp/blob/408c7516d69ea9b6bcd4c0f5efab26c0f64b3c2d/fms_fsdp/policies/ac_handler.py#L16 + """apply activation checkpointing to model + returns None as model is updated directly + """ + print("--> applying fdsp activation checkpointing...") + block_idx = 0 + cut_off = 1 / 2 + # when passing p as a fraction number (e.g. 1/3), it will be interpreted + # as a string in argv, thus we need eval("1/3") here for fractions. + p = eval(p) if isinstance(p, str) else p + + def selective_checkpointing(submodule): + nonlocal block_idx + nonlocal cut_off + + if isinstance(submodule, no_split_modules): + block_idx += 1 + if block_idx * p >= cut_off: + cut_off += 1 + return True + return False + + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=non_reentrant_wrapper, + check_fn=selective_checkpointing, + ) + + +def get_mixed_precision(master_weight_type="fp32"): + weight_type = torch.float32 if master_weight_type == "fp32" else torch.bfloat16 + mixed_precision = MixedPrecision( + param_dtype=weight_type, + # Gradient communication precision. + reduce_dtype=weight_type, + # Buffer precision. + buffer_dtype=weight_type, + cast_forward_inputs=False, + ) + return mixed_precision + + +def get_dit_fsdp_kwargs( + transformer, + sharding_strategy, + use_lora=False, + cpu_offload=False, + master_weight_type="fp32", +): + no_split_modules = get_no_split_modules(transformer) + if use_lora: + auto_wrap_policy = fsdp_auto_wrap_policy + else: + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=no_split_modules, + ) + + # we use float32 for fsdp but autocast during training + mixed_precision = get_mixed_precision(master_weight_type) + + if sharding_strategy == "full": + sharding_strategy = ShardingStrategy.FULL_SHARD + elif sharding_strategy == "hybrid_full": + sharding_strategy = ShardingStrategy.HYBRID_SHARD + elif sharding_strategy == "none": + sharding_strategy = ShardingStrategy.NO_SHARD + auto_wrap_policy = None + elif sharding_strategy == "hybrid_zero2": + sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 + + device_id = torch.cuda.current_device() + cpu_offload = (torch.distributed.fsdp.CPUOffload( + offload_params=True) if cpu_offload else None) + fsdp_kwargs = { + "auto_wrap_policy": auto_wrap_policy, + "mixed_precision": mixed_precision, + "sharding_strategy": sharding_strategy, + "device_id": device_id, + "limit_all_gathers": True, + "cpu_offload": cpu_offload, + } + + # Add LoRA-specific settings when LoRA is enabled + if use_lora: + fsdp_kwargs.update({ + "use_orig_params": False, # Required for LoRA memory savings + "sync_module_states": True, + }) + + return fsdp_kwargs, no_split_modules + + +def get_discriminator_fsdp_kwargs(master_weight_type="fp32"): + auto_wrap_policy = None + + # Use existing mixed precision settings + + mixed_precision = get_mixed_precision(master_weight_type) + sharding_strategy = ShardingStrategy.NO_SHARD + device_id = torch.cuda.current_device() + fsdp_kwargs = { + "auto_wrap_policy": auto_wrap_policy, + "mixed_precision": mixed_precision, + "sharding_strategy": sharding_strategy, + "device_id": device_id, + "limit_all_gathers": True, + } + + return fsdp_kwargs diff --git a/fastvideo/utils/fsdp_util_qwenimage.py b/fastvideo/utils/fsdp_util_qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ae50be963e30c900e175c65887a0491fdd5ebc --- /dev/null +++ b/fastvideo/utils/fsdp_util_qwenimage.py @@ -0,0 +1,69 @@ +import os +import functools +import torch +import torch.distributed as dist +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch, MixedPrecision, CPUOffload +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) +from fastvideo.models.qwenimage.transformer_qwenimage import QwenImageTransformerBlock + +class FSDPConfig: + def __init__( + self, + sharding_strategy="FULL_SHARD", + backward_prefetch="BACKWARD_PRE", + cpu_offload=False, + num_replicate=1, + num_shard=8, + mixed_precision_dtype=torch.bfloat16, + use_device_mesh=False, + ): + self.sharding_strategy = sharding_strategy + self.backward_prefetch = backward_prefetch + self.cpu_offload = cpu_offload + self.num_replicate = num_replicate + self.num_shard = num_shard + self.mixed_precision_dtype = mixed_precision_dtype + self.use_device_mesh = use_device_mesh + +def fsdp_wrapper(model, fsdp_config, ignored_modules=None): + if ignored_modules is None: + ignored_modules = [] + device_mesh = None + if fsdp_config.sharding_strategy == 'HYBRID_SHARD' and fsdp_config.use_device_mesh: + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(fsdp_config.num_replicate, fsdp_config.num_shard), + mesh_dim_names=("replicate", "shard") + ) + def get_transformer_layer_cls(): + return { + QwenImageTransformerBlock, + } + fsdp_model = FSDP( + model, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=get_transformer_layer_cls(), + ), + ignored_modules=ignored_modules, + mixed_precision=MixedPrecision( + param_dtype=fsdp_config.mixed_precision_dtype, + reduce_dtype=fsdp_config.mixed_precision_dtype, + buffer_dtype=fsdp_config.mixed_precision_dtype, + ), + device_id=dist.get_rank() % torch.cuda.device_count(), + sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy], + backward_prefetch=BackwardPrefetch[fsdp_config.backward_prefetch], + cpu_offload=CPUOffload(offload_params=fsdp_config.cpu_offload), + device_mesh=device_mesh, + use_orig_params=True, + ) + + return fsdp_model \ No newline at end of file diff --git a/fastvideo/utils/load.py b/fastvideo/utils/load.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2a0dbd122faa149ee3290c01307de9bb32cfc7 --- /dev/null +++ b/fastvideo/utils/load.py @@ -0,0 +1,382 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import os +from pathlib import Path + +import torch +import torch.nn.functional as F +from diffusers import AutoencoderKLHunyuanVideo, AutoencoderKLMochi +from torch import nn +from transformers import AutoTokenizer, T5EncoderModel + +from fastvideo.models.hunyuan.modules.models import ( + HYVideoDiffusionTransformer, MMDoubleStreamBlock, MMSingleStreamBlock) +from fastvideo.models.hunyuan.text_encoder import TextEncoder +from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import \ + AutoencoderKLCausal3D +from fastvideo.models.hunyuan_hf.modeling_hunyuan import ( + HunyuanVideoSingleTransformerBlock, HunyuanVideoTransformer3DModel, + HunyuanVideoTransformerBlock) +from fastvideo.models.mochi_hf.modeling_mochi import (MochiTransformer3DModel, + MochiTransformerBlock) +from fastvideo.utils.logging_ import main_print +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel, FluxTransformerBlock, FluxSingleTransformerBlock + +hunyuan_config = { + "mm_double_blocks_depth": 20, + "mm_single_blocks_depth": 40, + "rope_dim_list": [16, 56, 56], + "hidden_size": 3072, + "heads_num": 24, + "mlp_width_ratio": 4, + "guidance_embed": True, +} + +PROMPT_TEMPLATE_ENCODE = ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") +PROMPT_TEMPLATE_ENCODE_VIDEO = ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") + +NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" + +PROMPT_TEMPLATE = { + "dit-llm-encode": { + "template": PROMPT_TEMPLATE_ENCODE, + "crop_start": 36, + }, + "dit-llm-encode-video": { + "template": PROMPT_TEMPLATE_ENCODE_VIDEO, + "crop_start": 95, + }, +} + + +class HunyuanTextEncoderWrapper(nn.Module): + + def __init__(self, pretrained_model_name_or_path, device): + super().__init__() + + text_len = 256 + crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"].get( + "crop_start", 0) + + max_length = text_len + crop_start + + # prompt_template + prompt_template = PROMPT_TEMPLATE["dit-llm-encode"] + + # prompt_template_video + prompt_template_video = PROMPT_TEMPLATE["dit-llm-encode-video"] + text_encoder_path = os.path.join(pretrained_model_name_or_path, + "text_encoder") + self.text_encoder = TextEncoder( + text_encoder_type="llm", + text_encoder_path=text_encoder_path, + max_length=max_length, + text_encoder_precision="fp16", + tokenizer_type="llm", + prompt_template=prompt_template, + prompt_template_video=prompt_template_video, + hidden_state_skip_layer=2, + apply_final_norm=False, + reproduce=False, + logger=None, + device=device, + ) + text_encoder_path_2 = os.path.join(pretrained_model_name_or_path, + "text_encoder_2") + self.text_encoder_2 = TextEncoder( + text_encoder_type="clipL", + text_encoder_path=text_encoder_path_2, + max_length=77, + text_encoder_precision="fp16", + tokenizer_type="clipL", + reproduce=False, + logger=None, + device=device, + ) + + def encode_(self, prompt, text_encoder, clip_skip=None): + # TODO + device = self.text_encoder.device + data_type = "video" + num_videos_per_prompt = 1 + + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) + + if clip_skip is None: + prompt_outputs = text_encoder.encode(text_inputs, + data_type="video", + device=device) + prompt_embeds = prompt_outputs.hidden_state + else: + prompt_outputs = text_encoder.encode( + text_inputs, + output_hidden_states=True, + data_type=data_type, + device=device, + ) + prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)] + + prompt_embeds = text_encoder.model.text_model.final_layer_norm( + prompt_embeds) + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.view( + bs_embed * num_videos_per_prompt, seq_len) + + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.transformer is not None: + prompt_embeds_dtype = self.transformer.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, + device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view( + bs_embed * num_videos_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_videos_per_prompt, seq_len, -1) + return (prompt_embeds, attention_mask) + + def encode_prompt(self, prompt): + prompt_embeds, attention_mask = self.encode_(prompt, self.text_encoder) + prompt_embeds_2, attention_mask_2 = self.encode_( + prompt, self.text_encoder_2) + prompt_embeds_2 = F.pad( + prompt_embeds_2, + (0, prompt_embeds.shape[2] - prompt_embeds_2.shape[1]), + value=0, + ).unsqueeze(1) + prompt_embeds = torch.cat([prompt_embeds_2, prompt_embeds], dim=1) + return prompt_embeds, attention_mask + + +class MochiTextEncoderWrapper(nn.Module): + + def __init__(self, pretrained_model_name_or_path, device): + super().__init__() + self.text_encoder = T5EncoderModel.from_pretrained( + os.path.join(pretrained_model_name_or_path, + "text_encoder")).to(device) + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(pretrained_model_name_or_path, "tokenizer")) + self.max_sequence_length = 256 + + def encode_prompt(self, prompt): + device = self.text_encoder.device + dtype = self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, + padding="longest", + return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.max_sequence_length - 1:-1]) + main_print( + f"Truncated text input: {prompt} to: {removed_text} for model input." + ) + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + + return prompt_embeds, prompt_attention_mask + + +def load_hunyuan_state_dict(model, dit_model_name_or_path): + load_key = "module" + model_path = dit_model_name_or_path + bare_model = "unknown" + + state_dict = torch.load(model_path, + map_location=lambda storage, loc: storage, + weights_only=True) + + if bare_model == "unknown" and ("ema" in state_dict + or "module" in state_dict): + bare_model = False + if bare_model is False: + if load_key in state_dict: + state_dict = state_dict[load_key] + else: + raise KeyError( + f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint " + f"are: {list(state_dict.keys())}.") + model.load_state_dict(state_dict, strict=True) + return model + + +def load_transformer( + model_type, + dit_model_name_or_path, + pretrained_model_name_or_path, + master_weight_type, +): + if model_type == "mochi": + if dit_model_name_or_path: + transformer = MochiTransformer3DModel.from_pretrained( + dit_model_name_or_path, + torch_dtype=master_weight_type, + # torch_dtype=torch.bfloat16 if args.use_lora else torch.float32, + ) + else: + transformer = MochiTransformer3DModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=master_weight_type, + # torch_dtype=torch.bfloat16 if args.use_lora else torch.float32, + ) + elif model_type == "hunyuan_hf": + if dit_model_name_or_path: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + dit_model_name_or_path, + torch_dtype=master_weight_type, + # torch_dtype=torch.bfloat16 if args.use_lora else torch.float32, + ) + else: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=master_weight_type, + # torch_dtype=torch.bfloat16 if args.use_lora else torch.float32, + ) + elif model_type == "hunyuan": + transformer = HYVideoDiffusionTransformer( + in_channels=16, + out_channels=16, + **hunyuan_config, + dtype=master_weight_type, + ) + transformer = load_hunyuan_state_dict(transformer, + dit_model_name_or_path) + if master_weight_type == torch.bfloat16: + transformer = transformer.bfloat16() + else: + raise ValueError(f"Unsupported model type: {model_type}") + return transformer + + +def load_vae(model_type, pretrained_model_name_or_path): + weight_dtype = torch.float32 + if model_type == "mochi": + vae = AutoencoderKLMochi.from_pretrained( + pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=weight_dtype).to("cuda") + autocast_type = torch.bfloat16 + fps = 30 + elif model_type == "hunyuan_hf": + vae = AutoencoderKLHunyuanVideo.from_pretrained( + pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=weight_dtype).to("cuda") + autocast_type = torch.bfloat16 + fps = 24 + elif model_type == "hunyuan": + vae_precision = torch.float32 + vae_path = os.path.join(pretrained_model_name_or_path, + "hunyuan-video-t2v-720p/vae") + + config = AutoencoderKLCausal3D.load_config(vae_path) + vae = AutoencoderKLCausal3D.from_config(config) + + vae_ckpt = Path(vae_path) / "pytorch_model.pt" + assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}" + + ckpt = torch.load(vae_ckpt, map_location=vae.device, weights_only=True) + if "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + if any(k.startswith("vae.") for k in ckpt.keys()): + ckpt = { + k.replace("vae.", ""): v + for k, v in ckpt.items() if k.startswith("vae.") + } + vae.load_state_dict(ckpt) + vae = vae.to(dtype=vae_precision) + vae.requires_grad_(False) + vae = vae.to("cuda") + vae.eval() + autocast_type = torch.float32 + fps = 24 + return vae, autocast_type, fps + + +def load_text_encoder(model_type, pretrained_model_name_or_path, device): + if model_type == "mochi": + text_encoder = MochiTextEncoderWrapper(pretrained_model_name_or_path, + device) + elif model_type == "hunyuan" or "hunyuan_hf": + text_encoder = HunyuanTextEncoderWrapper(pretrained_model_name_or_path, + device) + else: + raise ValueError(f"Unsupported model type: {model_type}") + return text_encoder + + +def get_no_split_modules(transformer): + # if of type MochiTransformer3DModel + if isinstance(transformer, MochiTransformer3DModel): + return (MochiTransformerBlock, ) + elif isinstance(transformer, HunyuanVideoTransformer3DModel): + return (HunyuanVideoSingleTransformerBlock, + HunyuanVideoTransformerBlock) + elif isinstance(transformer, HYVideoDiffusionTransformer): + return (MMDoubleStreamBlock, MMSingleStreamBlock) + elif isinstance(transformer, FluxTransformer2DModel): + return (FluxTransformerBlock, FluxSingleTransformerBlock) + else: + raise ValueError(f"Unsupported transformer type: {type(transformer)}") + + +if __name__ == "__main__": + # test encode prompt + device = torch.cuda.current_device() + pretrained_model_name_or_path = "data/hunyuan" + text_encoder = load_text_encoder("hunyuan", pretrained_model_name_or_path, + device) + prompt = "A man on stage claps his hands together while facing the audience. The audience, visible in the foreground, holds up mobile devices to record the event, capturing the moment from various angles. The background features a large banner with text identifying the man on stage. Throughout the sequence, the man's expression remains engaged and directed towards the audience. The camera angle remains constant, focusing on capturing the interaction between the man on stage and the audience." + prompt_embeds, attention_mask = text_encoder.encode_prompt(prompt) diff --git a/fastvideo/utils/logging_.py b/fastvideo/utils/logging_.py new file mode 100644 index 0000000000000000000000000000000000000000..0332658764c0258b0eb968bfef9b1e51252ae8f4 --- /dev/null +++ b/fastvideo/utils/logging_.py @@ -0,0 +1,26 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import os +import pdb +import sys + + +def main_print(content): + if int(os.environ["LOCAL_RANK"]) <= 0: + print(content) + + +# ForkedPdb().set_trace() +class ForkedPdb(pdb.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open("/dev/stdin") + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin diff --git a/fastvideo/utils/optimizer.py b/fastvideo/utils/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6506490b13d5bb266bd4b15d9834d0eab9c543f3 --- /dev/null +++ b/fastvideo/utils/optimizer.py @@ -0,0 +1,78 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import torch +from accelerate.logging import get_logger + +logger = get_logger(__name__) + + +def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy"] + if args.optimizer not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not (args.optimizer.lower() + not in ["adam", "adamw"]): + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}") + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if args.optimizer.lower() == "adamw": + optimizer_class = (bnb.optim.AdamW8bit + if args.use_8bit_adam else torch.optim.AdamW) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "adam": + optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError( + "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" + ) + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + return optimizer diff --git a/fastvideo/utils/parallel_states.py b/fastvideo/utils/parallel_states.py new file mode 100644 index 0000000000000000000000000000000000000000..a8801e44125a820ef2e39b3d39991015d3d6d43e --- /dev/null +++ b/fastvideo/utils/parallel_states.py @@ -0,0 +1,66 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import os + +import torch.distributed as dist + + +class COMM_INFO: + + def __init__(self): + self.group = None + self.sp_size = 1 + self.global_rank = 0 + self.rank_within_group = 0 + self.group_id = 0 + + +nccl_info = COMM_INFO() +_SEQUENCE_PARALLEL_STATE = False + + +def initialize_sequence_parallel_state(sequence_parallel_size): + global _SEQUENCE_PARALLEL_STATE + if sequence_parallel_size > 1: + _SEQUENCE_PARALLEL_STATE = True + initialize_sequence_parallel_group(sequence_parallel_size) + else: + nccl_info.sp_size = 1 + nccl_info.global_rank = int(os.getenv("RANK", "0")) + nccl_info.rank_within_group = 0 + nccl_info.group_id = int(os.getenv("RANK", "0")) + + +def set_sequence_parallel_state(state): + global _SEQUENCE_PARALLEL_STATE + _SEQUENCE_PARALLEL_STATE = state + + +def get_sequence_parallel_state(): + return _SEQUENCE_PARALLEL_STATE + + +def initialize_sequence_parallel_group(sequence_parallel_size): + """Initialize the sequence parallel group.""" + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + assert ( + world_size % sequence_parallel_size == 0 + ), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format( + world_size, sequence_parallel_size) + nccl_info.sp_size = sequence_parallel_size + nccl_info.global_rank = rank + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, + (i + 1) * sequence_parallel_size) + group = dist.new_group(ranks) + if rank in ranks: + nccl_info.group = group + nccl_info.rank_within_group = rank - i * sequence_parallel_size + nccl_info.group_id = i + + +def destroy_sequence_parallel_group(): + """Destroy the sequence parallel group.""" + dist.destroy_process_group() diff --git a/fastvideo/utils/validation.py b/fastvideo/utils/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..50c32e0e357f1f382d709d4b79a8b7103814361e --- /dev/null +++ b/fastvideo/utils/validation.py @@ -0,0 +1,347 @@ +#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. + +import gc +import os +from typing import List, Optional, Union + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import export_to_video +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from tqdm import tqdm + +import wandb +from fastvideo.distill.solver import PCMFMScheduler +from fastvideo.models.mochi_hf.pipeline_mochi import ( + linear_quadratic_schedule, retrieve_timesteps) +from fastvideo.utils.communications import all_gather +from fastvideo.utils.load import load_vae +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) + + +def prepare_latents( + batch_size, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + vae_spatial_scale_factor, + vae_temporal_scale_factor, +): + height = height // vae_spatial_scale_factor + width = width // vae_spatial_scale_factor + num_frames = (num_frames - 1) // vae_temporal_scale_factor + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + latents = randn_tensor(shape, + generator=generator, + device=device, + dtype=dtype) + return latents + + +def sample_validation_video( + model_type, + transformer, + vae, + scheduler, + scheduler_type="euler", + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 4.5, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + vae_spatial_scale_factor=8, + vae_temporal_scale_factor=6, + num_channels_latents=12, +): + device = vae.device + + batch_size = prompt_embeds.shape[0] + + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], + dim=0) + prompt_attention_mask = torch.cat( + [negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + # TODO: Remove hardcore + latents = prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + vae_spatial_scale_factor, + vae_temporal_scale_factor, + ) + world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group + if get_sequence_parallel_state(): + latents = rearrange(latents, + "b t (n s) h w -> b t n s h w", + n=world_size).contiguous() + latents = latents[:, :, rank, :, :, :] + + # 5. Prepare timestep + # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 + threshold_noise = 0.025 + sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) + sigmas = np.array(sigmas) + if scheduler_type == "euler" and model_type == "mochi": #todo + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + ) + else: + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + device, + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * scheduler.order, 0) + + # 6. Denoising loop + # with self.progress_bar(total=num_inference_steps) as progress_bar: + # write with tqdm instead + # only enable if nccl_info.global_rank == 0 + + with tqdm( + total=num_inference_steps, + disable=nccl_info.rank_within_group != 0, + desc="Validation sampling...", + ) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = (torch.cat([latents] * 2) + if do_classifier_free_guidance else latents) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + with torch.autocast("cuda", dtype=torch.bfloat16): + noise_pred = transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + return_dict=False, + )[0] + + # Mochi CFG + Sampling runs in FP32 + noise_pred = noise_pred.to(torch.float32) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = scheduler.step(noise_pred, + t, + latents.to(torch.float32), + return_dict=False)[0] + latents = latents.to(latents_dtype) + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and + (i + 1) % scheduler.order == 0): + progress_bar.update() + + if get_sequence_parallel_state(): + latents = all_gather(latents, dim=2) + + if output_type == "latent": + video = latents + else: + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = (hasattr(vae.config, "latents_mean") + and vae.config.latents_mean is not None) + has_latents_std = (hasattr(vae.config, "latents_std") + and vae.config.latents_std is not None) + if has_latents_mean and has_latents_std: + latents_mean = (torch.tensor(vae.config.latents_mean).view( + 1, 12, 1, 1, 1).to(latents.device, latents.dtype)) + latents_std = (torch.tensor(vae.config.latents_std).view( + 1, 12, 1, 1, 1).to(latents.device, latents.dtype)) + latents = latents * latents_std / vae.config.scaling_factor + latents_mean + else: + latents = latents / vae.config.scaling_factor + with torch.autocast("cuda", dtype=vae.dtype): + video = vae.decode(latents, return_dict=False)[0] + video_processor = VideoProcessor( + vae_scale_factor=vae_spatial_scale_factor) + video = video_processor.postprocess_video(video, + output_type=output_type) + + return (video, ) + + +@torch.no_grad() +@torch.autocast("cuda", dtype=torch.bfloat16) +def log_validation( + args, + transformer, + device, + weight_dtype, # TODO + global_step, + scheduler_type="euler", + shift=1.0, + num_euler_timesteps=100, + linear_quadratic_threshold=0.025, + linear_range=0.5, + ema=False, +): + # TODO + print("Running validation....\n") + if args.model_type == "mochi": + vae_spatial_scale_factor = 8 + vae_temporal_scale_factor = 6 + num_channels_latents = 12 + elif args.model_type == "hunyuan" or "hunyuan_hf": + vae_spatial_scale_factor = 8 + vae_temporal_scale_factor = 4 + num_channels_latents = 16 + else: + raise ValueError(f"Model type {args.model_type} not supported") + vae, autocast_type, fps = load_vae(args.model_type, + args.pretrained_model_name_or_path) + vae.enable_tiling() + if scheduler_type == "euler": + scheduler = FlowMatchEulerDiscreteScheduler(shift=shift) + else: + linear_quadraic = True if scheduler_type == "pcm_linear_quadratic" else False + scheduler = PCMFMScheduler( + 1000, + shift, + num_euler_timesteps, + linear_quadraic, + linear_quadratic_threshold, + linear_range, + ) + # args.validation_prompt_dir + + validation_guidance_scale_ls = args.validation_guidance_scale.split(",") + validation_guidance_scale_ls = [ + float(scale) for scale in validation_guidance_scale_ls + ] + for validation_sampling_step in args.validation_sampling_steps.split(","): + validation_sampling_step = int(validation_sampling_step) + for validation_guidance_scale in validation_guidance_scale_ls: + videos = [] + # prompt_embed are named embed0 to embedN + # check how many embeds are there + embe_dir = os.path.join(args.validation_prompt_dir, "prompt_embed") + mask_dir = os.path.join(args.validation_prompt_dir, + "prompt_attention_mask") + embeds = sorted([f for f in os.listdir(embe_dir)]) + masks = sorted([f for f in os.listdir(mask_dir)]) + num_embeds = len(embeds) + validation_prompt_ids = list(range(num_embeds)) + num_sp_groups = int(os.getenv("WORLD_SIZE", + "1")) // nccl_info.sp_size + # pad to multiple of groups + if num_embeds % num_sp_groups != 0: + validation_prompt_ids += [0] * (num_sp_groups - + num_embeds % num_sp_groups) + num_embeds_per_group = len(validation_prompt_ids) // num_sp_groups + local_prompt_ids = validation_prompt_ids[nccl_info.group_id * + num_embeds_per_group: + (nccl_info.group_id + 1) * + num_embeds_per_group] + + for i in local_prompt_ids: + prompt_embed_path = os.path.join(embe_dir, f"{embeds[i]}") + prompt_mask_path = os.path.join(mask_dir, f"{masks[i]}") + prompt_embeds = (torch.load( + prompt_embed_path, map_location="cpu", + weights_only=True).to(device).unsqueeze(0)) + prompt_attention_mask = (torch.load( + prompt_mask_path, map_location="cpu", + weights_only=True).to(device).unsqueeze(0)) + negative_prompt_embeds = torch.zeros( + 256, 4096).to(device).unsqueeze(0) + negative_prompt_attention_mask = ( + torch.zeros(256).bool().to(device).unsqueeze(0)) + generator = torch.Generator(device="cpu").manual_seed(12345) + video = sample_validation_video( + args.model_type, + transformer, + vae, + scheduler, + scheduler_type=scheduler_type, + num_frames=args.num_frames, + height=args.num_height, + width=args.num_width, + num_inference_steps=validation_sampling_step, + guidance_scale=validation_guidance_scale, + generator=generator, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_attention_mask= + negative_prompt_attention_mask, + vae_spatial_scale_factor=vae_spatial_scale_factor, + vae_temporal_scale_factor=vae_temporal_scale_factor, + num_channels_latents=num_channels_latents, + )[0] + if nccl_info.rank_within_group == 0: + videos.append(video[0]) + # collect videos from all process to process zero + + gc.collect() + torch.cuda.empty_cache() + # log if main process + torch.distributed.barrier() + all_videos = [ + None for i in range(int(os.getenv("WORLD_SIZE", "1"))) + ] # remove padded videos + torch.distributed.all_gather_object(all_videos, videos) + if nccl_info.global_rank == 0: + # remove padding + videos = [video for videos in all_videos for video in videos] + videos = videos[:num_embeds] + # linearize all videos + video_filenames = [] + for i, video in enumerate(videos): + filename = os.path.join( + args.output_dir, + f"validation_step_{global_step}_sample_{validation_sampling_step}_guidance_{validation_guidance_scale}_video_{i}.mp4", + ) + export_to_video(video, filename, fps=fps) + video_filenames.append(filename) + + logs = { + f"{'ema_' if ema else ''}validation_sample_{validation_sampling_step}_guidance_{validation_guidance_scale}": + [ + wandb.Video(filename) + for i, filename in enumerate(video_filenames) + ] + } + wandb.log(logs, step=global_step) diff --git a/hope/ablation/finetune_mergestep_2.hope b/hope/ablation/finetune_mergestep_2.hope new file mode 100644 index 0000000000000000000000000000000000000000..e06faf45c3e2e5dbcc9d5e0339343d5a0ea55426 --- /dev/null +++ b/hope/ablation/finetune_mergestep_2.hope @@ -0,0 +1,68 @@ +[base] +type = ml-vision + +[resource] +usergroup = hadoop-camera3d +queue = root.hldy_training_cluster.hadoop-aipnlp.h800_vi_sp + +[dataset] +dataset_name = +dataset_type = +dataset_path = + +[job_track] +demand_id = 91369190 +upstream_jobid = +input_dir = +output_dir = +log_dir = + +[user_args] + +[roles] +workers = 1 +worker.memory = 1920000 +worker.vcore = 128 +worker.gcoresh800-80g = 8 +worker.script = sh /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/Granular-GRPO/hope/ablation/finetune_mergestep_2.sh /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/DanceGRPO /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/conda-envs/dancegrpo-v2/bin/python fastvideo/train_g2rpo_hps_merge.py 1 8 + +worker.ports = 1 + +[am] +afo.app.am.resource.mb = 4096 + +[tensorboard] +with.tensor.board = false + +[docker] +afo.docker.image.name = registryonline-hulk.sankuai.com/custom_prod/com.sankuai.data.hadoop.gpu/data-hadoop-camera3d_cuda12.4-nccl2.21.5-prod-10ab7b1d + + +[data] +afo.data.prefetch = false + +[failover] +afo.app.support.engine.failover = true + +[conda] +afo.conda.env.name = +afo.conda.env.path = +afo.conda.store.type = + +[distribute] +afo.role.worker.gpu_driver_version = 470.103.01 + +[others] +afo.app.env.YARN_CONTAINER_RUNTIME_DOCKER_SHM_SIZE_BYTES = 640000000000 +afo.xm.notice.receivers.account = zhangshengjun02 +with_requirements = false +afo.app.yarn.allocate.timeout.seconds = 3600000 +afo.app.blacklist.fail_times = 16 +#afo.role.worker.task.attempt.max.retry = 16 +afo.role.worker.task.attempt.max.retry = 1 +afo.dolphinfs.otherusers = hadoop-videogen-hl,hadoop-imagen-hl:true,hadoop-vision-data:true +afo.use.hdfs.fuse=true +afo.use.hdfs.fuse.subpath=:/mnt/hdfs +afo.use.hdfs.fuse.readonly=false +afo.role.worker.not.node_name = hldy-data-k8s-gpu-h800-node0483.mt,hldy-data-k8s-gpu-h800-node0866.mt,hldy-data-k8s-gpu-h800-node0187.mt,hldy-data-k8s-gpu-h800-node0059.mt,hldy-data-k8s-gpu-h800-node0178.mt,hldy-data-k8s-gpu-h800-node0670.mt,hldy-data-k8s-gpu-h800-node0303.mt,hldy-data-k8s-gpu-h800-node0950.mt,hldy-data-k8s-gpu-h800-node0785.mt,hldy-data-k8s-gpu-h800-node0416.mt,hldy-data-k8s-gpu-h800-node0846.mt,hldy-data-k8s-gpu-h800-node0836.mt,hldy-data-k8s-gpu-h800-node0802.mt,hldy-data-k8s-gpu-h800-node0768.mt,hldy-data-k8s-gpu-h800-node1014.mt,hldy-data-k8s-gpu-h800-node0843.mt +afo.role.am.not.node_name = hlsc-data-k8s-node0187.mt \ No newline at end of file diff --git a/hope/ablation/finetune_mergestep_2.sh b/hope/ablation/finetune_mergestep_2.sh new file mode 100644 index 0000000000000000000000000000000000000000..41c9c80305e38072e5906b50f97200f3087865a4 --- /dev/null +++ b/hope/ablation/finetune_mergestep_2.sh @@ -0,0 +1,97 @@ + +# cluster_spec='{"am":["psx2s7cxrbvmlcvk-am-0.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local"],"index":"0","role":"worker","worker":["psx2s7cxrbvmlcvk-worker-0.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local:3400","psx2s7cxrbvmlcvk-worker-1.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local:3400"]}' +# echo "cluster spec is $cluster_spec" +WORK_DIR=$1 +PYTHON_BIN=$2 +SCRIPT=$3 +NNODES=$4 +NPROC_PER_NODE=$5 + +echo "WORK_DIR is $WORK_DIR" +echo "PYTHON_BIN is $PYTHON_BIN" +echo "SCRIPT is $SCRIPT" +echo "NNODES is $NNODES" +echo "NPROC_PER_NODE is $NPROC_PER_NODE" + +PORT=${PORT:-29509} +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ + +cluster_spec=${AFO_ENV_CLUSTER_SPEC//\"/\\\"} +echo "cluster spec is $cluster_spec" +# Assuming worker_list contains the JSON string (it's already been parsed) +worker_list_command="import json_parser; data = json_parser.parse('$cluster_spec'); print(data['worker'])" +worker_list=$($PYTHON_BIN -c "$worker_list_command") + +# Remove the square brackets and quotes from worker_list +worker_list_cleaned=$(echo $worker_list | tr -d '[]' | tr -d "'") + +# Convert the cleaned worker list into an array by splitting by commas +worker_strs=($(echo $worker_list_cleaned | tr ',' '\n')) + +# Extract the master (first worker) +master=${worker_strs[0]} + +# Extract master address and port +master_addr=$(echo $master | cut -d ':' -f1) +master_port=$(echo $master | cut -d ':' -f2) + +# Output the master information without brackets and quotes +echo "worker list is $worker_list_cleaned" +echo "master is $master" +echo "master address is $master_addr" +echo "master port is $master_port" + +worker_list_command="import json_parser; data = json_parser.parse('$cluster_spec'); print(data['index'])" +node_rank=$($PYTHON_BIN -c "$worker_list_command") +echo "node rank is $node_rank" +dist_url="tcp://$master_addr:$master_port" +echo "dist url is $dist_url" + +export TOKENIZERS_PARALLELISM=false +export OMP_NUM_THREADS=1 +export NCCL_DEBUG=INFO +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 + +### launch with DDP (multi-machines-multi-gpus) +source scl_source enable devtoolset-7 +ifconfig +cd $WORK_DIR=/mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/Granular-GRPO +$PYTHON_BIN -m torch.distributed.run \ +--nnodes=$NNODES --nproc_per_node=$NPROC_PER_NODE --node_rank=$node_rank --master_addr=$master_addr --master_port=$PORT \ +$SCRIPT \ +--seed 42 \ +--pretrained_model_name_or_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/flux \ +--hps_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/hps/HPS_v2.1_compressed.pt \ +--hps_clip_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin \ +--data_json_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/rl_embeddings/videos2caption.json \ +--gradient_checkpointing \ +--train_batch_size 1 \ +--num_latent_t 1 \ +--sp_size 1 \ +--train_sp_batch_size 1 \ +--dataloader_num_workers 4 \ +--max_train_steps 301 \ +--learning_rate 2e-6 \ +--mixed_precision bf16 \ +--checkpointing_steps 50 \ +--cfg 0.0 \ +--output_dir /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/save_exp/hps_mergestep_2 \ +--h 1024 \ +--w 1024 \ +--t 1 \ +--sampling_steps 16 \ +--eta 0.7 \ +--lr_warmup_steps 0 \ +--sampler_seed 1223627 \ +--max_grad_norm 1.0 \ +--weight_decay 0.0001 \ +--num_generations 12 \ +--shift 3 \ +--init_same_noise \ +--clip_range 1e-4 \ +--adv_clip_max 5.0 \ +--eta_step_list 0 1 2 3 4 5 6 7 \ +--eta_step_merge_list 2 2 2 2 2 2 2 2 \ +--granular_list 1 \ \ No newline at end of file diff --git a/hope/ablation/finetune_mergestep_4.hope b/hope/ablation/finetune_mergestep_4.hope new file mode 100644 index 0000000000000000000000000000000000000000..df4f46b1e6f19c3cc260162989a16f345e1371c6 --- /dev/null +++ b/hope/ablation/finetune_mergestep_4.hope @@ -0,0 +1,68 @@ +[base] +type = ml-vision + +[resource] +usergroup = hadoop-camera3d +queue = root.hldy_training_cluster.hadoop-aipnlp.h800_vi_sp + +[dataset] +dataset_name = +dataset_type = +dataset_path = + +[job_track] +demand_id = 91369190 +upstream_jobid = +input_dir = +output_dir = +log_dir = + +[user_args] + +[roles] +workers = 1 +worker.memory = 1920000 +worker.vcore = 128 +worker.gcoresh800-80g = 8 +worker.script = sh /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/Granular-GRPO/hope/ablation/finetune_mergestep_4.sh /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/DanceGRPO /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/conda-envs/dancegrpo-v2/bin/python fastvideo/train_g2rpo_hps_merge.py 1 8 + +worker.ports = 1 + +[am] +afo.app.am.resource.mb = 4096 + +[tensorboard] +with.tensor.board = false + +[docker] +afo.docker.image.name = registryonline-hulk.sankuai.com/custom_prod/com.sankuai.data.hadoop.gpu/data-hadoop-camera3d_cuda12.4-nccl2.21.5-prod-10ab7b1d + + +[data] +afo.data.prefetch = false + +[failover] +afo.app.support.engine.failover = true + +[conda] +afo.conda.env.name = +afo.conda.env.path = +afo.conda.store.type = + +[distribute] +afo.role.worker.gpu_driver_version = 470.103.01 + +[others] +afo.app.env.YARN_CONTAINER_RUNTIME_DOCKER_SHM_SIZE_BYTES = 640000000000 +afo.xm.notice.receivers.account = zhangshengjun02 +with_requirements = false +afo.app.yarn.allocate.timeout.seconds = 3600000 +afo.app.blacklist.fail_times = 16 +#afo.role.worker.task.attempt.max.retry = 16 +afo.role.worker.task.attempt.max.retry = 1 +afo.dolphinfs.otherusers = hadoop-videogen-hl,hadoop-imagen-hl:true,hadoop-vision-data:true +afo.use.hdfs.fuse=true +afo.use.hdfs.fuse.subpath=:/mnt/hdfs +afo.use.hdfs.fuse.readonly=false +afo.role.worker.not.node_name = hldy-data-k8s-gpu-h800-node0483.mt,hldy-data-k8s-gpu-h800-node0866.mt,hldy-data-k8s-gpu-h800-node0187.mt,hldy-data-k8s-gpu-h800-node0059.mt,hldy-data-k8s-gpu-h800-node0178.mt,hldy-data-k8s-gpu-h800-node0670.mt,hldy-data-k8s-gpu-h800-node0303.mt,hldy-data-k8s-gpu-h800-node0950.mt,hldy-data-k8s-gpu-h800-node0785.mt,hldy-data-k8s-gpu-h800-node0416.mt,hldy-data-k8s-gpu-h800-node0846.mt,hldy-data-k8s-gpu-h800-node0836.mt,hldy-data-k8s-gpu-h800-node0802.mt,hldy-data-k8s-gpu-h800-node0768.mt,hldy-data-k8s-gpu-h800-node1014.mt,hldy-data-k8s-gpu-h800-node0843.mt +afo.role.am.not.node_name = hlsc-data-k8s-node0187.mt \ No newline at end of file diff --git a/hope/ablation/finetune_mergestep_4.sh b/hope/ablation/finetune_mergestep_4.sh new file mode 100644 index 0000000000000000000000000000000000000000..afff44db934f9d92be9fc856f9be7e8b38662ddb --- /dev/null +++ b/hope/ablation/finetune_mergestep_4.sh @@ -0,0 +1,97 @@ + +# cluster_spec='{"am":["psx2s7cxrbvmlcvk-am-0.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local"],"index":"0","role":"worker","worker":["psx2s7cxrbvmlcvk-worker-0.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local:3400","psx2s7cxrbvmlcvk-worker-1.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local:3400"]}' +# echo "cluster spec is $cluster_spec" +WORK_DIR=$1 +PYTHON_BIN=$2 +SCRIPT=$3 +NNODES=$4 +NPROC_PER_NODE=$5 + +echo "WORK_DIR is $WORK_DIR" +echo "PYTHON_BIN is $PYTHON_BIN" +echo "SCRIPT is $SCRIPT" +echo "NNODES is $NNODES" +echo "NPROC_PER_NODE is $NPROC_PER_NODE" + +PORT=${PORT:-29509} +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ + +cluster_spec=${AFO_ENV_CLUSTER_SPEC//\"/\\\"} +echo "cluster spec is $cluster_spec" +# Assuming worker_list contains the JSON string (it's already been parsed) +worker_list_command="import json_parser; data = json_parser.parse('$cluster_spec'); print(data['worker'])" +worker_list=$($PYTHON_BIN -c "$worker_list_command") + +# Remove the square brackets and quotes from worker_list +worker_list_cleaned=$(echo $worker_list | tr -d '[]' | tr -d "'") + +# Convert the cleaned worker list into an array by splitting by commas +worker_strs=($(echo $worker_list_cleaned | tr ',' '\n')) + +# Extract the master (first worker) +master=${worker_strs[0]} + +# Extract master address and port +master_addr=$(echo $master | cut -d ':' -f1) +master_port=$(echo $master | cut -d ':' -f2) + +# Output the master information without brackets and quotes +echo "worker list is $worker_list_cleaned" +echo "master is $master" +echo "master address is $master_addr" +echo "master port is $master_port" + +worker_list_command="import json_parser; data = json_parser.parse('$cluster_spec'); print(data['index'])" +node_rank=$($PYTHON_BIN -c "$worker_list_command") +echo "node rank is $node_rank" +dist_url="tcp://$master_addr:$master_port" +echo "dist url is $dist_url" + +export TOKENIZERS_PARALLELISM=false +export OMP_NUM_THREADS=1 +export NCCL_DEBUG=INFO +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 + +### launch with DDP (multi-machines-multi-gpus) +source scl_source enable devtoolset-7 +ifconfig +cd $WORK_DIR=/mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/Granular-GRPO +$PYTHON_BIN -m torch.distributed.run \ +--nnodes=$NNODES --nproc_per_node=$NPROC_PER_NODE --node_rank=$node_rank --master_addr=$master_addr --master_port=$PORT \ +$SCRIPT \ +--seed 42 \ +--pretrained_model_name_or_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/flux \ +--hps_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/hps/HPS_v2.1_compressed.pt \ +--hps_clip_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin \ +--data_json_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/rl_embeddings/videos2caption.json \ +--gradient_checkpointing \ +--train_batch_size 1 \ +--num_latent_t 1 \ +--sp_size 1 \ +--train_sp_batch_size 1 \ +--dataloader_num_workers 4 \ +--max_train_steps 301 \ +--learning_rate 2e-6 \ +--mixed_precision bf16 \ +--checkpointing_steps 50 \ +--cfg 0.0 \ +--output_dir /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/save_exp/hps_mergestep_4 \ +--h 1024 \ +--w 1024 \ +--t 1 \ +--sampling_steps 16 \ +--eta 0.7 \ +--lr_warmup_steps 0 \ +--sampler_seed 1223627 \ +--max_grad_norm 1.0 \ +--weight_decay 0.0001 \ +--num_generations 12 \ +--shift 3 \ +--init_same_noise \ +--clip_range 1e-4 \ +--adv_clip_max 5.0 \ +--eta_step_list 0 1 2 3 4 5 6 7 \ +--eta_step_merge_list 4 4 4 4 4 4 4 4 \ +--granular_list 1 \ \ No newline at end of file diff --git a/hope/ablation/finetune_mergestep_6.hope b/hope/ablation/finetune_mergestep_6.hope new file mode 100644 index 0000000000000000000000000000000000000000..6ac5d12313e3dee7410ca25ede0196fe65bab2f9 --- /dev/null +++ b/hope/ablation/finetune_mergestep_6.hope @@ -0,0 +1,68 @@ +[base] +type = ml-vision + +[resource] +usergroup = hadoop-camera3d +queue = root.hldy_training_cluster.hadoop-aipnlp.h800_vi_sp + +[dataset] +dataset_name = +dataset_type = +dataset_path = + +[job_track] +demand_id = 91369190 +upstream_jobid = +input_dir = +output_dir = +log_dir = + +[user_args] + +[roles] +workers = 1 +worker.memory = 1920000 +worker.vcore = 128 +worker.gcoresh800-80g = 8 +worker.script = sh /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/Granular-GRPO/hope/ablation/finetune_mergestep_6.sh /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/DanceGRPO /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/conda-envs/dancegrpo-v2/bin/python fastvideo/train_g2rpo_hps_merge.py 1 8 + +worker.ports = 1 + +[am] +afo.app.am.resource.mb = 4096 + +[tensorboard] +with.tensor.board = false + +[docker] +afo.docker.image.name = registryonline-hulk.sankuai.com/custom_prod/com.sankuai.data.hadoop.gpu/data-hadoop-camera3d_cuda12.4-nccl2.21.5-prod-10ab7b1d + + +[data] +afo.data.prefetch = false + +[failover] +afo.app.support.engine.failover = true + +[conda] +afo.conda.env.name = +afo.conda.env.path = +afo.conda.store.type = + +[distribute] +afo.role.worker.gpu_driver_version = 470.103.01 + +[others] +afo.app.env.YARN_CONTAINER_RUNTIME_DOCKER_SHM_SIZE_BYTES = 640000000000 +afo.xm.notice.receivers.account = zhangshengjun02 +with_requirements = false +afo.app.yarn.allocate.timeout.seconds = 3600000 +afo.app.blacklist.fail_times = 16 +#afo.role.worker.task.attempt.max.retry = 16 +afo.role.worker.task.attempt.max.retry = 1 +afo.dolphinfs.otherusers = hadoop-videogen-hl,hadoop-imagen-hl:true,hadoop-vision-data:true +afo.use.hdfs.fuse=true +afo.use.hdfs.fuse.subpath=:/mnt/hdfs +afo.use.hdfs.fuse.readonly=false +afo.role.worker.not.node_name = hldy-data-k8s-gpu-h800-node0483.mt,hldy-data-k8s-gpu-h800-node0866.mt,hldy-data-k8s-gpu-h800-node0187.mt,hldy-data-k8s-gpu-h800-node0059.mt,hldy-data-k8s-gpu-h800-node0178.mt,hldy-data-k8s-gpu-h800-node0670.mt,hldy-data-k8s-gpu-h800-node0303.mt,hldy-data-k8s-gpu-h800-node0950.mt,hldy-data-k8s-gpu-h800-node0785.mt,hldy-data-k8s-gpu-h800-node0416.mt,hldy-data-k8s-gpu-h800-node0846.mt,hldy-data-k8s-gpu-h800-node0836.mt,hldy-data-k8s-gpu-h800-node0802.mt,hldy-data-k8s-gpu-h800-node0768.mt,hldy-data-k8s-gpu-h800-node1014.mt,hldy-data-k8s-gpu-h800-node0843.mt +afo.role.am.not.node_name = hlsc-data-k8s-node0187.mt \ No newline at end of file diff --git a/hope/ablation/finetune_mergestep_6.sh b/hope/ablation/finetune_mergestep_6.sh new file mode 100644 index 0000000000000000000000000000000000000000..832f6856fd6eea0ab162fc1a5d14e084c3069a83 --- /dev/null +++ b/hope/ablation/finetune_mergestep_6.sh @@ -0,0 +1,97 @@ + +# cluster_spec='{"am":["psx2s7cxrbvmlcvk-am-0.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local"],"index":"0","role":"worker","worker":["psx2s7cxrbvmlcvk-worker-0.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local:3400","psx2s7cxrbvmlcvk-worker-1.psx2s7cxrbvmlcvk.hadoop-aipnlp.svc.cluster.local:3400"]}' +# echo "cluster spec is $cluster_spec" +WORK_DIR=$1 +PYTHON_BIN=$2 +SCRIPT=$3 +NNODES=$4 +NPROC_PER_NODE=$5 + +echo "WORK_DIR is $WORK_DIR" +echo "PYTHON_BIN is $PYTHON_BIN" +echo "SCRIPT is $SCRIPT" +echo "NNODES is $NNODES" +echo "NPROC_PER_NODE is $NPROC_PER_NODE" + +PORT=${PORT:-29509} +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ + +cluster_spec=${AFO_ENV_CLUSTER_SPEC//\"/\\\"} +echo "cluster spec is $cluster_spec" +# Assuming worker_list contains the JSON string (it's already been parsed) +worker_list_command="import json_parser; data = json_parser.parse('$cluster_spec'); print(data['worker'])" +worker_list=$($PYTHON_BIN -c "$worker_list_command") + +# Remove the square brackets and quotes from worker_list +worker_list_cleaned=$(echo $worker_list | tr -d '[]' | tr -d "'") + +# Convert the cleaned worker list into an array by splitting by commas +worker_strs=($(echo $worker_list_cleaned | tr ',' '\n')) + +# Extract the master (first worker) +master=${worker_strs[0]} + +# Extract master address and port +master_addr=$(echo $master | cut -d ':' -f1) +master_port=$(echo $master | cut -d ':' -f2) + +# Output the master information without brackets and quotes +echo "worker list is $worker_list_cleaned" +echo "master is $master" +echo "master address is $master_addr" +echo "master port is $master_port" + +worker_list_command="import json_parser; data = json_parser.parse('$cluster_spec'); print(data['index'])" +node_rank=$($PYTHON_BIN -c "$worker_list_command") +echo "node rank is $node_rank" +dist_url="tcp://$master_addr:$master_port" +echo "dist url is $dist_url" + +export TOKENIZERS_PARALLELISM=false +export OMP_NUM_THREADS=1 +export NCCL_DEBUG=INFO +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 + +### launch with DDP (multi-machines-multi-gpus) +source scl_source enable devtoolset-7 +ifconfig +cd $WORK_DIR=/mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/Granular-GRPO +$PYTHON_BIN -m torch.distributed.run \ +--nnodes=$NNODES --nproc_per_node=$NPROC_PER_NODE --node_rank=$node_rank --master_addr=$master_addr --master_port=$PORT \ +$SCRIPT \ +--seed 42 \ +--pretrained_model_name_or_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/flux \ +--hps_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/hps/HPS_v2.1_compressed.pt \ +--hps_clip_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin \ +--data_json_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/rl_embeddings/videos2caption.json \ +--gradient_checkpointing \ +--train_batch_size 1 \ +--num_latent_t 1 \ +--sp_size 1 \ +--train_sp_batch_size 1 \ +--dataloader_num_workers 4 \ +--max_train_steps 301 \ +--learning_rate 2e-6 \ +--mixed_precision bf16 \ +--checkpointing_steps 50 \ +--cfg 0.0 \ +--output_dir /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/save_exp/hps_mergestep_6 \ +--h 1024 \ +--w 1024 \ +--t 1 \ +--sampling_steps 16 \ +--eta 0.7 \ +--lr_warmup_steps 0 \ +--sampler_seed 1223627 \ +--max_grad_norm 1.0 \ +--weight_decay 0.0001 \ +--num_generations 12 \ +--shift 3 \ +--init_same_noise \ +--clip_range 1e-4 \ +--adv_clip_max 5.0 \ +--eta_step_list 0 1 2 3 4 5 6 7 \ +--eta_step_merge_list 6 6 6 6 6 6 6 6 \ +--granular_list 1 \ \ No newline at end of file diff --git a/scripts/dataset_preparation/prepare_json_file.py b/scripts/dataset_preparation/prepare_json_file.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9ca8d0a4ca81015a64a5e1e44684fd8bbaac84 --- /dev/null +++ b/scripts/dataset_preparation/prepare_json_file.py @@ -0,0 +1,154 @@ +import json +from pathlib import Path + +import cv2 + + +def get_video_info(video_path, prompt_text): + """Extract video information using OpenCV and corresponding prompt text""" + cap = cv2.VideoCapture(str(video_path)) + + if not cap.isOpened(): + print(f"Error: Could not open video {video_path}") + return None + + # Get video properties + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = frame_count / fps if fps > 0 else 0 + + cap.release() + + return { + "path": video_path.name, + "resolution": { + "width": width, + "height": height + }, + "fps": fps, + "duration": duration, + "cap": [prompt_text] + } + + +def read_prompt_file(prompt_path): + """Read and return the content of a prompt file""" + try: + with open(prompt_path, 'r', encoding='utf-8') as f: + return f.read().strip() + except Exception as e: + print(f"Error reading prompt file {prompt_path}: {e}") + return None + + +def process_videos_and_prompts(video_dir_path, prompt_dir_path, verbose=False): + """Process videos and their corresponding prompt files + + Args: + video_dir_path (str): Path to directory containing video files + prompt_dir_path (str): Path to directory containing prompt files + verbose (bool): Whether to print verbose processing information + """ + video_dir = Path(video_dir_path) + prompt_dir = Path(prompt_dir_path) + processed_data = [] + + # Ensure directories exist + if not video_dir.exists() or not prompt_dir.exists(): + print( + f"Error: One or both directories do not exist:\nVideos: {video_dir}\nPrompts: {prompt_dir}" + ) + return [] + + # Process each video file + for video_file in video_dir.glob('*.mp4'): + video_name = video_file.stem + prompt_file = prompt_dir / f"{video_name}.txt" + + # Check if corresponding prompt file exists + if not prompt_file.exists(): + print(f"Warning: No prompt file found for video {video_name}") + continue + + # Read prompt content + prompt_text = read_prompt_file(prompt_file) + if prompt_text is None: + continue + + # Process video and add to results + video_info = get_video_info(video_file, prompt_text) + if video_info: + processed_data.append(video_info) + + return processed_data + + +def save_results(processed_data, output_path): + """Save processed data to JSON file + + Args: + processed_data (list): List of processed video information + output_path (str): Full path for output JSON file + """ + output_path = Path(output_path) + + # Create parent directories if they don't exist + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(processed_data, f, indent=2, ensure_ascii=False) + + return output_path + + +def parse_args(): + """Parse command line arguments""" + import argparse + + parser = argparse.ArgumentParser( + description='Process videos and their corresponding prompt files') + parser.add_argument('--video_dir', + '-v', + required=True, + help='Directory containing video files') + parser.add_argument('--prompt_dir', + '-p', + required=True, + help='Directory containing prompt text files') + parser.add_argument( + '--output_path', + '-o', + required=True, + help= + 'Full path for output JSON file (e.g., /path/to/output/videos2caption.json)' + ) + parser.add_argument('--verbose', + action='store_true', + help='Print verbose processing information') + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_args() + + # Process videos and prompts + processed_videos = process_videos_and_prompts(args.video_dir, + args.prompt_dir, + args.verbose) + + if processed_videos: + # Save results + output_path = save_results(processed_videos, args.output_path) + + print(f"\nProcessed {len(processed_videos)} videos") + print(f"Results saved to: {output_path}") + + # Print example of processed data + print("\nExample of processed video info:") + print(json.dumps(processed_videos[0], indent=2)) + else: + print("No videos were processed successfully") diff --git a/scripts/dataset_preparation/resize_videos.py b/scripts/dataset_preparation/resize_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..523fa6ef28a547b6ef919836a45e5a27a291bf21 --- /dev/null +++ b/scripts/dataset_preparation/resize_videos.py @@ -0,0 +1,174 @@ +import argparse +import logging +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path + +import numpy as np +from moviepy.editor import VideoFileClip +from skimage.transform import resize +from tqdm import tqdm + +# Configure logging +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[logging.FileHandler('video_processing.log')]) + + +def is_16_9_ratio(width: int, height: int, tolerance: float = 0.1) -> bool: + target_ratio = 16 / 9 + actual_ratio = width / height + return abs(actual_ratio - target_ratio) <= (target_ratio * tolerance) + + +def resize_video(args_tuple): + """ + Resize a single video file. + args_tuple: (input_file, output_dir, width, height, fps) + """ + input_file, output_dir, width, height, fps = args_tuple + video = None + resized = None + output_file = output_dir / f"{input_file.name}" + + if output_file.exists(): + output_file.unlink() + + video = VideoFileClip(str(input_file)) + + if not is_16_9_ratio(video.w, video.h): + return (input_file.name, "skipped", "Not 16:9") + + def process_frame(frame): + frame_float = frame.astype(float) / 255.0 + resized = resize(frame_float, (height, width, 3), + mode='reflect', + anti_aliasing=True, + preserve_range=True) + return (resized * 255).astype(np.uint8) + + resized = video.fl_image(process_frame) + resized = resized.set_fps(fps) + + resized.write_videofile(str(output_file), + codec='libx264', + audio_codec='aac', + temp_audiofile=f'temp-audio-{input_file.stem}.m4a', + remove_temp=True, + verbose=False, + logger=None, + fps=fps) + + return (input_file.name, "success", None) + + +def process_folder(args): + input_path = Path(args.input_dir) + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'} + video_files = [ + f for f in input_path.iterdir() + if f.is_file() and f.suffix.lower() in video_extensions + ] + + if not video_files: + print(f"No video files found in {args.input_dir}") + return + + print(f"Found {len(video_files)} videos") + print(f"Target: {args.width}x{args.height} at {args.fps}fps") + + # Prepare arguments for parallel processing + process_args = [(video_file, output_path, args.width, args.height, + args.fps) for video_file in video_files] + + successful = 0 + skipped = 0 + failed = [] + + # Use ProcessPoolExecutor instead of ThreadPoolExecutor + with tqdm(total=len(video_files), + desc="Converting videos", + dynamic_ncols=True) as pbar: + # Use max_workers as specified or default to CPU count + max_workers = args.max_workers + with ProcessPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + future_to_file = { + executor.submit(resize_video, arg): arg[0] + for arg in process_args + } + + # Process completed tasks + for future in as_completed(future_to_file): + filename, status, message = future.result() + if status == "success": + successful += 1 + elif status == "skipped": + skipped += 1 + else: + failed.append((filename, message)) + pbar.update(1) + + # Print final summary + print( + f"\nDone! Processed: {successful}, Skipped: {skipped}, Failed: {len(failed)}" + ) + if failed: + print("Failed files:") + for fname, error in failed: + print(f"- {fname}: {error}") + + +def parse_args(): + parser = argparse.ArgumentParser( + description= + 'Batch resize videos to specified resolution and FPS (16:9 only)') + parser.add_argument('--input_dir', + required=True, + help='Input directory containing video files') + parser.add_argument('--output_dir', + required=True, + help='Output directory for processed videos') + parser.add_argument('--width', + type=int, + default=1280, + help='Target width in pixels (default: 848)') + parser.add_argument('--height', + type=int, + default=720, + help='Target height in pixels (default: 480)') + parser.add_argument('--fps', + type=int, + default=30, + help='Target frames per second (default: 30)') + parser.add_argument( + '--max_workers', + type=int, + default=4, + help='Maximum number of concurrent processes (default: 4)') + parser.add_argument('--log-level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + default='INFO', + help='Set the logging level (default: INFO)') + return parser.parse_args() + + +def main(): + args = parse_args() + logging.getLogger().setLevel(getattr(logging, args.log_level)) + + if not Path(args.input_dir).exists(): + logging.error(f"Input directory not found: {args.input_dir}") + return + + start_time = time.time() + process_folder(args) + duration = time.time() - start_time + logging.info(f"Batch processing completed in {duration:.2f} seconds") + + +if __name__ == "__main__": + main() diff --git a/scripts/evaluation/eval.py b/scripts/evaluation/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..9022d027c962acb219b1824871859e186e0d3d95 --- /dev/null +++ b/scripts/evaluation/eval.py @@ -0,0 +1,83 @@ +import os +import re +import torch +import torch.distributed as dist +from pathlib import Path +from diffusers import FluxPipeline +from diffusers import FluxTransformer2DModel +from torch.utils.data import Dataset, DistributedSampler + +class PromptDataset(Dataset): + def __init__(self, file_path): + with open(file_path, 'r') as f: + self.prompts = [line.strip() for line in f if line.strip()] + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, idx): + return self.prompts[idx] + +def sanitize_filename(text, max_length=200): + sanitized = re.sub(r'[\\/:*?"<>|]', '_', text) + return sanitized[:max_length].rstrip() or "untitled" + +def distributed_setup(): + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + + dist.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + return rank, local_rank, world_size + +def main(): + rank, local_rank, world_size = distributed_setup() + + model_path = "CKPT_PATH" + flux_path = "./ckpt/flux" + + transformer = FluxTransformer2DModel.from_pretrained(model_path, use_safetensors=True, torch_dtype=torch.float16).to("cuda") + pipe = FluxPipeline.from_pretrained(flux_path, transformer=None, torch_dtype=torch.float16).to("cuda") + pipe.transformer = transformer + + dataset = PromptDataset("scripts/evaluation/prompt_test.txt") + sampler = DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=False + ) + + output_dir = Path("IMAGE_SAVE_FOLDER") + output_dir.mkdir(parents=True, exist_ok=True) + + for idx in sampler: + prompt = dataset[idx] + try: + generator = torch.Generator(device=f"cuda:{local_rank}") + generator.manual_seed(42 + idx + rank*1000) + + image = pipe( + prompt, + guidance_scale=3.5, + height=1024, + width=1024, + num_inference_steps=50, + max_sequence_length=512, + generator=generator, + ).images[0] + + filename = sanitize_filename(prompt) + save_path = output_dir / f"{filename}.png" + image.save(save_path) + + print(f"[Rank {rank}] Generated: {save_path.name}") + + except Exception as e: + print(f"[Rank {rank}] Error processing '{prompt[:20]}...': {str(e)}") + + dist.destroy_process_group() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/evaluation/prompt_test.txt b/scripts/evaluation/prompt_test.txt new file mode 100644 index 0000000000000000000000000000000000000000..5cc3768ffda1cbeb7079433fa094e3a910f4d365 --- /dev/null +++ b/scripts/evaluation/prompt_test.txt @@ -0,0 +1,400 @@ +Dwayne the Rock Johnson wrestles Jesus Christ in a WWE match in a hell in a cell. +An anime man in flight uniform with hyper detailed digital artwork and an art style inspired by Klimt, Nixeu, Ian Sprigger, Wlop, and Krenz Cushart. +A Wojak looking over a sea of memes from a cliff on 4chan. +Ralsei and Asriel from Deltarune eating pizza. +A portrait of an anime mecha robot with a Japanese town background and a starred night sky. +A minimalist portrait of Chloe Grace by Jean Giraud in a comic style. +A raccoon riding an oversized fox through a forest in a furry art anime still. +Chucky doll dressed as Beetlejuice. +2B from NieR Automata eating a bagel. +Groot depicted as a flower. +A portrait of two women with purple hair flying in different directions against a dark background. +A girl with pink pigtails and face tattoos. +A cat in a tutu dancing to Swan Lake. +Wicked witch casting fireball dressed in green with screaming expression. +Link fights an octorok in a cave in a Don Bluth-style from The Legend of Zelda, Breath of the Wild. +A cat with two horns on its head. +A cute anime schoolgirl with a sad face submerged in dark pink and blue water, portrayed in an oil painting style. +The image is a portrait of Homer Simpson as a Na'vi from Avatar, created with vibrant colors and highly detailed in a cinematic style reminiscent of romanticism by Eugene de Blaas and Ross Tran, available on Artstation with credits to Greg Rutkowski. +A depiction of Chucky, the killer doll in anime style. +A hand-drawn cute gnome holding a pumpkin in an autumn disguise, portrayed in a detailed close-up of the face with warm lighting and high detail. +The image is of an anthropomorphic orange walking on a sidewalk. +16-year-old teenager wearing a white bear-ear hat with a smirk on their face. +A spoon dressed up with eyes and a smile. +Keqing from Genshin Impact. +a papaya fruit dressed as a sailor. +The image features artwork in the style of Neon Genesis Evangelion, with a colorful anime design of John Key. +A still of Doraemon from "Shaun the Sheep" by Aardman Animation. +Rosario Dawson minimalist portrait by Jean Giraud in a comic book style. +Totem pole made out of cats. +Paul Chuckle bowling on a pirate ship during a storm. +Cartoonish serpent cephalopod mutants emerge from a fiery hell. +A gummy chameleon hanging on a tree branch. +A portrait of Nyan Cat, styled after Annie Leibovitz's dramatic photography. +An anime girl is riding a bicycle in Akihabara, resembling the style seen in Studio Ghibli films, and the depiction is detailed. +A full body portrait of Andre the Giant in the style of Justin Roiland. +An image of an emo with dark brown hair in a messy pixie cut, large entirely-black eyes, wearing black clothing and boots. +A fruit basket on a kitchen table with a Studio Ghibli reference. +A corgi puppy with many eyes depicted in a horror manga drawn by Junji Ito. +A young woman witch cosplaying with a magic wand and broom, wearing boots, and posing in a full body shot with a detailed face. +There is a Michael J. Fox Funko Pop figurine depicted in the image. +Girl shooting fireballs at a dragon in a battle pose, Madhouse studio anime style. +Portrait of anime girl in mechanic armor in night Tokyo. +Mugshot of Superman with unkempt appearance and green skin. +The image is of dancing potatoes in a cute cartoony style. +Miss Piggy dressed in futuristic outfit resembling Leeloo from The Fifth Element. +The image is of a raccoon wearing a Peaky Blinders hat, surrounded by swirling mist and rendered with fine detail. +A green Gundam in an action pose that resembles Shrek. +Digital anime art of mattress-man with a serious expression in an empty warehouse, highly detailed and spotlighted. +Portrait of an anime maid by Krenz Cushart, Alphonse Mucha, and Ilya Kuvshinov. +Siamese twins enjoying pickled eggs at a pub. +An anime furry anthro dragon in a deep forest, depicted in a realistic and detailed cel shading style. +Gnomes are playing music during Independence Day festivities in a forest near Lake George. +Spider-Man holding a ginger cat. +Close-up of Cad Bane, with bad flash. +Claymation of Futurama characters. +The image is a highly-detailed, symmetrical concept art depicting a full-body illustration of a character from the anime Saitama, with vibrant colors and a galaxy background. +A full body pic of a comic character drawn by Rob Liefield. +A still frame from the anime film Akira. +A small green dinosaur toy with orange spots standing on its hind legs and roaring with its mouth open. +An anime Spider-Man girl. +A 3D rendering of anime schoolgirls with a sad expression underwater, surrounded by dramatic lighting. +Teenage boy wearing a skull mask and smoking. +A colorful digital painting with a front view and anime-inspired vibes featuring a magical composition. +A vampire wearing Dr. Martens shoes. +Jack Pumpkinhead in the Land of Oz. +Portrait of young Jerry Lewis in comic style, colorized and created digitally by four artists. +A happy daffodil with big eyes, multiple leaf arms and vine legs, rendered in 3D Pixar style. +Head and shoulders portrait of Jinx from League of Legends of Arcane animated Series. +A one-eyed dwarf wizard holding a flagon in clean cel shaded vector art. +A painting of a koala wearing a princess dress and crown, with a confetti background. +A lemon with a McDonald's hat. +Ewoks swinging from Walmart rafters. +The image is of Pixel Art Huggy Wuggy performing a jumpscare. +Cartoonish illustration of a sci-fi machine shop inside a shipping container. +Australian soldiers surrendering to an emu. +A digital anime portrait of tatsumaki with green curly hair and green eyes wearing a jacket, featuring intricate details and atmospheric lighting, by artists wlop, Ilya Kuvshinov, and Krenz Cushart, trending on ArtStation. +A book about the history of Pepe the Frog. +A slime monster. +Francois Hollande depicted as a manga character in Japan. +The image features an anime girl in a short skirt and thigh-high socks, with a slim figure and accentuated hips. +Portrait of an anime princess in white and golden clothes. +A cute little anthropomorphic Tropical fish knight wearing a cape and a crown in short, pale blue armor. +Fullbody portrait of half-mouse anime girl by A-1 Pictures, trending on ArtStation. +An anthropomorphized block of tofu cartoon-style busting through a brick wall, inspired by Kool Aid man. +A kitten with a panda coloring eating bamboo. +The image is of Aunt Jemima in period attire on stage at Van's Warped Tour at 40 years old, throwing pancakes to the crowd. +A demon boy smiling while reading a book in a library. +Uncanny creatures with yin yang flagella in a surreal 1930s cartoon style. +Popcorn in mouth. +An anime-style demon princess is depicted in a digital painting. +A white polar bear cub wearing sunglasses sits in a meadow with flowers. +A plush of a cute sun-eating creature. +A cute rainbow kitten with different colored eyes in the chibi-style of Studio Ghibli is featured on a postcard. +Two Somali friends sitting and watching a Studio Ghibli movie. +An image of an aircraft carrier made of cheese. +Scene from Muppet Mad Max-Fury Road. +Family assembling missile in living room. +A lemon wearing a suit and tie, full body portrait. +Astolfo, an anime character, wearing a witch hat and lab coat, is flying on a broom while hexing in a forest. +A sci-fi machine shop in a shipping container, depicted in a manga-style digital painting with intricate details. +The image shows centric diatoms. +A spaceship in an empty landscape. +A mechanical planet amidst a space war with superships and exploding stars, featuring steampunk and clockpunk elements. +A photo of a mechanical angel woman with crystal wings, in the sci-fi style of Stefan Kostic, created by Stanley Lau and Artgerm. +Anthropomorphic alien plant creature with big eyes and leafy limbs depicted in a detailed painting. +A cyber girl with demon horns holds a black feather in front of a cybercity with a gloomy expression. +A massive and brightly colored spacecraft in a deserted landscape, depicted in retro 1960s sci-fi art. +A photorealistic image from a furry fandom convention set in a biopunk era after the genetic revolution and quantum singularity. +Gillian Anderson in a science fiction film directed by John Carpenter. +The image depicts a muscular blonde butch tomboy engineer wearing a patched flight suit, set in a detailed sci-fi environment with cinematic lighting and crepuscular rays. +A digital art image of a detailed surreal alien biomechanical temple interior by Giger. +A key visual of a young female swat officer with a neon futuristic gas mask in a cyberpunk setting. +A person wearing a lab coat holding a green apple and standing in front of a whiteboard filled with equations and diagrams. +Scene from the 1931 science fiction film "Escape from New York." +Sandman wearing black clothing, in a sci-fi themed digital painting by Greg Rutkowski. +A man watches an old TV while toxic slime and debris pour from the cracked steampunk ceiling into a glowing, neon-lit room. +A jellyfish sleeping in a space station pod. +Dr. Pepper floating in space, viewed through the window of a spaceship. +Alien abduction depicted in a blurry, old polaroid with lost and later found footage. +A ginger haired mouse mechanic in blue overalls in a cyberpunk scene with neon slums in the background. +The image is a symmetrical barge with a clearheaded textured fractal pattern. +Looking down at a destroyed city from a plane. +Witches performing a ritual in a dark mall. +The image depicts a concept car resembling a supercar or hypercar with a chrome reflection and global illumination. +The image is a front view of a mutant from Doom Eternal, with tubes fused to its body, and is a digital art masterpiece painted by Stanley Lau (Artgerm) and Greg Rutkowski. +A detailed and realistic fantasy Proto-Slavic skinny red troll creature. +A teddy bear mad scientist mixing chemicals depicted in oil painting style as a fantasy concept art piece. +Lava falls upon a crumbling stadium as crowds panic. +Image of a woman with snakes in her mouth, surrounded by flowers and a twisted branch background, created by various artists, with a dark and moody atmosphere. +A portrait of Childe Hassam in digital art style surrounded by other famous artists in HD. +Mixed media collage with broken glass photo and torn paper textures in a contemporary art style. +A cyber goth elf priestess strikes dramatic poses in a post-apocalyptic cyberpunk city with overgrown vegetation. +An alter made of bones with a glowing pineapple lamp on it and surrounded by candles, in front of a swirling mist with epic lighting. +A 3D render of a volcanic icon on a rocky background, in isometric perspective and darkly lit. +Sonic the Hedgehog depicted as a muscular Greek god in a highly detailed digital painting by Greg Rutkowski and Alphonse Mucha on Artstation. +The image is a highly detailed concept art of a medieval city, with a cinematic style and painted beautifully in oil by various artists including Wlop, Greg Rutkowski, and Artgerm, and can be viewed on Artstation. +An abandoned Israeli bus station in Tel Aviv depicted in a flat, colorful Ghibli-style digital artwork by Makoto Shinkai. +The image depicts alien flowers and plants surrounded by visceral exoskeletal formations in front of mythical mountains with dramatic contrast lighting, created with surreal hyper detailing in a 3D render. +Grogu is featured in the center of the image, with a cloudy sky, sun, and neon lights in the background, utilizing the rule of thirds and incorporating elements of retrofuturism and Studio Ghibli-inspired aesthetics. +Plasticine sculptures of two lovers walking through Paris with strict clothing and bright colors. +Undead army on riding beasts with symbols and music. +Kinetic wind sculpture in 3D render. +A warrior in glowing azure plate armor stands in a doorway to hell sliced by iridescent glass cracks, with crimson clouds and an art deco palace backdrop. +A male elf wearing heavy armor with a cape and a weathered face portrayed in detailed, smooth illustration. +A digital painting by Loish featuring a rush of half-body, cyberpunk androids and cyborgs adorned with intricate jewelry and colorful holographic dreads. +An investigator fights a tentacled monster in a finely detailed horror film still. +The image is a close up portrait of a man and a girl, with vibrant colors and a thermal background, resembling the style of Francis Bacon. +An intricate and elegant art deco-inspired metropolis with retrofuturistic elements in a cyberpunk style. +American cowboy with a scruffy appearance in a retrofuturistic style, inspired by the animations of Studio Ghibli. +A dragon standing in a forest, drinking river water. +White lines depict topography on a black background. +A planisphere lavalamp glows inside a glass jar buried in sand with swirling mist around it. +A full-body wide-angle photo of a wooden art doll by Agostino Arrivabene, depicting a peaceful sleeping pose. +A photorealistic 3D render of wooly mammoths grazing in a surreal mystical forest with a bright winding blue creek. +A portrait of a zebra with super detailed eyes and nose by various artists, posted on art platforms. +The image is titled "Queen of the Robots," created by artists Greg Rutowski, Victo Ngai, and Alphonse Mucha. +Close-up shot of a person running on a treadmill with worn running shoes under dramatic lighting and a comic book-style painting effect. +A point and click adventure based on Breaking Bad. +The image depicts a stunning supernova within a fantasy artwork on Artstation. +A surreal restaurant floating on water with architecture resembling Salvador Dalí's art, depicted in a film directed by Denis Villeneuve. +A highly detailed portrait of Yakuza 0's Goro Majima, featuring a variety of talented artists, created using Unreal Engine and featuring intricate environments. +A red-haired female knight with a golden prosthetic arm wields a long golden blade. +A key shot of an Australian Shepherd with a pastel color palette and dramatic lighting. +A black and white drawing of a road splitting the ocean leading to a giant eyeball looking at clouds in the distance. +Jim Carrey is eating giant hamburgers in a hyper-detailed, horror-inspired image trending on DeviantArt. +A photograph of a giant diamond skull in the ocean, featuring vibrant colors and detailed textures. +A woman's face in profile, with white carapace plates extruding from the skin and red kintsurugi. +Zarya from Overwatch jumping from a tall bridge into action. +A minimalist portrait of Rita Ora by Jean Giraud, inspired by the Moebius Starwatcher comic. +The image features baroque architecture by Escher and Jean Delville in the walled city of Kowloon, lit with golden lighting and displaying ornate details. +A hybrid creature concept painting of a zebra-striped unicorn with bunny ears and a colorful mane. +An image of a corn elemental, created as concept art for a high fantasy setting. +A raw green gemstone covered in black slime, with a shiny appearance, captured in a photorealistic digital art photograph. +Molten lava hanging from the ceiling creates art with octagons in a museum setting. +A giant nose made of water in a restaurant is depicted in a film still with unique art direction. +A red cobweb is seen inside a marble with an hourglass, lightning and intricate details, creating a sense of awe with swirling mist. +A village of stone buildings built on the side of a hill glowing with silver light. +A steampunk pocketwatch owl is trapped inside a glass jar buried in sand, surrounded by an hourglass and swirling mist. +A pocketwatch hangs from a steam punk hot air balloon amidst a swirling mist. +Small carved figurines of fantasy buildings, miniatures and standalone objects on a table. +A full body character of a mouse technomage in cyberpunk armor with a neon background, painted by jorsch. +A female human barbarian depicted in a traditional Dungeons and Dragons illustration. +Hooded figure standing over a ruined city with red haze and a grin. +A screenshot taken in hl2. +A digital illustration titled "The Rise of a Dadaist Government" by Jeffrey Smith and Tim Biskup. +A lady in a purple dress sitting in a tree - concept art. +A spray painted and analogue collage with canvas texture in a contemporary art style featuring a mathematically correct Tetris design. +The image features a closeup portrait of stone angel statues, created with the Unreal Engine and featuring intricate details by various artists. +A colorful tin toy robot runs a steam engine on a path near a beautiful flower meadow in the Swiss Alps with a mountain panorama in the background, captured in a long shot with motion blur and depth of field. +Male vampire of clan Banu Haqim with blue braided hair stands in a modern city at night surrounded by neon signs, jewelry, and tattoos. +The image is a fashion photograph of a humanoid lobster wearing a designer outfit with titanium claws, captured in red color and detailed texture by photographer Jovana Rikalo. +A giant cosmic tardigrade descending on Tokyo at sunset in a highly detailed concept art. +The image depicts Rengoku as Lucifer morningstar in a detailed digital painting available on ArtStation, featuring smooth render and sharp focus. +An image depicting a fantasy architectural concept with dramatic and cinematic touches through environmental concept art and an infographic-like display of marginalia. +"A samurai warrior made of smoke in Ghibli Studio's mystical and magical style." +A brightly painted temple with ornate structures and dramatic lighting inspired by Mayan and Islamic architecture. +Underwater concept art of marine life in Sea of Thieves featuring a wild boar. +Side-view blue-ice sneaker inspired by Spiderman created by Weta FX. +The image is a digital art depiction of a female angel warrior with detailed features by artist Magali Villeneuve. +A phantom airship. +An abstract painting depicting the balance between dark and light in nature with rough brushstrokes and fine details in natural colors. +Image of a foothpath in Indian summer with Zugspitze mountain in the background, painted by Sargent, Leyendecker, and Greg Hildebrandt. +A painting featuring two men in a fighting scene wearing black jodhpurs. +A town of pod homes integrated in a forest area with water and trees, depicted in a detailed watercolor by Lurid. +The image is a mixed media collage with broken glass and torn paper elements, featuring intricate oil details and a canvas texture, in a contemporary art style. +The Mona Lisa wearing headphones and listening to Lana Del Rey on a phone, depicted with photorealistic high detail. +"Albus Dumbledore" - a portrait of the headmaster of Hogwarts School of Witchcraft and Wizardry from the Harry Potter series. +The image is a Roy Lichtenstein emo portraying a woman with dark brown pixie hair, entirely black eyes, wearing a black tank top, leather jacket, skirt, choker, and boots. +"A portrait painting of Batman on Artstation." +Elvis Presley performing in a jumpsuit, artwork by Alessandro Pautasso. +The image depicts a gold filigree tree of life in a detailed fantasy painting. +A surreal painting by Ronny Khalil depicting a bestiary of wild emotion monsters repressed in the deep sea of the unconscious psyche, led by Baba Yaga, glowing with dramatic fire light as they prepare to escape in a revolution. +A man wearing a hat performs a magic trick for Jesus in a kitchen painting by Rockwell, Lovell, and Schoonover. +A graveyard at night with a moonlit grave under a sakura tree, rain falling, by Aleksandra Waliszewska. +The image depicts three female figures, known as the muses, playing musical instruments. +A person wearing black and white boots and pants sits in a portrait pose in a detailed fantasy painting by various artists. +Female human barbarian in a dungeon, illustrated by Jeff Easley for Dungeons and Dragons. +a pointillism ink painting of a Japanese demon with high detail. +A nebula forms the shape of a face in this detailed artwork. +The image depicts a Studio Ghibli-style painting of a colossal, ancient ruin with a road winding through the forest, overlooking a sunrise above the cloudy sea. +Michael Van Gerwen drinking a cup of tea in black and white Warhammer fantasy art. +Portrait of a Victorian gentleman standing on a balcony, richly detailed color illustration with cinematic lighting. +A painting of a woman by Zinaida Serebriakova wearing a T-shirt with the Supreme brand logo, a sleeveless white blouse, dark brown capris, and black loafers. +A painting of a starship landing by a temple, created by Hubert Robert. +Panorama of Hogwarts. +Close up portrait of a person speaking on the phone in front of a dark, geometrically abstract painting in the style of Sophie Taeuber-Arp, Gary Hume, and Tatsuro Kiuchi. +a shiny metallic renaissance steampunk robot in the style of Jan van Eyck. +Album art of a hand holding a balloon emerging from the water against a red sky. +A canvas artwork with art nouveau style depicting an ocean on fire, inspired by artists Reylia Slaby, Peter Gric, and Lexie Liu, featuring intricate and ornate details and volumetric lighting. +A painting depicting a foothpath at Indian summer with an epic evening sky at sunset and low thunder clouds. +A surreal dark painting featuring mythical creatures, by Ronny Khalil. +A frightened woman with blood on her face, wearing an emerald necklace, crouched in fear in a castle hallway. +Robot painted by Salvador Dali, resembling Wall-E. +The image features a person wearing jodhpurs, knee-high boots, and a leather jacket, painted in a fantasy style by various artists including Greg Manchess, Leyendecker, Greg Rutkowski, Greg Tocchini, James Gilleard, and Joe Fenton. +A portrait painting of a male deer in a suit sitting on a sofa near a window by John Singer Sargent. +The image is an ink and wash painting featuring vibrant colors, calligraphy, woodblock and geometric 3D shapes. +A detailed soft painting of a bat with golden rose flowers and amethyst stained glass in the background. +The image is a painting by Pierre-Auguste Renoir of an emo with short, messy brown hair, large entirely-black eyes, wearing a black tank top, leather jacket, knee-length skirt, choker, and boots. +Hyper realistic matte painting of the legendary ancient world of Shangrila. +Lady Britannia portrayed in crosshatching in 18th century art by William Hogarth. +The image is an oil on canvas painting of a broken man. +A watercolor portrait of a woman by Luke Rueda Studios and David Downton. +Cyrano de Bergerac holds Roxanne's hand in an illustrated scene by Csók István. +Geometric, colorful creature painted with rough brushstrokes on an abstract background by Pavel Lizano (2018). +Oil portrait of Gearless Joe by Greg Rutowski and Alphonse Mucha. +A portrait painting of Leighann Vail. +Oil portrait of a skeleton in a Victorian suit. +A colorful, detailed painting of a raccoon with a long, flowing mane reminiscent of a lion's, styled in a mohawk. +Two cats, one grey and one black, are wearing steampunk attire and standing in front of a ship in a heavily detailed painting. +Mythical beasts around a fire in a dark, surreal painting by Ronny Khalil. +Renaissance angel depicted in Gerhard Richter's oil painting. +Interior of a cathedral with a koi pond and surrounding greenery. +A silkscreen print of a woman with a short dark brown pixie cut, evil entirely-black eyes, wearing all black clothing. +The image features a painting of a person wearing white cloth fabric jodhpurs, with elements of fantasy and intricacy. +Portrait of Steve Buscemi by Ilya Kuvshinov. +A portrait of Phoenix Wright, painted in oil by Greg Rutowski and inspired by the art of Alphonse Mucha. +An award-winning portrait of a lemon in a muted, space age style reminiscent of the 1930s. +A painting by Raffaello Sanzi portraying Kajol and symbiots Riot during the Renaissance era, showcased on Artstation. +"Beeple's painting 'ethos of ego, mythos of id' on canvas features hyperrealistic photorealism." +Artwork by Shaun Tan. +Illustration of a sugar skull day of the dead girl. +A painting of the TV tower by Zaha Hadid. +A giant minotaur warrior holding a two handed axe, depicted in a dark fantasy digital painting. +A Japanese castle landscape painting trending on Artstation. +A depiction of Boudica, queen and warrior, with a battlefield in the background. +Mila Kunis portrayed as a fire elemental in a highly detailed digital painting. +The image is titled "Left 4 Dead" and was painted by John William Waterhouse. +A Proto-Slavic hero is depicted in an intricate and elegant digital painting by multiple artists on ArtStation. +A painting depicting a snowy winter scene featuring a river, a small house on a hill, and a dreamy cloudy sky. +Digital painting with vivid colors and a front view featuring a magical composition. +The image depicts a portrait of a panda by Petros Afshar. +An androgynous glam rocker poses outside CBGB in the style of Phil Hale. +A digital painting of a symmetric fantasy depiction of a Shinigami Japanese figure with highly detailed and realistic intricate port. +A magical hand reaching up on a dark-violet background, depicted through a digital painting. +An oil painting portrait of a beautiful dryad wearing an ombre velvet gown, with long hair and a tiara, adorned with dozens of jeweled necklaces, and illuminated with dramatic cinematic lighting. +Michal Karcz has painted a beautiful landscape, featuring purple trees with intricate and elegant details. +Portrait of a black man with open mouth, viewed from below, in the style of Lucian Freud. +Whimsical creatures in a dark, surreal painting by Ronny Khalil. +A falcon in flight, depicted in a highly detailed painting by Ilya Repin, Phil Hale, and Kent Williams. +A portrait painting of Lashaundra Garrette. +A polar expedition unloads from a ship in the 19th century in an intricate and elegant fantasy illustration. +Close-up view of ancient Greek ruins set against a colourful, starry night sky creating a mystical atmosphere. +A Monet portrait. +A portrait painting of Roxie Kownacki. +A cave painting of a Stone Age party. +An oil on canvas painting depicting a surreal cognitive illusion of a key, by artists Oleg Shupliak and Jeffrey Smith, with nods to Afrofuturism and surrealism. +A Weezer album painted by Hieronymus Bosch. +A painting by Ludek Marold of cattle grazing in a field, done in oil on canvas. +The image depicts a Sri Lankan king, created by artist Nizovtsev, Victor. +A semirealistic digital painting of a Japanese schoolgirl in a gentle grayish color palette, by Chinese artists on ArtStation. +A portrait of Frank Zappa smoking, with vivid neon colors, by various artists. +A portrait of Larry David playing poker by Sandra Chevrier, featured on Artstation. +A digital painting of a small universe with intricate details and ornate features, featuring the styles of Claude Monet and Vincent van Gogh. +A white Persian cat wearing a peacock feather headdress and surrounded by flowers, in a magical realism painting. +A purple and black generative art, with repeating biomorphic patterns by Patrick Heron. +The Magician by René Magritte. +Digital painting of a lush natural scene on an alien planet with colourful, weird vegetation, cliffs, and water by Gerald Brom. +Plankton creatures gathered around a fire geyser in a surreal, dark painting by Ronny Khalil. +Oil painting portrait of demon king with gazing eyes, art by John Howe, Keith Parkinson, and Larry Elmore, featured on ArtStation and CGSociety. +"Hyperrealistic acrylic painting on canvas by Junji Ito depicting the mythos of ego and ethos of id." +A train is moving along the track in the countryside. +a castle is in the middle of a eurpean city +A lot of building on each side of the road, with a very curvy road in the middle. +People standing in the grass playing with a frisbee. +A street light and other road signs. +Two men sitting in a green living room talking to a girl seen in the mirror. +Empty double decker London bus in rural area. +A plate topped with lots of different kinds of fruit. +an empty bench next to a busy street. +A person holding a remote while playing a game +Racing horse being guided by an asian man. +there is a old rusted train sitting on the ground +baseball player swinging metal bat at home plate. +Room full of people playing video games and having a party. +An airplane is flying with a white line in the sky above it. +Three zebras crowd around each other in a painting. +a plate that has some kind of food on it +a vehicle with many people even on top +The skater is riding very low on his board. +Snow boarder falling down in almost white out conditions of snow. +A brown and black dog sticking its head out a window. +A train is coming along near some unusual looking tracks. +A group of four friends commemorating a ski trip in the snow. +Very ornate bedroom with a chandelier over the bed. +A mother bear and her cub crossing a two lane road. +Two little dogs looking a large pizza sitting on a table. +an image of a messy counter in a bathroom +The bus is parked beside of the shopping center. +Side of a street, where there is a fire hydrant and a mirror showing the street. +Two young ladies seated with several other people at a dinner table. +A great room with the living area in the foreground, dining table behind it and kitchen in the very back. +The shadow of a person across an asphalt street next to a street sign. +There is a bicycle parked next to a car. +A man walking on the bea with his surfboard. +A lighted birthday cake with chunks of walnuts. +A woman wearing a bridal veil and suitcase poses with a man in a yellow tie. +a man walking alone down the street in a velvet jacket +A person holding a very small slice on pizza between their fingers. +A person flying a kite while standing in the grass. +A man stands in front of a bus and a car that have collided. +An airplane is flying down above a tree. +A woman riding skis across snow covered ground. +A bunch of people waiting in line by a rail. +A man playing with a tennis racquet, on a court. +A small dog looking at a white plate holding donuts. +A couple of men are standing outside their car watching sheep cross a road. +People rafting on a river while others ride on top of elephants. +A motorcycle that is sitting in the dirt. +A blue vase filled with yellow flowers on a window sill. +A giraffe walking through the grass towards a wall. +A baseball player pitching a baseball on a field. +Bikes sit parked under trees on a city street. +a young male holding a baseball bat in a baseball uniform +A clock tower with lighted clock faces, against a twilight sky. +a girl a white bear a camera a tea pot and a cup +An athletic middle aged male skier courses downhill. +a public transit bus on a city street +A tanker trunk is on it's side on the side of a road near a police car. +A bathroom stall containing an empty toilet in it. +A flag flying with kites all around it. +A train that is going by a building. +The feet of Lacrosse players scramble for the ball. +A little kid riding a skateboard down a sidewalk. +A man with his finger held up to his nose. +A man placing a turkey in an oven with oven mitts. +freshly baked donuts priced to sell at 60cents each +Some giraffes are walking around the zoo exhibit. +A pizza is displayed inside a pizza box. +a vintage photo of a man and a nurse +A blue airplane in a blue, cloudless sky +A large dog laying on a couch in a room. +A wooden trunk sitting outside with stickers on it. +A white Fed Ex sitting in front of a tall building. +there are three woman that are laying on the beach +Carrots bundled up for sale at an outside market +A nighstand topped with a white land-line phone, remote control, a metallic lamp, and two pens next to a black hardcover book. +There is a man and woman posing together in a restaurant +A tall giraffe in a zoo eating branches +A bear walks through a group of bushes with a plant in its mouth. +A giraffe stands inside a enclosure with a group of people standing near by. +A pigeon sitting on top of a white table. +A group of people on a field playing with a frisbee. +A view of a red light at Tenth Avenue +A sculpture that is planted in rocks in front of the water. +A bowl filled with apple slices and ice cream +people bicycling near the beach on a bicycle lane +A high shot of many people standing in an airport. +A portrait of a dinner dish of a protein and greens. +A big metal bed frame with no mattress on it. +Several different types of cell phones are shown here. +The umpire, catcher and batter during a baseball game +Kids are peering over a fence as a train rolls by. +Three cows eating in a field with sea in background. +A tall giraffe is eating out of a basket. +Group of birds sitting on top of a television antenna on a building. +a group of people wearing ski equipment on a snowy field +A stop sign with the phrase "hammer time" written on it. +Several bunches of unripened bananas growing from trees. +a cellphone standing on a counter next to a small statue of buddy christ +a train on a track with a car near by \ No newline at end of file diff --git a/scripts/evaluation/run_eval.sh b/scripts/evaluation/run_eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..f4952dd810985cd3b56c950ebe2f0f28a18e7fc2 --- /dev/null +++ b/scripts/evaluation/run_eval.sh @@ -0,0 +1 @@ +torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="localhost" --master_port=29500 vis_flux.py \ No newline at end of file diff --git a/scripts/evaluation/test_clip_score.py b/scripts/evaluation/test_clip_score.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e9b75c06fdef8fb570015d531e0bbdbef936e1 --- /dev/null +++ b/scripts/evaluation/test_clip_score.py @@ -0,0 +1,67 @@ +from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer +import torch +from torchvision import transforms +from PIL import Image +import os +from tqdm import tqdm +from torch.nn import functional as F +from open_clip import create_model_from_pretrained, get_tokenizer + +def initialize_model(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model_dict = {} + + processor = get_tokenizer('ViT-H-14') + reward_model, preprocess_dgn5b = create_model_from_pretrained( + 'local-dir:ckpt/clip_score') + reward_model.to(device).eval() + model_dict['model'] = reward_model + model_dict['preprocess_val'] = preprocess_dgn5b + + return model_dict, device + +def load_images_from_folder(folder): + images = [] + filenames = [] + for filename in os.listdir(folder): + if filename.endswith(".png"): + img_path = os.path.join(folder, filename) + image = Image.open(img_path).convert("RGB") + images.append(image) + filenames.append(filename) + return images, filenames + +def main(): + model_dict, device = initialize_model() + model = model_dict['model'] + preprocess_val = model_dict['preprocess_val'] + + tokenizer = get_tokenizer('ViT-H-14') + reward_model = model.to(device) + reward_model.eval() + + img_folder = "IMAGE_SAVE_FOLDER" + images, filenames = load_images_from_folder(img_folder) + + eval_rewards = [] + with torch.no_grad(): + for image_pil, filename in tqdm(zip(images, filenames), total=400): + + image = preprocess_val(image_pil).unsqueeze(0).to(device=device, non_blocking=True) + prompt = os.path.splitext(filename)[0] + text = tokenizer([prompt]).to(device=device, non_blocking=True) + + ## get score + clip_image_features = reward_model.encode_image(image) + clip_text_features = reward_model.encode_text(text) + clip_image_features = F.normalize(clip_image_features, dim=-1) + clip_text_features = F.normalize(clip_text_features, dim=-1) + clip_score = (clip_image_features @ clip_text_features.T)[0] + clip_score = clip_score.item() + eval_rewards.append(clip_score) + + avg_reward = sum(eval_rewards) / len(eval_rewards) if eval_rewards else 0 + print(f"Average CLIP score: {avg_reward:.4f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/evaluation/test_hps_score.py b/scripts/evaluation/test_hps_score.py new file mode 100644 index 0000000000000000000000000000000000000000..1db1e2d835bca12e8882c686bb7b602b3636b2a1 --- /dev/null +++ b/scripts/evaluation/test_hps_score.py @@ -0,0 +1,77 @@ +from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer +import torch +from torchvision import transforms +from PIL import Image +import os +from tqdm import tqdm + +def initialize_model(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model_dict = {} + model, preprocess_train, preprocess_val = create_model_and_transforms( + 'ViT-H-14', + '/mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/CLIP-ViT-H-14-laion2B-s32B-b79K/pytorch_model.bin', + precision='amp', + device=device, + jit=False, + force_quick_gelu=False, + force_custom_text=False, + force_patch_dropout=False, + force_image_size=None, + pretrained_image=False, + image_mean=None, + image_std=None, + light_augmentation=True, + aug_cfg={}, + output_dict=True, + with_score_predictor=False, + with_region_predictor=False + ) + model_dict['model'] = model + model_dict['preprocess_val'] = preprocess_val + return model_dict, device + +def load_images_from_folder(folder): + images = [] + filenames = [] + for filename in os.listdir(folder): + if filename.endswith(".png"): + img_path = os.path.join(folder, filename) + image = Image.open(img_path).convert("RGB") + images.append(image) + filenames.append(filename) + return images, filenames + +def main(): + model_dict, device = initialize_model() + model = model_dict['model'] + preprocess_val = model_dict['preprocess_val'] + + cp = "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/hps/HPS_v2.1_compressed.pt" + checkpoint = torch.load(cp, map_location=device) + model.load_state_dict(checkpoint['state_dict']) + tokenizer = get_tokenizer('ViT-H-14') + reward_model = model.to(device) + reward_model.eval() + + img_folder = "IMAGE_SAVE_FOLDER" + images, filenames = load_images_from_folder(img_folder) + + eval_rewards = [] + with torch.no_grad(): + for image_pil, filename in tqdm(zip(images, filenames), total=400): + + image = preprocess_val(image_pil).unsqueeze(0).to(device=device, non_blocking=True) + prompt = os.path.splitext(filename)[0] # 剔除文件扩展名 + text = tokenizer([prompt]).to(device=device, non_blocking=True) + outputs = reward_model(image, text) + image_features, text_features = outputs["image_features"], outputs["text_features"] + logits_per_image = image_features @ text_features.T + hps_score = torch.diagonal(logits_per_image).item() # 转换为 Python 数值 + eval_rewards.append(hps_score) + + avg_reward = sum(eval_rewards) / len(eval_rewards) if eval_rewards else 0 + print(f"Average HPS score: {avg_reward:.4f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/evaluation/test_imagereward_score.py b/scripts/evaluation/test_imagereward_score.py new file mode 100644 index 0000000000000000000000000000000000000000..686d5eb0bbf3f7c53047f0727306244a131ed809 --- /dev/null +++ b/scripts/evaluation/test_imagereward_score.py @@ -0,0 +1,54 @@ +import torch +from torchvision import transforms +from PIL import Image +import os +from tqdm import tqdm +from torch.nn import functional as F +from open_clip import create_model_from_pretrained, get_tokenizer +import ImageReward as RM + + +def initialize_model(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model_dict = {} + ## download from https://huggingface.co/zai-org/ImageReward + model_path = "ckpt/ImageReward/ImageReward.pt" + config_path = "ckpt/ImageReward/med_config.json" + model = RM.load(model_path, device=device, med_config=config_path) + + return model, device + +def load_images_from_folder(folder): + images = [] + filenames = [] + for filename in os.listdir(folder): + if filename.endswith(".png"): + img_path = os.path.join(folder, filename) + image = Image.open(img_path).convert("RGB") + images.append(image) + filenames.append(filename) + return images, filenames + +def main(): + model, device = initialize_model() + + reward_model = model.to(device) + reward_model.eval() + + img_folder = "IMAGE_SAVE_FOLDER" + images, filenames = load_images_from_folder(img_folder) + + eval_rewards = [] + with torch.no_grad(): + for image_pil, filename in tqdm(zip(images, filenames), total=400): + prompt = os.path.splitext(filename)[0] + ## get score + rewards = reward_model.score(prompt, image_pil) + + eval_rewards.append(rewards) + + avg_reward = sum(eval_rewards) / len(eval_rewards) if eval_rewards else 0 + print(f"Average image reward score: {avg_reward:.4f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/evaluation/test_pickscore_score.py b/scripts/evaluation/test_pickscore_score.py new file mode 100644 index 0000000000000000000000000000000000000000..efee5a27c42b6b470492ed2877bc2237ebba5d04 --- /dev/null +++ b/scripts/evaluation/test_pickscore_score.py @@ -0,0 +1,88 @@ +from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer +import torch +from torchvision import transforms +from PIL import Image +import os +from tqdm import tqdm +from torch.nn import functional as F +from open_clip import create_model_from_pretrained, get_tokenizer +from transformers import AutoProcessor, AutoModel + +def initialize_model(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model_dict = {} + + process_path = "ckpt/CLIP-ViT-H-14-laion2B-s32B-b79K" + # download from https://huggingface.co/yuvalkirstain/PickScore_v1 + model_path = "ckpt/PickScore_v1" + + processor = AutoProcessor.from_pretrained(process_path) + reward_model = AutoModel.from_pretrained(model_path) + reward_model.to(device).eval() + + model_dict['model'] = reward_model + model_dict['preprocess_val'] = processor + + return model_dict, device + +def load_images_from_folder(folder): + images = [] + filenames = [] + for filename in os.listdir(folder): + if filename.endswith(".png"): + img_path = os.path.join(folder, filename) + image = Image.open(img_path).convert("RGB") + images.append(image) + filenames.append(filename) + return images, filenames + +def main(): + model_dict, device = initialize_model() + model = model_dict['model'] + preprocess_val = model_dict['preprocess_val'] + + tokenizer = get_tokenizer('ViT-H-14') + reward_model = model.to(device) + reward_model.eval() + + img_folder = "IMAGE_SAVE_FOLDER" + images, filenames = load_images_from_folder(img_folder) + + eval_rewards = [] + with torch.no_grad(): + for image_pil, filename in tqdm(zip(images, filenames), total=400): + + image_inputs = preprocess_val( + images=[image_pil], + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ).to(device) + + prompt = os.path.splitext(filename)[0] # 剔除文件扩展名 + + text_inputs = preprocess_val( + text=prompt, + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ).to(device) + + # Get embeddings + image_embs = reward_model.get_image_features(**image_inputs) + image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True) + + text_embs = reward_model.get_text_features(**text_inputs) + text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True) + + # Calculate scores + score = reward_model.logit_scale.exp() * (text_embs @ image_embs.T)[0] + eval_rewards.append(score.item()) + + avg_reward = sum(eval_rewards) / len(eval_rewards) if eval_rewards else 0 + print(f"Average pickscore score: {avg_reward:.4f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/huggingface/upload_hf.py b/scripts/huggingface/upload_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..e8dc3957ebcfa53a33ee31c0539d74f448ac1cb3 --- /dev/null +++ b/scripts/huggingface/upload_hf.py @@ -0,0 +1,9 @@ +from huggingface_hub import HfApi + +api = HfApi() + +api.upload_folder( + folder_path="data/Black-Myth-Taylor-Src", + repo_id="FastVideo/Image-Vid-Finetune-Src", + repo_type="dataset", +) diff --git a/wandb/run-20260124_005321-uwv5dfod/files/config.yaml b/wandb/run-20260124_005321-uwv5dfod/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8012625d5fe056267ce7dc0113900ebde78dd4d5 --- /dev/null +++ b/wandb/run-20260124_005321-uwv5dfod/files/config.yaml @@ -0,0 +1,88 @@ +_wandb: + value: + cli_version: 0.18.5 + m: [] + python_version: 3.10.19 + t: + "1": + - 1 + - 11 + - 41 + - 49 + - 55 + - 63 + - 71 + - 83 + - 98 + "2": + - 1 + - 11 + - 41 + - 49 + - 55 + - 63 + - 71 + - 83 + - 98 + "3": + - 13 + - 23 + - 55 + "4": 3.10.19 + "5": 0.18.5 + "6": 4.46.1 + "8": + - 5 + "12": 0.18.5 + "13": linux-x86_64 +allow_tf32: + value: true +logdir: + value: logs +mixed_precision: + value: bf16 +num_checkpoint_limit: + value: 5 +num_epochs: + value: 300 +pretrained: + value: + model: ./data/StableDiffusion + revision: main +prompt_fn: + value: imagenet_animals +resume_from: + value: "" +reward_fn: + value: hpsv2 +run_name: + value: 2026.01.24_00.53.09 +sample: + value: + batch_size: 1 + eta: 1 + guidance_scale: 5 + num_batches_per_epoch: 2 + num_steps: 50 +save_freq: + value: 20 +seed: + value: 42 +train: + value: + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_epsilon: 1e-08 + adam_weight_decay: 0.0001 + adv_clip_max: 5 + batch_size: 1 + cfg: true + clip_range: 0.0001 + gradient_accumulation_steps: 1 + learning_rate: 1e-05 + max_grad_norm: 1 + num_inner_epochs: 1 + timestep_fraction: 1 + use_8bit_adam: false +use_lora: + value: false diff --git a/wandb/run-20260124_005321-uwv5dfod/files/output.log b/wandb/run-20260124_005321-uwv5dfod/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..4822a95915d0cb95880a28a99779f99983ac1d82 --- /dev/null +++ b/wandb/run-20260124_005321-uwv5dfod/files/output.log @@ -0,0 +1,84 @@ +I0124 00:53:22.778190 137351200614208 train_g2rpo_sd_merge.py:510] +allow_tf32: true +logdir: logs +mixed_precision: bf16 +num_checkpoint_limit: 5 +num_epochs: 300 +pretrained: + model: ./data/StableDiffusion + revision: main +prompt_fn: imagenet_animals +prompt_fn_kwargs: {} +resume_from: '' +reward_fn: hpsv2 +run_name: 2026.01.24_00.53.09 +sample: + batch_size: 1 + eta: 1.0 + guidance_scale: 5.0 + num_batches_per_epoch: 2 + num_steps: 50 +save_freq: 20 +seed: 42 +train: + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_epsilon: 1.0e-08 + adam_weight_decay: 0.0001 + adv_clip_max: 5 + batch_size: 1 + cfg: true + clip_range: 0.0001 + gradient_accumulation_steps: 1 + learning_rate: 1.0e-05 + max_grad_norm: 1.0 + num_inner_epochs: 1 + timestep_fraction: 1.0 + use_8bit_adam: false +use_lora: false + +Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.48it/s] +Traceback (most recent call last): + File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 929, in + app.run(main) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/absl/app.py", line 316, in run + _run_main(main, args) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/absl/app.py", line 261, in _run_main + sys.exit(main(argv)) + File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 602, in main + unet, optimizer = accelerator.prepare(unet, optimizer) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1350, in prepare + result = tuple( + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1351, in + self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1226, in _prepare_one + return self.prepare_model(obj, device_placement=device_placement) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1477, in prepare_model + model = torch.nn.parallel.DistributedDataParallel( + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 858, in __init__ + _verify_param_shape_across_processes(self.process_group, parameters) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/distributed/utils.py", line 281, in _verify_param_shape_across_processes + return dist._verify_params_across_processes(process_group, tensors, logger) +RuntimeError: DDP expects same model across all ranks, but Rank 0 has 686 params, while rank 1 has inconsistent 0 params. +[rank0]: Traceback (most recent call last): +[rank0]: File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 929, in +[rank0]: app.run(main) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/absl/app.py", line 316, in run +[rank0]: _run_main(main, args) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/absl/app.py", line 261, in _run_main +[rank0]: sys.exit(main(argv)) +[rank0]: File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 602, in main +[rank0]: unet, optimizer = accelerator.prepare(unet, optimizer) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1350, in prepare +[rank0]: result = tuple( +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1351, in +[rank0]: self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1226, in _prepare_one +[rank0]: return self.prepare_model(obj, device_placement=device_placement) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1477, in prepare_model +[rank0]: model = torch.nn.parallel.DistributedDataParallel( +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 858, in __init__ +[rank0]: _verify_param_shape_across_processes(self.process_group, parameters) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/distributed/utils.py", line 281, in _verify_param_shape_across_processes +[rank0]: return dist._verify_params_across_processes(process_group, tensors, logger) +[rank0]: RuntimeError: DDP expects same model across all ranks, but Rank 0 has 686 params, while rank 1 has inconsistent 0 params. diff --git a/wandb/run-20260124_005321-uwv5dfod/files/requirements.txt b/wandb/run-20260124_005321-uwv5dfod/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ee5c7ffa6079b296e15f3c9ff9edceed1bfe0802 --- /dev/null +++ b/wandb/run-20260124_005321-uwv5dfod/files/requirements.txt @@ -0,0 +1,189 @@ +scipy==1.13.0 +regex==2024.9.11 +sentencepiece==0.2.0 +six==1.16.0 +anyio==4.11.0 +nvidia-cuda-nvrtc-cu12==12.6.77 +scikit-video==1.1.11 +platformdirs==4.5.0 +mypy==1.11.1 +ruff==0.6.5 +charset-normalizer==3.4.4 +torch==2.9.0+cu126 +av==13.1.0 +pillow==10.2.0 +gpustat==1.1.1 +torchvision==0.24.0+cu126 +multidict==6.7.0 +torchmetrics==1.5.1 +aiohttp==3.13.1 +transformers==4.46.1 +decord==0.6.0 +wcwidth==0.2.14 +sphinx-lint==1.0.0 +nvidia-cuda-runtime-cu12==12.6.77 +pytz==2025.2 +codespell==2.3.0 +hpsv2==1.2.0 +mypy_extensions==1.1.0 +numpy==1.26.3 +omegaconf==2.3.0 +Markdown==3.9 +tzdata==2025.2 +pandas==2.2.3 +pytorch-lightning==2.4.0 +aiosignal==1.4.0 +aiohappyeyeballs==2.6.1 +python-dateutil==2.9.0.post0 +seaborn==0.13.2 +beautifulsoup4==4.12.3 +isort==5.13.2 +httpx==0.28.1 +certifi==2025.10.5 +ml_collections==1.1.0 +nvidia-cudnn-cu12==9.10.2.21 +hf-xet==1.2.0 +requests==2.31.0 +inflect==6.0.4 +iniconfig==2.1.0 +braceexpand==0.1.7 +h5py==3.12.1 +wandb==0.18.5 +protobuf==3.20.3 +ninja==1.13.0 +kiwisolver==1.4.9 +networkx==3.3 +packaging==25.0 +fvcore==0.1.5.post20221221 +pyparsing==3.2.5 +starlette==0.41.3 +frozenlist==1.8.0 +docker-pycreds==0.4.0 +Werkzeug==3.1.3 +MarkupSafe==2.1.5 +einops==0.8.0 +sentry-sdk==2.42.0 +PyYAML==6.0.1 +nvidia-nccl-cu12==2.27.5 +datasets==4.3.0 +polib==1.2.0 +safetensors==0.6.2 +async-timeout==5.0.1 +setproctitle==1.3.7 +clint==0.5.1 +matplotlib==3.9.2 +propcache==0.4.1 +termcolor==3.1.0 +antlr4-python3-runtime==4.9.3 +cycler==0.12.1 +fastvideo==1.2.0 +toml==0.10.2 +xxhash==3.6.0 +wheel==0.44.0 +albumentations==1.4.20 +fastapi==0.115.3 +nvidia-cufft-cu12==11.3.0.4 +yarl==1.22.0 +psutil==7.1.0 +tensorboard-data-server==0.7.2 +pydantic==2.9.2 +nvidia-nvtx-cu12==12.6.77 +portalocker==3.2.0 +triton==3.5.0 +annotated-types==0.7.0 +proglog==0.1.12 +nvidia-cusparselt-cu12==0.7.1 +yapf==0.32.0 +Jinja2==3.1.6 +types-requests==2.32.4.20250913 +lightning-utilities==0.15.2 +grpcio==1.75.1 +uvicorn==0.32.0 +typing_extensions==4.15.0 +nvidia-nvjitlink-cu12==12.6.85 +watch==0.2.7 +moviepy==1.0.3 +timm==1.0.11 +pytest-split==0.8.0 +gdown==5.2.0 +types-setuptools==80.9.0.20250822 +nvidia-cusolver-cu12==11.7.1.2 +types-PyYAML==6.0.12.20250915 +pip==25.2 +qwen-vl-utils==0.0.14 +soupsieve==2.8 +zipp==3.23.0 +flash_attn==2.8.3 +yacs==0.1.8 +diffusers==0.32.0 +pluggy==1.6.0 +opencv-python-headless==4.11.0.86 +mpmath==1.3.0 +test_tube==0.7.5 +stringzilla==4.2.1 +fonttools==4.60.1 +nvidia-ml-py==13.580.82 +parameterized==0.9.0 +loguru==0.7.3 +tabulate==0.9.0 +idna==3.6 +iopath==0.1.10 +decorator==4.4.2 +nvidia-cufile-cu12==1.11.1.6 +threadpoolctl==3.6.0 +pyarrow==21.0.0 +httpcore==1.0.9 +hydra-core==1.3.2 +multiprocess==0.70.16 +contourpy==1.3.2 +clip==1.0 +tqdm==4.66.5 +open_clip_torch==3.2.0 +accelerate==1.0.1 +gitdb==4.0.12 +importlib_metadata==8.7.0 +nvidia-cublas-cu12==12.6.4.1 +h11==0.16.0 +filelock==3.19.1 +liger_kernel==0.4.1 +click==8.3.0 +urllib3==2.2.0 +imageio-ffmpeg==0.5.1 +setuptools==80.9.0 +joblib==1.5.2 +tensorboard==2.20.0 +attrs==25.4.0 +future==1.0.0 +albucore==0.0.19 +fsspec==2025.9.0 +sympy==1.14.0 +eval_type_backport==0.2.2 +pydantic_core==2.23.4 +sniffio==1.3.1 +nvidia-nvshmem-cu12==3.3.20 +exceptiongroup==1.3.0 +smmap==5.0.2 +tomli==2.0.2 +ftfy==6.3.0 +dill==0.4.0 +pytest==7.2.0 +PySocks==1.7.1 +nvidia-curand-cu12==10.3.7.77 +tokenizers==0.20.1 +args==0.1.0 +fairscale==0.4.13 +peft==0.13.2 +webdataset==1.0.2 +huggingface-hub==0.26.1 +GitPython==3.1.45 +pytorchvideo==0.1.5 +scikit-learn==1.5.2 +bitsandbytes==0.48.1 +nvidia-cusparse-cu12==12.5.4.2 +nvidia-cuda-cupti-cu12==12.6.80 +imageio==2.36.0 +pydub==0.25.1 +image-reward==1.5 +absl-py==2.3.1 +blessed==1.22.0 +torchdiffeq==0.2.4 diff --git a/wandb/run-20260124_005321-uwv5dfod/files/wandb-metadata.json b/wandb/run-20260124_005321-uwv5dfod/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..5158f40bd27c421e7c688b16142d600fa58efad0 --- /dev/null +++ b/wandb/run-20260124_005321-uwv5dfod/files/wandb-metadata.json @@ -0,0 +1,96 @@ +{ + "os": "Linux-6.8.0-85-generic-x86_64-with-glibc2.35", + "python": "3.10.19", + "startedAt": "2026-01-23T16:53:21.730845Z", + "args": [ + "--config", + "fastvideo/config_sd/base.py", + "--eta_step_list", + "0,1,2,3,4,5,6,7", + "--eta_step_merge_list", + "1,1,1,2,2,2,3,3", + "--granular_list", + "1", + "--num_generations", + "4", + "--eta", + "1.0", + "--init_same_noise" + ], + "program": "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", + "codePath": "fastvideo/train_g2rpo_sd_merge.py", + "email": "zhangemail1428@163.com", + "root": "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code", + "host": "abc", + "username": "zsj", + "executable": "/home/zsj/anaconda3/envs/g2rpo/bin/python", + "codePathLocal": "fastvideo/train_g2rpo_sd_merge.py", + "cpu_count": 48, + "cpu_count_logical": 96, + "gpu": "NVIDIA RTX 5880 Ada Generation", + "gpu_count": 8, + "disk": { + "/": { + "total": "1006773899264", + "used": "812434640896" + } + }, + "memory": { + "total": "540697260032" + }, + "cpu": { + "count": 48, + "countLogical": 96 + }, + "gpu_nvidia": [ + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + } + ], + "cudaVersion": "12.9" +} \ No newline at end of file diff --git a/wandb/run-20260124_005321-uwv5dfod/files/wandb-summary.json b/wandb/run-20260124_005321-uwv5dfod/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..c47d14e1f66787d69cd7161169ba5ed1b76c36f7 --- /dev/null +++ b/wandb/run-20260124_005321-uwv5dfod/files/wandb-summary.json @@ -0,0 +1 @@ +{"_wandb":{"runtime":666}} \ No newline at end of file diff --git a/wandb/run-20260124_005321-uwv5dfod/logs/debug-core.log b/wandb/run-20260124_005321-uwv5dfod/logs/debug-core.log new file mode 100644 index 0000000000000000000000000000000000000000..39f3e95ebb77e12a392b57b343b8fceb46c8f3ee --- /dev/null +++ b/wandb/run-20260124_005321-uwv5dfod/logs/debug-core.log @@ -0,0 +1,12 @@ +{"time":"2026-01-24T00:53:20.75562521+08:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmpu1dm1d_5/port-587300.txt","pid":587300,"debug":false,"disable-analytics":false} +{"time":"2026-01-24T00:53:20.755664638+08:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false} +{"time":"2026-01-24T00:53:20.756652255+08:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":587300} +{"time":"2026-01-24T00:53:20.756592406+08:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":46113,"Zone":""}} +{"time":"2026-01-24T00:53:20.943895472+08:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:58190"} +{"time":"2026-01-24T00:53:21.7358555+08:00","level":"INFO","msg":"handleInformInit: received","streamId":"uwv5dfod","id":"127.0.0.1:58190"} +{"time":"2026-01-24T00:53:21.855581286+08:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"uwv5dfod","id":"127.0.0.1:58190"} +{"time":"2026-01-24T01:04:28.352849879+08:00","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"127.0.0.1:58190"} +{"time":"2026-01-24T01:04:28.353084799+08:00","level":"INFO","msg":"connection: Close: initiating connection closure","id":"127.0.0.1:58190"} +{"time":"2026-01-24T01:04:28.353137399+08:00","level":"INFO","msg":"server is shutting down"} +{"time":"2026-01-24T01:04:28.353218116+08:00","level":"INFO","msg":"connection: Close: connection successfully closed","id":"127.0.0.1:58190"} +{"time":"2026-01-24T01:04:29.663195793+08:00","level":"INFO","msg":"Parent process exited, terminating service process."} diff --git a/wandb/run-20260124_005321-uwv5dfod/logs/debug-internal.log b/wandb/run-20260124_005321-uwv5dfod/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..a33d1051d1db73628c00bb6a7c16deb1e591fe69 --- /dev/null +++ b/wandb/run-20260124_005321-uwv5dfod/logs/debug-internal.log @@ -0,0 +1,15 @@ +{"time":"2026-01-24T00:53:21.736200931+08:00","level":"INFO","msg":"using version","core version":"0.18.5"} +{"time":"2026-01-24T00:53:21.736234192+08:00","level":"INFO","msg":"created symlink","path":"/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_005321-uwv5dfod/logs/debug-core.log"} +{"time":"2026-01-24T00:53:21.855497866+08:00","level":"INFO","msg":"created new stream","id":"uwv5dfod"} +{"time":"2026-01-24T00:53:21.855569627+08:00","level":"INFO","msg":"stream: started","id":"uwv5dfod"} +{"time":"2026-01-24T00:53:21.855783446+08:00","level":"INFO","msg":"handler: started","stream_id":{"value":"uwv5dfod"}} +{"time":"2026-01-24T00:53:21.855708998+08:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"uwv5dfod"}} +{"time":"2026-01-24T00:53:21.857158173+08:00","level":"INFO","msg":"sender: started","stream_id":"uwv5dfod"} +{"time":"2026-01-24T00:53:22.62886866+08:00","level":"INFO","msg":"Starting system monitor"} +{"time":"2026-01-24T01:04:28.352995417+08:00","level":"INFO","msg":"stream: closing","id":"uwv5dfod"} +{"time":"2026-01-24T01:04:28.35306103+08:00","level":"INFO","msg":"Stopping system monitor"} +{"time":"2026-01-24T01:04:28.354386663+08:00","level":"INFO","msg":"Stopped system monitor"} +{"time":"2026-01-24T01:04:28.679507012+08:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"} +{"time":"2026-01-24T01:04:28.679536167+08:00","level":"WARN","msg":"No source type found, not creating job artifact"} +{"time":"2026-01-24T01:04:28.67954769+08:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"} +{"time":"2026-01-24T01:04:29.628598981+08:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"} diff --git a/wandb/run-20260124_005321-uwv5dfod/logs/debug.log b/wandb/run-20260124_005321-uwv5dfod/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..31ef907ca0d9d43b24f2e8ef2cf7404dcd16e784 --- /dev/null +++ b/wandb/run-20260124_005321-uwv5dfod/logs/debug.log @@ -0,0 +1,27 @@ +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_setup.py:_flush():79] Current SDK version is 0.18.5 +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_setup.py:_flush():79] Configure stats pid to 587300 +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_setup.py:_flush():79] Loading settings from /home/zsj/.config/wandb/settings +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_setup.py:_flush():79] Loading settings from /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/settings +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_setup.py:_flush():79] Loading settings from environment variables: {} +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_setup.py:_flush():79] Applying setup settings: {'mode': None, '_disable_service': None} +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_setup.py:_flush():79] Inferring run settings from compute environment: {'program_relpath': 'fastvideo/train_g2rpo_sd_merge.py', 'program_abspath': '/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py', 'program': '/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py'} +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_setup.py:_flush():79] Applying login settings: {} +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_init.py:_log_setup():534] Logging user logs to /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_005321-uwv5dfod/logs/debug.log +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_init.py:_log_setup():535] Logging internal logs to /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_005321-uwv5dfod/logs/debug-internal.log +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_init.py:init():621] calling init triggers +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_init.py:init():628] wandb.init called with sweep_config: {} +config: {} +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_init.py:init():671] starting backend +2026-01-24 00:53:21,728 INFO MainThread:587300 [wandb_init.py:init():675] sending inform_init request +2026-01-24 00:53:21,729 INFO MainThread:587300 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=fork,spawn,forkserver, using: spawn +2026-01-24 00:53:21,730 INFO MainThread:587300 [wandb_init.py:init():688] backend started and connected +2026-01-24 00:53:21,734 INFO MainThread:587300 [wandb_init.py:init():783] updated telemetry +2026-01-24 00:53:21,735 INFO MainThread:587300 [wandb_init.py:init():816] communicating run to backend with 90.0 second timeout +2026-01-24 00:53:22,622 INFO MainThread:587300 [wandb_init.py:init():867] starting run threads in backend +2026-01-24 00:53:22,773 INFO MainThread:587300 [wandb_run.py:_console_start():2463] atexit reg +2026-01-24 00:53:22,773 INFO MainThread:587300 [wandb_run.py:_redirect():2311] redirect: wrap_raw +2026-01-24 00:53:22,774 INFO MainThread:587300 [wandb_run.py:_redirect():2376] Wrapping output streams. +2026-01-24 00:53:22,774 INFO MainThread:587300 [wandb_run.py:_redirect():2401] Redirects installed. +2026-01-24 00:53:22,775 INFO MainThread:587300 [wandb_init.py:init():911] run started, returning control to user process +2026-01-24 00:53:22,776 INFO MainThread:587300 [wandb_run.py:_config_callback():1390] config_cb None None {'allow_tf32': True, 'logdir': 'logs', 'mixed_precision': 'bf16', 'num_checkpoint_limit': 5, 'num_epochs': 300, 'pretrained': {'model': './data/StableDiffusion', 'revision': 'main'}, 'prompt_fn': 'imagenet_animals', 'prompt_fn_kwargs': {}, 'resume_from': '', 'reward_fn': 'hpsv2', 'run_name': '2026.01.24_00.53.09', 'sample': {'batch_size': 1, 'eta': 1.0, 'guidance_scale': 5.0, 'num_batches_per_epoch': 2, 'num_steps': 50}, 'save_freq': 20, 'seed': 42, 'train': {'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'adam_weight_decay': 0.0001, 'adv_clip_max': 5, 'batch_size': 1, 'cfg': True, 'clip_range': 0.0001, 'gradient_accumulation_steps': 1, 'learning_rate': 1e-05, 'max_grad_norm': 1.0, 'num_inner_epochs': 1, 'timestep_fraction': 1.0, 'use_8bit_adam': False}, 'use_lora': False} +2026-01-24 01:04:28,353 WARNING MsgRouterThr:587300 [router.py:message_loop():77] message_loop has been closed diff --git a/wandb/run-20260124_132718-kot97lcx/files/output.log b/wandb/run-20260124_132718-kot97lcx/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..791b8cc60021b644e893fda0031f797ae08b11d3 --- /dev/null +++ b/wandb/run-20260124_132718-kot97lcx/files/output.log @@ -0,0 +1,40 @@ +I0124 13:27:19.117835 131683152287552 train_g2rpo_sd_merge.py:478] +allow_tf32: true +logdir: logs +mixed_precision: bf16 +num_checkpoint_limit: 5 +num_epochs: 300 +pretrained: + model: ./data/StableDiffusion + revision: main +prompt_fn: imagenet_animals +prompt_fn_kwargs: {} +resume_from: '' +reward_fn: hpsv2 +run_name: 2026.01.24_13.27.17 +sample: + batch_size: 1 + eta: 1.0 + guidance_scale: 5.0 + num_batches_per_epoch: 2 + num_steps: 50 +save_freq: 20 +seed: 42 +train: + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_epsilon: 1.0e-08 + adam_weight_decay: 0.0001 + adv_clip_max: 5 + batch_size: 1 + cfg: true + clip_range: 0.0001 + gradient_accumulation_steps: 1 + learning_rate: 1.0e-05 + max_grad_norm: 1.0 + num_inner_epochs: 1 + timestep_fraction: 1.0 + use_8bit_adam: false +use_lora: false + +Loading pipeline components...: 100%|███████████████████████████████████████| 7/7 [00:18<00:00, 2.58s/it] diff --git a/wandb/run-20260124_132718-kot97lcx/files/requirements.txt b/wandb/run-20260124_132718-kot97lcx/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ee5c7ffa6079b296e15f3c9ff9edceed1bfe0802 --- /dev/null +++ b/wandb/run-20260124_132718-kot97lcx/files/requirements.txt @@ -0,0 +1,189 @@ +scipy==1.13.0 +regex==2024.9.11 +sentencepiece==0.2.0 +six==1.16.0 +anyio==4.11.0 +nvidia-cuda-nvrtc-cu12==12.6.77 +scikit-video==1.1.11 +platformdirs==4.5.0 +mypy==1.11.1 +ruff==0.6.5 +charset-normalizer==3.4.4 +torch==2.9.0+cu126 +av==13.1.0 +pillow==10.2.0 +gpustat==1.1.1 +torchvision==0.24.0+cu126 +multidict==6.7.0 +torchmetrics==1.5.1 +aiohttp==3.13.1 +transformers==4.46.1 +decord==0.6.0 +wcwidth==0.2.14 +sphinx-lint==1.0.0 +nvidia-cuda-runtime-cu12==12.6.77 +pytz==2025.2 +codespell==2.3.0 +hpsv2==1.2.0 +mypy_extensions==1.1.0 +numpy==1.26.3 +omegaconf==2.3.0 +Markdown==3.9 +tzdata==2025.2 +pandas==2.2.3 +pytorch-lightning==2.4.0 +aiosignal==1.4.0 +aiohappyeyeballs==2.6.1 +python-dateutil==2.9.0.post0 +seaborn==0.13.2 +beautifulsoup4==4.12.3 +isort==5.13.2 +httpx==0.28.1 +certifi==2025.10.5 +ml_collections==1.1.0 +nvidia-cudnn-cu12==9.10.2.21 +hf-xet==1.2.0 +requests==2.31.0 +inflect==6.0.4 +iniconfig==2.1.0 +braceexpand==0.1.7 +h5py==3.12.1 +wandb==0.18.5 +protobuf==3.20.3 +ninja==1.13.0 +kiwisolver==1.4.9 +networkx==3.3 +packaging==25.0 +fvcore==0.1.5.post20221221 +pyparsing==3.2.5 +starlette==0.41.3 +frozenlist==1.8.0 +docker-pycreds==0.4.0 +Werkzeug==3.1.3 +MarkupSafe==2.1.5 +einops==0.8.0 +sentry-sdk==2.42.0 +PyYAML==6.0.1 +nvidia-nccl-cu12==2.27.5 +datasets==4.3.0 +polib==1.2.0 +safetensors==0.6.2 +async-timeout==5.0.1 +setproctitle==1.3.7 +clint==0.5.1 +matplotlib==3.9.2 +propcache==0.4.1 +termcolor==3.1.0 +antlr4-python3-runtime==4.9.3 +cycler==0.12.1 +fastvideo==1.2.0 +toml==0.10.2 +xxhash==3.6.0 +wheel==0.44.0 +albumentations==1.4.20 +fastapi==0.115.3 +nvidia-cufft-cu12==11.3.0.4 +yarl==1.22.0 +psutil==7.1.0 +tensorboard-data-server==0.7.2 +pydantic==2.9.2 +nvidia-nvtx-cu12==12.6.77 +portalocker==3.2.0 +triton==3.5.0 +annotated-types==0.7.0 +proglog==0.1.12 +nvidia-cusparselt-cu12==0.7.1 +yapf==0.32.0 +Jinja2==3.1.6 +types-requests==2.32.4.20250913 +lightning-utilities==0.15.2 +grpcio==1.75.1 +uvicorn==0.32.0 +typing_extensions==4.15.0 +nvidia-nvjitlink-cu12==12.6.85 +watch==0.2.7 +moviepy==1.0.3 +timm==1.0.11 +pytest-split==0.8.0 +gdown==5.2.0 +types-setuptools==80.9.0.20250822 +nvidia-cusolver-cu12==11.7.1.2 +types-PyYAML==6.0.12.20250915 +pip==25.2 +qwen-vl-utils==0.0.14 +soupsieve==2.8 +zipp==3.23.0 +flash_attn==2.8.3 +yacs==0.1.8 +diffusers==0.32.0 +pluggy==1.6.0 +opencv-python-headless==4.11.0.86 +mpmath==1.3.0 +test_tube==0.7.5 +stringzilla==4.2.1 +fonttools==4.60.1 +nvidia-ml-py==13.580.82 +parameterized==0.9.0 +loguru==0.7.3 +tabulate==0.9.0 +idna==3.6 +iopath==0.1.10 +decorator==4.4.2 +nvidia-cufile-cu12==1.11.1.6 +threadpoolctl==3.6.0 +pyarrow==21.0.0 +httpcore==1.0.9 +hydra-core==1.3.2 +multiprocess==0.70.16 +contourpy==1.3.2 +clip==1.0 +tqdm==4.66.5 +open_clip_torch==3.2.0 +accelerate==1.0.1 +gitdb==4.0.12 +importlib_metadata==8.7.0 +nvidia-cublas-cu12==12.6.4.1 +h11==0.16.0 +filelock==3.19.1 +liger_kernel==0.4.1 +click==8.3.0 +urllib3==2.2.0 +imageio-ffmpeg==0.5.1 +setuptools==80.9.0 +joblib==1.5.2 +tensorboard==2.20.0 +attrs==25.4.0 +future==1.0.0 +albucore==0.0.19 +fsspec==2025.9.0 +sympy==1.14.0 +eval_type_backport==0.2.2 +pydantic_core==2.23.4 +sniffio==1.3.1 +nvidia-nvshmem-cu12==3.3.20 +exceptiongroup==1.3.0 +smmap==5.0.2 +tomli==2.0.2 +ftfy==6.3.0 +dill==0.4.0 +pytest==7.2.0 +PySocks==1.7.1 +nvidia-curand-cu12==10.3.7.77 +tokenizers==0.20.1 +args==0.1.0 +fairscale==0.4.13 +peft==0.13.2 +webdataset==1.0.2 +huggingface-hub==0.26.1 +GitPython==3.1.45 +pytorchvideo==0.1.5 +scikit-learn==1.5.2 +bitsandbytes==0.48.1 +nvidia-cusparse-cu12==12.5.4.2 +nvidia-cuda-cupti-cu12==12.6.80 +imageio==2.36.0 +pydub==0.25.1 +image-reward==1.5 +absl-py==2.3.1 +blessed==1.22.0 +torchdiffeq==0.2.4 diff --git a/wandb/run-20260124_132718-kot97lcx/files/wandb-metadata.json b/wandb/run-20260124_132718-kot97lcx/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6f125541f26c5e57f8a4801e0de66a0ea2557b6 --- /dev/null +++ b/wandb/run-20260124_132718-kot97lcx/files/wandb-metadata.json @@ -0,0 +1,96 @@ +{ + "os": "Linux-6.8.0-86-generic-x86_64-with-glibc2.35", + "python": "3.10.19", + "startedAt": "2026-01-24T05:27:18.297822Z", + "args": [ + "--config", + "fastvideo/config_sd/base.py", + "--eta_step_list", + "0,1,2,3,4,5,6,7", + "--eta_step_merge_list", + "1,1,1,2,2,2,3,3", + "--granular_list", + "1", + "--num_generations", + "4", + "--eta", + "1.0", + "--init_same_noise" + ], + "program": "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", + "codePath": "fastvideo/train_g2rpo_sd_merge.py", + "email": "zhangemail1428@163.com", + "root": "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code", + "host": "abc", + "username": "zsj", + "executable": "/home/zsj/anaconda3/envs/g2rpo/bin/python", + "codePathLocal": "fastvideo/train_g2rpo_sd_merge.py", + "cpu_count": 48, + "cpu_count_logical": 96, + "gpu": "NVIDIA RTX 5880 Ada Generation", + "gpu_count": 8, + "disk": { + "/": { + "total": "1006773899264", + "used": "803118243840" + } + }, + "memory": { + "total": "540697153536" + }, + "cpu": { + "count": 48, + "countLogical": 96 + }, + "gpu_nvidia": [ + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + } + ], + "cudaVersion": "12.9" +} \ No newline at end of file diff --git a/wandb/run-20260124_132718-kot97lcx/logs/debug-core.log b/wandb/run-20260124_132718-kot97lcx/logs/debug-core.log new file mode 100644 index 0000000000000000000000000000000000000000..7b2a7af649f26c5787e273a3abd72b7a645d07a4 --- /dev/null +++ b/wandb/run-20260124_132718-kot97lcx/logs/debug-core.log @@ -0,0 +1,8 @@ +{"time":"2026-01-24T13:27:17.479328305+08:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmp3yb0luu5/port-15390.txt","pid":15390,"debug":false,"disable-analytics":false} +{"time":"2026-01-24T13:27:17.479348762+08:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false} +{"time":"2026-01-24T13:27:17.480307676+08:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":15390} +{"time":"2026-01-24T13:27:17.480315566+08:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":41159,"Zone":""}} +{"time":"2026-01-24T13:27:17.653874599+08:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:40180"} +{"time":"2026-01-24T13:27:18.301714865+08:00","level":"INFO","msg":"handleInformInit: received","streamId":"kot97lcx","id":"127.0.0.1:40180"} +{"time":"2026-01-24T13:27:18.420914331+08:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"kot97lcx","id":"127.0.0.1:40180"} +{"time":"2026-01-24T13:27:37.439602212+08:00","level":"INFO","msg":"Parent process exited, terminating service process."} diff --git a/wandb/run-20260124_132718-kot97lcx/logs/debug-internal.log b/wandb/run-20260124_132718-kot97lcx/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..47fdcaf141dbac81c96d7f73b3581462faff5193 --- /dev/null +++ b/wandb/run-20260124_132718-kot97lcx/logs/debug-internal.log @@ -0,0 +1,8 @@ +{"time":"2026-01-24T13:27:18.302247221+08:00","level":"INFO","msg":"using version","core version":"0.18.5"} +{"time":"2026-01-24T13:27:18.302283615+08:00","level":"INFO","msg":"created symlink","path":"/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_132718-kot97lcx/logs/debug-core.log"} +{"time":"2026-01-24T13:27:18.42083917+08:00","level":"INFO","msg":"created new stream","id":"kot97lcx"} +{"time":"2026-01-24T13:27:18.42090629+08:00","level":"INFO","msg":"stream: started","id":"kot97lcx"} +{"time":"2026-01-24T13:27:18.421086737+08:00","level":"INFO","msg":"sender: started","stream_id":"kot97lcx"} +{"time":"2026-01-24T13:27:18.421086062+08:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"kot97lcx"}} +{"time":"2026-01-24T13:27:18.421218478+08:00","level":"INFO","msg":"handler: started","stream_id":{"value":"kot97lcx"}} +{"time":"2026-01-24T13:27:18.953305384+08:00","level":"INFO","msg":"Starting system monitor"} diff --git a/wandb/run-20260124_132718-kot97lcx/logs/debug.log b/wandb/run-20260124_132718-kot97lcx/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..4368942582c59c0e7417d8dd0b43863e6433517f --- /dev/null +++ b/wandb/run-20260124_132718-kot97lcx/logs/debug.log @@ -0,0 +1,26 @@ +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_setup.py:_flush():79] Current SDK version is 0.18.5 +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_setup.py:_flush():79] Configure stats pid to 15390 +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_setup.py:_flush():79] Loading settings from /home/zsj/.config/wandb/settings +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_setup.py:_flush():79] Loading settings from /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/settings +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_setup.py:_flush():79] Loading settings from environment variables: {} +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_setup.py:_flush():79] Applying setup settings: {'mode': None, '_disable_service': None} +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_setup.py:_flush():79] Inferring run settings from compute environment: {'program_relpath': 'fastvideo/train_g2rpo_sd_merge.py', 'program_abspath': '/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py', 'program': '/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py'} +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_setup.py:_flush():79] Applying login settings: {} +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_init.py:_log_setup():534] Logging user logs to /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_132718-kot97lcx/logs/debug.log +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_init.py:_log_setup():535] Logging internal logs to /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_132718-kot97lcx/logs/debug-internal.log +2026-01-24 13:27:18,294 INFO MainThread:15390 [wandb_init.py:init():621] calling init triggers +2026-01-24 13:27:18,295 INFO MainThread:15390 [wandb_init.py:init():628] wandb.init called with sweep_config: {} +config: {} +2026-01-24 13:27:18,295 INFO MainThread:15390 [wandb_init.py:init():671] starting backend +2026-01-24 13:27:18,295 INFO MainThread:15390 [wandb_init.py:init():675] sending inform_init request +2026-01-24 13:27:18,297 INFO MainThread:15390 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=fork,spawn,forkserver, using: spawn +2026-01-24 13:27:18,297 INFO MainThread:15390 [wandb_init.py:init():688] backend started and connected +2026-01-24 13:27:18,301 INFO MainThread:15390 [wandb_init.py:init():783] updated telemetry +2026-01-24 13:27:18,301 INFO MainThread:15390 [wandb_init.py:init():816] communicating run to backend with 90.0 second timeout +2026-01-24 13:27:18,945 INFO MainThread:15390 [wandb_init.py:init():867] starting run threads in backend +2026-01-24 13:27:19,112 INFO MainThread:15390 [wandb_run.py:_console_start():2463] atexit reg +2026-01-24 13:27:19,112 INFO MainThread:15390 [wandb_run.py:_redirect():2311] redirect: wrap_raw +2026-01-24 13:27:19,113 INFO MainThread:15390 [wandb_run.py:_redirect():2376] Wrapping output streams. +2026-01-24 13:27:19,113 INFO MainThread:15390 [wandb_run.py:_redirect():2401] Redirects installed. +2026-01-24 13:27:19,114 INFO MainThread:15390 [wandb_init.py:init():911] run started, returning control to user process +2026-01-24 13:27:19,115 INFO MainThread:15390 [wandb_run.py:_config_callback():1390] config_cb None None {'allow_tf32': True, 'logdir': 'logs', 'mixed_precision': 'bf16', 'num_checkpoint_limit': 5, 'num_epochs': 300, 'pretrained': {'model': './data/StableDiffusion', 'revision': 'main'}, 'prompt_fn': 'imagenet_animals', 'prompt_fn_kwargs': {}, 'resume_from': '', 'reward_fn': 'hpsv2', 'run_name': '2026.01.24_13.27.17', 'sample': {'batch_size': 1, 'eta': 1.0, 'guidance_scale': 5.0, 'num_batches_per_epoch': 2, 'num_steps': 50}, 'save_freq': 20, 'seed': 42, 'train': {'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'adam_weight_decay': 0.0001, 'adv_clip_max': 5, 'batch_size': 1, 'cfg': True, 'clip_range': 0.0001, 'gradient_accumulation_steps': 1, 'learning_rate': 1e-05, 'max_grad_norm': 1.0, 'num_inner_epochs': 1, 'timestep_fraction': 1.0, 'use_8bit_adam': False}, 'use_lora': False} diff --git a/wandb/run-20260124_133035-730g3p9r/files/output.log b/wandb/run-20260124_133035-730g3p9r/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..65049c4d0b1155c1ded39b79b2e4144e73bcd839 --- /dev/null +++ b/wandb/run-20260124_133035-730g3p9r/files/output.log @@ -0,0 +1,143 @@ +I0124 13:30:36.244537 134824147314496 train_g2rpo_sd_merge.py:478] +allow_tf32: true +logdir: logs +mixed_precision: bf16 +num_checkpoint_limit: 5 +num_epochs: 300 +pretrained: + model: ./data/StableDiffusion + revision: main +prompt_fn: imagenet_animals +prompt_fn_kwargs: {} +resume_from: '' +reward_fn: hpsv2 +run_name: 2026.01.24_13.30.34 +sample: + batch_size: 1 + eta: 1.0 + guidance_scale: 5.0 + num_batches_per_epoch: 2 + num_steps: 50 +save_freq: 20 +seed: 42 +train: + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_epsilon: 1.0e-08 + adam_weight_decay: 0.0001 + adv_clip_max: 5 + batch_size: 1 + cfg: true + clip_range: 0.0001 + gradient_accumulation_steps: 1 + learning_rate: 1.0e-05 + max_grad_norm: 1.0 + num_inner_epochs: 1 + timestep_fraction: 1.0 + use_8bit_adam: false +use_lora: false + +Loading pipeline components...: 100%|███████████████████████████████████████| 7/7 [00:02<00:00, 3.01it/s] +/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers + warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning) +I0124 13:30:39.152479 134824147314496 factory.py:159] Loaded ViT-H-14 model config. +I0124 13:30:43.499639 134824147314496 factory.py:207] Loading pretrained ViT-H-14 weights (./data/hps/open_clip_pytorch_model.bin). +I0124 13:30:59.466224 134824147314496 train_g2rpo_sd_merge.py:670] ***** Running E-GRPO (G2RPO) Training for Stable Diffusion ***** +I0124 13:30:59.467000 134824147314496 train_g2rpo_sd_merge.py:671] Num Epochs = 300 +I0124 13:30:59.467134 134824147314496 train_g2rpo_sd_merge.py:672] Num generations per prompt = 4 +I0124 13:30:59.467244 134824147314496 train_g2rpo_sd_merge.py:673] Eta step list = [0, 1, 2, 3, 4, 5, 6, 7] +I0124 13:30:59.467338 134824147314496 train_g2rpo_sd_merge.py:674] Eta step merge list = [1, 1, 1, 2, 2, 2, 3, 3] +I0124 13:30:59.467431 134824147314496 train_g2rpo_sd_merge.py:675] Granular list = [1] +I0124 13:31:39.156423 134824147314496 train_g2rpo_sd_merge.py:892] Epoch 0: Mean eval reward = 0.2323 +Traceback (most recent call last): + File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 1003, in + app.run(main) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/absl/app.py", line 316, in run + _run_main(main, args) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/absl/app.py", line 261, in _run_main + sys.exit(main(argv)) + File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 831, in main + final_latents = run_ode_sample_step( + File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 324, in run_ode_sample_step + noise_pred = pipeline.unet( + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl + return forward_call(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/utils/operations.py", line 820, in forward + return model_forward(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/utils/operations.py", line 808, in __call__ + return convert_to_fp32(self.model_forward(*args, **kwargs)) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1281, in forward + sample = upsample_block( + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl + return forward_call(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 2618, in forward + hidden_states = attn( + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl + return forward_call(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 407, in forward + hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 481, in _operate_on_continuous_inputs + hidden_states = self.norm(hidden_states) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl + return forward_call(*args, **kwargs) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 325, in forward + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/functional.py", line 2956, in group_norm + return torch.group_norm( +KeyboardInterrupt +[rank0]: Traceback (most recent call last): +[rank0]: File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 1003, in +[rank0]: app.run(main) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/absl/app.py", line 316, in run +[rank0]: _run_main(main, args) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/absl/app.py", line 261, in _run_main +[rank0]: sys.exit(main(argv)) +[rank0]: File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 831, in main +[rank0]: final_latents = run_ode_sample_step( +[rank0]: File "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", line 324, in run_ode_sample_step +[rank0]: noise_pred = pipeline.unet( +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl +[rank0]: return self._call_impl(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl +[rank0]: return forward_call(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/utils/operations.py", line 820, in forward +[rank0]: return model_forward(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/accelerate/utils/operations.py", line 808, in __call__ +[rank0]: return convert_to_fp32(self.model_forward(*args, **kwargs)) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast +[rank0]: return func(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1281, in forward +[rank0]: sample = upsample_block( +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl +[rank0]: return self._call_impl(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl +[rank0]: return forward_call(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 2618, in forward +[rank0]: hidden_states = attn( +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl +[rank0]: return self._call_impl(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl +[rank0]: return forward_call(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 407, in forward +[rank0]: hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 481, in _operate_on_continuous_inputs +[rank0]: hidden_states = self.norm(hidden_states) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl +[rank0]: return self._call_impl(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl +[rank0]: return forward_call(*args, **kwargs) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 325, in forward +[rank0]: return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) +[rank0]: File "/home/zsj/anaconda3/envs/g2rpo/lib/python3.10/site-packages/torch/nn/functional.py", line 2956, in group_norm +[rank0]: return torch.group_norm( +[rank0]: KeyboardInterrupt diff --git a/wandb/run-20260124_133035-730g3p9r/files/requirements.txt b/wandb/run-20260124_133035-730g3p9r/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ee5c7ffa6079b296e15f3c9ff9edceed1bfe0802 --- /dev/null +++ b/wandb/run-20260124_133035-730g3p9r/files/requirements.txt @@ -0,0 +1,189 @@ +scipy==1.13.0 +regex==2024.9.11 +sentencepiece==0.2.0 +six==1.16.0 +anyio==4.11.0 +nvidia-cuda-nvrtc-cu12==12.6.77 +scikit-video==1.1.11 +platformdirs==4.5.0 +mypy==1.11.1 +ruff==0.6.5 +charset-normalizer==3.4.4 +torch==2.9.0+cu126 +av==13.1.0 +pillow==10.2.0 +gpustat==1.1.1 +torchvision==0.24.0+cu126 +multidict==6.7.0 +torchmetrics==1.5.1 +aiohttp==3.13.1 +transformers==4.46.1 +decord==0.6.0 +wcwidth==0.2.14 +sphinx-lint==1.0.0 +nvidia-cuda-runtime-cu12==12.6.77 +pytz==2025.2 +codespell==2.3.0 +hpsv2==1.2.0 +mypy_extensions==1.1.0 +numpy==1.26.3 +omegaconf==2.3.0 +Markdown==3.9 +tzdata==2025.2 +pandas==2.2.3 +pytorch-lightning==2.4.0 +aiosignal==1.4.0 +aiohappyeyeballs==2.6.1 +python-dateutil==2.9.0.post0 +seaborn==0.13.2 +beautifulsoup4==4.12.3 +isort==5.13.2 +httpx==0.28.1 +certifi==2025.10.5 +ml_collections==1.1.0 +nvidia-cudnn-cu12==9.10.2.21 +hf-xet==1.2.0 +requests==2.31.0 +inflect==6.0.4 +iniconfig==2.1.0 +braceexpand==0.1.7 +h5py==3.12.1 +wandb==0.18.5 +protobuf==3.20.3 +ninja==1.13.0 +kiwisolver==1.4.9 +networkx==3.3 +packaging==25.0 +fvcore==0.1.5.post20221221 +pyparsing==3.2.5 +starlette==0.41.3 +frozenlist==1.8.0 +docker-pycreds==0.4.0 +Werkzeug==3.1.3 +MarkupSafe==2.1.5 +einops==0.8.0 +sentry-sdk==2.42.0 +PyYAML==6.0.1 +nvidia-nccl-cu12==2.27.5 +datasets==4.3.0 +polib==1.2.0 +safetensors==0.6.2 +async-timeout==5.0.1 +setproctitle==1.3.7 +clint==0.5.1 +matplotlib==3.9.2 +propcache==0.4.1 +termcolor==3.1.0 +antlr4-python3-runtime==4.9.3 +cycler==0.12.1 +fastvideo==1.2.0 +toml==0.10.2 +xxhash==3.6.0 +wheel==0.44.0 +albumentations==1.4.20 +fastapi==0.115.3 +nvidia-cufft-cu12==11.3.0.4 +yarl==1.22.0 +psutil==7.1.0 +tensorboard-data-server==0.7.2 +pydantic==2.9.2 +nvidia-nvtx-cu12==12.6.77 +portalocker==3.2.0 +triton==3.5.0 +annotated-types==0.7.0 +proglog==0.1.12 +nvidia-cusparselt-cu12==0.7.1 +yapf==0.32.0 +Jinja2==3.1.6 +types-requests==2.32.4.20250913 +lightning-utilities==0.15.2 +grpcio==1.75.1 +uvicorn==0.32.0 +typing_extensions==4.15.0 +nvidia-nvjitlink-cu12==12.6.85 +watch==0.2.7 +moviepy==1.0.3 +timm==1.0.11 +pytest-split==0.8.0 +gdown==5.2.0 +types-setuptools==80.9.0.20250822 +nvidia-cusolver-cu12==11.7.1.2 +types-PyYAML==6.0.12.20250915 +pip==25.2 +qwen-vl-utils==0.0.14 +soupsieve==2.8 +zipp==3.23.0 +flash_attn==2.8.3 +yacs==0.1.8 +diffusers==0.32.0 +pluggy==1.6.0 +opencv-python-headless==4.11.0.86 +mpmath==1.3.0 +test_tube==0.7.5 +stringzilla==4.2.1 +fonttools==4.60.1 +nvidia-ml-py==13.580.82 +parameterized==0.9.0 +loguru==0.7.3 +tabulate==0.9.0 +idna==3.6 +iopath==0.1.10 +decorator==4.4.2 +nvidia-cufile-cu12==1.11.1.6 +threadpoolctl==3.6.0 +pyarrow==21.0.0 +httpcore==1.0.9 +hydra-core==1.3.2 +multiprocess==0.70.16 +contourpy==1.3.2 +clip==1.0 +tqdm==4.66.5 +open_clip_torch==3.2.0 +accelerate==1.0.1 +gitdb==4.0.12 +importlib_metadata==8.7.0 +nvidia-cublas-cu12==12.6.4.1 +h11==0.16.0 +filelock==3.19.1 +liger_kernel==0.4.1 +click==8.3.0 +urllib3==2.2.0 +imageio-ffmpeg==0.5.1 +setuptools==80.9.0 +joblib==1.5.2 +tensorboard==2.20.0 +attrs==25.4.0 +future==1.0.0 +albucore==0.0.19 +fsspec==2025.9.0 +sympy==1.14.0 +eval_type_backport==0.2.2 +pydantic_core==2.23.4 +sniffio==1.3.1 +nvidia-nvshmem-cu12==3.3.20 +exceptiongroup==1.3.0 +smmap==5.0.2 +tomli==2.0.2 +ftfy==6.3.0 +dill==0.4.0 +pytest==7.2.0 +PySocks==1.7.1 +nvidia-curand-cu12==10.3.7.77 +tokenizers==0.20.1 +args==0.1.0 +fairscale==0.4.13 +peft==0.13.2 +webdataset==1.0.2 +huggingface-hub==0.26.1 +GitPython==3.1.45 +pytorchvideo==0.1.5 +scikit-learn==1.5.2 +bitsandbytes==0.48.1 +nvidia-cusparse-cu12==12.5.4.2 +nvidia-cuda-cupti-cu12==12.6.80 +imageio==2.36.0 +pydub==0.25.1 +image-reward==1.5 +absl-py==2.3.1 +blessed==1.22.0 +torchdiffeq==0.2.4 diff --git a/wandb/run-20260124_133035-730g3p9r/files/wandb-metadata.json b/wandb/run-20260124_133035-730g3p9r/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..2920e41bf633b292eb639cb21451ea329285c40e --- /dev/null +++ b/wandb/run-20260124_133035-730g3p9r/files/wandb-metadata.json @@ -0,0 +1,96 @@ +{ + "os": "Linux-6.8.0-86-generic-x86_64-with-glibc2.35", + "python": "3.10.19", + "startedAt": "2026-01-24T05:30:35.245116Z", + "args": [ + "--config", + "fastvideo/config_sd/base.py", + "--eta_step_list", + "0,1,2,3,4,5,6,7", + "--eta_step_merge_list", + "1,1,1,2,2,2,3,3", + "--granular_list", + "1", + "--num_generations", + "4", + "--eta", + "1.0", + "--init_same_noise" + ], + "program": "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py", + "codePath": "fastvideo/train_g2rpo_sd_merge.py", + "email": "zhangemail1428@163.com", + "root": "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code", + "host": "abc", + "username": "zsj", + "executable": "/home/zsj/anaconda3/envs/g2rpo/bin/python", + "codePathLocal": "fastvideo/train_g2rpo_sd_merge.py", + "cpu_count": 48, + "cpu_count_logical": 96, + "gpu": "NVIDIA RTX 5880 Ada Generation", + "gpu_count": 8, + "disk": { + "/": { + "total": "1006773899264", + "used": "803118403584" + } + }, + "memory": { + "total": "540697153536" + }, + "cpu": { + "count": 48, + "countLogical": 96 + }, + "gpu_nvidia": [ + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + } + ], + "cudaVersion": "12.9" +} \ No newline at end of file diff --git a/wandb/run-20260124_133035-730g3p9r/files/wandb-summary.json b/wandb/run-20260124_133035-730g3p9r/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..ccc9fc140b79e363c80ea4867fd0603c4ebd8067 --- /dev/null +++ b/wandb/run-20260124_133035-730g3p9r/files/wandb-summary.json @@ -0,0 +1 @@ +{"epoch":0,"eval_reward_mean":0.2323201596736908,"_timestamp":1.7692327027971902e+09,"_runtime":88.509799362,"_step":1,"_wandb":{"runtime":88}} \ No newline at end of file diff --git a/wandb/run-20260124_133035-730g3p9r/logs/debug-core.log b/wandb/run-20260124_133035-730g3p9r/logs/debug-core.log new file mode 100644 index 0000000000000000000000000000000000000000..18194fe651ea5769a873347071aaf3058020572d --- /dev/null +++ b/wandb/run-20260124_133035-730g3p9r/logs/debug-core.log @@ -0,0 +1,12 @@ +{"time":"2026-01-24T13:30:34.428962886+08:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmpskpp_6ir/port-16468.txt","pid":16468,"debug":false,"disable-analytics":false} +{"time":"2026-01-24T13:30:34.42899829+08:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false} +{"time":"2026-01-24T13:30:34.429925965+08:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":42405,"Zone":""}} +{"time":"2026-01-24T13:30:34.430034082+08:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":16468} +{"time":"2026-01-24T13:30:34.61883156+08:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:36330"} +{"time":"2026-01-24T13:30:35.248196515+08:00","level":"INFO","msg":"handleInformInit: received","streamId":"730g3p9r","id":"127.0.0.1:36330"} +{"time":"2026-01-24T13:30:35.365853872+08:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"730g3p9r","id":"127.0.0.1:36330"} +{"time":"2026-01-24T13:32:03.754835769+08:00","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"127.0.0.1:36330"} +{"time":"2026-01-24T13:32:03.754897961+08:00","level":"INFO","msg":"connection: Close: initiating connection closure","id":"127.0.0.1:36330"} +{"time":"2026-01-24T13:32:03.754948774+08:00","level":"INFO","msg":"server is shutting down"} +{"time":"2026-01-24T13:32:03.754986715+08:00","level":"INFO","msg":"connection: Close: connection successfully closed","id":"127.0.0.1:36330"} +{"time":"2026-01-24T13:32:03.998463353+08:00","level":"INFO","msg":"Parent process exited, terminating service process."} diff --git a/wandb/run-20260124_133035-730g3p9r/logs/debug-internal.log b/wandb/run-20260124_133035-730g3p9r/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..3a029019773c9d7eb794e32dab2bf077b8432e5f --- /dev/null +++ b/wandb/run-20260124_133035-730g3p9r/logs/debug-internal.log @@ -0,0 +1,11 @@ +{"time":"2026-01-24T13:30:35.248464828+08:00","level":"INFO","msg":"using version","core version":"0.18.5"} +{"time":"2026-01-24T13:30:35.248483451+08:00","level":"INFO","msg":"created symlink","path":"/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_133035-730g3p9r/logs/debug-core.log"} +{"time":"2026-01-24T13:30:35.365767136+08:00","level":"INFO","msg":"created new stream","id":"730g3p9r"} +{"time":"2026-01-24T13:30:35.365849238+08:00","level":"INFO","msg":"stream: started","id":"730g3p9r"} +{"time":"2026-01-24T13:30:35.365874251+08:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"730g3p9r"}} +{"time":"2026-01-24T13:30:35.365904813+08:00","level":"INFO","msg":"handler: started","stream_id":{"value":"730g3p9r"}} +{"time":"2026-01-24T13:30:35.365992322+08:00","level":"INFO","msg":"sender: started","stream_id":"730g3p9r"} +{"time":"2026-01-24T13:30:36.096372999+08:00","level":"INFO","msg":"Starting system monitor"} +{"time":"2026-01-24T13:32:03.754897796+08:00","level":"INFO","msg":"stream: closing","id":"730g3p9r"} +{"time":"2026-01-24T13:32:03.754937215+08:00","level":"INFO","msg":"Stopping system monitor"} +{"time":"2026-01-24T13:32:03.755857193+08:00","level":"INFO","msg":"Stopped system monitor"} diff --git a/wandb/run-20260124_133035-730g3p9r/logs/debug.log b/wandb/run-20260124_133035-730g3p9r/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..b51c6152b00a0d1dc49f280adba5c9ae33203a00 --- /dev/null +++ b/wandb/run-20260124_133035-730g3p9r/logs/debug.log @@ -0,0 +1,27 @@ +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_setup.py:_flush():79] Current SDK version is 0.18.5 +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_setup.py:_flush():79] Configure stats pid to 16468 +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_setup.py:_flush():79] Loading settings from /home/zsj/.config/wandb/settings +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_setup.py:_flush():79] Loading settings from /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/settings +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_setup.py:_flush():79] Loading settings from environment variables: {} +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_setup.py:_flush():79] Applying setup settings: {'mode': None, '_disable_service': None} +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_setup.py:_flush():79] Inferring run settings from compute environment: {'program_relpath': 'fastvideo/train_g2rpo_sd_merge.py', 'program_abspath': '/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py', 'program': '/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_sd_merge.py'} +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_setup.py:_flush():79] Applying login settings: {} +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_init.py:_log_setup():534] Logging user logs to /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_133035-730g3p9r/logs/debug.log +2026-01-24 13:30:35,241 INFO MainThread:16468 [wandb_init.py:_log_setup():535] Logging internal logs to /data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/wandb/run-20260124_133035-730g3p9r/logs/debug-internal.log +2026-01-24 13:30:35,242 INFO MainThread:16468 [wandb_init.py:init():621] calling init triggers +2026-01-24 13:30:35,242 INFO MainThread:16468 [wandb_init.py:init():628] wandb.init called with sweep_config: {} +config: {} +2026-01-24 13:30:35,242 INFO MainThread:16468 [wandb_init.py:init():671] starting backend +2026-01-24 13:30:35,242 INFO MainThread:16468 [wandb_init.py:init():675] sending inform_init request +2026-01-24 13:30:35,244 INFO MainThread:16468 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=fork,spawn,forkserver, using: spawn +2026-01-24 13:30:35,244 INFO MainThread:16468 [wandb_init.py:init():688] backend started and connected +2026-01-24 13:30:35,248 INFO MainThread:16468 [wandb_init.py:init():783] updated telemetry +2026-01-24 13:30:35,249 INFO MainThread:16468 [wandb_init.py:init():816] communicating run to backend with 90.0 second timeout +2026-01-24 13:30:36,086 INFO MainThread:16468 [wandb_init.py:init():867] starting run threads in backend +2026-01-24 13:30:36,239 INFO MainThread:16468 [wandb_run.py:_console_start():2463] atexit reg +2026-01-24 13:30:36,239 INFO MainThread:16468 [wandb_run.py:_redirect():2311] redirect: wrap_raw +2026-01-24 13:30:36,239 INFO MainThread:16468 [wandb_run.py:_redirect():2376] Wrapping output streams. +2026-01-24 13:30:36,239 INFO MainThread:16468 [wandb_run.py:_redirect():2401] Redirects installed. +2026-01-24 13:30:36,241 INFO MainThread:16468 [wandb_init.py:init():911] run started, returning control to user process +2026-01-24 13:30:36,242 INFO MainThread:16468 [wandb_run.py:_config_callback():1390] config_cb None None {'allow_tf32': True, 'logdir': 'logs', 'mixed_precision': 'bf16', 'num_checkpoint_limit': 5, 'num_epochs': 300, 'pretrained': {'model': './data/StableDiffusion', 'revision': 'main'}, 'prompt_fn': 'imagenet_animals', 'prompt_fn_kwargs': {}, 'resume_from': '', 'reward_fn': 'hpsv2', 'run_name': '2026.01.24_13.30.34', 'sample': {'batch_size': 1, 'eta': 1.0, 'guidance_scale': 5.0, 'num_batches_per_epoch': 2, 'num_steps': 50}, 'save_freq': 20, 'seed': 42, 'train': {'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'adam_weight_decay': 0.0001, 'adv_clip_max': 5, 'batch_size': 1, 'cfg': True, 'clip_range': 0.0001, 'gradient_accumulation_steps': 1, 'learning_rate': 1e-05, 'max_grad_norm': 1.0, 'num_inner_epochs': 1, 'timestep_fraction': 1.0, 'use_8bit_adam': False}, 'use_lora': False} +2026-01-24 13:32:03,755 WARNING MsgRouterThr:16468 [router.py:message_loop():77] message_loop has been closed diff --git a/wandb/run-20260124_154730-hq86r6nt/files/requirements.txt b/wandb/run-20260124_154730-hq86r6nt/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e5e7de03dd36845d392ea536c1fc4a3e640eda3b --- /dev/null +++ b/wandb/run-20260124_154730-hq86r6nt/files/requirements.txt @@ -0,0 +1,191 @@ +scipy==1.13.0 +regex==2024.9.11 +sentencepiece==0.2.0 +six==1.16.0 +anyio==4.11.0 +nvidia-cuda-nvrtc-cu12==12.6.77 +scikit-video==1.1.11 +platformdirs==4.5.0 +mypy==1.11.1 +ruff==0.6.5 +charset-normalizer==3.4.4 +torch==2.9.0+cu126 +av==13.1.0 +pillow==10.2.0 +gpustat==1.1.1 +torchvision==0.24.0+cu126 +multidict==6.7.0 +torchmetrics==1.5.1 +aiohttp==3.13.1 +decord==0.6.0 +wcwidth==0.2.14 +sphinx-lint==1.0.0 +nvidia-cuda-runtime-cu12==12.6.77 +pytz==2025.2 +codespell==2.3.0 +hpsv2==1.2.0 +mypy_extensions==1.1.0 +numpy==1.26.3 +omegaconf==2.3.0 +Markdown==3.9 +tzdata==2025.2 +pandas==2.2.3 +pytorch-lightning==2.4.0 +aiosignal==1.4.0 +aiohappyeyeballs==2.6.1 +python-dateutil==2.9.0.post0 +seaborn==0.13.2 +beautifulsoup4==4.12.3 +isort==5.13.2 +httpx==0.28.1 +certifi==2025.10.5 +ml_collections==1.1.0 +nvidia-cudnn-cu12==9.10.2.21 +peft==0.18.1 +hf-xet==1.2.0 +requests==2.31.0 +inflect==6.0.4 +iniconfig==2.1.0 +braceexpand==0.1.7 +h5py==3.12.1 +wandb==0.18.5 +protobuf==3.20.3 +ninja==1.13.0 +kiwisolver==1.4.9 +networkx==3.3 +packaging==25.0 +fvcore==0.1.5.post20221221 +pyparsing==3.2.5 +starlette==0.41.3 +frozenlist==1.8.0 +docker-pycreds==0.4.0 +Werkzeug==3.1.3 +MarkupSafe==2.1.5 +shellingham==1.5.4 +einops==0.8.0 +sentry-sdk==2.42.0 +PyYAML==6.0.1 +nvidia-nccl-cu12==2.27.5 +datasets==4.3.0 +polib==1.2.0 +safetensors==0.6.2 +async-timeout==5.0.1 +setproctitle==1.3.7 +clint==0.5.1 +matplotlib==3.9.2 +propcache==0.4.1 +termcolor==3.1.0 +antlr4-python3-runtime==4.9.3 +transformers==4.57.6 +cycler==0.12.1 +fastvideo==1.2.0 +toml==0.10.2 +xxhash==3.6.0 +wheel==0.44.0 +albumentations==1.4.20 +fastapi==0.115.3 +nvidia-cufft-cu12==11.3.0.4 +yarl==1.22.0 +psutil==7.1.0 +tensorboard-data-server==0.7.2 +huggingface-hub==0.36.0 +pydantic==2.9.2 +nvidia-nvtx-cu12==12.6.77 +portalocker==3.2.0 +triton==3.5.0 +annotated-types==0.7.0 +proglog==0.1.12 +nvidia-cusparselt-cu12==0.7.1 +yapf==0.32.0 +Jinja2==3.1.6 +types-requests==2.32.4.20250913 +lightning-utilities==0.15.2 +grpcio==1.75.1 +uvicorn==0.32.0 +typing_extensions==4.15.0 +nvidia-nvjitlink-cu12==12.6.85 +watch==0.2.7 +moviepy==1.0.3 +timm==1.0.11 +pytest-split==0.8.0 +gdown==5.2.0 +types-setuptools==80.9.0.20250822 +nvidia-cusolver-cu12==11.7.1.2 +types-PyYAML==6.0.12.20250915 +pip==25.2 +typer-slim==0.21.1 +qwen-vl-utils==0.0.14 +soupsieve==2.8 +zipp==3.23.0 +flash_attn==2.8.3 +yacs==0.1.8 +pluggy==1.6.0 +opencv-python-headless==4.11.0.86 +mpmath==1.3.0 +test_tube==0.7.5 +stringzilla==4.2.1 +fonttools==4.60.1 +nvidia-ml-py==13.580.82 +parameterized==0.9.0 +loguru==0.7.3 +diffusers==0.36.0 +tabulate==0.9.0 +idna==3.6 +iopath==0.1.10 +decorator==4.4.2 +nvidia-cufile-cu12==1.11.1.6 +threadpoolctl==3.6.0 +pyarrow==21.0.0 +httpcore==1.0.9 +hydra-core==1.3.2 +multiprocess==0.70.16 +contourpy==1.3.2 +clip==1.0 +tqdm==4.66.5 +tokenizers==0.22.2 +open_clip_torch==3.2.0 +accelerate==1.0.1 +gitdb==4.0.12 +importlib_metadata==8.7.0 +nvidia-cublas-cu12==12.6.4.1 +h11==0.16.0 +filelock==3.19.1 +liger_kernel==0.4.1 +click==8.3.0 +urllib3==2.2.0 +imageio-ffmpeg==0.5.1 +setuptools==80.9.0 +joblib==1.5.2 +tensorboard==2.20.0 +attrs==25.4.0 +future==1.0.0 +albucore==0.0.19 +fsspec==2025.9.0 +sympy==1.14.0 +eval_type_backport==0.2.2 +pydantic_core==2.23.4 +sniffio==1.3.1 +nvidia-nvshmem-cu12==3.3.20 +exceptiongroup==1.3.0 +smmap==5.0.2 +tomli==2.0.2 +ftfy==6.3.0 +dill==0.4.0 +pytest==7.2.0 +PySocks==1.7.1 +nvidia-curand-cu12==10.3.7.77 +args==0.1.0 +fairscale==0.4.13 +webdataset==1.0.2 +GitPython==3.1.45 +pytorchvideo==0.1.5 +scikit-learn==1.5.2 +bitsandbytes==0.48.1 +nvidia-cusparse-cu12==12.5.4.2 +nvidia-cuda-cupti-cu12==12.6.80 +imageio==2.36.0 +pydub==0.25.1 +image-reward==1.5 +absl-py==2.3.1 +blessed==1.22.0 +torchdiffeq==0.2.4 diff --git a/wandb/run-20260124_154730-hq86r6nt/files/wandb-metadata.json b/wandb/run-20260124_154730-hq86r6nt/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..5c1eae790a92b246f2d216c4a31bbed4ed0670f7 --- /dev/null +++ b/wandb/run-20260124_154730-hq86r6nt/files/wandb-metadata.json @@ -0,0 +1,137 @@ +{ + "os": "Linux-6.8.0-86-generic-x86_64-with-glibc2.35", + "python": "3.10.19", + "startedAt": "2026-01-24T07:47:30.394392Z", + "args": [ + "--pretrained_model_name_or_path", + "./data/QwenImage", + "--data_json_path", + "./data/qwenimage_rl_embeddings/videos2caption.json", + "--output_dir", + "./output/g2rpo_qwenimage", + "--hps_path", + "./data/hps/HPS_v2.1_compressed.pt", + "--hps_clip_path", + "./data/hps/open_clip_pytorch_model.bin", + "--h", + "1024", + "--w", + "1024", + "--sampling_steps", + "16", + "--eta", + "0.7", + "--shift", + "3.0", + "--num_generations", + "12", + "--learning_rate", + "2e-6", + "--max_train_steps", + "301", + "--checkpointing_steps", + "50", + "--eta_step_list", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "--eta_step_merge_list", + "1", + "1", + "1", + "2", + "2", + "2", + "3", + "3", + "--granular_list", + "1", + "--init_same_noise", + "--clip_range", + "1e-4", + "--adv_clip_max", + "5.0", + "--use_hpsv2" + ], + "program": "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code/fastvideo/train_g2rpo_qwenimage_merge.py", + "codePath": "fastvideo/train_g2rpo_qwenimage_merge.py", + "email": "zhangemail1428@163.com", + "root": "/data1/zsj/SceneDPO/Rebuttal/E-GRPO/scoure_code", + "host": "abc", + "username": "zsj", + "executable": "/home/zsj/anaconda3/envs/g2rpo/bin/python", + "codePathLocal": "fastvideo/train_g2rpo_qwenimage_merge.py", + "cpu_count": 48, + "cpu_count_logical": 96, + "gpu": "NVIDIA RTX 5880 Ada Generation", + "gpu_count": 8, + "disk": { + "/": { + "total": "1006773899264", + "used": "803119620096" + } + }, + "memory": { + "total": "540697153536" + }, + "cpu": { + "count": 48, + "countLogical": 96 + }, + "gpu_nvidia": [ + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + }, + { + "name": "NVIDIA RTX 5880 Ada Generation", + "memoryTotal": "51527024640", + "cudaCores": 14080, + "architecture": "Ada" + } + ], + "cudaVersion": "12.9" +} \ No newline at end of file