| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
| import torch.nn.functional as F |
|
|
| |
|
|
| num_frequencies = None |
|
|
| |
|
|
|
|
| class CharacterDataset(Dataset): |
| def __init__( |
| self, |
| name: str, |
| dataset_dir: str, |
| standardize: bool, |
| num_feats: int, |
| num_cams: int, |
| sequential: bool, |
| num_frequencies: int, |
| min_freq: int, |
| max_freq: int, |
| load_vertices: bool, |
| **kwargs, |
| ): |
| super().__init__() |
| self.modality = "char" |
| self.name = name |
| self.dataset_dir = Path(dataset_dir) |
| self.traj_dir = self.dataset_dir / "traj" |
| self.data_dir = self.dataset_dir / self.name |
| self.vert_dir = self.dataset_dir / "vert_raw" |
| self.center_dir = self.dataset_dir / "char_raw" |
|
|
| self.filenames = None |
| self.standardize = standardize |
| if self.standardize: |
| mean_std = kwargs["standardization"] |
| self.norm_mean = torch.Tensor(mean_std["norm_mean_h"])[:, None] |
| self.norm_std = torch.Tensor(mean_std["norm_std_h"])[:, None] |
| self.velocity = mean_std["velocity"] |
|
|
| self.num_cams = num_cams |
| self.num_feats = num_feats |
| self.sequential = sequential |
| self.num_frequencies = num_frequencies |
| self.min_freq = min_freq |
| self.max_freq = max_freq |
|
|
| self.load_vertices = load_vertices |
|
|
| def __len__(self): |
| return len(self.filenames) |
|
|
| def __getitem__(self, index): |
| filename = self.filenames[index] |
|
|
| char_filename = filename + ".npy" |
| char_path = self.data_dir / char_filename |
|
|
| raw_char_feature = torch.from_numpy(np.load((char_path))).to(torch.float32) |
| padding_size = self.num_cams - raw_char_feature.shape[0] |
| padded_raw_char_feature = F.pad( |
| raw_char_feature, (0, 0, 0, padding_size) |
| ).permute(1, 0) |
|
|
| center_path = self.center_dir / char_filename |
| center_offset = torch.from_numpy(np.load(center_path)[0]).to(torch.float32) |
| if self.load_vertices: |
| vert_path = self.vert_dir / char_filename |
| raw_verts = np.load(vert_path, allow_pickle=True)[()] |
| if raw_verts["vertices"] is None: |
| num_frames = raw_char_feature.shape[0] |
| verts = torch.zeros((num_frames, 6890, 3), dtype=torch.float32) |
| padded_verts = torch.zeros( |
| (self.num_cams, 6890, 3), dtype=torch.float32 |
| ) |
| faces = torch.zeros((13776, 3), dtype=torch.int16) |
| else: |
| verts = torch.from_numpy(raw_verts["vertices"]).to(torch.float32) |
| verts -= center_offset |
| padded_verts = F.pad(verts, (0, 0, 0, 0, 0, padding_size)) |
| faces = torch.from_numpy(raw_verts["faces"]).to(torch.int16) |
|
|
| char_feature = raw_char_feature.clone() |
| if self.velocity: |
| velocity = char_feature[1:].clone() - char_feature[:-1].clone() |
| char_feature = torch.cat([raw_char_feature[0][None], velocity]) |
|
|
| if self.standardize: |
| |
| if len(self.norm_mean) == 6: |
| char_feature[0] -= self.norm_mean[:3, 0].to(raw_char_feature.device) |
| char_feature[0] /= self.norm_std[:3, 0].to(raw_char_feature.device) |
| char_feature[1:] -= self.norm_mean[3:, 0].to(raw_char_feature.device) |
| char_feature[1:] /= self.norm_std[3:, 0].to(raw_char_feature.device) |
| |
| else: |
| char_feature -= self.norm_mean[:, 0].to(raw_char_feature.device) |
| char_feature /= self.norm_std[:, 0].to(raw_char_feature.device) |
| padded_char_feature = F.pad( |
| char_feature, |
| (0, 0, 0, self.num_cams - char_feature.shape[0]), |
| ) |
|
|
| if self.sequential: |
| padded_char_feature = padded_char_feature.permute(1, 0) |
| else: |
| padded_char_feature = padded_char_feature.reshape(-1) |
|
|
| raw_feats = {"char_raw_feat": padded_raw_char_feature} |
| if self.load_vertices: |
| raw_feats["char_vertices"] = padded_verts |
| raw_feats["char_faces"] = faces |
|
|
| return char_filename, padded_char_feature, raw_feats |
|
|