# ChristophSchuhmann's picture
# Add model code, inference script, and examples
# dfd1909 verified
#type
from typing import Type, Literal, Dict, List, Union, Literal
#package
import os
import argparse
import numpy as np
#torchjaekwon
from TorchJaekwon.GetModule import GetModule
#internal
from HParams import HParams
class Controller():
    """Command-line driven dispatcher for the pipeline stages.

    Construction parses the command line, writes any overrides into
    ``HParams`` and then caches the settings needed for dispatch.
    ``run()`` calls the method whose name equals ``self.stage``; the stages
    offered on the command line are ``preprocess``, ``train``,
    ``inference`` and ``evaluate``.

    NOTE(review): the code relies on every ``HParams()`` call exposing the
    same shared configuration (``set_argparse`` writes through one call and
    ``__init__`` reads the values back through another) -- presumably a
    singleton; confirm against ``HParams``.
    """

    def __init__(self) -> None:
        # Parse the CLI first: its overrides are written into HParams and
        # must be in place before the values below are read back.
        self.set_argparse()

        h_params = HParams()
        self.config_name: str = h_params.mode.config_name
        self.stage: Literal['preprocess', 'train', 'inference', 'evaluate'] = h_params.mode.stage
        self.config_per_dataset_dict: Dict[str, dict] = h_params.data.config_per_dataset_dict
        self.train_mode: Literal['start', 'resume'] = h_params.mode.train
        self.train_resume_path: str = h_params.mode.resume_path
        self.eval_class_meta: dict = h_params.evaluate.class_meta  # {'name': 'Evaluater', 'args': {}}

    def run(self) -> None:
        """Print a banner and invoke the method named by ``self.stage``."""
        separator = "============================================="
        print(separator)
        print(f"{self.stage} start.")
        print(separator)
        print(f"{self.config_name} start.")
        print(separator)

        # Dispatch by attribute name; the stage method takes no arguments.
        getattr(self, self.stage)()
        print("Finish app.")

    def preprocess(self) -> None:
        """Run every preprocessor configured for every dataset, in order."""
        from TorchJaekwon.DataProcess.Preprocess.Preprocessor import Preprocessor

        h_params = HParams()
        for data_name, dataset_config in self.config_per_dataset_dict.items():
            for preprocessor_meta in dataset_config['preprocessor_class_meta_list']:
                # Defaults first; the per-preprocessor 'args' override them.
                preprocessor_args: dict = {
                    'data_name': data_name,
                    'root_dir': h_params.data.root_path,
                    'num_workers': h_params.resource.preprocess['num_workers'],
                    'device': h_params.resource.device,
                    **preprocessor_meta['args'],
                }
                preprocessor_class: Type[Preprocessor] = GetModule.get_module_class(
                    "./DataProcess/Preprocess", preprocessor_meta['name']
                )
                preprocessor: Preprocessor = preprocessor_class(**preprocessor_args)
                preprocessor.preprocess_data()

    def train(self) -> None:
        """Build the configured trainer and fit it, resuming if requested."""
        import torch
        from TorchJaekwon.Train.Trainer.Trainer import Trainer

        h_params = HParams()
        train_hp = h_params.train
        train_class_meta: dict = train_hp.class_meta  # {'name': 'Trainer', 'args': {}}

        seed = train_hp.seed
        if seed is None:
            # No seed configured: derive one from the CUDA generator's
            # initial seed, scaled down by 2**32.
            # NOTE(review): this queries torch.cuda -- confirm it is
            # acceptable on hosts without a GPU.
            seed = int(torch.cuda.initial_seed() / (2**32))

        # Defaults first; the trainer's own 'args' override them.
        trainer_args: dict = {
            'device': h_params.resource.device,
            'data_class_meta_dict': h_params.pytorch_data.class_meta,
            'model_class_name': h_params.model.class_name,
            'model_class_meta_dict': h_params.model.class_meta_dict,
            'optimizer_class_meta_dict': train_hp.optimizer['class_meta'],
            'lr_scheduler_class_meta_dict': train_hp.scheduler['class_meta'],
            'loss_class_meta': train_hp.loss_dict,
            'max_norm_value_for_gradient_clip': getattr(train_hp, 'max_norm_value_for_gradient_clip', None),
            'total_epoch': getattr(train_hp, 'total_epoch', int(1e20)),
            'total_step': getattr(train_hp, 'total_step', np.inf),
            'save_model_every_step': getattr(train_hp, 'save_model_every_step', None),
            'do_log_every_epoch': getattr(train_hp, 'do_log_every_epoch', True),
            'seed': seed,
            'seed_strict': train_hp.seed_strict,
            'debug_mode': getattr(h_params.mode, 'debug_mode', False),
            'use_torch_compile': getattr(h_params.mode, 'use_torch_compile', True),
            **train_class_meta['args'],
        }

        trainer_class: Type[Trainer] = GetModule.get_module_class('./Train/Trainer', train_class_meta['name'])
        trainer: Trainer = trainer_class(**trainer_args)
        trainer.init_train()

        if self.train_mode == "resume":
            print('resume the training')
            trainer.load_train(self.train_resume_path + "/train_checkpoint.pth")

        trainer.fit()

    def inference(self) -> None:
        """Build the configured inferencer and run it on pretrained weights."""
        from TorchJaekwon.Inference.Inferencer.Inferencer import Inferencer

        h_params = HParams()
        infer_hp = h_params.inference
        infer_class_meta: dict = infer_hp.class_meta  # {'name': 'Inferencer', 'args': {}}

        # Defaults first; the inferencer's own 'args' override them.
        inferencer_args: dict = {
            'output_dir': infer_hp.output_dir,
            'experiment_name': h_params.mode.config_name,
            'model': None,
            'model_class_name': h_params.model.class_name,
            'set_type': infer_hp.set_type,
            'set_meta_dict': infer_hp.set_meta_dict,
            'device': h_params.resource.device,
            **infer_class_meta['args'],
        }

        inferencer_class: Type[Inferencer] = GetModule.get_module_class(
            "./Inference/Inferencer", infer_class_meta['name']
        )
        inferencer: Inferencer = inferencer_class(**inferencer_args)

        # An empty 'pretrain_dir' means "use the directory named after the config".
        pretrain_dir = infer_hp.pretrain_dir
        inferencer.inference(
            pretrained_root_dir=infer_hp.pretrain_root_dir,
            pretrained_dir_name=h_params.mode.config_name if pretrain_dir == '' else pretrain_dir,
            pretrain_module_name=infer_hp.pretrain_module_name,
        )

    def evaluate(self) -> None:
        """Build the configured evaluater and run it.

        When the evaluater args carry no ``source_dir`` (missing or ``''``),
        the directory is resolved from the inference output location via
        ``_latest_output_dir``.

        Raises:
            FileNotFoundError: if ``source_dir`` has to be resolved and no
                matching directory exists.
        """
        from TorchJaekwon.Evaluater.Evaluater import Evaluater

        h_params = HParams()
        evaluater_class: Type[Evaluater] = GetModule.get_module_class("./Evaluater", self.eval_class_meta['name'])

        # Work on a copy so the shared config dict is not modified in place;
        # the device from HParams takes precedence over any configured one.
        evaluater_args: dict = {
            **self.eval_class_meta['args'],
            'device': h_params.resource.device,
        }
        if evaluater_args.get('source_dir', '') == '':
            evaluater_args['source_dir'] = self._latest_output_dir(
                f'{h_params.inference.output_dir}/{h_params.mode.config_name}'
            )

        evaluater: Evaluater = evaluater_class(**evaluater_args)
        evaluater.evaluate()

    @staticmethod
    def _latest_output_dir(source_dir_prefix: str) -> str:
        """Return the last (in sorted order) entry matching a path prefix.

        ``source_dir_prefix`` is split at its final ``/`` into a parent
        directory and a tag; every entry of the parent whose name contains
        the tag is a candidate, and the candidate that sorts last is chosen.

        Raises:
            FileNotFoundError: if the parent directory does not exist or
                contains no entry whose name includes the tag.
        """
        parent_dir, _, tag = source_dir_prefix.rpartition('/')
        candidates = sorted(name for name in os.listdir(parent_dir) if tag in name)
        if not candidates:
            raise FileNotFoundError(
                f"No directory containing '{tag}' found in '{parent_dir}'; "
                "run the inference stage first or set 'source_dir' explicitly."
            )
        return f'{parent_dir}/{candidates[-1]}'

    def set_argparse(self) -> argparse.Namespace:
        """Parse the command line and write the overrides into ``HParams``.

        Only options that were actually supplied overwrite the configured
        values.

        Returns:
            The parsed argument namespace.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-c",
            "--config_path",
            type=str,
            required=False,
            default=None,
            help="",
        )
        parser.add_argument(
            "-s",
            "--stage",
            type=str,
            required=False,
            default=None,
            choices=['preprocess', 'train', 'inference', 'evaluate'],
            help="",
        )
        parser.add_argument(
            '-r',
            '--resume',
            help='train resume',
            action='store_true',
        )
        parser.add_argument(
            "-do",
            "--debug_off",
            help="debug mode off",
            action='store_true',
        )
        parser.add_argument(
            "-lv",
            "--log_visualizer",
            type=str,
            required=False,
            default=None,
            choices=['tensorboard', 'wandb'],
            help="",
        )

        args = parser.parse_args()

        # The config is applied first so the overrides below land on top of it.
        if args.config_path is not None:
            HParams().set_config(args.config_path)
        if args.stage is not None:
            HParams().mode.stage = args.stage
        if args.log_visualizer is not None:
            HParams().log.visualizer_type = args.log_visualizer
        if args.resume:
            HParams().mode.train = "resume"
        if args.debug_off:
            HParams().mode.debug_mode = False

        return args