| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import argparse |
| | import torch |
| | from accelerate.logging import get_logger |
| | import cv2 |
| | import json |
| | import os |
| | import torch.distributed as dist |
| | import pandas as pd |
| | from torch.utils.data.dataset import ConcatDataset, Dataset |
| | import io |
| | import torchvision.transforms as transforms |
| | logger = get_logger(__name__) |
| | from torch.utils.data import Dataset |
| | from torch.utils.data.distributed import DistributedSampler |
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| | import re |
| | from diffusers import FluxPipeline |
| | from PIL import Image |
| | from diffusers.image_processor import VaeImageProcessor |
| |
|
def contains_chinese(text):
    """Tell whether ``text`` holds at least one CJK Unified Ideograph.

    Only the basic block U+4E00..U+9FFF is matched; CJK extension blocks and
    CJK punctuation are not counted.
    """
    first_ideograph = re.search(r'[\u4e00-\u9fff]', text)
    return first_ideograph is not None
| |
|
class RFPTdataset(Dataset):
    """Map-style dataset over every file in a directory, read as parquet.

    All files in ``file_path`` are loaded with ``pandas.read_parquet`` and
    concatenated into a single DataFrame. Each row must provide an ``image``
    field indexable by ``"bytes"`` (raw encoded image data; presumably the
    Hugging Face ``datasets`` image layout -- TODO confirm) and a
    ``caption_composition`` caption.

    Args:
        file_path: Directory holding the parquet shards.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # Sorted so every distributed rank builds the same row order:
        # os.listdir order is arbitrary, and DistributedSampler assumes all
        # ranks index an identical dataset.
        self.file_names = [
            os.path.join(self.file_path, name)
            for name in sorted(os.listdir(self.file_path))
        ]
        self.train_dataset = self.read_data()
        self.transform = transforms.ToTensor()

    def read_data(self):
        """Load every file in ``self.file_names`` and concatenate them row-wise.

        Returns:
            A DataFrame with a fresh 0..N-1 index.

        Raises:
            ValueError: if the directory contained no files.
        """
        if not self.file_names:
            raise ValueError(f"No parquet files found in {self.file_path!r}")
        frames = [pd.read_parquet(name) for name in self.file_names]
        return pd.concat(frames, axis=0, ignore_index=True)

    def __len__(self):
        return len(self.train_dataset)

    def _find_valid_row(self, index):
        """Return ``(row_index, image_bytes, caption)`` for the first usable row.

        Starting at ``index``, scan forward (wrapping past the end) over rows
        whose image bytes or caption are missing. ``row_index`` is the
        non-negative position of the row actually chosen.

        Raises:
            IndexError: if ``index`` is outside the dataset (or it is empty).
            RuntimeError: if no row has both image bytes and a caption.
        """
        num_rows = len(self.train_dataset)
        if not -num_rows <= index < num_rows:
            raise IndexError(f"index {index} out of range for {num_rows} rows")
        for offset in range(num_rows):
            row_index = (index + offset) % num_rows
            # Fetch the row once instead of re-running iloc per column.
            row = self.train_dataset.iloc[row_index]
            image_field = row["image"]
            image_bytes = None if image_field is None else image_field["bytes"]
            caption = row["caption_composition"]
            if image_bytes is not None and caption is not None:
                return row_index, image_bytes, caption
        raise RuntimeError("No row has both image bytes and a caption")

    def __getitem__(self, index):
        """Return ``{"caption", "image", "filename"}`` for the row at ``index``.

        ``image`` is the decoded RGB image converted by ``transforms.ToTensor``.
        Rows with a missing image or caption are skipped in favor of the next
        usable row (wrapping past the end); ``filename`` is ``str`` of the
        index of the row actually returned.
        """
        row_index, image_bytes, caption = self._find_valid_row(index)
        image = self.transform(Image.open(io.BytesIO(image_bytes)).convert("RGB"))
        return dict(caption=caption, image=image, filename=str(row_index))
| |
|
class T5dataset(Dataset):
    """Caption-only dataset read from a UTF-8 text file, one caption per line.

    Lines for which ``contains_chinese`` is true are dropped, and at most
    ``max_prompts`` of the remaining lines are kept.

    Args:
        txt_path: Path to the caption file.
        vae_debug: Requests precomputed latents per item; not supported for
            plain caption lines (``__getitem__`` raises ``TypeError``).
        max_prompts: Maximum number of captions kept (default 50000, the
            previously hard-coded limit). ``None`` keeps every line.
    """

    def __init__(self, txt_path, vae_debug, max_prompts=50000):
        self.txt_path = txt_path
        self.vae_debug = vae_debug
        with open(self.txt_path, "r", encoding="utf-8") as f:
            captions = [
                line for line in f.read().splitlines() if not contains_chinese(line)
            ]
        self.train_dataset = captions[:max_prompts]

    def __getitem__(self, idx):
        """Return ``{"caption", "latents", "filename"}`` for caption ``idx``.

        ``latents`` is an empty list and ``filename`` is ``str(idx)``.

        Raises:
            IndexError: if ``idx`` is out of range.
            TypeError: if ``vae_debug`` is set. Entries are plain strings
                with no ``latent_path`` to load, and the old lookup also
                depended on a module-global ``args`` that only exists when
                the file runs as a script.
        """
        caption = self.train_dataset[idx]
        if self.vae_debug:
            raise TypeError(
                "T5dataset(vae_debug=True) is unsupported: captions are plain "
                "strings with no 'latent_path' to load"
            )
        return dict(caption=caption, latents=[], filename=str(idx))

    def __len__(self):
        return len(self.train_dataset)
| |
|
| |
|
def main(args):
    """Dump dataset images to ``<output_dir>/images`` and write a caption index.

    Run one process per device (e.g. via ``torchrun``). Each process reads its
    shard of ``RFPTdataset(args.prompt_dir)`` through a ``DistributedSampler``
    and saves each image tensor as bfloat16 to
    ``<output_dir>/images/<filename>.pt``. It also collects one JSON record
    per sample. Records from all ranks are gathered, and global rank 0 writes
    them to ``<output_dir>/videos2caption.json``.

    Args:
        args: Parsed CLI namespace. Reads ``output_dir``, ``prompt_dir``,
            ``train_batch_size``, ``dataloader_num_workers``, ``model_path``
            and ``vae_debug``.

    Environment:
        RANK / WORLD_SIZE: global rank and process count (defaults 0 / 1).
        LOCAL_RANK: CUDA device index on this node; falls back to RANK.
    """
    rank = int(os.getenv("RANK", 0))
    # The GPU index must come from LOCAL_RANK. On multi-node jobs the global
    # RANK exceeds the per-node device count. The fallback keeps the previous
    # single-node behavior.
    local_rank = int(os.getenv("LOCAL_RANK", rank))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.set_device(local_rank)
    owns_process_group = not dist.is_initialized()
    if owns_process_group:
        # NCCL requires CUDA; gloo lets CPU-only runs start at all.
        dist.init_process_group(
            backend="nccl" if use_cuda else "gloo",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )

    os.makedirs(args.output_dir, exist_ok=True)
    for subdir in ("prompt_embed", "text_ids", "pooled_prompt_embeds", "images"):
        os.makedirs(os.path.join(args.output_dir, subdir), exist_ok=True)

    train_dataset = RFPTdataset(args.prompt_dir)
    sampler = DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    # NOTE(review): `pipe` is not used anywhere below. The prompt_embed /
    # pooled_prompt_embeds / text_ids directories are created, and the JSON
    # records reference files in them, but nothing in this function writes
    # those files. Confirm whether prompt encoding belongs in the loop below.
    pipe = FluxPipeline.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    ).to(device)

    json_data = []
    for data in tqdm(train_dataloader, disable=rank != 0):
        try:
            with torch.inference_mode():
                if args.vae_debug:
                    # NOTE(review): RFPTdataset items carry no "latents" key,
                    # so with --vae_debug this lookup raises KeyError.
                    latents = data["latents"]
                images = data["image"]
                for idx, sample_name in enumerate(data["filename"]):
                    image_path = os.path.join(
                        args.output_dir, "images", sample_name + ".pt"
                    )
                    # Slice rather than index: each file holds only its own
                    # image while keeping the leading size-1 batch dimension
                    # that batch_size=1 runs have always produced.
                    torch.save(images[idx : idx + 1].to(torch.bfloat16), image_path)
                    file_name = sample_name + ".pt"
                    json_data.append(
                        {
                            "prompt_embed_path": file_name,
                            "text_ids": file_name,
                            "pooled_prompt_embeds_path": file_name,
                            "caption": data["caption"][idx],
                        }
                    )
        except Exception as e:
            print(f"Rank {rank} Error: {repr(e)}")
            dist.barrier()
            raise

    dist.barrier()
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, json_data)
    if rank == 0:
        all_json_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_json_data, f, indent=4)
    # Only tear down a process group this function created itself.
    if owns_process_group:
        dist.destroy_process_group()
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # NOTE(review): main() does not read --model_type, --text_encoder_name or
    # --cache_dir. They are kept only so existing command lines still parse.
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")

    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    # Required: main() calls os.makedirs(output_dir), which fails on None.
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory where the image tensors and videos2caption.json will be written.",
    )
    parser.add_argument("--vae_debug", action="store_true")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default="./empty.txt",
        help="Directory of parquet shards read by RFPTdataset.",
    )
    args = parser.parse_args()
    main(args)