# LHMPP / scripts/inference/infer_eval_cano.py
# Author: Lingteng Qiu (邱陵腾)
# Commit 434b0b0: "rm assets & wheels"
# -*- coding: utf-8 -*-
# @Organization : Tongyi Lab, Alibaba
# @Author : Lingteng Qiu
# @Email : 220019047@link.cuhk.edu.cn
# @Time : 2025-07-03 20:11:35
# @Function : PF-LHM Inference Demo Code.
# Given specific motion sequences.
# We reenattach the motion sequences to the corresponding human body.
import glob
import json
import os
import sys
import time
from collections import defaultdict
from typing import Dict, Optional, Tuple

sys.path.append("./")

import cv2
import numpy as np
import torch

# Disable torch.compile / dynamo graph capture for inference stability.
torch._dynamo.config.disable = True

from accelerate import Accelerator
from omegaconf import DictConfig, OmegaConf
from PIL import Image

from core.runners.infer.utils import (
    prepare_motion_seqs_cano,
    prepare_motion_seqs_eval,
)
from core.utils.hf_hub import wrap_model_hub
def resize_with_padding(images, target_size):
    """Combine 4 images into a 2x2 grid, resize with aspect ratio preserved,
    and pad with white to match the target size.

    Args:
        images: List of exactly 4 np.ndarray with identical spatial size.
            Each image is either (H, W, 3) color or (H, W) grayscale;
            grayscale inputs are promoted to 3 channels. (The original
            docstring advertised (H, W) inputs but the code crashed on them.)
        target_size: Tuple (target_h, target_w) of the output canvas.

    Returns:
        np.ndarray: (target_h, target_w, 3) uint8 image with the resized
        2x2 grid centered on a white (255) background.
    """
    assert len(images) == 4, "Exactly 4 images are required"
    H, W = images[0].shape[:2]
    assert all(
        img.shape[:2] == (H, W) for img in images
    ), "All images must have the same shape (H, W)"
    # Promote grayscale (H, W) inputs to 3 channels so stacking yields (..., 3).
    images = [
        img if img.ndim == 3 else np.stack([img] * 3, axis=-1) for img in images
    ]
    # 1. Concatenate into a 2H x 2W image (2x2 grid)
    top_row = np.hstack([images[0], images[1]])  # Shape: (H, 2W, 3)
    bottom_row = np.hstack([images[2], images[3]])  # Shape: (H, 2W, 3)
    combined = np.vstack([top_row, bottom_row])  # Shape: (2H, 2W, 3)
    Hc, Wc = combined.shape[:2]
    target_h, target_w = target_size
    # 2. Resize with preserved aspect ratio.
    # Use the smaller scale so the grid never overflows the target canvas.
    scale = min(target_h / Hc, target_w / Wc)
    new_h = int(Hc * scale)
    new_w = int(Wc * scale)
    resized = cv2.resize(combined, (new_w, new_h), interpolation=cv2.INTER_AREA)
    # 3. Pad with white (255) to reach target size (centered)
    padded = np.full((target_h, target_w, 3), 255, dtype=np.uint8)
    top = (target_h - new_h) // 2
    left = (target_w - new_w) // 2
    padded[top : top + new_h, left : left + new_w] = resized
    return padded
