import torch
from pathlib import Path
from torch.utils.data import Dataset
from genmo.utils.pylogger import Log
from genmo.utils.net_utils import repeat_to_max_len, repeat_to_max_len_dict, get_valid_mask
import numpy as np


# --- Helper: Axis-Angle (3D) to Rotation 6D ---
def axis_angle_to_6d(feat_3d):
    """Convert 3D axis-angle vectors to the 6D rotation representation.

    Builds rotation matrices via the Rodrigues formula, then keeps the first
    two matrix columns (Zhou et al., "On the Continuity of Rotation
    Representations in Neural Networks").

    Args:
        feat_3d: (B, 3) or (L, 3) axis-angle vectors; the rotation angle is
            encoded as the vector norm.

    Returns:
        (B, 6) or (L, 6) tensor, flattened column-wise:
        [r00, r10, r20, r01, r11, r21].
    """
    batch_size = feat_3d.shape[0]
    device = feat_3d.device
    dtype = feat_3d.dtype

    # Angle = vector norm; clamp to avoid division by zero for
    # (near-)identity rotations.
    angle = torch.norm(feat_3d, dim=1, keepdim=True).clamp_min(1e-8)
    rot_dir = feat_3d / angle
    cos = torch.cos(angle)
    sin = torch.sin(angle)

    rx, ry, rz = rot_dir[:, 0], rot_dir[:, 1], rot_dir[:, 2]
    zeros = torch.zeros((batch_size,), dtype=dtype, device=device)
    # Skew-symmetric cross-product matrix K of the (unit) rotation axis.
    K = torch.stack(
        [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1
    ).view(batch_size, 3, 3)

    # Rodrigues: R = I + sin(a) * K + (1 - cos(a)) * K @ K
    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(0)
    R = (
        ident
        + sin.view(batch_size, 1, 1) * K
        + (1 - cos.view(batch_size, 1, 1)) * torch.bmm(K, K)
    )

    # (B, 3, 3) -> (B, 6) by flattening the first two columns column-wise:
    # [r00, r10, r20, r01, r11, r21]
    return R[:, :, :2].transpose(1, 2).reshape(batch_size, 6)


class UnityDataset(Dataset):
    """Dataset over cached per-sequence feature files (``*.pt``).

    Each item is a fixed-length window of ``motion_frames`` frames sliced at
    a random offset from one sequence; shorter sequences are padded by
    repeating the last frame (via ``repeat_to_max_len``) and masked with
    ``get_valid_mask``.
    """

    def __init__(self, root, split, motion_frames):
        """
        Args:
            root: Dataset root; cached features live in ``root/genmo_features``.
            split: Split name (kept for interface parity; not used to filter files here).
            motion_frames: Fixed window length (frames) returned per item.

        Raises:
            FileNotFoundError: If the feature directory or any ``.pt`` file is missing.
        """
        self.root = Path(root)
        self.split = split
        self.motion_frames = motion_frames

        # Path to the .pt files (Cached Features)
        self.feat_dir = self.root / "genmo_features"
        if not self.feat_dir.exists():
            raise FileNotFoundError(f"Feature dir not found: {self.feat_dir}")

        self.pt_files = sorted(list(self.feat_dir.glob("*.pt")))
        if not self.pt_files:
            raise FileNotFoundError(f"No .pt files found in {self.feat_dir}")

        Log.info(f"[UnityDataset] Found {len(self.pt_files)} sequences.")

    def __len__(self):
        return len(self.pt_files)

    def __getitem__(self, idx):
        pt_path = self.pt_files[idx]
        # Load data (CPU).
        # SECURITY NOTE: weights_only=False deserializes arbitrary pickled
        # objects — only load feature files produced by a trusted pipeline.
        data = torch.load(pt_path, map_location="cpu", weights_only=False)

        # 1. Determine length & slice a random fixed-size window.
        full_len = data["f_imgseq"].shape[0]
        if full_len > self.motion_frames:
            # randint's upper bound is exclusive; +1 so the window ending at
            # the last frame (start = full_len - motion_frames) is reachable.
            start_idx = torch.randint(0, full_len - self.motion_frames + 1, (1,)).item()
            end_idx = start_idx + self.motion_frames
        else:
            start_idx = 0
            end_idx = full_len

        length = end_idx - start_idx
        max_len = self.motion_frames

        def repeat_list_to_max_len(items, max_len):
            # List analogue of repeat_to_max_len: truncate or pad by
            # repeating the last element; empty lists pad with "".
            if len(items) == 0:
                return [""] * max_len
            if len(items) >= max_len:
                return items[:max_len]
            return items + [items[-1]] * (max_len - len(items))

        # Helper to slice tensors/arrays (recurses into dicts; passes other
        # values through unchanged).
        def sl(tensor_or_dict):
            if isinstance(tensor_or_dict, dict):
                return {k: sl(v) for k, v in tensor_or_dict.items()}
            if isinstance(tensor_or_dict, (torch.Tensor, np.ndarray)):
                return tensor_or_dict[start_idx:end_idx]
            return tensor_or_dict

        # --- 2. Extract data ---
        f_imgseq = sl(data["f_imgseq"])

        # Check dimension for features (expected to be 1024).
        if f_imgseq.shape[-1] != 1024:
            # If the features are not 1024-d, the downstream Linear layer
            # will also crash. Assumed correct for now; a projection would
            # be needed here otherwise. TODO(review): warn or raise.
            pass

        bbx_xys = sl(data["bbx_xys"]).float()
        kp2d = sl(data["kp2d"]).float()
        K_fullimg = sl(data["K_fullimg"])
        T_w2c = sl(data["T_w2c"])

        # Extrinsics: camera-to-gravity-view rotation is the transpose
        # (inverse) of the world-to-camera rotation block.
        R_w2c = T_w2c[:, :3, :3]
        R_c2gv = R_w2c.transpose(1, 2)

        smpl_params_c = sl(data["smpl_params_c"])
        smpl_params_w = sl(data["smpl_params_w"])

        # Prepare smpl_params for MetricMocap (needs 23 joints / 69 dims):
        # pad a 63-dim (21-joint) body_pose with zeros for the two hands.
        smpl_params_metric = {}
        for k, v in smpl_params_w.items():
            if k == "body_pose" and v.shape[-1] == 63:
                padding = torch.zeros((v.shape[0], 6), dtype=v.dtype, device=v.device)
                smpl_params_metric[k] = torch.cat([v, padding], dim=-1)
            else:
                smpl_params_metric[k] = v

        cam_angvel = sl(data["cam_angvel"])  # (L, 3)
        cam_tvel = sl(data["cam_tvel"])  # (L, 3)

        # Optional per-frame image paths for rendering.
        imgname = data.get("imgname", None)
        if isinstance(imgname, list):
            imgname = imgname[start_idx:end_idx]
            img_paths = [str(self.root / p) for p in imgname]
            img_paths = repeat_list_to_max_len(img_paths, max_len)
        else:
            img_paths = None

        # Convert 3D angular velocity to 6D rotation representation if needed.
        if cam_angvel.shape[-1] == 3:
            cam_angvel_6d = axis_angle_to_6d(cam_angvel)
        else:
            cam_angvel_6d = cam_angvel

        gender = str(data.get("gender", "female"))

        # --- 3. Construct return dict ---
        ret = {
            "meta": {
                "id": pt_path.stem,
                "dataset_id": "Unity",
                "vid": pt_path.stem,
            },
            # Rendering-only metadata. Collate keeps `meta*` keys as lists
            # without stacking.
            "meta_render": {
                "img_paths": img_paths,
            },
            "B": 1,
            "gender": gender,
            "length": length,
            # Input features (all padded to max_len frames)
            "f_imgseq": repeat_to_max_len(f_imgseq, max_len),
            "bbx_xys": repeat_to_max_len(bbx_xys, max_len),
            "kp2d": repeat_to_max_len(kp2d, max_len),
            "cam_angvel": repeat_to_max_len(cam_angvel_6d, max_len),  # Pass 6D here
            "cam_tvel": repeat_to_max_len(cam_tvel, max_len),
            # Ground truth
            "smpl_params": repeat_to_max_len_dict(smpl_params_metric, max_len),
            "smpl_params_c": repeat_to_max_len_dict(smpl_params_c, max_len),
            "smpl_params_w": repeat_to_max_len_dict(smpl_params_w, max_len),
            "gt_T_w2c": repeat_to_max_len(T_w2c, max_len),
            "T_w2c": repeat_to_max_len(T_w2c, max_len),
            "R_c2gv": repeat_to_max_len(R_c2gv, max_len),
            "K_fullimg": repeat_to_max_len(K_fullimg, max_len),
            # --- Text fix (prevents KeyError/crash if text encoder is on) ---
            "caption": "",
            "has_text": False,
            "mask": {
                # 1. Tensor masks (per frame)
                "valid": get_valid_mask(max_len, length),
                "has_img_mask": get_valid_mask(max_len, length),
                "has_2d_mask": get_valid_mask(max_len, length),
                "has_cam_mask": get_valid_mask(max_len, length),
                # 2. Scalar flags
                "bbx_xys": True,
                "f_imgseq": True,
                "2d_only": False,
                "spv_incam_only": False,
                "vitpose": False,
                "invalid_contact": False,
                # 3. Missing modalities
                "has_audio_mask": torch.zeros(max_len).bool(),
                "has_music_mask": torch.zeros(max_len).bool(),
            },
        }
        return ret