File size: 2,427 Bytes
dfd1909 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | from typing import Optional, Callable
import os
import importlib
import TorchJaekwon
TORCH_JAEKWON_PATH = os.path.dirname(TorchJaekwon.__file__)
try: import torch.nn as nn
except: print('''Can't import torch.nn''')
try: from HParams import HParams
except: print('There is no Hparams')
class GetModule:
    """Find a Python module by file name under a directory tree, import it, and build objects from it.

    Lookups search ``root_path`` first and then the directory obtained by replacing
    ``"./"`` in ``root_path`` with the installed TorchJaekwon package directory
    (``TORCH_JAEKWON_PATH``), so matches under ``root_path`` take precedence.
    """

    @staticmethod
    def get_import_path_of_module(root_path: str, module_name: str) -> Optional[str]:
        """Return the dotted import path of the first ``<module_name>.py`` found.

        Args:
            root_path: Directory to search, e.g. ``"./Model"``. The same path with
                ``"./"`` replaced by the TorchJaekwon package directory is searched next.
            module_name: File stem to match; ``"Foo"`` matches ``Foo.py``.

        Returns:
            A dotted path such as ``"Model.sub.Foo"``. Matches whose directory lies in
            the TorchJaekwon package are made relative to the package's parent
            directory (so they start with the package name); all other matches are
            made relative to the current working directory. ``None`` if nothing matches.
        """
        search_roots: list = [root_path]
        package_root: str = root_path.replace("./", f'{TORCH_JAEKWON_PATH}/')
        if package_root != root_path:
            # An identical second root would only repeat the same walk.
            search_roots.append(package_root)
        # Parent of the package dir; `or os.curdir` keeps relpath valid if the
        # package path has no parent component.
        package_parent: str = os.path.dirname(TORCH_JAEKWON_PATH) or os.curdir

        for search_root in search_roots:
            for directory, _subdirs, files in os.walk(search_root):
                for file_name in files:
                    stem, extension = os.path.splitext(file_name)
                    # Only Python sources are importable; a same-named config or
                    # checkpoint file (e.g. Foo.yaml, Foo.pth) must not win the lookup.
                    if stem != module_name or extension != '.py':
                        continue
                    module_file: str = os.path.join(directory, stem)
                    if TORCH_JAEKWON_PATH in directory:
                        relative_path: str = os.path.relpath(module_file, package_parent)
                    else:
                        relative_path = os.path.relpath(module_file)
                    # relpath normalizes separators (and drops a leading "./"),
                    # so splitting on os.sep is portable.
                    return relative_path.replace(os.sep, '.')
        return None

    @staticmethod
    def get_module_class(root_path: str, module_name: str):
        """Import the module ``<module_name>.py`` found under ``root_path`` and return
        its attribute named ``module_name``.

        Raises:
            ModuleNotFoundError: if no ``<module_name>.py`` exists in the searched trees.
            AttributeError: if the imported module has no attribute ``module_name``.
        """
        module_path: Optional[str] = GetModule.get_import_path_of_module(root_path, module_name)
        if module_path is None:
            # importlib.import_module(None) would only fail with an opaque AttributeError.
            raise ModuleNotFoundError(
                f"GetModule: no module file '{module_name}.py' found under '{root_path}'",
                name=module_name,
            )
        imported_module = importlib.import_module(module_path)
        return getattr(imported_module, module_name)

    @staticmethod
    def get_model(
        model_name: str,
        root_path: str = './Model'
    ) -> 'nn.Module':  # string annotation: torch.nn is an optional import in this file
        """Instantiate the class ``model_name`` defined in ``<model_name>.py`` under ``root_path``.

        Constructor keyword arguments are taken from the first non-empty source of:
          1. the class's ``get_argument_of_this_model()`` hook (called with no arguments);
          2. ``HParams().model.class_meta_dict[model_name]``;
          3. the attribute ``HParams().model.<model_name>``.
        If all are empty, a notice is printed and the class is called with no arguments.

        Raises:
            ModuleNotFoundError: if the model file cannot be found.
        """
        class_module = GetModule.get_module_class(root_path, model_name)
        argument_getter: Callable[[], dict] = getattr(class_module, 'get_argument_of_this_model', lambda: dict())
        model_parameter = argument_getter()
        if not model_parameter:
            # HParams is consulted only when the class supplies no arguments itself,
            # and is instantiated once for both lookups.
            model_hparams = HParams().model
            class_meta_dict = getattr(model_hparams, 'class_meta_dict', None) or {}
            model_parameter = class_meta_dict.get(model_name, {})
            if not model_parameter:
                model_parameter = getattr(model_hparams, model_name, dict())
        if not model_parameter:
            print(f'''GetModule: Model [{model_name}] doesn't have changed arguments''')
        # `or {}` guards against a None result, which `**` cannot unpack.
        model = class_module(**(model_parameter or {}))
        return model