# Registry of dataset splits: each entry maps an experiment name (the
# `-p/--pre` CLI argument) to the dataset root directory and its metadata
# JSON. A meta_path of None means the dataset class enumerates samples from
# the directory itself.
DATASETS_CONFIG = {
    # LHM video benchmark — validation subset.
    "eval": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/LHM_video_dataset/",
        meta_path="./train_data/ClothVideo/label/valid_LHM_dataset_train_val_100.json",
    ),
    # Curated video-human dataset v5 — self-rotated validation subset.
    "dataset5": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v5_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv5_val_filter460-self-rotated-69.json",
    ),
    # Curated video-human dataset v6 — self-rotated test subset.
    "dataset6": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v6_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv6_test_100-self-rotated-25.json",
    ),
    # Synthetic LHM validation split.
    "synthetic": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/synthetic_data_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_LHM_synthetic_dataset_val_17.json",
    ),
    # Training splits (self-rotated filtering) for datasets v5/v6.
    "dataset5_train": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v5_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv5_train_filter40K-self-rotated-7147.json",
    ),
    "dataset6_train": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v6_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv6_train_5W-self-rotated-11341.json",
    ),
    # LHM video benchmark — filtered training split.
    "eval_train": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/LHM_video_dataset/",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/clean_labels/valid_LHM_dataset_train_filter_16W.json",
    ),
    # In-the-wild evaluation sources (causal video, PeopleSnapshot, REC-MV,
    # MVHumanNet).
    "in_the_wild": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/PFLHM-Causal-Video/",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/clean_labels/eval_sparse_lhm_wild.json",
    ),
    "in_the_wild_people_snapshot": dict(
        root_dirs="/mnt/workspaces/dataset/people_snapshot/",
        meta_path="/mnt/workspaces/dataset/people_snapshot/peoplesnapshot.json",
    ),
    "in_the_wild_rec_mv": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/rec_mv_dataset",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/rec_mv_dataset/rec_mv_dataset.json",
    ),
    "in_the_wild_mvhumannet": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/mvhumannet/",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/mvhumannet/mvhumannet.json",
    ),
    # "Real" (non-self-rotated) training splits.
    "dataset_train_real": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/LHM_video_dataset/",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/clean_labels/valid_LHM_dataset_train_filter_16W.json",
    ),
    "dataset5_train_real": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v5_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv5_train_filter40K.json",
    ),
    "dataset6_train_real": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v6_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv6_train_5W.json",
    ),
    # Web-sourced generated datasets (no metadata file — enumerated from disk).
    "web_dresscode": dict(
        root_dirs="/mnt/workspaces/datasets/lhm_human_datasets/DressCode/",  # heyuan generation dataset.
        meta_path=None,
    ),
    "hweb_hero": dict(
        root_dirs="/mnt/workspaces/datasets/lhm_human_datasets/pinterest_download_0903_gen_full_fixed_view_random_pose_2_filtered",  # heyuan generation dataset.
        meta_path=None,
    ),
}
def obtain_motion_sequence(motion_seqs):
    """Load all per-frame SMPL-X parameter JSONs found under *motion_seqs*.

    For each frame, if a sibling "flame_params" JSON exists (same filename,
    "smplx_params" replaced in the path), its expression, jaw pose and eye
    poses are merged into the SMPL-X entry; otherwise a zero expression
    vector of length 100 is used.

    Args:
        motion_seqs: Directory containing per-frame ``*.json`` SMPL-X files.

    Returns:
        List of parameter dicts, ordered by filename.
    """
    frame_files = sorted(glob.glob(os.path.join(motion_seqs, "*.json")))
    smplx_list = []
    for frame_file in frame_files:
        with open(frame_file) as fp:
            params = json.load(fp)
        flame_file = frame_file.replace("smplx_params", "flame_params")
        if not os.path.exists(flame_file):
            # No FLAME annotation: fall back to a neutral expression.
            params["expr"] = torch.FloatTensor([0.0] * 100)
        else:
            with open(flame_file) as fp:
                flame = json.load(fp)
            params["expr"] = torch.FloatTensor(flame["expcode"])
            # Replace with FLAME's jaw pose and per-eye poses.
            params["jaw_pose"] = torch.FloatTensor(flame["posecode"][3:])
            params["leye_pose"] = torch.FloatTensor(flame["eyecode"][:3])
            params["reye_pose"] = torch.FloatTensor(flame["eyecode"][3:])
        smplx_list.append(params)
    return smplx_list
def _build_model(cfg):
    """Instantiate the LHM model via the HF-hub wrapper from ``cfg.model_name``."""
    from core.models import model_dict

    hub_cls = wrap_model_hub(model_dict["human_lrm_a4o"])
    return hub_cls.from_pretrained(cfg.model_name)
def get_smplx_params(data, device):
    """Extract SMPL-X parameter tensors from *data*, add a batch dim, move to *device*.

    Args:
        data: Mapping whose tensor values include SMPL-X parameter keys;
            unrecognized keys are ignored.
        device: Target device for the returned tensors.

    Returns:
        Dict mapping each recognized key to ``data[k].unsqueeze(0).to(device)``.
    """
    # Set gives O(1) membership (original used an O(n) list scan per key).
    smplx_keys = {
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "expr",
        "trans",
        "betas",
    }
    return {
        k: v.unsqueeze(0).to(device) for k, v in data.items() if k in smplx_keys
    }
