| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import glob |
| | import json |
| | import os |
| | import random |
| |
|
| | import torch |
| | import torchvision |
| | from einops import rearrange |
| | from huggingface_hub import snapshot_download |
| | from nemo.collections.diffusion.models.model import DiT7BConfig |
| | from tqdm import tqdm |
| | from transformers import T5EncoderModel, T5TokenizerFast |
| |
|
| | from .log import log |
| |
|
| |
|
def get_parser():
    """
    Build the command-line parser for the dataset-caching script.

    Returns:
        argparse.ArgumentParser with the tokenizer directory, dataset/output
        paths, prompt, number of chunks per video, and resize height/width.
    """
    parser = argparse.ArgumentParser(description="Process some configurations.")
    # (flag, type, default, help) for each option, registered in this order.
    option_specs = (
        ("--tokenizer_dir", str, "", "Path to the VAE model"),
        ("--dataset_path", str, "video_dataset", "Path to the dataset (a folder of videos)"),
        ("--output_path", str, "video_dataset_cached", "Path to the output directory"),
        ("--prompt", str, "a video of sks.", "Prompt for the video"),
        ("--num_chunks", int, 5, "Number of random chunks to sample per video"),
        ("--height", int, 704, "Height to resize video"),
        ("--width", int, 1280, "Width to resize video"),
    )
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return parser
| |
|
| |
|
def init_t5():
    """
    Load the T5-11B tokenizer and text encoder from the Hugging Face Hub.

    Returns:
        (tokenizer, text_encoder): the encoder has been moved to ``"cuda"``
        and switched to evaluation mode.
    """
    checkpoint = "google-t5/t5-11b"
    t5_tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
    t5_encoder = T5EncoderModel.from_pretrained(checkpoint)
    t5_encoder.to("cuda")
    t5_encoder.eval()
    return t5_tokenizer, t5_encoder
| |
|
| |
|
def init_video_tokenizer(tokenizer_dir: str):
    """
    Build the Cosmos video tokenizer (VAE) from a checkpoint directory.

    Parameters:
        tokenizer_dir: Directory with the tokenizer weights; forwarded to
            ``DiT7BConfig`` as ``vae_path``.

    Returns:
        The VAE object produced by ``DiT7BConfig.configure_vae()``.
    """
    return DiT7BConfig(vae_path=tokenizer_dir).configure_vae()
| |
|
| |
|
@torch.no_grad()
def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512, device="cuda"):
    """
    Encode a batch of text prompts to a batch of T5 embeddings.

    Each prompt is truncated/padded to exactly ``max_length`` tokens and run
    through ``encoder``; hidden states at padding positions (attention mask 0)
    are set to zero.

    Parameters:
        tokenizer: T5 embedding tokenizer.
        encoder: T5 embedding text encoder.
        prompts: A batch of text prompts.
        max_length: Sequence length of text embedding (defaults to 512).
        device: Device the token ids and attention mask are moved to before
            encoding; must match the encoder's device (defaults to "cuda").

    Returns:
        The encoder's ``last_hidden_state`` (one row per prompt, ``max_length``
        positions each) with padding positions zeroed, on the encoder's device.
    """
    batch_encoding = tokenizer.batch_encode_plus(
        prompts,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_offsets_mapping=False,
    )

    input_ids = batch_encoding.input_ids.to(device)
    attn_mask = batch_encoding.attention_mask.to(device)

    outputs = encoder(input_ids=input_ids, attention_mask=attn_mask)
    encoded_text = outputs.last_hidden_state

    # Zero the embeddings wherever the attention mask is 0. Masking directly
    # (rather than slicing from each row's token count) is one vectorized op
    # and stays correct whichever side the tokenizer pads on.
    is_padding = ~attn_mask.bool().unsqueeze(-1)
    return encoded_text.masked_fill(is_padding, 0)
