| import os
|
| import sys
|
| import logging
|
| import torch
|
| from tqdm import tqdm
|
| from omegaconf import OmegaConf
|
| from ovi.utils.io_utils import save_video
|
| from ovi.utils.processing_utils import format_prompt_for_filename, validate_and_process_user_prompt
|
| from ovi.utils.utils import get_arguments
|
| from ovi.distributed_comms.util import get_world_size, get_local_rank, get_global_rank
|
| from ovi.distributed_comms.parallel_states import initialize_sequence_parallel_state, get_sequence_parallel_state, nccl_info
|
| from ovi.ovi_fusion_engine import OviFusionEngine
|
|
|
|
|
|
|
| def _init_logging(rank):
|
|
|
| if rank == 0:
|
|
|
| logging.basicConfig(
|
| level=logging.INFO,
|
| format="[%(asctime)s] %(levelname)s: %(message)s",
|
| handlers=[logging.StreamHandler(stream=sys.stdout)])
|
| else:
|
| logging.basicConfig(level=logging.ERROR)
|
|
|
|
|
def _shard_eval_data(eval_data, group_id, num_groups, pad_to_multiple=False):
    """Return the round-robin share of ``eval_data`` owned by ``group_id``.

    Item ``i`` is assigned to group ``i % num_groups``. When
    ``pad_to_multiple`` is true, the data is first padded with copies of its
    first item so every group receives the same number of items. The input
    is never mutated, and empty input yields an empty list.
    """
    if not eval_data:
        return []
    items = list(eval_data)
    remainder = len(items) % num_groups
    if pad_to_multiple and remainder:
        items += [items[0]] * (num_groups - remainder)
    return items[group_id::num_groups]


def _build_output_path(output_dir, text_prompt, video_frame_height_width, seed, rank):
    """Return the ``.mp4`` output path for one generated sample.

    The file name is the filename-safe prompt, then the values of
    ``video_frame_height_width`` joined with ``x``, then the seed and the
    rank, all separated by ``_``. The resolution part is omitted when
    ``video_frame_height_width`` is ``None`` (the config default), so that
    case no longer raises ``TypeError``.
    """
    parts = [format_prompt_for_filename(text_prompt)]
    if video_frame_height_width is not None:
        parts.append("x".join(map(str, video_frame_height_width)))
    parts.extend([str(seed), str(rank)])
    return os.path.join(output_dir, "_".join(parts) + ".mp4")