def animation_infer(
    renderer,
    gs_model_list,
    query_points,
    smplx_params,
    render_c2ws,
    render_intrs,
    render_bg_colors,
) -> dict:
    """Render animation frames in parallel without redundant computations.

    Args:
        renderer: The rendering engine
        gs_model_list: List of Gaussian models
        query_points: 3D query points
        smplx_params: SMPL-X parameters
        render_c2ws: Camera-to-world matrices, indexed [batch, view, ...]
        render_intrs: Intrinsic camera parameters, indexed [batch, view, ...]
        render_bg_colors: Background colors

    Returns:
        Dictionary of rendered results (rgb, mask, depth, etc.)
    """
    # Local import: `defaultdict` was used here without any import in the
    # original file (NameError at runtime).
    from collections import defaultdict

    # Infer the output resolution from the principal point (cx, cy) of the
    # first view's intrinsics: H = 2*cy, W = 2*cx.
    render_h, render_w = int(render_intrs[0, 0, 1, 2] * 2), int(
        render_intrs[0, 0, 0, 2] * 2
    )
    num_views = render_c2ws.shape[1]
    start_time = time.time()
    # Render each target view independently (dead duplicate initialization of
    # render_res_list removed).
    render_res_list = [
        renderer.forward_animate_gs(
            gs_model_list,
            query_points,
            renderer.get_single_view_smpl_data(smplx_params, view_idx),
            render_c2ws[:, view_idx : view_idx + 1],
            render_intrs[:, view_idx : view_idx + 1],
            render_h,
            render_w,
            render_bg_colors[:, view_idx : view_idx + 1],
        )
        for view_idx in range(num_views)
    ]
    # Calculate and print timing information
    avg_time = (time.time() - start_time) / num_views
    print(f"Average time per frame: {avg_time:.4f}s")
    # Aggregate per-view results, detaching tensors to CPU as we go.
    out = defaultdict(list)
    for res in render_res_list:
        for k, v in res.items():
            out[k].append(v.detach().cpu() if isinstance(v, torch.Tensor) else v)
    # Post-process tensor outputs: concatenate views and reorder image-like
    # outputs to channel-last.
    for k, v in out.items():
        if isinstance(v[0], torch.Tensor):
            out[k] = torch.concat(v, dim=1)
            if k in {"comp_rgb", "comp_mask", "comp_depth"}:
                out[k] = out[k][0].permute(
                    0, 2, 3, 1
                )  # [1, Nv, 3, H, W] -> [Nv, H, W, 3]
    return out
