| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | from typing import Any, Dict, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from omegaconf import DictConfig, OmegaConf |
| |
|
| | from core.utils.hf_hub import wrap_model_hub |
| |
|
| | |
| | DEFAULT_BATCH_SIZE = 40 |
| |
|
| |
|
def parse_app_configs(
    model_cards: Dict[str, Dict[str, str]],
) -> Tuple[DictConfig, DictConfig]:
    """Parse model configuration from environment variables and config files.

    Args:
        model_cards: Mapping of model name -> card dict carrying at least
            "model_path" and "model_config" entries.

    Returns:
        A tuple of (cfg, cfg_train) containing merged configurations.

    Raises:
        NotImplementedError: If the APP_MODEL_NAME environment variable is
            not set (kept as-is for backward compatibility with callers).
        KeyError: If APP_MODEL_NAME names no entry in ``model_cards``.
        ValueError: If the selected card has no model_config to load.
    """
    cli_cfg = OmegaConf.create()
    cfg = OmegaConf.create()

    app_model_name = os.environ.get("APP_MODEL_NAME")
    if app_model_name is None:
        raise NotImplementedError("APP_MODEL_NAME environment variable must be set")

    if app_model_name not in model_cards:
        # Clearer diagnostics than the bare KeyError the lookup would raise.
        raise KeyError(
            f"Unknown APP_MODEL_NAME {app_model_name!r}; "
            f"available models: {sorted(model_cards)}"
        )
    model_card = model_cards[app_model_name]
    model_path = model_card["model_path"]
    model_config = model_card["model_config"]

    cli_cfg.model_name = model_path

    # BUG FIX: the original guarded the config load with
    # `if model_config is not None:` but still returned `cfg_train`
    # unconditionally, producing a NameError when the card had no config.
    # Fail fast with an explicit error instead.
    if model_config is None:
        raise ValueError(
            f"model card for {app_model_name!r} has no 'model_config' to load"
        )

    cfg_train = OmegaConf.load(model_config)
    cfg.source_size = cfg_train.dataset.source_image_res
    # Older training configs may lack src_head_size; fall back to 112.
    try:
        cfg.src_head_size = cfg_train.dataset.src_head_size
    except AttributeError:
        cfg.src_head_size = 112
    cfg.render_size = cfg_train.dataset.render_image.high

    # Experiment-relative output directory, keyed on the checkpoint suffix.
    _relative_path = os.path.join(
        cfg_train.experiment.parent,
        cfg_train.experiment.child,
        os.path.basename(cli_cfg.model_name).split("_")[-1],
    )

    cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
    cfg.image_dump = os.path.join("exps", "images", _relative_path)
    cfg.video_dump = os.path.join("exps", "videos", _relative_path)

    cfg.motion_video_read_fps = 6
    # CLI/env-derived values win over values loaded from the config file.
    cfg.merge_with(cli_cfg)
    cfg.setdefault("logger", "INFO")
    assert cfg.model_name is not None, "model_name is required"

    return cfg, cfg_train
| |
|
| |
|
def build_app_model(cfg: DictConfig) -> torch.nn.Module:
    """Instantiate the LHM model and load its pretrained weights.

    Args:
        cfg: Configuration object; ``cfg.model_name`` identifies the
            pretrained checkpoint to load.

    Returns:
        The loaded LHM model, ready for inference.
    """
    # Imported lazily so merely importing this module stays cheap.
    from core.models import model_dict

    hub_cls = wrap_model_hub(model_dict["human_lrm_a4o"])
    return hub_cls.from_pretrained(cfg.model_name)
