| import cv2 |
| import torch |
| import pytorch_lightning as pl |
| import numpy as np |
| import argparse |
| from hmr4d.utils.pylogger import Log |
| import hydra |
| from hydra import initialize_config_module, compose |
| from pathlib import Path |
| from pytorch3d.transforms import quaternion_to_matrix |
|
|
| from hmr4d.configs import register_store_gvhmr |
| from hmr4d.utils.video_io_utils import ( |
| get_video_lwh, |
| read_video_np, |
| save_video, |
| merge_videos_horizontal, |
| get_writer, |
| get_video_reader, |
| ) |
| from hmr4d.utils.vis.cv2_utils import draw_bbx_xyxy_on_image_batch, draw_coco17_skeleton_batch |
|
|
| from hmr4d.utils.preproc import Tracker, Extractor, VitPoseExtractor, SLAMModel |
|
|
| from hmr4d.utils.geo.hmr_cam import get_bbx_xys_from_xyxy, estimate_K, convert_K_to_K4, create_camera_sensor |
| from hmr4d.utils.geo_transform import compute_cam_angvel |
| from hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL |
| from hmr4d.utils.net_utils import detach_to_cpu, to_cuda |
| from hmr4d.utils.smplx_utils import make_smplx |
| from hmr4d.utils.vis.renderer import Renderer, get_global_cameras_static, get_ground_params_from_points |
| from tqdm import tqdm |
| from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay |
| from einops import einsum, rearrange |
|
|
|
|
CRF = 23  # ffmpeg/x264 constant-rate-factor used for every written video (lower = better quality)
|
|
|
|
def parse_args_to_cfg():
    """Parse CLI arguments and compose the hydra demo config.

    Also creates the output directories and copies (re-encodes at 30 fps)
    the input video into the output folder when it is missing or its frame
    count differs from the source.

    Returns:
        The composed hydra config for the demo pipeline.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--video", type=str, default="inputs/demo/dance_3.mp4")
    arg_parser.add_argument("--output_root", type=str, default=None, help="by default to outputs/demo")
    arg_parser.add_argument("-s", "--static_cam", action="store_true", help="If true, skip DPVO")
    arg_parser.add_argument("--verbose", action="store_true", help="If true, draw intermediate results")
    opts = arg_parser.parse_args()

    # The input video must exist before any further work.
    video_path = Path(opts.video)
    assert video_path.exists(), f"Video not found at {video_path}"
    length, width, height = get_video_lwh(video_path)
    Log.info(f"[Input]: {video_path}")
    Log.info(f"(L, W, H) = ({length}, {width}, {height})")

    # Compose the hydra config with CLI-derived overrides.
    with initialize_config_module(version_base="1.3", config_module=f"hmr4d.configs"):
        overrides = [
            f"video_name={video_path.stem}",
            f"static_cam={opts.static_cam}",
            f"verbose={opts.verbose}",
        ]
        if opts.output_root is not None:
            overrides.append(f"output_root={opts.output_root}")
        register_store_gvhmr()
        cfg = compose(config_name="demo", overrides=overrides)

    # Prepare output directories.
    Log.info(f"[Output Dir]: {cfg.output_dir}")
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.preprocess_dir).mkdir(parents=True, exist_ok=True)

    # Copy (re-encode) the source video next to the outputs unless a copy
    # with the same number of frames is already there.
    Log.info(f"[Copy Video] {video_path} -> {cfg.video_path}")
    copy_needed = (
        not Path(cfg.video_path).exists()
        or get_video_lwh(video_path)[0] != get_video_lwh(cfg.video_path)[0]
    )
    if copy_needed:
        reader = get_video_reader(video_path)
        writer = get_writer(cfg.video_path, fps=30, crf=CRF)
        for frame in tqdm(reader, total=get_video_lwh(video_path)[0], desc=f"Copy"):
            writer.write_frame(frame)
        writer.close()
        reader.close()

    return cfg
|
|
|
|
@torch.no_grad()
def run_preprocess(cfg, progress=None):
    """Run the preprocessing pipeline: tracking, 2D pose, features, SLAM.

    Each stage caches its result on disk (under ``cfg.paths``) and is
    skipped when the cached file already exists.

    Args:
        cfg: composed demo config (video_path, paths, static_cam, verbose).
        progress: optional ``progress(fraction, message)`` callback used to
            report stage progress (e.g. to a UI). Defaults to a no-op so the
            function can also be called as ``run_preprocess(cfg)`` — the
            original signature required it, which broke the CLI entry point.
    """
    if progress is None:
        progress = lambda fraction, message: None
    Log.info(f"[Preprocess] Start!")
    tic = Log.time()
    video_path = cfg.video_path
    paths = cfg.paths
    static_cam = cfg.static_cam
    verbose = cfg.verbose

    # Stage 1: person tracking -> bounding boxes (xyxy and xys forms).
    progress(0, '[Preprocess] YoloV8 Tracking')
    if not Path(paths.bbx).exists():
        tracker = Tracker()
        bbx_xyxy = tracker.get_one_track(video_path).float()
        # Enlarge the crop slightly around the tracked person.
        bbx_xys = get_bbx_xys_from_xyxy(bbx_xyxy, base_enlarge=1.2).float()
        torch.save({"bbx_xyxy": bbx_xyxy, "bbx_xys": bbx_xys}, paths.bbx)
        del tracker  # release the detector before loading the next model
    else:
        bbx_xys = torch.load(paths.bbx)["bbx_xys"]
        Log.info(f"[Preprocess] bbx (xyxy, xys) from {paths.bbx}")
    if verbose:
        video = read_video_np(video_path)
        bbx_xyxy = torch.load(paths.bbx)["bbx_xyxy"]
        video_overlay = draw_bbx_xyxy_on_image_batch(bbx_xyxy, video)
        save_video(video_overlay, cfg.paths.bbx_xyxy_video_overlay)

    # Stage 2: ViTPose 2D keypoints inside the tracked boxes.
    progress(1/4, '[Preprocess] ViTPose')
    if not Path(paths.vitpose).exists():
        vitpose_extractor = VitPoseExtractor()
        vitpose = vitpose_extractor.extract(video_path, bbx_xys)
        torch.save(vitpose, paths.vitpose)
        del vitpose_extractor
    else:
        vitpose = torch.load(paths.vitpose)
        Log.info(f"[Preprocess] vitpose from {paths.vitpose}")
    if verbose:
        video = read_video_np(video_path)
        video_overlay = draw_coco17_skeleton_batch(video, vitpose, 0.5)
        save_video(video_overlay, paths.vitpose_video_overlay)

    # Stage 3: HMR2 image features for the cropped frame sequence.
    progress(2/4, '[Preprocess] HMR2 Feature')
    if not Path(paths.vit_features).exists():
        extractor = Extractor()
        vit_features = extractor.extract_video_features(video_path, bbx_xys)
        torch.save(vit_features, paths.vit_features)
        del extractor
    else:
        Log.info(f"[Preprocess] vit_features from {paths.vit_features}")

    # Stage 4: DPVO visual odometry — only needed for a moving camera.
    progress(3/4, '[Preprocess] DPVO')
    if not static_cam:
        if not Path(paths.slam).exists():
            length, width, height = get_video_lwh(cfg.video_path)
            K_fullimg = estimate_K(width, height)
            intrinsics = convert_K_to_K4(K_fullimg)
            slam = SLAMModel(video_path, width, height, intrinsics, buffer=4000, resize=0.5)
            bar = tqdm(total=length, desc="DPVO")
            while True:
                ret = slam.track()
                if ret:
                    bar.update()
                else:
                    break
            slam_results = slam.process()
            torch.save(slam_results, paths.slam)
        else:
            Log.info(f"[Preprocess] slam results from {paths.slam}")

    Log.info(f"[Preprocess] End. Time elapsed: {Log.time()-tic:.2f}s")
|
|
|
|
def load_data_dict(cfg):
    """Assemble the HMR4D model-input dict from cached preprocessing files.

    Returns a dict with length, bounding boxes, 2D keypoints, per-frame
    camera intrinsics, camera angular velocity, and image features.
    """
    paths = cfg.paths
    num_frames, width, height = get_video_lwh(cfg.video_path)

    # World->camera rotations: identity for a static camera, otherwise
    # derived from the DPVO trajectory quaternions.
    if cfg.static_cam:
        R_w2c = torch.eye(3).repeat(num_frames, 1, 1)
    else:
        traj = torch.load(cfg.paths.slam)
        # Reorder the stored quaternion to w-first for pytorch3d
        # (assumes DPVO rows are [x, y, z, qx, qy, qz, qw, ...] — confirm).
        quat_wxyz = torch.from_numpy(traj[:, [6, 3, 4, 5]])
        R_w2c = quaternion_to_matrix(quat_wxyz).mT
    K_fullimg = estimate_K(width, height).repeat(num_frames, 1, 1)

    return {
        "length": torch.tensor(num_frames),
        "bbx_xys": torch.load(paths.bbx)["bbx_xys"],
        "kp2d": torch.load(paths.vitpose),
        "K_fullimg": K_fullimg,
        "cam_angvel": compute_cam_angvel(R_w2c),
        "f_imgseq": torch.load(paths.vit_features),
    }
|
|
|
|
def render_incam(cfg, pred=None, smpl_utils=None):
    """Render the in-camera-view video: SMPL mesh overlaid on each frame.

    Skips rendering entirely when the output video already exists.

    Args:
        cfg: demo config (paths, video_path).
        pred: HMR4D prediction dict; loaded from ``cfg.paths.hmr4d_results``
            when None, so the function can be called as ``render_incam(cfg)``
            (the required parameters broke the CLI entry point).
        smpl_utils: dict with keys "smplx", "smplx2smpl", "faces_smpl";
            built from the packaged body-model assets when None.
    """
    incam_video_path = Path(cfg.paths.incam_video)
    if incam_video_path.exists():
        Log.info(f"[Render Incam] Video already exists at {incam_video_path}")
        return

    if pred is None:
        pred = torch.load(cfg.paths.hmr4d_results)
    if smpl_utils is None:
        # NOTE(review): asset paths follow the upstream GVHMR demo — confirm
        # they resolve relative to the working directory.
        smpl_utils = {
            "smplx": make_smplx("supermotion").cuda(),
            "smplx2smpl": torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").cuda(),
            "faces_smpl": make_smplx("smpl").faces,
        }

    smplx = smpl_utils["smplx"]
    smplx2smpl = smpl_utils["smplx2smpl"]
    faces_smpl = smpl_utils["faces_smpl"]

    # SMPL-X forward pass, then map vertices to SMPL topology per frame.
    smplx_out = smplx(**to_cuda(pred["smpl_params_incam"]))
    pred_c_verts = torch.stack([torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices])

    video_path = cfg.video_path
    length, width, height = get_video_lwh(video_path)
    K = pred["K_fullimg"][0]  # intrinsics are constant across frames here

    renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K)
    reader = get_video_reader(video_path)

    # Render the mesh over every raw frame and encode the result.
    writer = get_writer(incam_video_path, fps=30, crf=CRF)
    for i, img_raw in tqdm(enumerate(reader), total=length, desc=f"Rendering Incam"):
        img = renderer.render_mesh(pred_c_verts[i].cuda(), img_raw, [0.8, 0.8, 0.8])
        writer.write_frame(img)
    writer.close()
    reader.close()
|
|
|
|
def render_global(cfg, pred=None, smpl_utils=None):
    """Render the global-view video: the SMPL mesh on a ground plane seen
    from an external camera.

    Skips rendering entirely when the output video already exists.

    Args:
        cfg: demo config (paths, video_path).
        pred: HMR4D prediction dict; loaded from ``cfg.paths.hmr4d_results``
            when None, so the function can be called as ``render_global(cfg)``
            (the required parameters broke the CLI entry point).
        smpl_utils: dict with keys "smplx", "smplx2smpl", "faces_smpl",
            "J_regressor"; built from the packaged body-model assets when
            None.
    """
    global_video_path = Path(cfg.paths.global_video)
    if global_video_path.exists():
        Log.info(f"[Render Global] Video already exists at {global_video_path}")
        return

    if pred is None:
        pred = torch.load(cfg.paths.hmr4d_results)
    if smpl_utils is None:
        # NOTE(review): asset paths follow the upstream GVHMR demo — confirm
        # they resolve relative to the working directory.
        smpl_utils = {
            "smplx": make_smplx("supermotion").cuda(),
            "smplx2smpl": torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").cuda(),
            "faces_smpl": make_smplx("smpl").faces,
            "J_regressor": torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt").cuda(),
        }

    debug_cam = False  # when True, render only the first 8 frames
    smplx = smpl_utils["smplx"]
    smplx2smpl = smpl_utils["smplx2smpl"]
    faces_smpl = smpl_utils["faces_smpl"]
    J_regressor = smpl_utils["J_regressor"]

    # SMPL-X forward pass, then map vertices to SMPL topology per frame.
    smplx_out = smplx(**to_cuda(pred["smpl_params_global"]))
    pred_ay_verts = torch.stack([torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices])

    def move_to_start_point_face_z(verts):
        "XZ to origin, Start from the ground, Face-Z"
        # Shift so the first-frame root is at XZ origin and the lowest
        # vertex over the whole sequence sits on the ground (y=0).
        verts = verts.clone()
        offset = einsum(J_regressor, verts[0], "j v, v i -> j i")[0]
        offset[1] = verts[:, :, [1]].min()
        verts = verts - offset
        # Rotate about Y so the first frame faces +Z.
        T_ay2ayfz = compute_T_ayfz2ay(einsum(J_regressor, verts[[0]], "j v, l v i -> l j i"), inverse=True)
        verts = apply_T_on_points(verts, T_ay2ayfz)
        return verts

    verts_glob = move_to_start_point_face_z(pred_ay_verts)
    joints_glob = einsum(J_regressor, verts_glob, "j v, l v i -> l j i")
    global_R, global_T, global_lights = get_global_cameras_static(
        verts_glob.cpu(),
        beta=2.0,
        cam_height_degree=20,
        target_center_height=1.0,
    )

    # Fixed synthetic camera (24mm-equivalent sensor) for the global view.
    video_path = cfg.video_path
    length, width, height = get_video_lwh(video_path)
    _, _, K = create_camera_sensor(width, height, 24)

    renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K)

    # Fit the ground plane to the root-joint trajectory.
    scale, cx, cz = get_ground_params_from_points(joints_glob[:, 0], verts_glob)
    renderer.set_ground(scale * 1.5, cx, cz)
    color = torch.ones(3).float().cuda() * 0.8

    render_length = length if not debug_cam else 8
    writer = get_writer(global_video_path, fps=30, crf=CRF)
    for i in tqdm(range(render_length), desc=f"Rendering Global"):
        cameras = renderer.create_camera(global_R[i], global_T[i])
        img = renderer.render_with_ground(verts_glob[[i]], color[None], cameras, global_lights)
        writer.write_frame(img)
    writer.close()
|
|
|
|
if __name__ == "__main__":
    cfg = parse_args_to_cfg()
    paths = cfg.paths
    Log.info(f"[GPU]: {torch.cuda.get_device_name()}")
    Log.info(f'[GPU]: {torch.cuda.get_device_properties("cuda")}')

    # ===== Preprocess ===== #
    # run_preprocess expects a progress(fraction, message) callback; on the
    # CLI we just log the stage messages. (The previous call passed no
    # callback and crashed with a TypeError.)
    run_preprocess(cfg, lambda fraction, message: Log.info(message))
    data = load_data_dict(cfg)

    # ===== HMR4D inference, cached on disk ===== #
    if not Path(paths.hmr4d_results).exists():
        Log.info("[HMR4D] Predicting")
        model: DemoPL = hydra.utils.instantiate(cfg.model, _recursive_=False)
        model.load_pretrained_model(cfg.ckpt_path)
        model = model.eval().cuda()
        tic = Log.sync_time()
        pred = model.predict(data, static_cam=cfg.static_cam)
        pred = detach_to_cpu(pred)
        data_time = data["length"] / 30
        Log.info(f"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s")
        torch.save(pred, paths.hmr4d_results)
    else:
        pred = torch.load(paths.hmr4d_results)

    # ===== Render ===== #
    # Body-model helpers shared by both render passes (the render functions
    # require pred and smpl_utils; the previous calls passed only cfg).
    # NOTE(review): asset paths follow the upstream GVHMR demo — confirm
    # they resolve relative to the working directory.
    smpl_utils = {
        "smplx": make_smplx("supermotion").cuda(),
        "smplx2smpl": torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").cuda(),
        "faces_smpl": make_smplx("smpl").faces,
        "J_regressor": torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt").cuda(),
    }
    render_incam(cfg, pred, smpl_utils)
    render_global(cfg, pred, smpl_utils)
    if not Path(paths.incam_global_horiz_video).exists():
        Log.info("[Merge Videos]")
        merge_videos_horizontal([paths.incam_video, paths.global_video], paths.incam_global_horiz_video)