@torch.no_grad()
def inference_results(
    lhm: torch.nn.Module,
    batch: Dict,
    smplx_params: Dict,
    motion_seq: Dict,
    camera_size: int = 40,
    ref_imgs_bool=None,
    batch_size: int = 40,
    device: str = "cuda",
) -> Tuple[np.ndarray, np.ndarray]:
    """Perform inference on a motion sequence in batches to avoid OOM.

    Encodes the source views once with ``lhm.infer_single_view``, then renders
    the motion sequence in chunks of ``batch_size`` frames, reusing the cached
    features for each chunk.

    Args:
        lhm: Model exposing ``infer_single_view`` and ``animation_infer``.
        batch: Sample dict; ``batch["source_rgbs"]`` holds the source images.
        smplx_params: SMPL-X parameters of the subject (must include "betas").
        motion_seq: Prepared motion dict — cameras, masks, offsets, per-frame
            SMPL-X params, and the original image size under "ori_size".
        camera_size: Total number of frames/views to render.
        ref_imgs_bool: Optional reference-view validity mask for the model.
        batch_size: Frames rendered per chunk (memory/throughput trade-off).
        device: Device tensors are moved to for inference.

    Returns:
        Tuple (rgbs, masks) of uint8 arrays concatenated over all frames.
    """
    offset_list = motion_seq["offset_list"]
    ori_h, ori_w = motion_seq["ori_size"]
    # White canvas the model composites each rendered frame onto.
    output_rgb = torch.ones((ori_h, ori_w, 3))  # draw board
    # Encode the source views once; features below are reused for every chunk.
    (
        gs_model_list,
        query_points,
        transform_mat_neutral_pose,
        gs_hidden_features,
        image_latents,
        motion_emb,
        pos_emb,
    ) = lhm.infer_single_view(
        batch["source_rgbs"].unsqueeze(0).to(device),
        None,
        None,
        render_c2ws=motion_seq["render_c2ws"].to(device),
        render_intrs=motion_seq["render_intrs"].to(device),
        render_bg_colors=motion_seq["render_bg_colors"].to(device),
        smplx_params={k: v.to(device) for k, v in smplx_params.items()},
        ref_imgs_bool=ref_imgs_bool,
    )
    # Parameters that are constant across batches.
    batch_smplx_params = {
        "betas": smplx_params["betas"].to(device),
        "transform_mat_neutral_pose": transform_mat_neutral_pose,
    }
    # Per-frame keys sliced out of the motion sequence for each chunk.
    keys = [
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "trans",
        "focal",
        "princpt",
        "img_size_wh",
        "expr",
    ]
    batch_list = []
    batch_mask_list = []
    for batch_i in range(0, camera_size, batch_size):
        print(
            f"Processing batch {batch_i//batch_size + 1}/{(camera_size + batch_size - 1)//batch_size}"
        )
        # Update batch-specific parameters: frames [batch_i, batch_i + batch_size).
        batch_smplx_params.update(
            {
                key: motion_seq["smplx_params"][key][
                    :, batch_i : batch_i + batch_size
                ].to(device)
                for key in keys
            }
        )
        # Perform inference on this chunk.
        batch_rgb, batch_mask = lhm.animation_infer(
            gs_model_list,
            query_points,
            batch_smplx_params,
            render_c2ws=motion_seq["render_c2ws"][:, batch_i : batch_i + batch_size].to(
                device
            ),
            render_intrs=motion_seq["render_intrs"][
                :, batch_i : batch_i + batch_size
            ].to(device),
            render_bg_colors=motion_seq["render_bg_colors"][
                :, batch_i : batch_i + batch_size
            ].to(device),
            gs_hidden_features=gs_hidden_features,
            image_latents=image_latents,
            motion_emb=motion_emb,
            pos_emb=pos_emb,
            offset_list=offset_list[batch_i : batch_i + batch_size],
            mask_seqs=motion_seq["masks"][batch_i : batch_i + batch_size],
            output_rgb=output_rgb,
        )
        # Convert and store results.
        # NOTE(review): assumes lhm.animation_infer returns CPU tensors in
        # [0, 1] — .numpy() fails on CUDA tensors; confirm against the model.
        batch_list.append((batch_rgb.clamp(0, 1) * 255).to(torch.uint8).numpy())
        batch_mask_list.append((batch_mask.clamp(0, 1) * 255).to(torch.uint8).numpy())
    return np.concatenate(batch_list, axis=0), np.concatenate(batch_mask_list, axis=0)
