import json

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


class SGDataset(Dataset):
    def __init__(
        self,
        path_data_definition,
        path_processed_data,
        window,
        style_encoding_type,
        example_window_length,
    ):
        """PyTorch Dataset instance.

        Args:
            path_data_definition : Path to the data_definition JSON file
            path_processed_data : Path to the processed_data npz file
            window : Length of the input/output slice (in frames)
            style_encoding_type : "label" or "example"
            example_window_length : Length of the style example window (in frames)
        """

        # Skeleton, label, and timing information from the data definition file.
        with open(path_data_definition, "r") as f:
            details = json.load(f)
        self.details = details
        self.njoints = len(details["bone_names"])
        self.nlabels = len(details["label_names"])
        self.label_names = details["label_names"]
        self.bone_names = details["bone_names"]
        self.parents = torch.LongTensor(details["parents"])
        self.dt = details["dt"]
        self.window = window
        self.style_encoding_type = style_encoding_type
        self.example_window_length = example_window_length

        # Preprocessed features and the frame ranges of the train/valid samples.
        processed_data = np.load(path_processed_data)

        self.ranges_train = processed_data["ranges_train"]
        self.ranges_valid = processed_data["ranges_valid"]
        self.ranges_train_labels = processed_data["ranges_train_labels"]
        self.ranges_valid_labels = processed_data["ranges_valid_labels"]
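        # Per-frame features: X_* holds the audio input features, Y_* the
        # animation targets (root transform and velocities, local joint
        # channels, and gaze position).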
        self.X_audio_features = torch.as_tensor(
            processed_data["X_audio_features"], dtype=torch.float32
        )
        self.Y_root_pos = torch.as_tensor(processed_data["Y_root_pos"], dtype=torch.float32)
        self.Y_root_rot = torch.as_tensor(processed_data["Y_root_rot"], dtype=torch.float32)
        self.Y_root_vel = torch.as_tensor(processed_data["Y_root_vel"], dtype=torch.float32)
        self.Y_root_vrt = torch.as_tensor(processed_data["Y_root_vrt"], dtype=torch.float32)
        self.Y_lpos = torch.as_tensor(processed_data["Y_lpos"], dtype=torch.float32)
        self.Y_ltxy = torch.as_tensor(processed_data["Y_ltxy"], dtype=torch.float32)
        self.Y_lvel = torch.as_tensor(processed_data["Y_lvel"], dtype=torch.float32)
        self.Y_lvrt = torch.as_tensor(processed_data["Y_lvrt"], dtype=torch.float32)
        self.Y_gaze_pos = torch.as_tensor(processed_data["Y_gaze_pos"], dtype=torch.float32)

        # Normalisation statistics for the audio input and the animation input/output.
        self.audio_input_mean = torch.as_tensor(
            processed_data["audio_input_mean"], dtype=torch.float32
        )
        self.audio_input_std = torch.as_tensor(
            processed_data["audio_input_std"], dtype=torch.float32
        )
        self.anim_input_mean = torch.as_tensor(
            processed_data["anim_input_mean"], dtype=torch.float32
        )
        self.anim_input_std = torch.as_tensor(
            processed_data["anim_input_std"], dtype=torch.float32
        )
        self.anim_output_mean = torch.as_tensor(
            processed_data["anim_output_mean"], dtype=torch.float32
        )
        self.anim_output_std = torch.as_tensor(
            processed_data["anim_output_std"], dtype=torch.float32
        )
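        # Build one entry per training window: R holds the frame indices of each
        # length-`window` slice, L its one-hot style label, and S the index of the
        # training range (sample) the slice was cut from.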
        R = []
        L = []
        S = []
        for sample_number, ((range_start, range_end), range_label) in enumerate(
            zip(self.ranges_train, self.ranges_train_labels)
        ):
            one_hot_label = np.zeros(self.nlabels, dtype=np.float32)
            one_hot_label[range_label] = 1.0

            for ri in range(range_start, range_end - window):
                R.append(np.arange(ri, ri + window))
                L.append(one_hot_label)
                S.append(sample_number)

        self.R = torch.as_tensor(np.array(R), dtype=torch.long)
        self.L = torch.as_tensor(np.array(L), dtype=torch.float32)
        self.S = torch.as_tensor(S, dtype=torch.short)
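    # Trivial property wrapper so the example window length can be changed after
    # the dataset has been constructed.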
    @property
    def example_window_length(self):
        return self._example_window_length

    @example_window_length.setter
    def example_window_length(self, a):
        self._example_window_length = a

    def __len__(self):
        return len(self.R)

    def __getitem__(self, index):
        # Frame indices of this window and its one-hot style label.
        Rwindow = self.R[index]
        Rwindow = Rwindow.contiguous()

        Rlabel = self.L[index]

        # Frame range of the training sample this window was cut from.
        RInd = self.S[index]
        sample_range = self.ranges_train[RInd]

        # Audio input features for the window.
        W_audio_features = self.X_audio_features[Rwindow]

        # Animation targets for the window.
        W_root_pos = self.Y_root_pos[Rwindow]
        W_root_rot = self.Y_root_rot[Rwindow]
        W_root_vel = self.Y_root_vel[Rwindow]
        W_root_vrt = self.Y_root_vrt[Rwindow]
        W_lpos = self.Y_lpos[Rwindow]
        W_ltxy = self.Y_ltxy[Rwindow]
        W_lvel = self.Y_lvel[Rwindow]
        W_lvrt = self.Y_lvrt[Rwindow]
        W_gaze_pos = self.Y_gaze_pos[Rwindow]

        # Style conditioning: either the one-hot label or an example animation window.
        if self.style_encoding_type == "label":
            style = Rlabel
        elif self.style_encoding_type == "example":
            style = self.get_example(Rwindow, sample_range, self.example_window_length)
        else:
            raise ValueError(f"Unknown style encoding type: {self.style_encoding_type}")

        return (
            W_audio_features,
            W_root_pos,
            W_root_rot,
            W_root_vel,
            W_root_vrt,
            W_lpos,
            W_ltxy,
            W_lvel,
            W_lvrt,
            W_gaze_pos,
            style,
        )

    def get_shapes(self):
        num_audio_features = self.X_audio_features.shape[1]
        pose_input_size = len(self.anim_input_std)
        pose_output_size = len(self.anim_output_std)
        dimensions = dict(
            num_audio_features=num_audio_features,
            pose_input_size=pose_input_size,
            pose_output_size=pose_output_size,
        )
        return dimensions

    def get_means_stds(self, device):
        return (
            self.audio_input_mean.to(device),
            self.audio_input_std.to(device),
            self.anim_input_mean.to(device),
            self.anim_input_std.to(device),
            self.anim_output_mean.to(device),
            self.anim_output_std.to(device),
        )

    def get_example(self, Rwindow, sample_range, example_window_length):
        # Extend the current window symmetrically towards `example_window_length`
        # frames, clamping the extension to the boundaries of the sample range.
        ext_window = (example_window_length - self.window) // 2
        ws = min(ext_window, Rwindow[0] - sample_range[0])
        we = min(ext_window, sample_range[1] - Rwindow[-1])
        s_ext = ws + ext_window - we
        w_ext = we + ext_window - ws
        start = max(Rwindow[0] - s_ext, sample_range[0])
        end = min(Rwindow[-1] + w_ext, sample_range[1]) + 1
        end = min(end, len(self.Y_root_vel))

        # Flatten the per-frame animation features into one vector per frame.
        S_root_vel = self.Y_root_vel[start:end].reshape(end - start, -1)
        S_root_vrt = self.Y_root_vrt[start:end].reshape(end - start, -1)
        S_lpos = self.Y_lpos[start:end].reshape(end - start, -1)
        S_ltxy = self.Y_ltxy[start:end].reshape(end - start, -1)
        S_lvel = self.Y_lvel[start:end].reshape(end - start, -1)
        S_lvrt = self.Y_lvrt[start:end].reshape(end - start, -1)
        example_feature_vec = torch.cat(
            [
                S_root_vel,
                S_root_vrt,
                S_lpos,
                S_ltxy,
                S_lvel,
                S_lvrt,
                torch.zeros_like(S_root_vel),
            ],
            dim=1,
        )

        # If the clamped slice is still too short, pad it by repeating its last frames.
        curr_len = len(example_feature_vec)
        if curr_len < example_window_length:
            example_feature_vec = torch.cat(
                [example_feature_vec, example_feature_vec[-example_window_length + curr_len:]],
                dim=0,
            )
        return example_feature_vec

    def get_sample(self, dataset, length=None, range_index=None):
        # Return one full (optionally length-limited) sample with a leading batch dimension.
        if dataset == "train":
            if range_index is None:
                range_index = np.random.randint(len(self.ranges_train))
            (s, e), label = self.ranges_train[range_index], self.ranges_train_labels[range_index]
        elif dataset == "valid":
            if range_index is None:
                range_index = np.random.randint(len(self.ranges_valid))
            (s, e), label = self.ranges_valid[range_index], self.ranges_valid_labels[range_index]
        else:
            raise ValueError(f"Unknown dataset split: {dataset}")

        if length is not None:
            # `length` is given in seconds; the data is assumed to be sampled at 60 fps.
            e = min(s + length * 60, e)

        return (
            self.X_audio_features[s:e][np.newaxis],
            self.Y_root_pos[s:e][np.newaxis],
            self.Y_root_rot[s:e][np.newaxis],
            self.Y_root_vel[s:e][np.newaxis],
            self.Y_root_vrt[s:e][np.newaxis],
            self.Y_lpos[s:e][np.newaxis],
            self.Y_ltxy[s:e][np.newaxis],
            self.Y_lvel[s:e][np.newaxis],
            self.Y_lvrt[s:e][np.newaxis],
            self.Y_gaze_pos[s:e][np.newaxis],
            label,
            [s, e],
            range_index,
        )

    def get_stats(self):
        from rich.console import Console
        from rich.table import Table

        console = Console(record=True)
        df = pd.DataFrame()
        df["Dataset"] = ["Train", "Validation", "Total"]
        pd.set_option("display.max_rows", None, "display.max_columns", None)
        table = Table(title="Data Info", show_lines=True, row_styles=["magenta"])
        table.add_column("Dataset")
        data_len = 0
        for i in range(self.nlabels):
            # Per-label frame counts for the train and validation splits.
            ind_mask = self.ranges_train_labels == i
            ranges = self.ranges_train[ind_mask]
            num_train_frames = np.sum(ranges[:, 1] - ranges[:, 0]) / 2
            ind_mask = self.ranges_valid_labels == i
            ranges = self.ranges_valid[ind_mask]
            num_valid_frames = np.sum(ranges[:, 1] - ranges[:, 0]) / 2
            total = num_train_frames + num_valid_frames
            df[self.label_names[i]] = [
                f"{num_train_frames} frames - {num_train_frames / 60:.1f} secs",
                f"{num_valid_frames} frames - {num_valid_frames / 60:.1f} secs",
                f"{total} frames - {total / 60:.1f} secs",
            ]
            table.add_column(self.label_names[i])
            data_len += total

        for i in range(3):
            table.add_row(*list(df.iloc[i]))
        console.print(table)
        dimensions = self.get_shapes()
        console.print(f"Total length of dataset is {data_len} frames - {data_len / 60:.1f} seconds")
        console.print("Num features: ", dimensions)
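

if __name__ == "__main__":
    # Minimal usage sketch. The paths and hyper-parameters below are
    # placeholders (not taken from any config); adjust them to your data.
    from torch.utils.data import DataLoader

    dataset = SGDataset(
        path_data_definition="processed/data_definition.json",  # placeholder path
        path_processed_data="processed/processed_data.npz",  # placeholder path
        window=80,  # illustrative value
        style_encoding_type="label",
        example_window_length=120,  # illustrative value
    )
    dataset.get_stats()

    # Each item is a tuple of fixed-length windows, so the default collate works.
    loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
    batch = next(iter(loader))
    print(dataset.get_shapes())
    print([t.shape for t in batch])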