from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

from genmo.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict
from genmo.utils.pylogger import Log


def axis_angle_to_6d(feat_3d):
    """
    Converts 3D axis-angle vectors to the 6D rotation representation.
    Input: (B, 3) or (L, 3)
    Output: (B, 6) or (L, 6)
    """
    batch_size = feat_3d.shape[0]
    device = feat_3d.device
    dtype = feat_3d.dtype

    # Clamp the angle away from zero so near-identity rotations do not divide by zero.
    angle = torch.norm(feat_3d, dim=1, keepdim=True).clamp_min(1e-8)
    rot_dir = feat_3d / angle

    cos = torch.cos(angle)
    sin = torch.sin(angle)

    # Skew-symmetric cross-product matrix K of the unit rotation axis.
    rx, ry, rz = rot_dir[:, 0], rot_dir[:, 1], rot_dir[:, 2]
    zeros = torch.zeros(batch_size, dtype=dtype, device=device)
    K = torch.stack([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(batch_size, 3, 3)

    # Rodrigues' formula: R = I + sin(theta) * K + (1 - cos(theta)) * K @ K.
    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(0)
    R = ident + sin.view(batch_size, 1, 1) * K + (1 - cos.view(batch_size, 1, 1)) * torch.bmm(K, K)

    # 6D rotation representation: the first two columns of R, concatenated.
    return R[:, :, :2].transpose(1, 2).reshape(batch_size, 6)
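
# Quick illustrative check (hedged, not executed at import time): a zero
# axis-angle vector is the identity rotation, so its 6D encoding is the first
# two columns of I_3 concatenated:
#
#     axis_angle_to_6d(torch.zeros(1, 3))
#     # -> tensor([[1., 0., 0., 0., 1., 0.]])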


class UnityDataset(Dataset):
    def __init__(self, root, split, motion_frames):
        self.root = Path(root)
        self.split = split
        self.motion_frames = motion_frames

        # Preprocessed features are stored as one .pt file per sequence.
        self.feat_dir = self.root / "genmo_features"
        if not self.feat_dir.exists():
            raise FileNotFoundError(f"Feature dir not found: {self.feat_dir}")

        self.pt_files = sorted(self.feat_dir.glob("*.pt"))
        if not self.pt_files:
            raise FileNotFoundError(f"No .pt files found in {self.feat_dir}")

        Log.info(f"[UnityDataset] Found {len(self.pt_files)} sequences.")

    def __len__(self):
        return len(self.pt_files)

    def __getitem__(self, idx):
        pt_path = self.pt_files[idx]
        data = torch.load(pt_path, map_location="cpu", weights_only=False)

        # Sample a random window of at most motion_frames consecutive frames.
        full_len = data["f_imgseq"].shape[0]
        if full_len > self.motion_frames:
            # randint's upper bound is exclusive; +1 lets the final full window be sampled.
            start_idx = torch.randint(0, full_len - self.motion_frames + 1, (1,)).item()
            end_idx = start_idx + self.motion_frames
        else:
            start_idx = 0
            end_idx = full_len

        length = end_idx - start_idx
        max_len = self.motion_frames

        def repeat_list_to_max_len(items, max_len):
            # Pad a list to max_len by repeating its last item (or truncate).
            if len(items) == 0:
                return [""] * max_len
            if len(items) >= max_len:
                return items[:max_len]
            return items + [items[-1]] * (max_len - len(items))

        def sl(tensor_or_dict):
            # Slice tensors/arrays (recursively for dicts) to the sampled window.
            if isinstance(tensor_or_dict, dict):
                return {k: sl(v) for k, v in tensor_or_dict.items()}
            if isinstance(tensor_or_dict, (torch.Tensor, np.ndarray)):
                return tensor_or_dict[start_idx:end_idx]
            return tensor_or_dict

        # Image features are expected to be 1024-dim; other widths are
        # currently passed through unchanged.
        f_imgseq = sl(data["f_imgseq"])

        bbx_xys = sl(data["bbx_xys"]).float()
        kp2d = sl(data["kp2d"]).float()
        K_fullimg = sl(data["K_fullimg"])
        T_w2c = sl(data["T_w2c"])

        # Camera-to-world rotation (transpose of the world-to-camera block);
        # the world frame is treated as the gravity-aligned view (gv) frame here.
        R_w2c = T_w2c[:, :3, :3]
        R_c2gv = R_w2c.transpose(1, 2)

        smpl_params_c = sl(data["smpl_params_c"])
        smpl_params_w = sl(data["smpl_params_w"])

        # Pad a 63-dim (21-joint) world body pose with zeros for the two hand
        # joints so it matches SMPL's 69-dim (23-joint) body_pose layout.
        smpl_params_metric = {}
        for k, v in smpl_params_w.items():
            if k == "body_pose" and v.shape[-1] == 63:
                padding = torch.zeros((v.shape[0], 6), dtype=v.dtype, device=v.device)
                smpl_params_metric[k] = torch.cat([v, padding], dim=-1)
            else:
                smpl_params_metric[k] = v

        cam_angvel = sl(data["cam_angvel"])
        cam_tvel = sl(data["cam_tvel"])

        imgname = data.get("imgname", None)
        if isinstance(imgname, list):
            imgname = imgname[start_idx:end_idx]
            img_paths = [str(self.root / p) for p in imgname]
            img_paths = repeat_list_to_max_len(img_paths, max_len)
        else:
            img_paths = None

        # Convert the per-frame camera angular velocity to 6D if it is stored
        # as axis-angle; otherwise assume it is already 6D.
        if cam_angvel.shape[-1] == 3:
            cam_angvel_6d = axis_angle_to_6d(cam_angvel)
        else:
            cam_angvel_6d = cam_angvel

        gender = str(data.get("gender", "female"))

        ret = {
            "meta": {
                "id": pt_path.stem,
                "dataset_id": "Unity",
                "vid": pt_path.stem,
            },
            "meta_render": {
                "img_paths": img_paths,
            },
            "B": 1,
            "gender": gender,
            "length": length,
            # Per-frame observations, padded to max_len by repeating the last frame.
            "f_imgseq": repeat_to_max_len(f_imgseq, max_len),
            "bbx_xys": repeat_to_max_len(bbx_xys, max_len),
            "kp2d": repeat_to_max_len(kp2d, max_len),
            "cam_angvel": repeat_to_max_len(cam_angvel_6d, max_len),
            "cam_tvel": repeat_to_max_len(cam_tvel, max_len),
            # Ground-truth SMPL parameters and camera poses.
            "smpl_params": repeat_to_max_len_dict(smpl_params_metric, max_len),
            "smpl_params_c": repeat_to_max_len_dict(smpl_params_c, max_len),
            "smpl_params_w": repeat_to_max_len_dict(smpl_params_w, max_len),
            "gt_T_w2c": repeat_to_max_len(T_w2c, max_len),
            "T_w2c": repeat_to_max_len(T_w2c, max_len),
            "R_c2gv": repeat_to_max_len(R_c2gv, max_len),
            "K_fullimg": repeat_to_max_len(K_fullimg, max_len),
            # No text conditioning for this dataset.
            "caption": "",
            "has_text": False,
            "mask": {
                # Frame validity: only the first `length` of max_len frames are real.
                "valid": get_valid_mask(max_len, length),
                "has_img_mask": get_valid_mask(max_len, length),
                "has_2d_mask": get_valid_mask(max_len, length),
                "has_cam_mask": get_valid_mask(max_len, length),
                # Modality flags.
                "bbx_xys": True,
                "f_imgseq": True,
                "2d_only": False,
                "spv_incam_only": False,
                "vitpose": False,
                "invalid_contact": False,
                # No audio/music conditioning.
                "has_audio_mask": torch.zeros(max_len).bool(),
                "has_music_mask": torch.zeros(max_len).bool(),
            },
        }
        return ret
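

if __name__ == "__main__":
    # Minimal smoke test: a sketch only. The root path and motion_frames value
    # below are placeholders, and the directory must contain genmo_features/*.pt
    # for the dataset branch to actually run.
    six_d = axis_angle_to_6d(torch.zeros(1, 3))
    assert torch.allclose(six_d, torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]]))

    root = Path("data/unity")  # placeholder root
    if (root / "genmo_features").exists():
        dataset = UnityDataset(root=root, split="train", motion_frames=120)
        sample = dataset[0]
        Log.info(f"[UnityDataset] sample id={sample['meta']['id']} length={sample['length']}")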