@torch.no_grad()
def inference_gs_model(
    lhm: torch.nn.Module,
    batch: Dict,
    smplx_params: Dict,
    motion_seq: Dict,
    camera_size: int = 40,
    ref_imgs_bool=None,
    batch_size: int = 40,
    device: str = "cuda",
) -> np.ndarray:
    """Perform inference on a motion sequence in batches to avoid OOM.

    Like ``inference_results`` but returns the Gaussian-splatting model
    produced by ``lhm.inference_gs`` instead of rendered frames.

    Args:
        lhm: Model exposing ``infer_single_view`` and ``inference_gs``.
        batch: Sample dict; ``batch["source_rgbs"]`` holds the source images.
        smplx_params: SMPL-X parameters of the subject (must include "betas").
        motion_seq: Prepared motion dict with cameras and per-frame params.
        camera_size: Total number of frames/views covered by the loop.
        ref_imgs_bool: Optional reference-view validity mask for the model.
        batch_size: Frames processed per chunk.
        device: Device tensors are moved to for inference.

    Returns:
        The GS model object from the LAST chunk (the annotation says
        np.ndarray but the value is whatever ``lhm.inference_gs`` returns —
        it is later saved via ``.save_ply``).

    NOTE(review): this unpacks 6 values from ``infer_single_view`` while
    ``inference_results`` unpacks 7 (extra ``pos_emb``) — confirm the model
    signature; if it always returns 7, this raises ValueError.
    NOTE(review): only the final loop iteration's result is returned; with
    ``camera_size <= batch_size`` (the intended usage, motion_size=1) the
    loop runs once, otherwise earlier chunks are discarded — verify intent.
    """
    # Extract model outputs — source views are encoded once.
    (
        gs_model_list,
        query_points,
        transform_mat_neutral_pose,
        gs_hidden_features,
        image_latents,
        motion_emb,
    ) = lhm.infer_single_view(
        batch["source_rgbs"].unsqueeze(0).to(device),
        None,
        None,
        render_c2ws=motion_seq["render_c2ws"].to(device),
        render_intrs=motion_seq["render_intrs"].to(device),
        render_bg_colors=motion_seq["render_bg_colors"].to(device),
        smplx_params={k: v.to(device) for k, v in smplx_params.items()},
        ref_imgs_bool=ref_imgs_bool,
    )
    # Parameters that are constant across batches.
    batch_smplx_params = {
        "betas": smplx_params["betas"].to(device),
        "transform_mat_neutral_pose": transform_mat_neutral_pose,
    }
    # Per-frame keys sliced out of the motion sequence for each chunk.
    keys = [
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "trans",
        "focal",
        "princpt",
        "img_size_wh",
        "expr",
    ]
    batch_list = []
    for batch_i in range(0, camera_size, batch_size):
        print(
            f"Processing batch {batch_i//batch_size + 1}/{(camera_size + batch_size - 1)//batch_size}"
        )
        # Update batch-specific parameters: frames [batch_i, batch_i + batch_size).
        batch_smplx_params.update(
            {
                key: motion_seq["smplx_params"][key][
                    :, batch_i : batch_i + batch_size
                ].to(device)
                for key in keys
            }
        )
        # Perform inference for this chunk; each iteration overwrites gs_model.
        gs_model = lhm.inference_gs(
            gs_model_list,
            query_points,
            batch_smplx_params,
            render_c2ws=motion_seq["render_c2ws"][:, batch_i : batch_i + batch_size].to(
                device
            ),
            render_intrs=motion_seq["render_intrs"][
                :, batch_i : batch_i + batch_size
            ].to(device),
            render_bg_colors=motion_seq["render_bg_colors"][
                :, batch_i : batch_i + batch_size
            ].to(device),
            gs_hidden_features=gs_hidden_features,
            image_latents=image_latents,
            motion_emb=motion_emb,
        )
    return gs_model
