|
|
""" |
|
|
Infinite World - Action-Conditioned Video Generation Inference Script |
|
|
====================================================================== |
|
|
A standalone inference script for generating long videos with action control. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
import cv2 |
|
|
import math |
|
|
import torch |
|
|
import random |
|
|
import json |
|
|
import datetime |
|
|
import importlib |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from omegaconf import OmegaConf |
|
|
import torch.distributed as dist |
|
|
import torchvision.transforms as transforms |
|
|
import re |
|
|
|
|
|
|
|
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
|
|
sys.path.insert(0, PROJECT_ROOT) |
|
|
|
|
|
from infworld.utils.prepare_dataloader import get_obj_from_str |
|
|
from infworld.utils.data_utils import get_first_clip_from_video, save_silent_video |
|
|
from infworld.utils.dataset_utils import is_vid, is_img |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MOVE_ACTION_MAP = { |
|
|
'no-op': 0, |
|
|
'go forward': 1, |
|
|
'go back': 2, |
|
|
'go left': 3, |
|
|
'go right': 4, |
|
|
'go forward and go left': 5, |
|
|
'go forward and go right': 6, |
|
|
'go back and go left': 7, |
|
|
'go back and go right': 8, |
|
|
'uncertain': 9 |
|
|
} |
|
|
|
|
|
VIEW_ACTION_MAP = { |
|
|
'no-op': 0, |
|
|
'turn up': 1, |
|
|
'turn down': 2, |
|
|
'turn left': 3, |
|
|
'turn right': 4, |
|
|
'turn up and turn left': 5, |
|
|
'turn up and turn right': 6, |
|
|
'turn down and turn left': 7, |
|
|
'turn down and turn right': 8, |
|
|
'uncertain': 9 |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_ckpt_step(path): |
|
|
"""Extract checkpoint step number from path.""" |
|
|
match = re.search(r'checkpoint-(\d+)\.ckpt', path) |
|
|
return int(match.group(1)) if match else 0 |
|
|
|
|
|
def resize_and_center_crop(image, target_size): |
|
|
"""Resize image and center crop to target size.""" |
|
|
orig_h, orig_w = image.shape[:2] |
|
|
target_h, target_w = target_size |
|
|
|
|
|
scale = max(target_h / orig_h, target_w / orig_w) |
|
|
final_h = math.ceil(scale * orig_h) |
|
|
final_w = math.ceil(scale * orig_w) |
|
|
|
|
|
resized = cv2.resize(image, (final_w, final_h), interpolation=cv2.INTER_AREA) |
|
|
tensor = torch.from_numpy(resized)[None, ...].permute(0, 3, 1, 2).contiguous() |
|
|
cropped = transforms.functional.center_crop(tensor, target_size) |
|
|
return cropped[:, :, None, :, :] |
|
|
|
|
|
def setup_seed(seed): |
|
|
"""Set random seeds for reproducibility.""" |
|
|
torch.manual_seed(seed) |
|
|
torch.cuda.manual_seed_all(seed) |
|
|
np.random.seed(seed) |
|
|
random.seed(seed) |
|
|
torch.backends.cudnn.deterministic = True |
|
|
|
|
|
def torch_gc(): |
|
|
"""Clear GPU memory cache.""" |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.ipc_collect() |
|
|
|
|
|
def load_action_sequence(action_path): |
|
|
"""Load action sequence from JSON file.""" |
|
|
with open(action_path, 'r') as f: |
|
|
actions = json.load(f) |
|
|
|
|
|
move_indices = [MOVE_ACTION_MAP[a['move']] for a in actions] |
|
|
view_indices = [VIEW_ACTION_MAP[a['view']] for a in actions] |
|
|
return move_indices, view_indices |
|
|
|
|
|
def load_condition_image(image_path, bucket_config): |
|
|
"""Load and preprocess condition image.""" |
|
|
if is_vid(image_path): |
|
|
frames = get_first_clip_from_video(image_path, clip_len=1) |
|
|
elif is_img(image_path): |
|
|
image = cv2.imread(image_path) |
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
|
|
frames = [image] |
|
|
else: |
|
|
raise ValueError(f'Unsupported file format: {image_path}') |
|
|
|
|
|
processed_frames = [] |
|
|
for frame in frames: |
|
|
ratio = frame.shape[0] / frame.shape[1] |
|
|
closest_bucket = sorted(bucket_config.keys(), key=lambda x: abs(float(x) - ratio))[0] |
|
|
target_h, target_w = bucket_config[closest_bucket][0] |
|
|
|
|
|
tensor = resize_and_center_crop(frame, (target_h, target_w)) |
|
|
tensor = (tensor / 255 - 0.5) * 2 |
|
|
processed_frames.append(tensor) |
|
|
|
|
|
return torch.cat(processed_frames, dim=2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_distributed(): |
|
|
"""Setup distributed or single-GPU mode.""" |
|
|
if 'RANK' in os.environ: |
|
|
|
|
|
rank = int(os.environ['RANK']) |
|
|
world_size = int(os.environ.get('WORLD_SIZE', 1)) |
|
|
local_rank = int(os.environ.get('LOCAL_RANK', rank % torch.cuda.device_count())) |
|
|
torch.cuda.set_device(local_rank) |
|
|
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*24)) |
|
|
global_rank = dist.get_rank() |
|
|
num_processes = dist.get_world_size() |
|
|
return local_rank, global_rank, num_processes, True |
|
|
else: |
|
|
|
|
|
local_rank = 0 |
|
|
global_rank = 0 |
|
|
num_processes = 1 |
|
|
torch.cuda.set_device(local_rank) |
|
|
return local_rank, global_rank, num_processes, False |
|
|
|
|
|
local_rank, global_rank, num_processes, use_dist = setup_distributed() |
|
|
print(f"[InfWorld] local_rank: {local_rank} | global_rank: {global_rank} | world_size: {num_processes}") |
|
|
|
|
|
|
|
|
context_parallel_size = 1 |
|
|
import infworld.context_parallel.context_parallel_util as cp_util |
|
|
if use_dist: |
|
|
from infworld.context_parallel.context_parallel_util import init_context_parallel, get_dp_size, get_dp_rank |
|
|
init_context_parallel(context_parallel_size=context_parallel_size, global_rank=global_rank, world_size=num_processes) |
|
|
dp_rank = get_dp_rank() |
|
|
dp_size = get_dp_size() |
|
|
else: |
|
|
|
|
|
cp_util.dp_rank = 0 |
|
|
cp_util.dp_size = 1 |
|
|
cp_util.cp_rank = 0 |
|
|
cp_util.cp_size = 1 |
|
|
dp_rank = 0 |
|
|
dp_size = 1 |
|
|
enable_context_parallel = (context_parallel_size > 1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
GLOBAL_SEED = 42 |
|
|
setup_seed(GLOBAL_SEED + global_rank) |
|
|
|
|
|
TEXT_CFG_SCALE = 5.0 |
|
|
NUM_SAMPLING_STEPS = 30 |
|
|
SHIFT = 7 |
|
|
NUM_CHUNKS = 13 |
|
|
HIGH_QUALITY_SAVE = True |
|
|
|
|
|
|
|
|
|
|
|
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'configs', 'infworld_config.yaml') |
|
|
|
|
|
PROMPTS_YAML = os.path.join(PROJECT_ROOT, 'prompts', 'demo.yaml') |
|
|
BUCKET_CONFIG_NAME = 'ASPECT_RATIO_627_F64' |
|
|
|
|
|
|
|
|
OUTPUT_BASE = os.path.join(PROJECT_ROOT, 'outputs') |
|
|
|
|
|
|
|
|
NEGATIVE_PROMPT = "many cars, crowds, Vivid hues, overexposed, static, blurry details, subtitles, style, work, artwork, image, still, overall grayish, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, deformed limbs, fused fingers, motionless image, cluttered background, three legs, crowded background, walking backwards." |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_path(path, root=PROJECT_ROOT): |
|
|
"""Resolve path: if relative, join with project root.""" |
|
|
if path is None: |
|
|
return path |
|
|
path = str(path).strip() |
|
|
if not os.path.isabs(path): |
|
|
path = os.path.join(root, path) |
|
|
return path |
|
|
|
|
|
|
|
|
def load_dit_state_dict(checkpoint_path): |
|
|
"""Load DiT state dict from .ckpt (torch) or .safetensors.""" |
|
|
checkpoint_path = resolve_path(checkpoint_path) |
|
|
if checkpoint_path.endswith(".safetensors"): |
|
|
from safetensors.torch import load_file |
|
|
state_dict = load_file(checkpoint_path) |
|
|
else: |
|
|
state_dict = torch.load(checkpoint_path, map_location="cpu") |
|
|
if "state_dict" in state_dict: |
|
|
state_dict = state_dict["state_dict"] |
|
|
return state_dict |
|
|
|
|
|
|
|
|
def main(): |
|
|
torch_gc() |
|
|
|
|
|
config_path = CONFIG_PATH |
|
|
args = OmegaConf.load(config_path) |
|
|
checkpoint_path = resolve_path(args.get("checkpoint_path", "checkpoints/models/diffusion_pytorch_model.safetensors")) |
|
|
|
|
|
ckpt_step = extract_ckpt_step(checkpoint_path) |
|
|
|
|
|
|
|
|
output_dir = os.path.join(OUTPUT_BASE, f"infworld-ckpt{ckpt_step}-step{NUM_SAMPLING_STEPS}-cfg{TEXT_CFG_SCALE}") |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
print(f"[InfWorld] Loading checkpoint: {checkpoint_path}") |
|
|
print(f"[InfWorld] Config: {config_path}") |
|
|
print(f"[InfWorld] Output directory: {output_dir}") |
|
|
|
|
|
|
|
|
if hasattr(args, "vae_cfg") and "vae_pth" in args.vae_cfg: |
|
|
args.vae_cfg.vae_pth = resolve_path(args.vae_cfg.vae_pth) |
|
|
if hasattr(args, "text_encoder_cfg"): |
|
|
if "checkpoint_path" in args.text_encoder_cfg: |
|
|
args.text_encoder_cfg.checkpoint_path = resolve_path(args.text_encoder_cfg.checkpoint_path) |
|
|
if "tokenizer_path" in args.text_encoder_cfg: |
|
|
args.text_encoder_cfg.tokenizer_path = resolve_path(args.text_encoder_cfg.tokenizer_path) |
|
|
|
|
|
|
|
|
print("[InfWorld] Loading VAE...") |
|
|
vae = get_obj_from_str(args.vae_target)(**args.vae_cfg).to(local_rank) |
|
|
|
|
|
print("[InfWorld] Loading Text Encoder...") |
|
|
text_encoder = get_obj_from_str(args.text_encoder_target)(device=local_rank, **args.text_encoder_cfg) |
|
|
text_encoder.t5.model.to(local_rank) |
|
|
|
|
|
print("[InfWorld] Loading Scheduler...") |
|
|
scheduler = get_obj_from_str(args.scheduler_target)(**args.val_scheduler_cfg) |
|
|
scheduler.num_sampling_steps = NUM_SAMPLING_STEPS |
|
|
scheduler.shift = SHIFT |
|
|
|
|
|
print("[InfWorld] Loading DiT Model...") |
|
|
dtype = getattr(torch, args.amp_dtype) |
|
|
dit = get_obj_from_str(args.model_target)( |
|
|
out_channels=vae.out_channels, |
|
|
caption_channels=text_encoder.output_dim, |
|
|
model_max_length=text_encoder.model_max_length, |
|
|
enable_context_parallel=enable_context_parallel, |
|
|
**args.model_cfg |
|
|
).to(dtype) |
|
|
dit.eval() |
|
|
|
|
|
|
|
|
state_dict = load_dit_state_dict(args.checkpoint_path) |
|
|
|
|
|
|
|
|
state_dict.pop("pos_embed_temporal", None) |
|
|
state_dict.pop("pos_embed", None) |
|
|
|
|
|
missing, unexpected = dit.load_state_dict(state_dict, strict=False) |
|
|
print(f"[InfWorld] Model loaded! Missing: {len(missing)}, Unexpected: {len(unexpected)}") |
|
|
|
|
|
dit.to(local_rank) |
|
|
|
|
|
|
|
|
from infworld.configs import bucket_config as bucket_config_module |
|
|
bucket_config = getattr(bucket_config_module, BUCKET_CONFIG_NAME) |
|
|
|
|
|
|
|
|
prompts_path = os.path.abspath(PROMPTS_YAML) |
|
|
target_prompts = OmegaConf.load(prompts_path).prompts |
|
|
print(f"[InfWorld] Loaded {len(target_prompts)} prompts") |
|
|
|
|
|
|
|
|
for task_idx, (prompt, image_path, action_path) in enumerate(target_prompts): |
|
|
if task_idx % dp_size != dp_rank: |
|
|
continue |
|
|
|
|
|
if not os.path.exists(image_path): |
|
|
print(f"[InfWorld] Skipping task {task_idx}: Image not found - {image_path}") |
|
|
continue |
|
|
|
|
|
if not os.path.exists(action_path): |
|
|
print(f"[InfWorld] Skipping task {task_idx}: Action not found - {action_path}") |
|
|
continue |
|
|
|
|
|
print(f"[InfWorld] Task {task_idx}: {prompt[:50]}...") |
|
|
|
|
|
|
|
|
cond_video = load_condition_image(image_path, bucket_config).to(local_rank) |
|
|
|
|
|
with torch.no_grad(): |
|
|
cond_latent = vae.encode(cond_video) |
|
|
|
|
|
|
|
|
move_indices, view_indices = load_action_sequence(action_path) |
|
|
|
|
|
|
|
|
video_buffer = cond_video.clone().cpu() |
|
|
|
|
|
|
|
|
latent_size = list(cond_latent.shape) |
|
|
latent_size[2] = 21 |
|
|
latent_size = torch.Size(latent_size) |
|
|
|
|
|
|
|
|
for chunk_idx in range(NUM_CHUNKS): |
|
|
print(f"[InfWorld] Generating chunk {chunk_idx + 1}/{NUM_CHUNKS}") |
|
|
|
|
|
with torch.no_grad(): |
|
|
current_cond = video_buffer.to(local_rank) |
|
|
current_latent = vae.encode(current_cond) |
|
|
|
|
|
|
|
|
curr_start = video_buffer.shape[2] - 1 |
|
|
curr_end = curr_start + args.validation_data.num_frames |
|
|
|
|
|
move = torch.tensor(move_indices[curr_start:curr_end], dtype=torch.long, device=local_rank) |
|
|
view = torch.tensor(view_indices[curr_start:curr_end], dtype=torch.long, device=local_rank) |
|
|
|
|
|
|
|
|
num_frames = args.validation_data.num_frames |
|
|
if move.shape[0] < num_frames: |
|
|
pad_len = num_frames - move.shape[0] |
|
|
move = torch.cat([move, torch.zeros(pad_len, dtype=torch.long, device=local_rank)]) |
|
|
view = torch.cat([view, torch.zeros(pad_len, dtype=torch.long, device=local_rank)]) |
|
|
|
|
|
additional_args = { |
|
|
"image_cond": current_latent, |
|
|
"move": move.unsqueeze(0), |
|
|
"view": view.unsqueeze(0), |
|
|
} |
|
|
|
|
|
torch_gc() |
|
|
|
|
|
with torch.no_grad(): |
|
|
samples = scheduler.sample( |
|
|
model=dit, |
|
|
text_encoder=text_encoder, |
|
|
null_embedder=dit.y_embedder, |
|
|
z_size=latent_size, |
|
|
prompts=[prompt], |
|
|
guidance_scale=TEXT_CFG_SCALE, |
|
|
negative_prompts=[NEGATIVE_PROMPT], |
|
|
device=torch.device(local_rank), |
|
|
additional_args=additional_args, |
|
|
) |
|
|
|
|
|
decoded_chunk = vae.decode(samples).cpu() |
|
|
video_buffer = torch.cat([video_buffer, decoded_chunk[:, :, 1:]], dim=2) |
|
|
|
|
|
print(f"[InfWorld] Chunk {chunk_idx + 1} done. Total frames: {video_buffer.shape[2]}") |
|
|
torch_gc() |
|
|
|
|
|
|
|
|
video_name = f"{task_idx:04d}_{prompt[:30].replace(' ', '_')}" |
|
|
save_path = os.path.join(output_dir, video_name) |
|
|
|
|
|
quality = 10 if HIGH_QUALITY_SAVE else 5 |
|
|
save_silent_video(video_buffer.to(local_rank), save_path, fps=30, quality=quality) |
|
|
print(f"[InfWorld] Saved: {save_path}.mp4") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|