| from typing import Optional, Callable |
|
|
| import os |
| import importlib |
|
|
| import TorchJaekwon |
| TORCH_JAEKWON_PATH = os.path.dirname(TorchJaekwon.__file__) |
|
|
| try: import torch.nn as nn |
| except: print('''Can't import torch.nn''') |
| try: from HParams import HParams |
| except: print('There is no Hparams') |
|
|
| class GetModule: |
| @staticmethod |
| def get_import_path_of_module(root_path:str, module_name:str) -> Optional[str]: |
| root_path_list:list = [root_path] |
| root_path_list.append(root_path.replace("./",f'{TORCH_JAEKWON_PATH}/')) |
| |
| for root_path in root_path_list: |
| for root,dirs,files in os.walk(root_path): |
| if len(files) > 0: |
| for file in files: |
| if os.path.splitext(file)[0] == module_name: |
| if TORCH_JAEKWON_PATH in root: |
| torch_jaekwon_parent_path:str = '/'.join(TORCH_JAEKWON_PATH.split('/')[:-1]) |
| return f'{root}/{os.path.splitext(file)[0]}'.replace(torch_jaekwon_parent_path+'/','').replace("/",".") |
| else: |
| return f'{root}/{os.path.splitext(file)[0]}'.replace("./","").replace("/",".") |
| return None |
| |
| @staticmethod |
| def get_module_class(root_path:str,module_name:str): |
| module_path:str = GetModule.get_import_path_of_module(root_path,module_name) |
| module_from = importlib.import_module(module_path) |
| return getattr(module_from,module_name) |
| |
| @staticmethod |
| def get_model( |
| model_name:str, |
| root_path:str = './Model' |
| ) -> nn.Module: |
| module_file_path:str = GetModule.get_import_path_of_module(root_path, model_name) |
| file_module = importlib.import_module(module_file_path) |
| class_module = getattr(file_module,model_name) |
| argument_getter:Callable[[],dict] = getattr(class_module,'get_argument_of_this_model',lambda: dict()) |
| model_parameter:dict = argument_getter() |
| if len(model_parameter) == 0: |
| model_parameter = HParams().model.class_meta_dict.get(model_name,{}) |
| if not model_parameter: |
| model_parameter = getattr(HParams().model,model_name,dict()) |
| if not model_parameter: print(f'''GetModule: Model [{model_name}] doesn't have changed arguments''') |
| model:nn.Module = class_module(**model_parameter) |
| return model |