FlashSR_One-step_Versatile_Audio_Super-resolution / TorchJaekwon /DataProcess /Preprocess /Preprocessor.py
| from typing import List | |
| from abc import ABC, abstractmethod | |
| from concurrent.futures import ProcessPoolExecutor | |
| import os | |
| import time | |
| import torch | |
| from tqdm import tqdm | |
class Preprocessor(ABC):
    """Base class for dataset preprocessing jobs.

    Subclasses override :meth:`get_meta_data_param` to enumerate the work
    items and :meth:`preprocess_one_data` to process a single item.
    :meth:`preprocess_data` drives the whole run, either serially or in a
    process pool, and then calls :meth:`final_process`.
    """

    def __init__(self,
                 data_name: str = None,
                 root_dir: str = None,
                 device: torch.device = None,
                 num_workers: int = 1,
                 ) -> None:
        """Store the configuration and create the output directory.

        Args:
            data_name: Name of the dataset; joined onto ``root_dir`` by
                :meth:`get_output_dir`.
            root_dir: Directory under which the output directory is created.
            device: Stored on the instance as ``self.device``; not used by
                this base class.
            num_workers: Worker count used by :meth:`preprocess_data` to
                choose between serial and process-pool execution.
        """
        # args to class variable
        self.data_name: str = data_name
        self.root_dir: str = root_dir
        self.num_workers: int = num_workers
        self.device: torch.device = device

        if self.root_dir is not None and self.data_name is not None:
            self.output_dir = self.get_output_dir()
            os.makedirs(self.output_dir, exist_ok=True)
        else:
            print('Warning: root_dir or data_name is None')
            # Make sure the attribute always exists so later code can test
            # it instead of hitting AttributeError; keep a value a subclass
            # may already have assigned.
            self.output_dir = getattr(self, 'output_dir', None)

    def get_output_dir(self) -> str:
        """Return ``root_dir`` joined with ``data_name``."""
        return os.path.join(self.root_dir, self.data_name)

    def write_message(self, message_type: str, message: str) -> None:
        """Append ``message`` as one line to ``<dir>/<message_type>.txt``.

        The target directory is ``self.preprocessed_data_path`` when a
        subclass defines it (the attribute name this method originally
        read), otherwise ``self.output_dir``.

        Raises:
            ValueError: If no target directory is configured.
        """
        target_dir = getattr(self, 'preprocessed_data_path', None)
        if target_dir is None:
            target_dir = getattr(self, 'output_dir', None)
        if target_dir is None:
            raise ValueError(
                'Cannot write message: output_dir is not set '
                '(root_dir or data_name is None)'
            )
        log_path = os.path.join(target_dir, f"{message_type}.txt")
        with open(log_path, 'a') as file_writer:
            file_writer.write(message + '\n')

    def preprocess_data(self) -> None:
        """Run :meth:`preprocess_one_data` over every meta parameter.

        Does nothing (besides printing a notice) when
        :meth:`get_meta_data_param` returns ``None``. Otherwise processes
        all items, calls :meth:`final_process`, and prints the elapsed
        wall-clock time in seconds.
        """
        meta_param_list: list = self.get_meta_data_param()
        if meta_param_list is None:
            print('meta_param_list is None, So we skip preprocess data')
            return

        start_time: float = time.time()
        # NOTE(review): the threshold is "> 2", so num_workers == 2 still
        # runs serially. Kept as-is for backward compatibility -- confirm
        # whether "> 1" was intended.
        if self.num_workers > 2:
            with ProcessPoolExecutor(max_workers=self.num_workers) as pool:
                # Executor.map only re-raises a worker's exception when its
                # result is retrieved; drain the iterator so failures are
                # not silently dropped (matching the serial path).
                for _ in pool.map(self.preprocess_one_data, meta_param_list):
                    pass
        else:
            for meta_param in tqdm(meta_param_list, desc='preprocess data'):
                self.preprocess_one_data(meta_param)

        self.final_process()
        print("{:.3f} s".format(time.time() - start_time))

    def get_meta_data_param(self) -> list:
        '''
        Return the list of parameters, one entry per item to preprocess.
        Must be overridden by subclasses.

        meta_data_param_list = list()
        '''
        raise NotImplementedError

    def preprocess_one_data(self, param: tuple) -> None:
        '''
        Preprocess a single item described by ``param``.
        Must be overridden by subclasses.

        ex) (subset, file_name) = param
        '''
        raise NotImplementedError

    def final_process(self) -> None:
        """Hook called once after all items are processed."""
        print("Finish preprocess")