"""General-purpose data utilities: file I/O helpers, array segmenting and padding."""

import copy
import csv
import json
import os
import pickle
import random
from inspect import isfunction
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import torch
import torch.nn.functional as F
from numpy import ndarray
from torch import Tensor

try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is purely cosmetic; fall back to a transparent iterator
    # so the rest of the utilities stay usable without tqdm installed.
    def tqdm(iterable, *args, **kwargs):
        return iterable

try:
    import yaml
except ImportError:
    # Only yaml_save / yaml_load need PyYAML; they raise a clear error if used.
    yaml = None


def _require_yaml(caller: str) -> None:
    """Raise ImportError when PyYAML is unavailable for a YAML helper."""
    if yaml is None:
        raise ImportError(f"PyYAML is required for UtilData.{caller}")


class UtilData:
    """Namespace of static helpers for file I/O and array/tensor manipulation."""

    @staticmethod
    def _axis_slice(ndim: int, axis: int, start: int, stop: int) -> tuple:
        """Build an index tuple selecting ``start:stop`` along ``axis`` only."""
        index = [slice(None)] * ndim
        index[axis] = slice(start, stop)
        return tuple(index)

    @staticmethod
    def get_file_name(file_path: str, with_ext: bool = False) -> str:
        """Return the final path component, with or without its extension.

        Prints a warning and returns an empty string when ``file_path`` is None.
        """
        if file_path is None:
            print("warning: path is None")
            return ""
        path = Path(file_path)
        return path.name if with_ext else path.stem

    @staticmethod
    def pickle_save(save_path: str, data: Union[ndarray, Tensor]) -> None:
        """Pickle ``data`` to ``save_path``, creating parent directories.

        If the path does not end in ``.pkl`` a notice is printed and ``.pkl``
        is appended to the path.
        """
        if os.path.splitext(save_path)[1] != ".pkl":
            print("file extension should be '.pkl'")
            save_path = f'{save_path}.pkl'
        parent_dir = os.path.dirname(save_path)
        # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(save_path, 'wb') as file_writer:
            pickle.dump(data, file_writer)

    @staticmethod
    def pickle_load(data_path: str) -> Union[ndarray, Tensor]:
        """Load and return a pickled object from ``data_path``.

        NOTE(security): unpickling can execute arbitrary code; only load files
        that come from a trusted source.
        """
        with open(data_path, 'rb') as pickle_file:
            return pickle.load(pickle_file)

    @staticmethod
    def yaml_save(save_path: str, data: Union[dict, list], sort_keys: bool = False) -> None:
        """Dump ``data`` as YAML to ``save_path`` (which must end in ``.yaml``)."""
        assert os.path.splitext(save_path)[1] == ".yaml", "file extension should be '.yaml'"
        _require_yaml("yaml_save")
        with open(save_path, 'w') as file:
            yaml.dump(data, file, sort_keys=sort_keys, allow_unicode=True)

    @staticmethod
    def yaml_load(data_path: str) -> dict:
        """Parse the YAML file at ``data_path`` with the safe loader."""
        _require_yaml("yaml_load")
        # Use a context manager so the handle is closed after parsing.
        with open(data_path, 'r') as yaml_file:
            return yaml.safe_load(yaml_file)

    @staticmethod
    def csv_load(data_path: str) -> list:
        """Read a CSV file and return its rows as a list of string lists."""
        with open(data_path, newline='') as csvfile:
            return list(csv.reader(csvfile))

    @staticmethod
    def txt_load(data_path: str) -> list:
        """Return all lines of a text file, line endings included."""
        with open(data_path, 'r') as txtfile:
            return txtfile.readlines()

    @staticmethod
    def txt_save(save_path: str, string_list: List[str], new_file: bool = True) -> None:
        """Write each item of ``string_list`` as one line of ``save_path``.

        ``new_file=True`` overwrites the file; ``False`` appends to it. Parent
        directories are created when the path has a directory part.
        """
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(save_path, 'w' if new_file else 'a') as file:
            file.writelines(f'{line}\n' for line in string_list)

    @staticmethod
    def csv_save(file_path: str,
                 data_dict_list: List[Dict[str, object]],  # [ {key: object}, ... ]
                 order_of_key: list = None  # [key1, key2, ...]
                 ) -> None:
        """Write a list of dicts to ``file_path`` as CSV via pandas.

        Columns follow ``order_of_key``; when it is None the key order of the
        first dict is used (no columns if the list is empty). Every dict must
        contain every selected key, otherwise ``KeyError`` is raised.
        """
        import pandas as pd

        if order_of_key is None:
            order_of_key = list(data_dict_list[0].keys()) if data_dict_list else []
        columns: dict = {
            key: [data_dict[key] for data_dict in data_dict_list]
            for key in order_of_key
        }
        pd.DataFrame(columns).to_csv(file_path)

    @staticmethod
    def json_load(file_path: str) -> dict:
        """Parse and return the JSON document stored at ``file_path``."""
        with open(file_path) as f:
            return json.loads(f.read())

    @staticmethod
    def save_data_segment(save_dir: str,
                          data: ndarray,
                          segment_len: int,
                          segment_axis: int = -1,
                          remainder: str = 'pad',
                          ext: str = 'pkl') -> None:
        """Split ``data`` along ``segment_axis`` and save each piece to ``save_dir``.

        Each segment is written as ``{save_dir}/{start_index}.{ext}``.

        Args:
            save_dir: Output directory (created if missing).
            data: Array to split; it is not modified.
            segment_len: Number of samples per segment along ``segment_axis``.
            segment_axis: Axis to split along.
            remainder: What to do with a trailing partial segment:
                ``'discard'`` drops it, ``'pad'`` zero-pads it to full length,
                ``'maintain'`` saves it as-is (shorter than ``segment_len``).
            ext: Output format; only ``'pkl'`` writes anything.
        """
        if segment_len <= 0:
            raise ValueError("segment_len must be a positive integer")
        os.makedirs(save_dir, exist_ok=True)

        # Slicing and np.pad never mutate the input, so no defensive copy is needed.
        data_total = data
        ndim = data_total.ndim
        leftover = data_total.shape[segment_axis] % segment_len
        if leftover != 0:
            if remainder == 'discard':
                keep = data_total.shape[segment_axis] - leftover
                data_total = data_total[UtilData._axis_slice(ndim, segment_axis, 0, keep)]
            elif remainder == 'pad':
                pad_width = [(0, 0)] * ndim
                pad_width[segment_axis] = (0, segment_len - leftover)
                data_total = np.pad(data_total, pad_width, 'constant')

        total_length_of_data: int = data_total.shape[segment_axis]
        for start_idx in range(0, total_length_of_data, segment_len):
            # Clamp so the final 'maintain' segment keeps every remaining sample.
            end_idx = min(start_idx + segment_len, total_length_of_data)
            data_segment = data_total[UtilData._axis_slice(ndim, segment_axis, start_idx, end_idx)]
            if remainder != 'maintain':
                assert data_segment.shape[segment_axis] == segment_len, \
                    'Error[UtilData.save_data_segment] segment length error!!'
            if ext == 'pkl':
                UtilData.pickle_save(f'{save_dir}/{start_idx}.{ext}', data_segment)

    @staticmethod
    def fit_shape_length(feature: Union[Tensor, ndarray], shape_length: int,
                         dim: int = 0) -> Union[Tensor, ndarray]:
        """Squeeze ``feature`` and add singleton axes at ``dim`` up to ``shape_length`` dims.

        If ``feature`` already has ``shape_length`` dimensions it is returned
        unchanged (an ndarray stays an ndarray). Otherwise an ndarray is
        converted to a Tensor, all size-1 axes are squeezed, and singleton axes
        are inserted at ``dim`` until the rank reaches ``shape_length``. If the
        squeezed rank already exceeds ``shape_length`` nothing is added.
        """
        if shape_length == len(feature.shape):
            return feature
        # isinstance (not type ==) so Tensor subclasses such as nn.Parameter work.
        if not isinstance(feature, torch.Tensor):
            feature = torch.from_numpy(feature)
        feature = torch.squeeze(feature)
        for _ in range(shape_length - len(feature.shape)):
            feature = torch.unsqueeze(feature, dim=dim)
        return feature

    @staticmethod
    def sort_dict_list(dict_list: List[dict], key: str, reverse: bool = False):
        """Return a new list of the dicts sorted by their ``key`` entry."""
        return sorted(dict_list, key=lambda dictionary: dictionary[key], reverse=reverse)

    @staticmethod
    def random_segment(data: ndarray, data_length: int) -> ndarray:
        """Return a random contiguous slice of ``data_length`` items from ``data``.

        Raises ``ValueError`` (from ``random.randint``) when ``data`` is shorter
        than ``data_length``.
        """
        max_data_start = len(data) - data_length
        data_start = random.randint(0, max_data_start)
        return data[data_start:data_start + data_length]

    @staticmethod
    def default(val, d):
        """Return ``val`` unless it is None; else ``d`` (called if it is a function)."""
        if val is not None:
            return val
        return d() if isfunction(d) else d

    @staticmethod
    def fix_length(data: Union[ndarray, Tensor], length: int,
                   dim: int = -1) -> Union[ndarray, Tensor]:
        """Zero-pad or truncate ``data`` so its size along ``dim`` equals ``length``.

        Padding is appended at the end of ``dim``; truncation keeps the first
        ``length`` entries. The container type (ndarray / Tensor) is preserved.
        Only 1-, 2- and 3-dimensional inputs are supported.
        """
        assert len(data.shape) in [1, 2, 3], \
            "Error[UtilData.fix_length] only support when data.shape is 1, 2 or 3"
        ndim = len(data.shape)
        current_length = data.shape[dim]
        if current_length == length:
            return data

        if current_length < length:
            pad_length = length - current_length
            axis = dim % ndim
            if isinstance(data, Tensor):
                # F.pad lists (left, right) pairs starting from the LAST axis,
                # so skip one (0, 0) pair for every axis that follows `axis`.
                pad_spec = (0, 0) * (ndim - 1 - axis) + (0, pad_length)
                return F.pad(data, pad_spec, "constant", 0)
            pad_width = [(0, 0)] * ndim
            pad_width[axis] = (0, pad_length)
            return np.pad(data, pad_width, 'constant')

        return data[UtilData._axis_slice(ndim, dim, 0, length)]

    @staticmethod
    def listdir(dir_name: str, ext: Union[str, list] = ['.wav', '.mp3', '.flac']) -> list:
        """List the entries of ``dir_name``, optionally filtered by extension.

        With ``ext=None`` the bare names from ``os.listdir`` are returned.
        Otherwise ``ext`` is one extension or a collection of extensions
        (leading dot included) and the result is a list of
        ``{'file_name', 'file_path'}`` dicts for the matching entries.
        """
        if ext is None:
            return os.listdir(dir_name)
        allowed_ext = {ext} if isinstance(ext, str) else set(ext)
        return [
            {'file_name': file_name, 'file_path': f'{dir_name}/{file_name}'}
            for file_name in os.listdir(dir_name)
            if os.path.splitext(file_name)[1] in allowed_ext
        ]

    @staticmethod
    def walk(dir_name: str, ext: list = ['.wav', '.mp3', '.flac']) -> list:
        """Recursively collect metadata for files under ``dir_name``.

        ``ext`` is an extension or a collection of extensions (leading dot
        included); ``None`` accepts every file. Each result dict holds
        ``file_name`` (stem only), ``file_path``, ``dir_name`` (the directory
        path relative to ``dir_name`` with every '/' removed) and ``dir_path``.
        """
        if ext is None:
            allowed_ext = None
        elif isinstance(ext, str):
            # Exact match; a raw `in` on a string would do substring matching.
            allowed_ext = {ext}
        else:
            allowed_ext = set(ext)

        file_meta_list: list = []
        for root, _, files in os.walk(dir_name):
            for filename in tqdm(files, desc=f'walk {root}'):
                if allowed_ext is not None and os.path.splitext(filename)[-1] not in allowed_ext:
                    continue
                file_meta_list.append({
                    'file_name': UtilData.get_file_name(file_path=filename),
                    'file_path': f'{root}/{filename}',
                    'dir_name': root.replace(dir_name, '').replace('/', ''),
                    'dir_path': root,
                })
        return file_meta_list

    @staticmethod
    def get_dir_name_list(root_dir: str) -> list:
        """Return the names of the immediate sub-directories of ``root_dir``."""
        return [
            dir_name for dir_name in os.listdir(root_dir)
            if os.path.isdir(f'{root_dir}/{dir_name}')
        ]

    @staticmethod
    def pretty_num(number: float) -> str:
        """Format ``number`` with a K / M / B suffix, rounded to 5 decimals."""
        if number < 1000:
            return str(number)
        if number < 1000000:
            return f'{round(number/1000,5)}K'
        if number < 1000000000:
            return f'{round(number/1000000,5)}M'
        return f'{round(number/1000000000,5)}B'

    @staticmethod
    def extract_num_from_str(string: str) -> float:
        """Join every digit and '.' in ``string`` and parse the result as a float.

        Raises ``ValueError`` when the kept characters do not form a valid
        float (e.g. no digits, or more than one '.').
        """
        return float(''.join(c for c in string if c.isdigit() or c == '.'))