| |
|
| |
|
def main(args):
    """
    Cache Cosmos video latents and T5 prompt embeddings for random video chunks.

    Every ``*.mp4`` file in ``args.dataset_path`` (visited in sorted order) is
    decoded. For each file, ``args.num_chunks`` windows of
    ``pixel_chunk_duration`` frames are drawn at random start frames. Each
    window is processed as follows:

    - resized to ``(args.height, args.width)``;
    - rescaled with ``x / 127.5 - 1``;
    - encoded with the Cosmos video tokenizer.

    Sample ``i`` is written to ``args.output_path`` as:

    - ``{i}.video_latent.pth``: the chunk's latent with the batch dim removed;
    - ``{i}.t5_text_embeddings.pth``: the bfloat16 T5 embedding of
      ``args.prompt`` (identical for every sample);
    - ``{i}.t5_text_mask.pth``: an all-ones bfloat16 mask of the embedding length;
    - ``{i}.info.json``: source size/fps, chunk length, file name, start frame.

    Videos shorter than one chunk are logged and skipped. If
    ``args.tokenizer_dir`` is empty, it is replaced by a Hugging Face snapshot
    of ``nvidia/Cosmos-1.0-Tokenizer-CV8x8x8``.

    Parameters:
        args: Namespace produced by ``get_parser().parse_args()``.

    Raises:
        ValueError: If ``args.dataset_path`` contains no ``.mp4`` files.
    """
    os.makedirs(args.output_path, exist_ok=True)

    # Validate the dataset before loading any multi-GB model so a bad path
    # fails fast. Sorting makes the sample numbering deterministic.
    files = sorted(glob.glob(os.path.join(args.dataset_path, "*.mp4")))
    if not files:
        raise ValueError(f"Dataset path {args.dataset_path} does not contain any .mp4 files.")

    t5_embedding_max_length = 512

    # The prompt is the same for every chunk, so run the T5 encoder exactly
    # once, then release it before the video tokenizer is loaded.
    tokenizer, text_encoder = init_t5()
    encoded_text = encode_for_batch(
        tokenizer, text_encoder, [args.prompt], max_length=t5_embedding_max_length
    )[0]
    encoded_text = encoded_text.to(device="cpu", dtype=torch.bfloat16)
    del tokenizer, text_encoder
    torch.cuda.empty_cache()

    # Zero-pad to the fixed embedding length. The tokenizer already pads to
    # max_length, so this is normally a plain copy.
    num_tokens, hidden_dim = encoded_text.shape
    t5_embed = torch.zeros(t5_embedding_max_length, hidden_dim, dtype=torch.bfloat16)
    t5_embed[:num_tokens] = encoded_text
    t5_mask = torch.ones(t5_embedding_max_length, dtype=torch.bfloat16)

    if args.tokenizer_dir == "":
        args.tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
    vae = init_video_tokenizer(args.tokenizer_dir)
    chunk_duration = vae.video_vae.pixel_chunk_duration

    cnt = 0
    with torch.no_grad():
        for video_path in tqdm(files):
            # pts_unit="sec" avoids torchvision's "pts gives wrong results"
            # warning; with the default start/end the whole file is decoded.
            video, _, meta = torchvision.io.read_video(video_path, pts_unit="sec")
            T, H, W, _ = video.shape

            if T < chunk_duration:
                log.info(f"Video {video_path} is shorter than {chunk_duration} frames. Skipped.")
                continue

            for _ in range(args.num_chunks):
                start_idx = random.randint(0, T - chunk_duration)
                chunk = video[start_idx : start_idx + chunk_duration]

                # Channels before spatial dims for resize: (t, h, w, c) -> (t, c, h, w).
                chunk = rearrange(chunk, "t h w c -> t c h w")
                chunk = torchvision.transforms.functional.resize(chunk, [args.height, args.width])

                # Add a batch dim and move time after channels: (1, c, t, h, w).
                chunk = rearrange(chunk, "(b t) c h w -> b c t h w", b=1)

                # [0, 255] -> [-1, 1] in bfloat16 on the GPU.
                chunk = chunk.to(device="cuda", dtype=torch.bfloat16, non_blocking=True) / 127.5 - 1.0

                latent = vae.encode(chunk).cpu()

                torch.save(latent[0], os.path.join(args.output_path, f"{cnt}.video_latent.pth"))
                torch.save(t5_embed, os.path.join(args.output_path, f"{cnt}.t5_text_embeddings.pth"))
                torch.save(t5_mask, os.path.join(args.output_path, f"{cnt}.t5_text_mask.pth"))

                # NOTE(review): height/width record the *source* resolution, not
                # (args.height, args.width) used for the latent. Confirm that is
                # what the dataset consumer expects.
                info = {
                    "height": H,
                    "width": W,
                    "fps": meta["video_fps"],
                    "num_frames": chunk_duration,
                    "video_path": os.path.basename(video_path),
                    "start_frame": start_idx,
                }
                with open(os.path.join(args.output_path, f"{cnt}.info.json"), "w") as json_file:
                    json.dump(info, json_file)

                cnt += 1
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse the command-line flags and run the pipeline.
    main(get_parser().parse_args())
| |
|