egrpo / fastvideo /data_preprocess /preprocess_validation_text_embeddings.py
studyOverflow's picture
Add files using upload-large-folder tool
b171568 verified
# This code file is from https://github.com/hao-ai-lab/FastVideo, which is licensed under the Apache License 2.0.
import argparse
import os
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from fastvideo.utils.load import load_text_encoder
logger = get_logger(__name__)
def main(args):
    """Encode each validation prompt and save its embedding and attention mask.

    Reads one prompt per line from ``args.validation_prompt_txt`` (blank lines
    are skipped), encodes every prompt with the text encoder returned by
    ``load_text_encoder`` and writes, for each prompt::

        <output_dir>/validation/prompt_embed/<name>.pt
        <output_dir>/validation/prompt_attention_mask/<name>.pt

    ``<name>`` is the prompt text up to its first ``"."`` with any path
    separator replaced by ``"_"``.

    Args:
        args: Namespace providing ``model_type``, ``model_path``,
            ``validation_prompt_txt`` and ``output_dir``.
    """
    rank = int(os.getenv("RANK", 0))
    # The CUDA device index must be the per-node rank; the global RANK exceeds
    # the number of local GPUs on multi-node runs. Fall back to RANK for
    # launchers that do not export LOCAL_RANK.
    local_rank = int(os.getenv("LOCAL_RANK", rank))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        # Only touch the CUDA runtime when it exists; otherwise the CPU
        # fallback chosen above would crash right here.
        torch.cuda.set_device(local_rank)

    # Remember whether this function created the process group so that a group
    # owned by the caller is never torn down from here.
    owns_process_group = not dist.is_initialized()
    if owns_process_group:
        dist.init_process_group(
            backend="nccl" if use_cuda else "gloo",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )

    try:
        text_encoder = load_text_encoder(args.model_type,
                                         args.model_path,
                                         device=device)
        autocast_type = (torch.float16
                         if args.model_type == "hunyuan" else torch.bfloat16)

        # output_dir/validation/prompt_attention_mask
        # output_dir/validation/prompt_embed
        validation_dir = os.path.join(args.output_dir, "validation")
        embed_dir = os.path.join(validation_dir, "prompt_embed")
        mask_dir = os.path.join(validation_dir, "prompt_attention_mask")
        os.makedirs(mask_dir, exist_ok=True)
        os.makedirs(embed_dir, exist_ok=True)

        with open(args.validation_prompt_txt, "r", encoding="utf-8") as file:
            stripped_lines = (line.strip() for line in file)
            # A blank line would otherwise be encoded and stored as ".pt".
            prompts = [prompt for prompt in stripped_lines if prompt]

        for prompt in prompts:
            with torch.inference_mode(), torch.autocast("cuda",
                                                        dtype=autocast_type):
                prompt_embeds, prompt_attention_mask = (
                    text_encoder.encode_prompt(prompt))

            # Name the files after the first sentence of the prompt.
            file_name = prompt.split(".")[0]
            # A path separator inside the prompt would point into a
            # sub-directory that does not exist and make torch.save fail.
            for separator in (os.sep, os.altsep):
                if separator:
                    file_name = file_name.replace(separator, "_")

            prompt_embed_path = os.path.join(embed_dir, f"{file_name}.pt")
            prompt_attention_mask_path = os.path.join(mask_dir,
                                                      f"{file_name}.pt")
            torch.save(prompt_embeds[0], prompt_embed_path)
            torch.save(prompt_attention_mask[0], prompt_attention_mask_path)
            print(f"sample {file_name} saved")
    finally:
        if owns_process_group and dist.is_initialized():
            dist.destroy_process_group()
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the preprocessing.
    parser = argparse.ArgumentParser()
    # text encoder
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # input prompts / output location
    # Both paths are mandatory: without them main() would fail later with an
    # opaque TypeError from open() / os.path.join().
    parser.add_argument(
        "--validation_prompt_txt",
        type=str,
        required=True,
        help="Path to a UTF-8 text file containing one validation prompt per line.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help=
        "The output directory under which validation/prompt_embed and validation/prompt_attention_mask will be written.",
    )
    args = parser.parse_args()
    main(args)