def main(config, args):
    """Run OVI generation over the configured prompts, sharded across ranks.

    Binds this process to its local CUDA device and, when ``world_size > 1``,
    joins an NCCL process group. It then initializes sequence parallelism,
    loads ``OviFusionEngine`` and generates ``each_example_n_times`` samples
    (seeds ``seed``, ``seed + 1``, ...) for each prompt. Prompts are dealt
    round-robin to sequence-parallel groups. Only the ``sp_rank == 0`` member
    of each group writes the video (and the image, when one is returned)
    to ``output_dir``.

    Args:
        config: OmegaConf config with ``mode``, ``text_prompt``,
            ``image_path``, ``sp_size``, ``output_dir`` and sampling settings.
        args: CLI namespace; ``local_rank`` and ``device`` are set on it.

    Raises:
        ValueError: if ``sp_size`` is incompatible with the world size, if
            ``mode`` is not ``t2v``/``i2v``/``t2i2v``, or if ``i2v`` mode is
            given image paths that are missing or are not files.
    """
    world_size = get_world_size()
    global_rank = get_global_rank()
    local_rank = get_local_rank()
    device = local_rank
    torch.cuda.set_device(local_rank)

    # Explicit raises rather than asserts: these validate user configuration
    # and must not be stripped under `python -O`.
    sp_size = config.get("sp_size", 1)
    if not (1 <= sp_size <= world_size and world_size % sp_size == 0):
        raise ValueError(
            "sp_size must be a positive integer less than or equal to world_size "
            "and world_size must be divisible by sp_size "
            f"(got sp_size={sp_size}, world_size={world_size})."
        )

    _init_logging(global_rank)

    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=global_rank,
            world_size=world_size,
        )
    elif sp_size != 1:
        raise ValueError(f"When world_size is 1, sp_size must also be 1, but got {sp_size}.")

    initialize_sequence_parallel_state(sp_size)
    logging.info(f"Using SP: {get_sequence_parallel_state()}, SP_SIZE: {sp_size}")

    args.local_rank = local_rank
    args.device = device
    target_dtype = torch.bfloat16

    mode = config.get("mode")
    if mode not in ("t2v", "i2v", "t2i2v"):
        raise ValueError(f"Invalid mode {mode}, must be one of ['t2v', 'i2v', 't2i2v']")
    text_prompts, image_paths = validate_and_process_user_prompt(
        config.get("text_prompt"), config.get("image_path", None), mode=mode
    )
    if mode != "i2v":
        logging.info(f"mode: {mode}, setting all image_paths to None")
        image_paths = [None] * len(text_prompts)
    else:
        bad_paths = [p for p in image_paths if p is None or not os.path.isfile(p)]
        if bad_paths:
            raise ValueError(
                f"In i2v mode, all image paths must be provided. Missing or invalid: {bad_paths}"
            )

    logging.info("Loading OVI Fusion Engine...")
    ovi_engine = OviFusionEngine(config=config, device=device, target_dtype=target_dtype)
    logging.info("OVI Fusion Engine loaded!")

    output_dir = config.get("output_dir", "./outputs")
    os.makedirs(output_dir, exist_ok=True)

    if get_sequence_parallel_state():
        group_size = nccl_info.sp_size
        sp_rank = nccl_info.rank_within_group
        sp_group_id = global_rank // group_size
        num_sp_groups = world_size // group_size
    else:
        # Without SP, each rank is its own group of one.
        sp_rank = 0
        sp_group_id = global_rank
        num_sp_groups = world_size

    all_eval_data = list(zip(text_prompts, image_paths))
    if not all_eval_data:
        logging.error("ERROR: No evaluation files found")
    this_rank_eval_data = _shard_eval_data(all_eval_data, sp_group_id, num_sp_groups)

    # Generation settings do not vary per prompt: read them once.
    video_frame_height_width = config.get("video_frame_height_width", None)
    seed = config.get("seed", 100)
    solver_name = config.get("solver_name", "unipc")
    sample_steps = config.get("sample_steps", 50)
    shift = config.get("shift", 5.0)
    video_guidance_scale = config.get("video_guidance_scale", 4.0)
    audio_guidance_scale = config.get("audio_guidance_scale", 3.0)
    slg_layer = config.get("slg_layer", 11)
    video_negative_prompt = config.get("video_negative_prompt", "")
    audio_negative_prompt = config.get("audio_negative_prompt", "")
    repeats = config.get("each_example_n_times", 1)

    # Iterate the list directly (not enumerate) so tqdm knows the total.
    for text_prompt, image_path in tqdm(this_rank_eval_data):
        for idx in range(repeats):
            sample_seed = seed + idx
            generated_video, generated_audio, generated_image = ovi_engine.generate(
                text_prompt=text_prompt,
                image_path=image_path,
                video_frame_height_width=video_frame_height_width,
                seed=sample_seed,
                solver_name=solver_name,
                sample_steps=sample_steps,
                shift=shift,
                video_guidance_scale=video_guidance_scale,
                audio_guidance_scale=audio_guidance_scale,
                slg_layer=slg_layer,
                video_negative_prompt=video_negative_prompt,
                audio_negative_prompt=audio_negative_prompt,
            )

            # Every SP group member calls generate(); only the group leader saves.
            if sp_rank != 0:
                continue
            output_path = _build_output_path(
                output_dir, text_prompt, video_frame_height_width, sample_seed, global_rank
            )
            save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
            if generated_image is not None:
                # splitext only touches the real extension, unlike str.replace.
                generated_image.save(os.path.splitext(output_path)[0] + ".png")

    # Tear down the process group this function created.
    if world_size > 1 and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse arguments, load the config file they name,
    # then hand both to main().
    args = get_arguments()
    config = OmegaConf.load(args.config_file)
    main(args=args, config=config)