# genmo/datasets/unity_dataset.py
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset
from genmo.utils.pylogger import Log
from genmo.utils.net_utils import repeat_to_max_len, repeat_to_max_len_dict, get_valid_mask
# --- Helper: Axis-Angle (3D) to Rotation 6D ---
def axis_angle_to_6d(feat_3d):
"""
Converts 3D axis-angle vectors to 6D rotation representation.
Input: (B, 3) or (L, 3)
Output: (B, 6) or (L, 6)
"""
# Axis-angle -> rotation matrix (Rodrigues), then take the first two columns (Zhou et al. 6D).
batch_size = feat_3d.shape[0]
device = feat_3d.device
dtype = feat_3d.dtype
angle = torch.norm(feat_3d, dim=1, keepdim=True).clamp_min(1e-8)
rot_dir = feat_3d / angle
cos = torch.cos(angle)
sin = torch.sin(angle)
rx, ry, rz = rot_dir[:, 0], rot_dir[:, 1], rot_dir[:, 2]
    zeros = torch.zeros(batch_size, dtype=dtype, device=device)
# Cross product matrix K
K = torch.stack([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(batch_size, 3, 3)
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(0)
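    # Rodrigues' rotation formula: R = I + sin(theta) * K + (1 - cos(theta)) * K @ K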
R = ident + sin.view(batch_size, 1, 1) * K + (1 - cos.view(batch_size, 1, 1)) * torch.bmm(K, K)
# (B, 3, 3) -> (B, 6) by flattening the first two columns column-wise:
# [r00, r10, r20, r01, r11, r21]
return R[:, :, :2].transpose(1, 2).reshape(batch_size, 6)
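# Sanity check (illustrative): the zero vector (identity rotation) maps to the
# first two columns of the identity matrix:
#   axis_angle_to_6d(torch.zeros(1, 3))  # -> tensor([[1., 0., 0., 0., 1., 0.]])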
class UnityDataset(Dataset):
def __init__(self, root, split, motion_frames):
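        """
        Args:
            root: dataset root containing a `genmo_features/` folder of cached .pt files.
            split: split name (stored for interface parity; not otherwise used here).
            motion_frames: fixed window length, in frames, each sample is cropped/padded to.
        """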
self.root = Path(root)
self.split = split
self.motion_frames = motion_frames
# Path to the .pt files (Cached Features)
self.feat_dir = self.root / "genmo_features"
if not self.feat_dir.exists():
raise FileNotFoundError(f"Feature dir not found: {self.feat_dir}")
        self.pt_files = sorted(self.feat_dir.glob("*.pt"))
if not self.pt_files:
raise FileNotFoundError(f"No .pt files found in {self.feat_dir}")
Log.info(f"[UnityDataset] Found {len(self.pt_files)} sequences.")
def __len__(self):
return len(self.pt_files)
def __getitem__(self, idx):
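        """Return one fixed-length window as a dict of padded tensors and masks."""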
pt_path = self.pt_files[idx]
        # Load the cached feature dict on CPU (weights_only=False: the cache
        # stores non-tensor entries such as strings and nested dicts).
        data = torch.load(pt_path, map_location="cpu", weights_only=False)
# 1. Determine Length & Slice
full_len = data["f_imgseq"].shape[0]
        if full_len > self.motion_frames:
            # randint's upper bound is exclusive; the +1 lets the sampled window
            # reach the final frame of the sequence.
            start_idx = torch.randint(0, full_len - self.motion_frames + 1, (1,)).item()
            end_idx = start_idx + self.motion_frames
else:
start_idx = 0
end_idx = full_len
length = end_idx - start_idx
max_len = self.motion_frames
def repeat_list_to_max_len(items, max_len):
if len(items) == 0:
return [""] * max_len
if len(items) >= max_len:
return items[:max_len]
return items + [items[-1]] * (max_len - len(items))
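        # e.g. repeat_list_to_max_len(["a", "b"], 4) -> ["a", "b", "b", "b"]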
        # Recursively slice tensors/arrays to the sampled window; dicts are
        # descended into and non-array values pass through unchanged.
def sl(tensor_or_dict):
if isinstance(tensor_or_dict, dict):
return {k: sl(v) for k, v in tensor_or_dict.items()}
if isinstance(tensor_or_dict, (torch.Tensor, np.ndarray)):
return tensor_or_dict[start_idx:end_idx]
return tensor_or_dict
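        # e.g. sl({"body_pose": t}) -> {"body_pose": t[start_idx:end_idx]}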
# --- 2. Extract Data ---
f_imgseq = sl(data["f_imgseq"])
        # Image features are expected to be 1024-D; fail fast with a clear error
        # rather than crashing later in the model's first Linear layer.
        if f_imgseq.shape[-1] != 1024:
            raise ValueError(
                f"{pt_path.name}: expected 1024-D image features, got {f_imgseq.shape[-1]}"
            )
bbx_xys = sl(data["bbx_xys"]).float()
kp2d = sl(data["kp2d"]).float()
K_fullimg = sl(data["K_fullimg"])
T_w2c = sl(data["T_w2c"])
        # Extrinsics: R_w2c rotates world -> camera; for a rotation matrix the
        # transpose is the inverse, giving the camera -> gravity-view rotation.
        R_w2c = T_w2c[:, :3, :3]
        R_c2gv = R_w2c.transpose(1, 2)
smpl_params_c = sl(data["smpl_params_c"])
smpl_params_w = sl(data["smpl_params_w"])
        # MetricMocap expects SMPL body_pose with 23 joints (69 dims). A 63-dim
        # pose (SMPL-X style, 21 joints) is padded with two zero hand rotations.
smpl_params_metric = {}
for k, v in smpl_params_w.items():
if k == "body_pose" and v.shape[-1] == 63:
padding = torch.zeros((v.shape[0], 6), dtype=v.dtype, device=v.device)
smpl_params_metric[k] = torch.cat([v, padding], dim=-1)
else:
smpl_params_metric[k] = v
cam_angvel = sl(data["cam_angvel"]) # (L, 3)
cam_tvel = sl(data["cam_tvel"]) # (L, 3)
imgname = data.get("imgname", None)
if isinstance(imgname, list):
imgname = imgname[start_idx:end_idx]
img_paths = [str(self.root / p) for p in imgname]
img_paths = repeat_list_to_max_len(img_paths, max_len)
else:
img_paths = None
        # --- Convert camera angular velocity from axis-angle (L, 3) to 6D (L, 6) ---
if cam_angvel.shape[-1] == 3:
cam_angvel_6d = axis_angle_to_6d(cam_angvel)
else:
cam_angvel_6d = cam_angvel
gender = str(data.get("gender", "female"))
# --- 3. Construct Return Dict ---
ret = {
"meta": {
"id": pt_path.stem,
"dataset_id": "Unity",
"vid": pt_path.stem,
},
# Rendering-only metadata. Collate keeps `meta*` keys as lists without stacking.
"meta_render": {
"img_paths": img_paths,
},
"B": 1,
"gender": gender,
"length": length,
# Input Features
"f_imgseq": repeat_to_max_len(f_imgseq, max_len),
"bbx_xys": repeat_to_max_len(bbx_xys, max_len),
"kp2d": repeat_to_max_len(kp2d, max_len),
"cam_angvel": repeat_to_max_len(cam_angvel_6d, max_len), # Pass 6D here
"cam_tvel": repeat_to_max_len(cam_tvel, max_len),
# Ground Truth
"smpl_params": repeat_to_max_len_dict(smpl_params_metric, max_len),
"smpl_params_c": repeat_to_max_len_dict(smpl_params_c, max_len),
"smpl_params_w": repeat_to_max_len_dict(smpl_params_w, max_len),
"gt_T_w2c": repeat_to_max_len(T_w2c, max_len),
"T_w2c": repeat_to_max_len(T_w2c, max_len),
"R_c2gv": repeat_to_max_len(R_c2gv, max_len),
"K_fullimg": repeat_to_max_len(K_fullimg, max_len),
            # --- Text placeholders (prevent a KeyError when the text encoder is enabled) ---
"caption": "",
"has_text": False,
"mask": {
# 1. Tensor Masks (Per frame)
"valid": get_valid_mask(max_len, length),
"has_img_mask": get_valid_mask(max_len, length),
"has_2d_mask": get_valid_mask(max_len, length),
"has_cam_mask": get_valid_mask(max_len, length),
# 2. Scalar Flags
"bbx_xys": True,
"f_imgseq": True,
"2d_only": False,
"spv_incam_only": False,
"vitpose": False,
"invalid_contact": False,
# 3. Missing Modalities
"has_audio_mask": torch.zeros(max_len).bool(),
"has_music_mask": torch.zeros(max_len).bool(),
}
}
return ret
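if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): the root path below is a
    # placeholder; point it at a directory containing genmo_features/*.pt caches.
    ds = UnityDataset(root="data/unity_example", split="train", motion_frames=120)
    sample = ds[0]
    # Assuming repeat_to_max_len pads along the time axis, f_imgseq should come
    # back as (motion_frames, 1024) with `length` frames flagged valid.
    print("f_imgseq:", tuple(sample["f_imgseq"].shape))
    print("valid frames:", int(sample["mask"]["valid"].sum()), "of", sample["length"])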