# OmniCustom / infer.py
# Origin: sunnyday1307, "Initial commit" (0f4bcb8)
import logging
import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
from distributed_comms.parallel_states import (
get_sequence_parallel_state,
initialize_sequence_parallel_state,
nccl_info,
)
from distributed_comms.util import get_global_rank, get_local_rank, get_world_size
from ovi_fusion_engine import OviFusionEngine
from utils.io_utils import save_video
from utils.processing_utils import (
format_prompt_for_filename,
validate_and_process_user_prompt,
)
from utils.utils import get_arguments
# One generation request: (text_prompt, image_path, ip_image_path, ip_audio_path).
GenerationItem = Tuple[str, Optional[str], Optional[str], Optional[str]]
# Generation modes accepted by _prepare_eval_data.
ALLOWED_MODES = {"id2v", "t2v", "i2v", "t2i2v"}
@dataclass(frozen=True)
class RuntimeState:
    """Immutable snapshot of the distributed / sequence-parallel topology."""

    world_size: int  # total number of processes in the job
    global_rank: int  # this process's rank across all nodes
    local_rank: int  # rank on the local node; also used as the CUDA device index
    device: int  # CUDA device ordinal (set equal to local_rank)
    sp_rank: int  # rank within the sequence-parallel group (0 when SP is off)
    sp_group_id: int  # index of this rank's SP group
    num_sp_groups: int  # number of SP groups (== world_size when SP is off)
class SequenceNumberManager:
    """Hands out monotonically increasing sequence numbers per condition dir.

    On first use of a condition directory, the existing output files there are
    scanned so numbering resumes after the highest number already on disk.
    Subsequent calls are served from an in-memory counter.
    """

    def __init__(self, output_dir: str) -> None:
        self.output_dir = output_dir
        # condition_dir -> next sequence number to hand out
        self._next_by_condition: Dict[str, int] = {}

    def next(self, condition_dir: str) -> int:
        """Return the next sequence number for *condition_dir* and advance it."""
        current = self._next_by_condition.get(condition_dir)
        if current is None:
            current = self._scan_next(condition_dir)
        self._next_by_condition[condition_dir] = current + 1
        return current

    def _scan_next(self, condition_dir: str) -> int:
        """Scan existing .mp4/.png files and return one past the highest prefix."""
        target = os.path.join(self.output_dir, condition_dir)
        if not os.path.exists(target):
            return 1
        numbers = [
            int(name.split("_")[0])
            for name in os.listdir(target)
            if name.endswith((".mp4", ".png")) and name.split("_")[0].isdigit()
        ]
        return (max(numbers) if numbers else 0) + 1
def _init_logging(rank: int) -> None:
if rank == 0:
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)],
)
else:
logging.basicConfig(level=logging.ERROR)
def _initialize_runtime(config, args) -> RuntimeState:
    """Set up the CUDA device, process group, and sequence-parallel state.

    Side effects: selects the CUDA device, configures logging, initializes the
    NCCL process group (when world_size > 1) and the SP state, and mutates
    *args* by setting ``args.local_rank`` and ``args.device``.

    Returns an immutable RuntimeState describing this rank's position.
    """
    world_size = get_world_size()
    global_rank = get_global_rank()
    local_rank = get_local_rank()
    device = local_rank
    torch.cuda.set_device(local_rank)
    # Sequence-parallel group size; world must split into equal-size groups.
    sp_size = config.get("sp_size", 1)
    assert sp_size <= world_size and world_size % sp_size == 0, (
        "sp_size must be less than or equal to world_size and world_size "
        "must be divisible by sp_size."
    )
    # Only rank 0 logs at INFO; other ranks log errors only.
    _init_logging(global_rank)
    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=global_rank,
            world_size=world_size,
        )
    else:
        # Single-process run: SP cannot exceed the (single) world size.
        assert (
            sp_size == 1
        ), f"When world_size is 1, sp_size must also be 1, but got {sp_size}."
    initialize_sequence_parallel_state(sp_size)
    logging.info("Using SP: %s, SP_SIZE: %s", get_sequence_parallel_state(), sp_size)
    # NOTE(review): downstream code reads these mutated attributes off args.
    args.local_rank = local_rank
    args.device = device
    if get_sequence_parallel_state():
        # Derive this rank's intra-group rank and group index from nccl_info.
        runtime_sp_size = nccl_info.sp_size
        sp_rank = nccl_info.rank_within_group
        sp_group_id = global_rank // runtime_sp_size
        num_sp_groups = world_size // runtime_sp_size
    else:
        # SP disabled: every rank forms its own group of one.
        sp_rank = 0
        sp_group_id = global_rank
        num_sp_groups = world_size
    return RuntimeState(
        world_size=world_size,
        global_rank=global_rank,
        local_rank=local_rank,
        device=device,
        sp_rank=sp_rank,
        sp_group_id=sp_group_id,
        num_sp_groups=num_sp_groups,
    )
