LHMPP / scripts /inference /app_inference.py
Lingteng Qiu (邱陵腾)
rm assets & wheels
434b0b0
# -*- coding: utf-8 -*-
# @Organization : Tongyi Lab, Alibaba
# @Author : Lingteng Qiu
# @Email : 220019047@link.cuhk.edu.cn
# @Time : 2026-03-10 10:00:00
# @Function : LHM++ Gradio App inference logic
import os
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from core.utils.hf_hub import wrap_model_hub
# Default batch size for inference
DEFAULT_BATCH_SIZE = 40
def parse_app_configs(
model_cards: Dict[str, Dict[str, str]],
) -> Tuple[DictConfig, DictConfig]:
"""Parse model configuration from environment variables and config files.
Returns:
A tuple of (cfg, cfg_train) containing merged configurations.
"""
cli_cfg = OmegaConf.create()
cfg = OmegaConf.create()
app_model_name = os.environ.get("APP_MODEL_NAME")
if app_model_name is None:
raise NotImplementedError("APP_MODEL_NAME environment variable must be set")
model_card = model_cards[app_model_name]
model_path = model_card["model_path"]
model_config = model_card["model_config"]
cli_cfg.model_name = model_path
if model_config is not None:
cfg_train = OmegaConf.load(model_config)
cfg.source_size = cfg_train.dataset.source_image_res
try:
cfg.src_head_size = cfg_train.dataset.src_head_size
except AttributeError:
cfg.src_head_size = 112
cfg.render_size = cfg_train.dataset.render_image.high
_relative_path = os.path.join(
cfg_train.experiment.parent,
cfg_train.experiment.child,
os.path.basename(cli_cfg.model_name).split("_")[-1],
)
cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
cfg.image_dump = os.path.join("exps", "images", _relative_path)
cfg.video_dump = os.path.join("exps", "videos", _relative_path)
cfg.motion_video_read_fps = 6
cfg.merge_with(cli_cfg)
cfg.setdefault("logger", "INFO")
assert cfg.model_name is not None, "model_name is required"
return cfg, cfg_train
def build_app_model(cfg: DictConfig) -> torch.nn.Module:
"""Build and load the LHM model from pretrained weights.
Args:
cfg: Configuration object containing model_name and other parameters.
Returns:
Loaded LHM model ready for inference.
"""
from core.models import model_dict
model_cls = wrap_model_hub(model_dict["human_lrm_a4o"])
model = model_cls.from_pretrained(cfg.model_name)
return model
@torch.no_grad()
def inference_results(
model: torch.nn.Module,
ref_img_tensors: torch.Tensor,
smplx_params: Dict[str, torch.Tensor],
motion_seq: Dict[str, Any],
video_size: int = 40,
ref_imgs_bool: Optional[torch.Tensor] = None,
visualized_center: bool = False,
batch_size: int = DEFAULT_BATCH_SIZE,
device: str = "cuda",
) -> np.ndarray:
"""Run inference on a motion sequence with batching to prevent OOM.
Args:
model: LHM model for human animation.
ref_img_tensors: Reference image tensors of shape (N, C, H, W).
smplx_params: SMPL-X parameters for the initial pose.
motion_seq: Dictionary containing motion sequence data.
video_size: Total number of frames to render.
ref_imgs_bool: Boolean mask indicating which reference images to use.
visualized_center: If True, crops output to subject bounds with 10% padding.
batch_size: Number of frames to process in each batch.
device: Device to run inference on.
Returns:
Rendered RGB frames as numpy array of shape (T, H, W, 3).
"""
offset_list = motion_seq.get("offset_list")
ori_h, ori_w = motion_seq.get("ori_size", (512, 512))
output_rgb = torch.ones((ori_h, ori_w, 3))
ref_imgs_bool = torch.ones(
ref_img_tensors.shape[0], dtype=torch.bool, device=device
)
model_outputs = model.infer_single_view(
ref_img_tensors.unsqueeze(0).to(device),
None,
None,
render_c2ws=motion_seq["render_c2ws"].to(device),
render_intrs=motion_seq["render_intrs"].to(device),
render_bg_colors=motion_seq["render_bg_colors"].to(device),
smplx_params={k: v.to(device) for k, v in smplx_params.items()},
ref_imgs_bool=ref_imgs_bool.unsqueeze(0),
)
if len(model_outputs) == 7:
(
gs_model_list,
query_points,
transform_mat_neutral_pose,
gs_hidden_features,
image_latents,
motion_emb,
pos_emb,
) = model_outputs
else:
(
gs_model_list,
query_points,
transform_mat_neutral_pose,
gs_hidden_features,
image_latents,
motion_emb,
) = model_outputs
pos_emb = None
batch_smplx_params = {
"betas": smplx_params["betas"].to(device),
"transform_mat_neutral_pose": transform_mat_neutral_pose,
}
frame_varying_keys = [
"root_pose",
"body_pose",
"jaw_pose",
"leye_pose",
"reye_pose",
"lhand_pose",
"rhand_pose",
"trans",
"focal",
"princpt",
"img_size_wh",
"expr",
]
batch_rgb_list = []
batch_mask_list = []
num_batches = (video_size + batch_size - 1) // batch_size
for batch_idx in range(0, video_size, batch_size):
current_batch = batch_idx // batch_size + 1
print(f"Processing batch {current_batch}/{num_batches}")
batch_smplx_params.update(
{
key: motion_seq["smplx_params"][key][
:, batch_idx : batch_idx + batch_size
].to(device)
for key in frame_varying_keys
}
)
mask_seqs = (
motion_seq.get("masks", [])[batch_idx : batch_idx + batch_size]
if "masks" in motion_seq
else None
)
anim_kwargs = {
"gs_model_list": gs_model_list,
"query_points": query_points,
"smplx_params": batch_smplx_params,
"render_c2ws": motion_seq["render_c2ws"][
:, batch_idx : batch_idx + batch_size
].to(device),
"render_intrs": motion_seq["render_intrs"][
:, batch_idx : batch_idx + batch_size
].to(device),
"render_bg_colors": motion_seq["render_bg_colors"][
:, batch_idx : batch_idx + batch_size
].to(device),
"gs_hidden_features": gs_hidden_features,
"image_latents": image_latents,
"motion_emb": motion_emb,
}
if pos_emb is not None:
anim_kwargs["pos_emb"] = pos_emb
if offset_list is not None:
anim_kwargs["offset_list"] = offset_list[batch_idx : batch_idx + batch_size]
if mask_seqs is not None:
anim_kwargs["mask_seqs"] = mask_seqs
if output_rgb is not None:
anim_kwargs["output_rgb"] = output_rgb
batch_rgb, batch_mask = model.animation_infer(**anim_kwargs)
batch_rgb_list.append((batch_rgb.clamp(0, 1) * 255).to(torch.uint8).numpy())
batch_mask_list.append((batch_mask.clamp(0, 1) * 255).to(torch.uint8).numpy())
print("End of inference")
if visualized_center:
mask_numpy = np.concatenate(batch_mask_list, axis=0)
h_indices, w_indices = np.where(mask_numpy > 0.25)[1:]
if len(h_indices) > 0 and len(w_indices) > 0:
top, bottom = h_indices.min(), h_indices.max()
left, right = w_indices.min(), w_indices.max()
center_y, center_x = (top + bottom) / 2, (left + right) / 2
height, width = bottom - top, right - left
new_height, new_width = height * 1.1, width * 1.1
top_new = max(0, int(center_y - new_height / 2))
bottom_new = int(center_y + new_height / 2)
left_new = max(0, int(center_x - new_width / 2))
right_new = int(center_x + new_width / 2)
rgb = np.concatenate(batch_rgb_list, axis=0)
output = rgb[:, top_new:bottom_new, left_new:right_new]
else:
output = np.concatenate(batch_rgb_list, axis=0)
else:
output = np.concatenate(batch_rgb_list, axis=0)
return output