@torch.no_grad()
def lhm_validation_inference(
    lhm: Optional[torch.nn.Module],
    save_path: str,
    view: int = 16,
    cfg: Optional[Dict] = None,
    motion_path: Optional[str] = None,
    exp_name: str = "eval",
    debug: bool = False,
    split: int = 1,
    gpus: int = 0,
) -> None:
    """Run validation inference on the model and save rendered frames/masks.

    Args:
        lhm: Model to evaluate; moved to CUDA and set to eval mode when given.
        save_path: Root folder for per-uid rendered outputs.
        view: Number of reference input views passed to the dataset.
        cfg: Optional config mapping (render_size, motion_img_need_mask, ...).
        motion_path: Folder containing "smplx_params" and "samurai_seg" masks.
        exp_name: Key into DATASETS_CONFIG selecting dataset class and paths.
        debug: When True, restrict the motion sequence to 100 frames.
        split: Number of shards for distributed inference.
        gpus: Shard index handled by this process.
    """
    if lhm is not None:
        lhm.cuda().eval()
    assert motion_path is not None
    cfg = cfg or {}
    # Setup output paths
    gt_save_path = os.path.join(os.path.dirname(save_path), "gt")
    gt_mask_save_path = os.path.join(os.path.dirname(save_path), "mask")
    os.makedirs(gt_save_path, exist_ok=True)
    os.makedirs(gt_mask_save_path, exist_ok=True)
    # Select the dataset implementation for this experiment.
    dataset_config = DATASETS_CONFIG[exp_name]
    kwargs = {}
    if exp_name == "eval" or exp_name == "eval_train":
        from core.datasets.video_human_lhm_dataset_a4o import (
            VideoHumanLHMA4ODatasetEval as VideoDataset,
        )
    elif "in_the_wild" in exp_name:
        from core.datasets.video_in_the_wild_dataset import (
            VideoInTheWildEval as VideoDataset,
        )
        kwargs["heuristic_sampling"] = True
    elif "hweb" in exp_name:
        from core.datasets.video_in_the_wild_web_dataset import (
            WebInTheWildHeurEval as VideoDataset,
        )
        kwargs["heuristic_sampling"] = False
    elif "web" in exp_name:
        from core.datasets.video_in_the_wild_web_dataset import (
            WebInTheWildEval as VideoDataset,
        )
        kwargs["heuristic_sampling"] = False
    else:
        from core.datasets.video_human_dataset_a4o import (
            VideoHumanA4ODatasetEval as VideoDataset,
        )
    dataset = VideoDataset(
        root_dirs=dataset_config["root_dirs"],
        meta_path=dataset_config["meta_path"],
        sample_side_views=7,
        render_image_res_low=420,
        render_image_res_high=420,
        render_region_size=(682, 420),
        source_image_res=512,
        debug=False,
        use_flame=True,
        ref_img_size=view,
        womask=True,
        is_val=True,
        processing_pipeline=[
            dict(name="PadRatioWithScale", target_ratio=5 / 3, tgt_max_size_list=[840]),
            dict(name="ToTensor"),
        ],
        **kwargs,
    )
    # Prepare motion sequences: SMPL-X params plus matching segmentation masks.
    smplx_path = os.path.join(motion_path, "smplx_params")
    mask_path = os.path.join(motion_path, "samurai_seg")
    motion_seqs = sorted(glob.glob(os.path.join(smplx_path, "*.json")))
    motion_id_seqs = [
        motion_seq.split("/")[-1].replace(".json", "") for motion_seq in motion_seqs
    ]
    mask_paths = [
        os.path.join(mask_path, motion_id_seq + ".png")
        for motion_id_seq in motion_id_seqs
    ]
    motion_seqs = prepare_motion_seqs_cano(
        obtain_motion_sequence(smplx_path),
        mask_paths=mask_paths,
        bg_color=1.0,
        aspect_standard=5.0 / 3,
        enlarge_ratio=[1.0, 1.0],
        tgt_size=cfg.get("render_size", 420),
        render_image_res=cfg.get("render_size", 420),
        need_mask=cfg.get("motion_img_need_mask", False),
        vis_motion=cfg.get("vis_motion", False),
        motion_size=100 if debug else 1000,
        specific_id_list=None,
    )
    motion_id = motion_seqs["motion_id"]
    # Shard the dataset across `split` workers; this process handles shard `gpus`.
    dataset_size = len(dataset)
    bins = int(np.ceil(dataset_size / split))
    for idx in range(bins * gpus, bins * (gpus + 1)):
        try:
            item = dataset.__getitem__(idx, view)
        except Exception as e:
            # Narrowed from a bare `except:` (which also swallowed Ctrl-C);
            # log the failure instead of skipping silently.
            print(f"Failed to load sample {idx}: {e}")
            continue
        uid = item["uid"]
        save_folder = os.path.join(save_path, uid)
        os.makedirs(save_folder, exist_ok=True)
        print(f"Processing {uid}, idx: {idx}")
        video_dir = os.path.join(save_path, f"view_{view:03d}")
        os.makedirs(video_dir, exist_ok=True)
        # Update parameters and perform inference
        motion_seqs["smplx_params"]["betas"] = item["betas"].unsqueeze(0)
        try:
            rgbs, masks = inference_results(
                lhm,
                item,
                motion_seqs["smplx_params"],
                motion_seqs,
                camera_size=motion_seqs["smplx_params"]["root_pose"].shape[1],
                ref_imgs_bool=item["ref_imgs_bool"].unsqueeze(0),
            )
        except Exception as e:
            # Narrowed from a bare `except:`; report the reason before skipping.
            print(f"Error in infering: {e}")
            continue
        # `frame_idx` renamed from `idx` (original clobbered the dataset index).
        for rgb, mask, frame_idx in zip(rgbs, masks, motion_id):
            pred_image = Image.fromarray(rgb)
            pred_mask = Image.fromarray(mask)
            pred_name = f"rgb_{frame_idx:05d}.png"
            mask_pred_name = f"mask_{frame_idx:05d}.png"
            save_img_path = os.path.join(save_folder, pred_name)
            save_mask_path = os.path.join(save_folder, mask_pred_name)
            pred_image.save(save_img_path)
            pred_mask.save(save_mask_path)