def _prepare_eval_data(config) -> List[GenerationItem]:
    """Build the list of generation work items from the user config.

    Returns one (text_prompt, image_path, ip_image_path, ip_audio_path) tuple
    per requested generation.

    Raises:
        ValueError: if ``mode`` is not in ALLOWED_MODES, or if any image path
            is missing or not a file in ``i2v`` mode.
    """
    mode = config.get("mode")
    # Raise instead of assert: this is user-input validation and asserts are
    # stripped under ``python -O``.
    if mode not in ALLOWED_MODES:
        raise ValueError(
            f"Invalid mode {mode}, must be one of {sorted(ALLOWED_MODES)}"
        )
    text_prompt = config.get("text_prompt")
    image_path = config.get("image_path")
    ip_image_path = config.get("ip_image_path")
    ip_audio_path = config.get("ip_audio_path")
    text_prompts, image_paths, ip_image_paths, ip_audio_paths = (
        validate_and_process_user_prompt(
            text_prompt,
            image_path,
            ip_image_path,
            ip_audio_path,
            mode=mode,
        )
    )
    if mode != "i2v":
        # Only the first-frame image paths are cleared; ip_image_paths and
        # ip_audio_paths are intentionally kept because other modes (e.g.
        # id2v) still consume them. (The previous log line wrongly claimed
        # all three lists were being set to None.)
        logging.info("mode: %s, setting all image_paths to None", mode)
        image_paths = [None] * len(text_prompts)
    else:
        invalid = [p for p in image_paths if p is None or not os.path.isfile(p)]
        if invalid:
            raise ValueError(
                f"In i2v mode, all image paths must be provided.{image_paths}"
            )
    return list(zip(text_prompts, image_paths, ip_image_paths, ip_audio_paths))
def _split_eval_data_for_current_rank(
    all_eval_data: Sequence[GenerationItem],
    runtime: RuntimeState,
    require_sample_padding: bool = False,
) -> List[GenerationItem]:
    """Return this SP group's round-robin shard of the full work list.

    When ``require_sample_padding`` is set, the list is first padded by
    repeating the first item so every SP group receives an equally sized
    shard. Returns [] (after logging an error) when there is no work at all.
    """
    if not all_eval_data:
        logging.error("ERROR: No evaluation files found")
        return []
    shard_source = list(all_eval_data)
    if require_sample_padding:
        leftover = len(shard_source) % runtime.num_sp_groups
        if leftover:
            pad = runtime.num_sp_groups - leftover
            shard_source.extend([shard_source[0]] * pad)
    # Round-robin assignment: group g takes items g, g + G, g + 2G, ...
    return shard_source[runtime.sp_group_id :: runtime.num_sp_groups]
def _validate_optional_path(path: Optional[str], display_name: str) -> Optional[str]:
if path is None:
return None
if not os.path.isfile(path):
logging.warning("%s %s not exists, using `None` instead", display_name, path)
return None
return path
def _frame_size_string(video_frame_height_width: Optional[Sequence[int]]) -> str:
if video_frame_height_width is None:
raise ValueError("video_frame_height_width must be provided in config.")
return "x".join(map(str, video_frame_height_width))
def _build_output_path(
    output_dir: str,
    sequence_manager: SequenceNumberManager,
    text_prompt: str,
    ip_image_path: Optional[str],
    ip_audio_path: Optional[str],
    crop_face: bool,
    video_frame_height_width: Optional[Sequence[int]],
    seed: int,
    global_rank: int,
) -> str:
    """Create the per-condition output directory and return a unique .mp4 path.

    Outputs are grouped by which IP conditioning inputs were present, and the
    filename encodes the sequence number, crop flag, prompt, frame size, seed,
    and rank so runs never collide.
    """
    has_ip_image = ip_image_path is not None
    has_ip_audio = ip_audio_path is not None
    condition_dir = f"ip_image_{has_ip_image}_ip_audio_{has_ip_audio}"
    condition_output_dir = os.path.join(output_dir, condition_dir)
    os.makedirs(condition_output_dir, exist_ok=True)
    output_filename = "_".join(
        [
            f"{sequence_manager.next(condition_dir):05d}",
            f"crop-{crop_face}",
            format_prompt_for_filename(text_prompt),
            _frame_size_string(video_frame_height_width),
            str(seed),
            f"{global_rank}.mp4",
        ]
    )
    return os.path.join(condition_output_dir, output_filename)