| |
|
| |
|
@torch.no_grad()
def inference_results(
    model: torch.nn.Module,
    ref_img_tensors: torch.Tensor,
    smplx_params: Dict[str, torch.Tensor],
    motion_seq: Dict[str, Any],
    video_size: int = 40,
    ref_imgs_bool: Optional[torch.Tensor] = None,
    visualized_center: bool = False,
    batch_size: int = DEFAULT_BATCH_SIZE,
    device: str = "cuda",
) -> np.ndarray:
    """Run inference on a motion sequence with batching to prevent OOM.

    Args:
        model: LHM model for human animation.
        ref_img_tensors: Reference image tensors of shape (N, C, H, W).
        smplx_params: SMPL-X parameters for the initial pose.
        motion_seq: Dictionary containing motion sequence data.
        video_size: Total number of frames to render.
        ref_imgs_bool: Boolean mask indicating which reference images to use.
            Defaults to "use all" when omitted.
        visualized_center: If True, crops output to subject bounds with 10% padding.
        batch_size: Number of frames to process in each batch.
        device: Device to run inference on.

    Returns:
        Rendered RGB frames as numpy uint8 array of shape (T, H, W, 3).
    """
    offset_list = motion_seq.get("offset_list")
    ori_h, ori_w = motion_seq.get("ori_size", (512, 512))
    output_rgb = torch.ones((ori_h, ori_w, 3))

    # BUG FIX: the caller-supplied ref_imgs_bool was previously overwritten
    # unconditionally, making the parameter dead. Build the all-True default
    # only when the caller did not provide a mask.
    if ref_imgs_bool is None:
        ref_imgs_bool = torch.ones(
            ref_img_tensors.shape[0], dtype=torch.bool, device=device
        )
    else:
        ref_imgs_bool = ref_imgs_bool.to(device)

    # Single forward pass over the reference view(s); the resulting Gaussian
    # model and latents are reused for every animation batch below.
    model_outputs = model.infer_single_view(
        ref_img_tensors.unsqueeze(0).to(device),
        None,
        None,
        render_c2ws=motion_seq["render_c2ws"].to(device),
        render_intrs=motion_seq["render_intrs"].to(device),
        render_bg_colors=motion_seq["render_bg_colors"].to(device),
        smplx_params={k: v.to(device) for k, v in smplx_params.items()},
        ref_imgs_bool=ref_imgs_bool.unsqueeze(0),
    )

    # Newer model variants additionally return a positional embedding.
    if len(model_outputs) == 7:
        (
            gs_model_list,
            query_points,
            transform_mat_neutral_pose,
            gs_hidden_features,
            image_latents,
            motion_emb,
            pos_emb,
        ) = model_outputs
    else:
        (
            gs_model_list,
            query_points,
            transform_mat_neutral_pose,
            gs_hidden_features,
            image_latents,
            motion_emb,
        ) = model_outputs
        pos_emb = None

    # Frame-invariant SMPL-X parameters; per-frame slices are merged in below.
    batch_smplx_params = {
        "betas": smplx_params["betas"].to(device),
        "transform_mat_neutral_pose": transform_mat_neutral_pose,
    }

    # Keys whose values vary per frame and must be sliced per batch.
    frame_varying_keys = [
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "trans",
        "focal",
        "princpt",
        "img_size_wh",
        "expr",
    ]

    batch_rgb_list = []
    batch_mask_list = []
    num_batches = (video_size + batch_size - 1) // batch_size

    for batch_idx in range(0, video_size, batch_size):
        current_batch = batch_idx // batch_size + 1
        print(f"Processing batch {current_batch}/{num_batches}")

        batch_smplx_params.update(
            {
                key: motion_seq["smplx_params"][key][
                    :, batch_idx : batch_idx + batch_size
                ].to(device)
                for key in frame_varying_keys
            }
        )

        mask_seqs = (
            motion_seq["masks"][batch_idx : batch_idx + batch_size]
            if "masks" in motion_seq
            else None
        )

        anim_kwargs = {
            "gs_model_list": gs_model_list,
            "query_points": query_points,
            "smplx_params": batch_smplx_params,
            "render_c2ws": motion_seq["render_c2ws"][
                :, batch_idx : batch_idx + batch_size
            ].to(device),
            "render_intrs": motion_seq["render_intrs"][
                :, batch_idx : batch_idx + batch_size
            ].to(device),
            "render_bg_colors": motion_seq["render_bg_colors"][
                :, batch_idx : batch_idx + batch_size
            ].to(device),
            "gs_hidden_features": gs_hidden_features,
            "image_latents": image_latents,
            "motion_emb": motion_emb,
            # output_rgb is always constructed above, so pass it
            # unconditionally (the original None-check was always true).
            "output_rgb": output_rgb,
        }

        if pos_emb is not None:
            anim_kwargs["pos_emb"] = pos_emb
        if offset_list is not None:
            anim_kwargs["offset_list"] = offset_list[batch_idx : batch_idx + batch_size]
        if mask_seqs is not None:
            anim_kwargs["mask_seqs"] = mask_seqs

        batch_rgb, batch_mask = model.animation_infer(**anim_kwargs)
        # BUG FIX: .numpy() fails on CUDA tensors; .detach().cpu() is a
        # no-op for CPU tensors and makes the default device="cuda" work.
        batch_rgb_list.append(
            (batch_rgb.clamp(0, 1) * 255).to(torch.uint8).detach().cpu().numpy()
        )
        batch_mask_list.append(
            (batch_mask.clamp(0, 1) * 255).to(torch.uint8).detach().cpu().numpy()
        )

    print("End of inference")

    if visualized_center:
        mask_numpy = np.concatenate(batch_mask_list, axis=0)
        # BUG FIX: the mask was scaled to uint8 [0, 255] above, so the
        # original `> 0.25` matched every nonzero pixel. Compare against the
        # 0.25 threshold on the rescaled range instead.
        h_indices, w_indices = np.where(mask_numpy > 0.25 * 255)[1:]

        if len(h_indices) > 0 and len(w_indices) > 0:
            # Crop to the subject's bounding box, expanded by 10% per side.
            top, bottom = h_indices.min(), h_indices.max()
            left, right = w_indices.min(), w_indices.max()

            center_y, center_x = (top + bottom) / 2, (left + right) / 2
            height, width = bottom - top, right - left
            new_height, new_width = height * 1.1, width * 1.1

            top_new = max(0, int(center_y - new_height / 2))
            bottom_new = int(center_y + new_height / 2)
            left_new = max(0, int(center_x - new_width / 2))
            right_new = int(center_x + new_width / 2)

            rgb = np.concatenate(batch_rgb_list, axis=0)
            output = rgb[:, top_new:bottom_new, left_new:right_new]
        else:
            # No foreground pixels found: return the full frames.
            output = np.concatenate(batch_rgb_list, axis=0)
    else:
        output = np.concatenate(batch_rgb_list, axis=0)

    return output
| |
|