| |
| from typing import Type, Literal, Dict, List, Union, Literal |
|
|
| |
| import os |
| import argparse |
| import numpy as np |
|
|
| |
| from TorchJaekwon.GetModule import GetModule |
|
|
| |
| from HParams import HParams |
|
|
class Controller:
    """Entry-point dispatcher for the four pipeline stages.

    Parses command-line overrides into the ``HParams`` singleton, snapshots
    the run configuration, and dispatches to ``preprocess`` / ``train`` /
    ``inference`` / ``evaluate`` according to the selected stage.
    """

    def __init__(self) -> None:
        # Parse CLI first: the flags mutate the shared HParams singleton, and
        # the snapshot below must observe those overrides.
        self.set_argparse()

        hparams = HParams()
        self.config_name: str = hparams.mode.config_name
        self.stage: Literal['preprocess', 'train', 'inference', 'evaluate'] = hparams.mode.stage

        self.config_per_dataset_dict: Dict[str, dict] = hparams.data.config_per_dataset_dict
        self.train_mode: Literal['start', 'resume'] = hparams.mode.train
        self.train_resume_path: str = hparams.mode.resume_path
        self.eval_class_meta: dict = hparams.evaluate.class_meta

    def run(self) -> None:
        """Announce and execute the configured stage."""
        print("=============================================")
        print(f"{self.stage} start.")
        print("=============================================")
        print(f"{self.config_name} start.")
        print("=============================================")

        # self.stage is restricted by the argparse `choices` list to the four
        # stage names, each of which is a method on this class.
        getattr(self, self.stage)()

        print("Finish app.")

    def preprocess(self) -> None:
        """Run every configured preprocessor for every dataset."""
        from TorchJaekwon.DataProcess.Preprocess.Preprocessor import Preprocessor
        hparams = HParams()
        for data_name, dataset_config in self.config_per_dataset_dict.items():
            for preprocessor_meta in dataset_config['preprocessor_class_meta_list']:
                preprocessor_args: dict = {
                    'data_name': data_name,
                    'root_dir': hparams.data.root_path,
                    'num_workers': hparams.resource.preprocess['num_workers'],
                    'device': hparams.resource.device,
                }
                # Config-supplied args take precedence over the defaults above.
                preprocessor_args.update(preprocessor_meta['args'])

                preprocessor_class: Type[Preprocessor] = GetModule.get_module_class(
                    "./DataProcess/Preprocess", preprocessor_meta['name']
                )
                preprocessor: Preprocessor = preprocessor_class(**preprocessor_args)
                preprocessor.preprocess_data()

    def train(self) -> None:
        """Build the trainer from config, optionally resume, then fit."""
        import torch
        from TorchJaekwon.Train.Trainer.Trainer import Trainer

        hparams = HParams()
        train_class_meta: dict = hparams.train.class_meta

        if hparams.train.seed is None:
            # Derive a 32-bit seed from CUDA's 64-bit initial seed. Integer
            # floor division replaces the old float division, which silently
            # dropped low-order bits of a 64-bit value (float53 mantissa).
            # NOTE(review): torch.cuda.initial_seed() requires a CUDA device —
            # confirm CPU-only runs always set an explicit seed in config.
            seed: int = int(torch.cuda.initial_seed() // (2 ** 32))
        else:
            seed = hparams.train.seed

        trainer_args: dict = {
            'device': hparams.resource.device,
            'data_class_meta_dict': hparams.pytorch_data.class_meta,
            'model_class_name': hparams.model.class_name,
            'model_class_meta_dict': hparams.model.class_meta_dict,
            'optimizer_class_meta_dict': hparams.train.optimizer['class_meta'],
            'lr_scheduler_class_meta_dict': hparams.train.scheduler['class_meta'],
            'loss_class_meta': hparams.train.loss_dict,
            'max_norm_value_for_gradient_clip': getattr(hparams.train, 'max_norm_value_for_gradient_clip', None),
            'total_epoch': getattr(hparams.train, 'total_epoch', int(1e20)),
            'total_step': getattr(hparams.train, 'total_step', np.inf),
            'save_model_every_step': getattr(hparams.train, 'save_model_every_step', None),
            'do_log_every_epoch': getattr(hparams.train, 'do_log_every_epoch', True),
            'seed': seed,
            'seed_strict': hparams.train.seed_strict,
            'debug_mode': getattr(hparams.mode, 'debug_mode', False),
            'use_torch_compile': getattr(hparams.mode, 'use_torch_compile', True),
        }
        # Config-supplied trainer args override the defaults above.
        trainer_args.update(train_class_meta['args'])

        trainer_class: Type[Trainer] = GetModule.get_module_class('./Train/Trainer', train_class_meta['name'])
        trainer: Trainer = trainer_class(**trainer_args)
        trainer.init_train()

        if self.train_mode == "resume":
            print('resume the training')
            trainer.load_train(self.train_resume_path + "/train_checkpoint.pth")

        trainer.fit()

    def inference(self) -> None:
        """Build the inferencer from config and run it on the configured set."""
        from TorchJaekwon.Inference.Inferencer.Inferencer import Inferencer

        hparams = HParams()
        infer_class_meta: dict = hparams.inference.class_meta
        inferencer_args: dict = {
            'output_dir': hparams.inference.output_dir,
            'experiment_name': hparams.mode.config_name,
            # No prebuilt model is passed; presumably the Inferencer resolves
            # one from model_class_name — verify against Inferencer.__init__.
            'model': None,
            'model_class_name': hparams.model.class_name,
            'set_type': hparams.inference.set_type,
            'set_meta_dict': hparams.inference.set_meta_dict,
            'device': hparams.resource.device
        }
        inferencer_args.update(infer_class_meta['args'])

        inferencer_class: Type[Inferencer] = GetModule.get_module_class("./Inference/Inferencer", infer_class_meta['name'])
        inferencer: Inferencer = inferencer_class(**inferencer_args)
        # Fall back to the config name when no explicit pretrain dir is set.
        inferencer.inference(
            pretrained_root_dir = hparams.inference.pretrain_root_dir,
            pretrained_dir_name = hparams.mode.config_name if hparams.inference.pretrain_dir == '' else hparams.inference.pretrain_dir,
            pretrain_module_name = hparams.inference.pretrain_module_name
        )

    def evaluate(self) -> None:
        """Instantiate the configured evaluater and score the inference output.

        When the config leaves ``source_dir`` empty, the lexicographically
        last output directory whose name contains the config name is used.

        Raises:
            FileNotFoundError: if no matching output directory exists.
        """
        from TorchJaekwon.Evaluater.Evaluater import Evaluater
        hparams = HParams()
        evaluater_class: Type[Evaluater] = GetModule.get_module_class("./Evaluater", self.eval_class_meta['name'])
        evaluater_args: dict = self.eval_class_meta['args']
        evaluater_args.update({
            'device': hparams.resource.device
        })
        if evaluater_args.get('source_dir', '') == '':
            source_dir_prefix: str = f'{hparams.inference.output_dir}/{hparams.mode.config_name}'
            source_dir_parent: str = '/'.join(source_dir_prefix.split('/')[:-1])
            source_dir_tag: str = source_dir_prefix.split('/')[-1]
            source_dir_name_candidate = [dir_name for dir_name in os.listdir(source_dir_parent) if source_dir_tag in dir_name]
            if not source_dir_name_candidate:
                # Fail with a clear message instead of the bare IndexError the
                # old `[-1]` raised when nothing matched.
                raise FileNotFoundError(
                    f'no directory containing "{source_dir_tag}" found under "{source_dir_parent}"'
                )
            source_dir_name_candidate.sort()
            evaluater_args['source_dir'] = f'{source_dir_parent}/{source_dir_name_candidate[-1]}'
        evaluater: Evaluater = evaluater_class(**evaluater_args)
        evaluater.evaluate()

    def set_argparse(self) -> argparse.Namespace:
        """Parse CLI flags and push the overrides into the HParams singleton.

        Returns:
            The parsed argument namespace. (The old annotation said ``None``
            while the code returned ``args``; the annotation is corrected.)
        """
        parser = argparse.ArgumentParser()

        parser.add_argument(
            "-c",
            "--config_path",
            type=str,
            required=False,
            default=None,
            help="",
        )

        parser.add_argument(
            "-s",
            "--stage",
            type=str,
            required=False,
            default=None,
            choices = ['preprocess', 'train', 'inference', 'evaluate'],
            help="",
        )

        parser.add_argument(
            '-r',
            '--resume',
            help='train resume',
            action='store_true'
        )

        parser.add_argument(
            "-do",
            "--debug_off",
            help="debug mode off",
            action='store_true'
        )

        parser.add_argument(
            "-lv",
            "--log_visualizer",
            type=str,
            required=False,
            default=None,
            choices = ['tensorboard', 'wandb'],
            help="",
        )

        args = parser.parse_args()

        # Only apply overrides that were actually passed on the command line.
        if args.config_path is not None: HParams().set_config(args.config_path)
        if args.stage is not None: HParams().mode.stage = args.stage
        if args.log_visualizer is not None: HParams().log.visualizer_type = args.log_visualizer
        if args.resume: HParams().mode.train = "resume"
        if args.debug_off: HParams().mode.debug_mode = False

        return args