# Source: SATA/src/sata/utils/data_utils.py (initial commit 5221c8c by zzysteve)
import copy
import numpy as np
def safe_normalize_to_one(np_vec, eps=1e-4):
    """Scale ``np_vec`` to unit norm unless its norm is too small.

    The norm is ``np.linalg.norm(np_vec)`` with default arguments, i.e. it is
    computed over the whole input rather than along a single axis.

    Args:
        np_vec: Vector (or array) to normalize.
        eps: Threshold on the norm; the input is only rescaled when its norm
            is strictly greater than this value.

    Returns:
        ``np_vec / norm`` when ``norm > eps``; otherwise ``np_vec`` itself,
        returned unchanged (same object).
    """
    magnitude = np.linalg.norm(np_vec)
    # Guard clause: a (near-)zero vector cannot be normalized meaningfully,
    # so hand it back untouched instead of dividing by ~0.
    if not magnitude > eps:
        return np_vec
    return np_vec / magnitude
def safe_normalize(data, axis=0, ignore_std=1e-4):
    """Standardize ``data`` along ``axis`` without dividing by ~zero stds.

    Computes the mean and standard deviation of ``data`` over ``axis`` and
    returns ``(data - mean) / std``. Any standard deviation smaller than
    ``ignore_std`` is replaced by ``1.0`` so that (near-)constant features are
    only centered, not blown up.

    Args:
        data: NumPy array to normalize.
        axis: Axis over which the statistics are computed.
        ignore_std: Standard deviations strictly below this value are
            replaced by ``1.0``.

    Returns:
        A tuple ``(normalized, (mean, std))``. ``normalized`` has the same
        shape as ``data``; ``mean`` and ``std`` have the shape of ``data`` with
        ``axis`` removed, and ``std`` already contains the ``1.0``
        replacements, so the pair can be reused to normalize other data.
    """
    # keepdims=True keeps the reduced axis as length 1, so the statistics
    # broadcast against `data` for any `axis` (and any number of dimensions),
    # not only for 2-D input with axis=0.
    mean = data.mean(axis=axis, keepdims=True)
    std = data.std(axis=axis, keepdims=True)

    # Vectorized floor: `std` is a fresh array, so editing it in place is safe
    # and preserves its dtype.
    std[std < ignore_std] = 1.0

    normalized = (data - mean) / std

    # Drop the helper dimension again so callers receive statistics with the
    # same shape as a plain `data.mean(axis=axis)` would have.
    return normalized, (np.squeeze(mean, axis=axis), np.squeeze(std, axis=axis))
def safe_normalize_pre_stat(data, mean, std, ignore_std=1e-4):
    """Standardize ``data`` with precomputed statistics.

    Returns ``(data - mean) / std`` where every entry of ``std`` smaller than
    ``ignore_std`` is treated as ``1.0``.

    The caller's ``std`` is left untouched: the replacement is done on a
    copy, so passing the same statistics to several calls is safe.

    Args:
        data: Array to normalize.
        mean: Precomputed mean; must broadcast against ``data``.
        std: Precomputed standard deviation (array or scalar); must broadcast
            against ``data``.
        ignore_std: Standard deviations strictly below this value are
            treated as ``1.0``.

    Returns:
        The normalized array.
    """
    std_arr = np.asarray(std)
    # np.where builds a new array (no in-place write into the caller's `std`)
    # and works for 0-d/scalar stds; ones_like keeps the original dtype.
    safe_std = np.where(std_arr < ignore_std, np.ones_like(std_arr), std_arr)
    return (data - mean) / safe_std
def gather_safe_normalize(data, axis=-1, ignore_std=1e-4, real=None):
    """Standardize ``data`` per entry of ``axis``, pooling all other axes.

    For data of shape ``[A, B, C, D, E]`` and ``axis=2`` (for example), every
    axis except 2 is flattened into one "sample" axis, a mean/std pair is
    computed for each of the ``C`` entries, the data is normalized with them
    and the result is returned in the original shape.

    Standard deviations smaller than ``ignore_std`` are replaced by ``1.0``.

    Args:
        data: Array to normalize. It is not modified.
        axis: Axis whose entries each get their own mean/std.
        ignore_std: Standard deviations strictly below this value are
            replaced by ``1.0``.
        real: Optional mask selecting which samples contribute to the
            statistics. Its shape must be the shape of ``data`` with ``axis``
            removed (``[A, B, D, E]`` in the example above). Non-zero / True
            entries are used. All samples are still normalized, including
            the ones that were masked out of the statistics.

    Returns:
        A tuple ``(normalized, (mean, std))`` where ``normalized`` has the
        shape of ``data`` and ``mean`` / ``std`` are 1-D arrays of length
        ``data.shape[axis]`` (``std`` already containing the ``1.0``
        replacements).

    Raises:
        AssertionError: If ``real`` is given and its shape does not match the
            shape of ``data`` with ``axis`` removed.
    """
    # Move (rather than swap) the target axis to the end: the remaining axes
    # keep their relative order, which is what makes `real` line up with the
    # documented [A, B, D, E] layout.
    moved = np.moveaxis(data, axis, -1)
    moved_shape = moved.shape
    flat = moved.reshape(-1, moved_shape[-1])

    # Restrict the statistics to the "real" samples only if a mask was given.
    if real is not None:
        mask = np.asarray(real)
        assert moved_shape[:-1] == mask.shape, (
            "real must have the shape of data with the target axis removed: "
            "expected {}, got {}".format(moved_shape[:-1], mask.shape)
        )
        # Cast to bool so that a 0/1 integer mask selects rows instead of
        # being interpreted as row indices.
        stat_rows = flat[mask.astype(bool).ravel()]
    else:
        stat_rows = flat

    mean = stat_rows.mean(axis=0)
    std = stat_rows.std(axis=0)
    std[std < ignore_std] = 1.0

    # Normalize the unfiltered rows so the result can be reshaped back.
    normalized = ((flat - mean) / std).reshape(moved_shape)
    return np.moveaxis(normalized, -1, axis), (mean, std)