|
|
|
|
|
|
|
|
import argparse |
|
|
import os |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from accelerate.logging import get_logger |
|
|
|
|
|
from fastvideo.utils.load import load_text_encoder |
|
|
|
|
|
# NOTE(review): logger is defined but never used in this script; the code
# below reports progress via print() instead.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
def main(args):
    """Encode every validation prompt with the model's text encoder and save
    the resulting embedding and attention mask as ``.pt`` files.

    Outputs are written under ``<output_dir>/validation/prompt_embed/`` and
    ``<output_dir>/validation/prompt_attention_mask/``, one file per prompt,
    named after the prompt's first sentence.

    Args:
        args: Parsed CLI namespace with ``model_type``, ``model_path``,
            ``validation_prompt_txt`` and ``output_dir``.
    """
    # NOTE: reads the global RANK (not LOCAL_RANK) despite the variable name;
    # preserved as-is since single-node launches make the two identical.
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    # Bind this process to its own GPU. The original called
    # torch.cuda.set_device() unconditionally (crashing on CPU-only hosts)
    # and used a bare "cuda" device, which points every rank at GPU 0.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=world_size,
                                rank=local_rank)

    text_encoder = load_text_encoder(args.model_type,
                                     args.model_path,
                                     device=device)
    # Hunyuan's encoder runs under fp16 autocast; every other model uses bf16.
    autocast_type = torch.float16 if args.model_type == "hunyuan" else torch.bfloat16

    # Hoist the output paths once instead of re-joining them per prompt.
    validation_dir = os.path.join(args.output_dir, "validation")
    embed_dir = os.path.join(validation_dir, "prompt_embed")
    mask_dir = os.path.join(validation_dir, "prompt_attention_mask")
    os.makedirs(embed_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    with open(args.validation_prompt_txt, "r", encoding="utf-8") as file:
        # Skip blank lines — the original encoded them as empty prompts and
        # saved the result under an empty filename.
        prompts = [line.strip() for line in file if line.strip()]

    for prompt in prompts:
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=autocast_type):
                prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
                    prompt)
        # The prompt's first sentence doubles as the output file stem.
        file_name = prompt.split(".")[0]
        torch.save(prompt_embeds[0],
                   os.path.join(embed_dir, f"{file_name}.pt"))
        torch.save(prompt_attention_mask[0],
                   os.path.join(mask_dir, f"{file_name}.pt"))
        print(f"sample {file_name} saved")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI for the validation-prompt preprocessing step.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    parser.add_argument("--validation_prompt_txt", type=str)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    main(parser.parse_args())
|
|
|