def main(config, args) -> None:
    """End-to-end inference driver.

    Initializes distributed/SP state, loads the OVI fusion engine, shards the
    prompt list across SP groups, then generates and saves videos (and
    optional images) for this rank's shard.
    """
    runtime = _initialize_runtime(config, args)
    target_dtype = torch.bfloat16
    all_eval_data = _prepare_eval_data(config)
    # Each sequence-parallel group works on a disjoint round-robin shard.
    this_rank_eval_data = _split_eval_data_for_current_rank(all_eval_data, runtime)
    self_lora = config.get("self_lora", True)
    logging.info("Loading OVI Fusion Engine...")
    ovi_engine = OviFusionEngine(
        config=config,
        device=runtime.device,
        target_dtype=target_dtype,
        self_lora=self_lora,
    )
    logging.info("OVI Fusion Engine loaded!")
    output_dir = config.get("output_dir", "./outputs")
    os.makedirs(output_dir, exist_ok=True)
    sequence_manager = SequenceNumberManager(output_dir)
    # Sampling knobs with defaults; all overridable from the config file.
    generation_kwargs = {
        "video_frame_height_width": config.get("video_frame_height_width"),
        "seed": config.get("seed", 100),
        "solver_name": config.get("solver_name", "unipc"),
        "sample_steps": config.get("sample_steps", 50),
        "shift": config.get("shift", 5.0),
        "video_guidance_scale": config.get("video_guidance_scale", 4.0),
        "audio_guidance_scale": config.get("audio_guidance_scale", 3.0),
        "slg_layer": config.get("slg_layer", 11),
        "video_negative_prompt": config.get("video_negative_prompt", ""),
        "audio_negative_prompt": config.get("audio_negative_prompt", ""),
    }
    crop_face = config.get("crop_face", False)
    each_example_n_times = config.get("each_example_n_times", 1)
    for text_prompt, image_path, ip_image_path, ip_audio_path in tqdm(this_rank_eval_data):
        # Missing optional conditioning files degrade to None with a warning.
        ip_image_path = _validate_optional_path(ip_image_path, "IP Image")
        ip_audio_path = _validate_optional_path(ip_audio_path, "IP Audio")
        for idx in range(each_example_n_times):
            # Offset the seed per repeat so repeats differ deterministically.
            current_seed = generation_kwargs["seed"] + idx
            generated_video, generated_audio, generated_image = ovi_engine.generate(
                text_prompt=text_prompt,
                image_path=image_path,
                ip_image_path=ip_image_path,
                ip_audio_path=ip_audio_path,
                video_frame_height_width=generation_kwargs["video_frame_height_width"],
                seed=current_seed,
                solver_name=generation_kwargs["solver_name"],
                sample_steps=generation_kwargs["sample_steps"],
                shift=generation_kwargs["shift"],
                video_guidance_scale=generation_kwargs["video_guidance_scale"],
                audio_guidance_scale=generation_kwargs["audio_guidance_scale"],
                slg_layer=generation_kwargs["slg_layer"],
                video_negative_prompt=generation_kwargs["video_negative_prompt"],
                audio_negative_prompt=generation_kwargs["audio_negative_prompt"],
            )
            # Only the first rank of each SP group writes outputs to disk.
            if runtime.sp_rank != 0:
                continue
            output_path = _build_output_path(
                output_dir=output_dir,
                sequence_manager=sequence_manager,
                text_prompt=text_prompt,
                ip_image_path=ip_image_path,
                ip_audio_path=ip_audio_path,
                crop_face=crop_face,
                video_frame_height_width=generation_kwargs["video_frame_height_width"],
                seed=current_seed,
                global_rank=runtime.global_rank,
            )
            save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
            logging.info("Saved video to: %s", output_path)
            # Some modes (e.g. t2i2v) also produce a still image alongside the video.
            if generated_image is not None:
                image_output_path = output_path.replace(".mp4", ".png")
                generated_image.save(image_output_path)
                logging.info("Saved image to: %s", image_output_path)
if __name__ == "__main__":
    # Script entry point: parse CLI args, load the OmegaConf config, run inference.
    args = get_arguments()
    config = OmegaConf.load(args.config_file)
    main(config=config, args=args)