@torch.no_grad()
def lhm_validation_inference_gs(
    lhm: Optional[torch.nn.Module],
    save_path: str,
    view: int = 16,
    cfg: Optional[Dict] = None,
    motion_path: Optional[str] = None,
    exp_name: str = "eval",
    debug: bool = False,
    split: int = 1,
    gpus: int = 0,
) -> None:
    """Run validation inference and export one Gaussian ``.ply`` per subject.

    Args:
        lhm: Model to evaluate; moved to CUDA and set to eval mode when given.
        save_path: Root folder; .ply files land in ``view_{view:03d}/``.
        view: Number of reference input views passed to the dataset.
        cfg: Optional config mapping (render_size, motion_img_need_mask, ...).
        motion_path: Folder containing the "smplx_params" motion directory.
        exp_name: Key into DATASETS_CONFIG selecting dataset class and paths.
        debug: Unused here; kept for signature parity with
            ``lhm_validation_inference``.
        split: Number of shards for distributed inference.
        gpus: Shard index handled by this process.
    """
    if lhm is not None:
        lhm.cuda().eval()
    assert motion_path is not None
    cfg = cfg or {}
    # Select the dataset implementation for this experiment.
    dataset_config = DATASETS_CONFIG[exp_name]
    kwargs = {}
    if exp_name == "eval" or exp_name == "eval_train":
        from core.datasets.video_human_lhm_dataset_a4o import (
            VideoHumanLHMA4ODatasetEval as VideoDataset,
        )
    elif "in_the_wild" in exp_name:
        from core.datasets.video_in_the_wild_dataset import (
            VideoInTheWildEval as VideoDataset,
        )
        kwargs["heuristic_sampling"] = True
    else:
        from core.datasets.video_human_dataset_a4o import (
            VideoHumanA4ODatasetEval as VideoDataset,
        )
    dataset = VideoDataset(
        root_dirs=dataset_config["root_dirs"],
        meta_path=dataset_config["meta_path"],
        sample_side_views=7,
        render_image_res_low=420,
        render_image_res_high=420,
        render_region_size=(700, 420),
        source_image_res=420,
        debug=False,
        use_flame=True,
        ref_img_size=view,
        womask=True,
        is_val=True,
        **kwargs,
    )
    # Prepare motion sequences
    smplx_path = os.path.join(motion_path, "smplx_params")
    motion_seqs = prepare_motion_seqs_eval(
        obtain_motion_sequence(smplx_path),
        bg_color=1.0,
        aspect_standard=5.0 / 3,
        enlarge_ratio=[1.0, 1.0],
        render_image_res=cfg.get("render_size", 384),
        need_mask=cfg.get("motion_img_need_mask", False),
        vis_motion=cfg.get("vis_motion", False),
        motion_size=1,  # a trick, only choose one, as we do not query the motion
        specific_id_list=None,
    )
    # Shard the dataset across `split` workers; this process handles shard `gpus`.
    dataset_size = len(dataset)
    bins = int(np.ceil(dataset_size / split))
    for idx in range(bins * gpus, bins * (gpus + 1)):
        try:
            item = dataset.__getitem__(idx, view)
        except Exception as e:
            # Narrowed from a bare `except:` (which also swallowed Ctrl-C);
            # log the failure instead of skipping silently.
            print(f"Failed to load sample {idx}: {e}")
            continue
        uid = item["uid"]
        print(f"Processing {uid}, idx: {idx}")
        gs_dir = os.path.join(save_path, f"view_{view:03d}")
        os.makedirs(gs_dir, exist_ok=True)
        gs_file_path = os.path.join(gs_dir, f"{uid}.ply")
        # Skip subjects already exported (resumable runs).
        if os.path.exists(gs_file_path):
            continue
        # Update parameters and perform inference
        motion_seqs["smplx_params"]["betas"] = item["betas"]
        gs_model = inference_gs_model(
            lhm,
            item,
            motion_seqs["smplx_params"],
            motion_seqs,
            camera_size=motion_seqs["smplx_params"]["root_pose"].shape[1],
        )
        print(f"generated GS Model!")
        gs_model.save_ply(gs_file_path)
