File size: 2,250 Bytes
from typing import List
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
import os
import time
import torch
from tqdm import tqdm
class Preprocessor(ABC):
    """Base class for dataset preprocessing jobs.

    A subclass supplies the list of work items (``get_meta_data_param``) and
    the per-item work (``preprocess_one_data``).  ``preprocess_data`` then
    runs every item, either serially or in a process pool depending on
    ``num_workers``, and finishes by calling ``final_process``.
    """

    def __init__(self,
                 data_name: str = None,
                 root_dir: str = None,
                 device: torch.device = None,
                 num_workers: int = 1,
                 ) -> None:
        """Store the configuration and create the output directory.

        Args:
            data_name: Name of the dataset; used as the last component of the
                output directory.
            root_dir: Directory under which the output directory is created.
            device: Stored on the instance as ``self.device``; not used by
                this base class itself.
            num_workers: Number of worker processes.  Values greater than 1
                make ``preprocess_data`` use a process pool.
        """
        # args to class variable
        self.data_name: str = data_name
        self.root_dir: str = root_dir
        self.num_workers: int = num_workers
        self.device: torch.device = device

        if self.root_dir is not None and self.data_name is not None:
            self.output_dir = self.get_output_dir()
            os.makedirs(self.output_dir, exist_ok=True)
        else:
            print('Warning: root_dir or data_name is None')
            # Keep the attribute defined (without clobbering a value a
            # subclass may already have set) so later reads do not raise
            # AttributeError.
            self.output_dir = getattr(self, 'output_dir', None)

    def get_output_dir(self) -> str:
        """Return ``<root_dir>/<data_name>``."""
        return os.path.join(self.root_dir, self.data_name)

    def write_message(self, message_type: str, message: str) -> None:
        """Append ``message`` as one line to ``<dir>/<message_type>.txt``.

        The directory is ``self.preprocessed_data_path`` when a subclass
        defines it, otherwise ``self.output_dir``.

        Raises:
            ValueError: If neither directory is configured.
        """
        # The base class never sets ``preprocessed_data_path``; fall back to
        # the directory it actually creates instead of failing on the
        # missing attribute.
        message_dir = (getattr(self, 'preprocessed_data_path', None)
                       or getattr(self, 'output_dir', None))
        if message_dir is None:
            raise ValueError(
                'No directory to write messages to: '
                'root_dir or data_name is None'
            )
        os.makedirs(message_dir, exist_ok=True)
        message_path = os.path.join(message_dir, f"{message_type}.txt")
        with open(message_path, 'a') as file_writer:
            file_writer.write(message + '\n')

    def preprocess_data(self) -> None:
        """Run ``preprocess_one_data`` on every item, then ``final_process``.

        Does nothing (apart from printing a notice) when
        ``get_meta_data_param`` returns ``None``.  With ``num_workers > 1``
        the items are processed in a ``ProcessPoolExecutor``; otherwise they
        are processed serially in this process.  An exception raised while
        processing an item propagates to the caller in both modes.
        """
        meta_param_list: list = self.get_meta_data_param()
        if meta_param_list is None:
            print('meta_param_list is None, So we skip preprocess data')
            return

        start_time: float = time.time()
        if self.num_workers > 1:
            with ProcessPoolExecutor(max_workers=self.num_workers) as pool:
                results = pool.map(self.preprocess_one_data, meta_param_list)
                # Executor.map only re-raises a worker's exception when its
                # result is retrieved, so the iterator must be drained or
                # failures would go unnoticed.
                for _ in tqdm(results,
                              total=len(meta_param_list),
                              desc='preprocess data'):
                    pass
        else:
            for meta_param in tqdm(meta_param_list, desc='preprocess data'):
                self.preprocess_one_data(meta_param)

        self.final_process()
        print("{:.3f} s".format(time.time() - start_time))

    @abstractmethod
    def get_meta_data_param(self) -> list:
        """Return the list of work items, one per ``preprocess_one_data`` call.

        Returning ``None`` makes ``preprocess_data`` skip all work.

        Example:
            meta_data_param_list = list()
        """
        raise NotImplementedError

    @abstractmethod
    def preprocess_one_data(self, param: tuple) -> None:
        """Process a single work item from ``get_meta_data_param``.

        Example:
            (subset, file_name) = param
        """
        raise NotImplementedError

    def final_process(self) -> None:
        """Hook called once after all items are processed; prints a notice."""
        print("Finish preprocess")