# TunaDance/dataset/preprocess.py
# (Uploaded via huggingface_hub, commit eb71a72 by NikhilMarisetty)
import glob
import os
import re
from pathlib import Path
import torch
from .scaler import MinMaxScaler
import pickle
def increment_path(path, exist_ok=False, sep="", mkdir=False):
    """Increment a file or directory path so it does not collide with existing ones.

    e.g. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

    Args:
        path: Initial path (str or Path).
        exist_ok: If True, return ``path`` unchanged even if it exists.
        sep: Separator inserted between the stem and the increment number.
        mkdir: If True, create the directory (or parent directory for a file path).

    Returns:
        Path: the original path, or an incremented variant if it existed.
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix("")
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        # re.escape guards against regex metacharacters in the stem/sep
        # (the original interpolated them raw, so e.g. "exp+1" broke the scan).
        pattern = re.compile(rf"{re.escape(path.stem)}{re.escape(sep)}(\d+)")
        matches = (pattern.search(d) for d in dirs)
        indices = [int(m.group(1)) for m in matches if m]
        n = max(indices) + 1 if indices else 2  # next increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # update path
    target_dir = path if path.suffix == "" else path.parent  # directory to create
    if mkdir and not target_dir.exists():
        target_dir.mkdir(parents=True, exist_ok=True)  # make directory
    return path
class Normalizer:
    """Min-max scaling of (batch, seq, channels) tensors into [-1, 1].

    Fits a ``MinMaxScaler`` over every timestep row of ``data`` at
    construction time; ``normalize``/``unnormalize`` then apply and invert
    that fitted transform.
    """

    def __init__(self, data):
        # Collapse (batch, seq, ch) -> (batch*seq, ch) so the scaler fits
        # per-channel ranges over all timesteps.
        self.scaler = MinMaxScaler((-1, 1), clip=True)
        self.scaler.fit(data.reshape(-1, data.shape[-1]))

    def normalize(self, x):
        """Scale ``x`` of shape (batch, seq, ch) into the [-1, 1] range."""
        b, s, c = x.shape
        flat = x.reshape(-1, c)
        return self.scaler.transform(flat).reshape((b, s, c))

    def unnormalize(self, x):
        """Invert ``normalize``; input is clipped to [-1, 1] first for safety."""
        b, s, c = x.shape
        flat = torch.clip(x.reshape(-1, c), -1, 1)  # clip to force compatibility
        return self.scaler.inverse_transform(flat).reshape((b, s, c))
class My_Normalizer:
    """Min-max normalizer accepting 2-D (batch, ch) or 3-D (batch, seq, ch) input.

    Constructed either from raw data (fits a fresh scaler) or from a path to
    a pickled state dict containing ``"scale"`` and ``"min"`` entries.
    """

    def __init__(self, data):
        if isinstance(data, str):
            # ``data`` is a path to a pickled scaler state.
            # SECURITY NOTE: pickle.load executes arbitrary code from the
            # file — only load checkpoints you created yourself.
            self.scaler = MinMaxScaler((-1, 1), clip=True)
            with open(data, 'rb') as f:
                normalizer_state_dict = pickle.load(f)
            self.scaler.scale_ = normalizer_state_dict["scale"]
            self.scaler.min_ = normalizer_state_dict["min"]
        else:
            # ``data`` is a tensor/array; fit on flattened (b*t, ch) rows.
            flat = data.reshape(-1, data.shape[-1])
            self.scaler = MinMaxScaler((-1, 1), clip=True)
            self.scaler.fit(flat)

    def normalize(self, x):
        """Scale ``x`` into [-1, 1]; accepts (b, s, ch) or (b, ch).

        Raises:
            ValueError: if ``x`` is neither 2-D nor 3-D.
        """
        if len(x.shape) == 3:
            batch, seq, ch = x.shape
            x = x.reshape(-1, ch)
            return self.scaler.transform(x).reshape((batch, seq, ch))
        elif len(x.shape) == 2:
            return self.scaler.transform(x)
        else:
            # BUG FIX: the original `raise("input error!")` raised a
            # TypeError (str is not an exception); raise a real exception.
            raise ValueError("input error!")

    def unnormalize(self, x):
        """Inverse transform; clips to [-1, 1] first. Accepts 2-D or 3-D input.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 3-D.
        """
        if len(x.shape) == 3:
            batch, seq, ch = x.shape
            x = x.reshape(-1, ch)
            x = torch.clip(x, -1, 1)  # clip to force compatibility
            return self.scaler.inverse_transform(x).reshape((batch, seq, ch))
        elif len(x.shape) == 2:
            x = torch.clip(x, -1, 1)
            return self.scaler.inverse_transform(x)
        else:
            raise ValueError("input error!")
def vectorize_many(data):
    """Flatten and concatenate a list of per-feature tensors.

    Each element of ``data`` has shape (batch, seqlen, ...); every tensor is
    flattened to (batch, seqlen, -1) and the results are concatenated along
    the last dimension.
    """
    b, s = data[0].shape[0], data[0].shape[1]
    flattened = [t.reshape(b, s, -1).contiguous() for t in data]
    return torch.cat(flattened, dim=2)