def parse_configs() -> Tuple[DictConfig, str]:
    """Parse configuration from environment variables and model config files.

    Returns:
        Tuple containing:
            - Merged configuration object
            - Model name extracted from MODEL_PATH
    """
    cfg = OmegaConf.create()
    cli_cfg = OmegaConf.create()

    # MODEL_PATH is mandatory; trailing slashes would break the name split.
    model_path = os.environ.get("MODEL_PATH", "").rstrip("/")
    if not model_path:
        raise ValueError("MODEL_PATH environment variable is required")
    cli_cfg.model_name = model_path
    # Assuming standard path structure: .../<model_name>/<checkpoint_dir>
    model_name = model_path.split("/")[-2]

    # Load the training config if available and derive runtime options from it.
    model_config = os.environ.get("APP_MODEL_CONFIG")
    if model_config:
        cfg_train = OmegaConf.load(model_config)
        cfg.update(
            {
                "source_size": cfg_train.dataset.source_image_res,
                "src_head_size": getattr(cfg_train.dataset, "src_head_size", 112),
                "render_size": cfg_train.dataset.render_image.high,
                "motion_video_read_fps": 30,
                "logger": "INFO",
            }
        )
        # Derive dump directories from the experiment hierarchy + exp id.
        exp_id = os.path.basename(model_path).split("_")[-1]
        relative_path = os.path.join(
            cfg_train.experiment.parent, cfg_train.experiment.child, exp_id
        )
        dump_dirs = {
            key: os.path.join("exps", sub, relative_path)
            for key, sub in (
                ("save_tmp_dump", "save_tmp"),
                ("image_dump", "images"),
                ("video_dump", "videos"),
            )
        }
        cfg.update(dump_dirs)

    # Merge CLI config and validate
    cfg.merge_with(cli_cfg)
    if not cfg.get("model_name"):
        raise ValueError("model_name is required in configuration")
    return cfg, model_name
def get_parse():
    """Build and parse the command-line arguments for this script.

    Returns:
        argparse.Namespace with attributes: config, ckpt, view, pre, motion,
        split, gpus, gs, output, debug.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Inference_Config")
    parser.add_argument("-c", "--config", required=True, help="config file path")
    parser.add_argument("-w", "--ckpt", required=True, help="model checkpoint")
    parser.add_argument("-v", "--view", help="input views", default=16, type=int)
    parser.add_argument("-p", "--pre", help="exp_name", type=str)
    parser.add_argument("-m", "--motion", help="motion_path", type=str)
    parser.add_argument(
        "-s",
        "--split",
        help="split_dataset, used for distribution inference.",
        type=int,
        default=1,
    )
    parser.add_argument("-g", "--gpus", help="current gpu id", type=int, default=0)
    parser.add_argument("--gs", help="output gaussian model", action="store_true")
    parser.add_argument(
        "--output", help="output_cano_folder", default="./debug/cano_output", type=str
    )
    # Help text fixed: the original said "motion_path" (copy-paste error).
    parser.add_argument("--debug", help="enable debug mode", action="store_true")
    args = parser.parse_args()
    return args
def main():
    """Script entry point: configure the environment, build the model, and
    run canonical-space validation inference."""
    args = get_parse()
    # Original dict literal listed "APP_TYPE" twice; the duplicate is removed.
    os.environ.update(
        {
            "APP_ENABLED": "1",
            "APP_MODEL_CONFIG": args.config,  # choice from MODEL_CARD
            "APP_TYPE": "infer.human_lrm_a4o",
            "NUMBA_THREADING_LAYER": "omp",
            "MODEL_PATH": args.ckpt,
        }
    )
    exp_name = args.pre
    # Strip trailing slash(es); also safe for None/empty motion (the original
    # `motion[-1]` raised IndexError on an empty string).
    motion = args.motion.rstrip("/") if args.motion else args.motion
    cfg, model_name = parse_configs()
    # save_path = os.path.join('./exps/validation/public_benchmark', model_name)
    assert exp_name in list(DATASETS_CONFIG.keys())
    output_path = args.output
    lhm = _build_model(cfg)
    if not args.gs:
        os.makedirs(output_path, exist_ok=True)
        lhm_validation_inference(
            lhm,
            output_path,
            args.view,
            cfg,
            motion,
            exp_name,
            debug=args.debug,
            split=args.split,
            gpus=args.gpus,
        )
    else:
        raise NotImplementedError


if __name__ == "__main__":
    main()