File size: 2,427 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from typing import Optional, Callable

import os
import importlib

# Directory that contains the imported TorchJaekwon package's __file__.
# GetModule uses it as a second search root when looking modules up by name.
import TorchJaekwon
TORCH_JAEKWON_PATH = os.path.dirname(TorchJaekwon.__file__)

# Optional imports: a failure is only reported on stdout, so the names `nn`
# and `HParams` may be undefined for the rest of this module.
# NOTE(review): the bare `except:` also hides non-import errors raised while
# these modules are being imported — confirm that this is intended.
try: import torch.nn as nn
except: print('''Can't import torch.nn''')
try: from HParams import HParams
except: print('There is no Hparams')

class GetModule:
    """Find Python modules by file name and import classes from them.

    A module is looked up by walking a root directory for a file called
    ``<module_name>.py``. When the root starts with ``./`` the same relative
    location inside the TorchJaekwon package directory is searched as a
    fallback. The class defined in that file is expected to have the same name
    as the file itself.
    """

    @staticmethod
    def get_import_path_of_module(root_path:str, module_name:str) -> Optional[str]:
        """Return the dotted import path of the file ``<module_name>.py``.

        Args:
            root_path: Directory to search recursively. If it starts with
                ``./`` the equivalent directory inside ``TORCH_JAEKWON_PATH``
                is searched second.
            module_name: File name without extension.

        Returns:
            A dotted path usable with ``importlib.import_module`` (relative to
            the current directory, or starting at the TorchJaekwon package name
            for files inside the package), or ``None`` when no file matches.
        """
        search_roots:list = [root_path]
        # Only a *leading* './' denotes "relative to the project"; replacing
        # every occurrence would also rewrite the './' inside '../'.
        if root_path.startswith('./'):
            search_roots.append(os.path.join(TORCH_JAEKWON_PATH, root_path[2:]))

        # Restrict matches to Python sources so that checkpoints or configs
        # sharing the model's name (UNet.pth, UNet.yaml, ...) are never returned.
        target_file:str = f'{module_name}.py'
        for search_root in search_roots:
            for current_dir, dir_names, file_names in os.walk(search_root):
                # os.walk order depends on the file system; sorting the
                # sub-directories in place makes the first match deterministic.
                dir_names.sort()
                if target_file in file_names:
                    return GetModule._to_import_path(current_dir, module_name)
        return None

    @staticmethod
    def _to_import_path(directory:str, module_name:str) -> str:
        """Convert ``directory/module_name`` into a dotted import path."""
        module_stem:str = os.path.normpath(os.path.join(directory, module_name))
        package_dir:str = os.path.normpath(TORCH_JAEKWON_PATH)
        if module_stem.startswith(package_dir + os.sep):
            # Inside the TorchJaekwon package: make the path start at the
            # package name by expressing it relative to the package's parent.
            module_stem = os.path.relpath(module_stem, os.path.dirname(package_dir))
        # normpath already dropped any leading './' and uses os.sep throughout,
        # so the conversion also works where the separator is not '/'.
        return '.'.join(module_stem.split(os.sep))

    @staticmethod
    def _import_module_by_name(root_path:str, module_name:str):
        """Import and return the module whose file is ``<module_name>.py``.

        Raises:
            ModuleNotFoundError: If no matching file exists under the search
                roots (instead of the opaque error that
                ``importlib.import_module(None)`` would produce).
        """
        import_path:Optional[str] = GetModule.get_import_path_of_module(root_path, module_name)
        if import_path is None:
            raise ModuleNotFoundError(
                f"GetModule: can't find '{module_name}.py' under '{root_path}' "
                f"or the matching TorchJaekwon directory",
                name=module_name,
            )
        return importlib.import_module(import_path)

    @staticmethod
    def get_module_class(root_path:str,module_name:str):
        """Return the attribute named ``module_name`` from the module of the same name.

        Args:
            root_path: Directory searched for ``<module_name>.py``.
            module_name: Name of both the file and the attribute to fetch.

        Raises:
            ModuleNotFoundError: If the file cannot be located.
            AttributeError: If the module has no attribute ``module_name``.
        """
        module_from = GetModule._import_module_by_name(root_path, module_name)
        return getattr(module_from,module_name)

    @staticmethod
    def get_model(
        model_name:str,
        root_path:str = './Model'
        ) -> 'nn.Module':
        """Instantiate the model class ``model_name`` found under ``root_path``.

        Constructor keyword arguments are taken from the first non-empty source:

        1. ``<class>.get_argument_of_this_model()`` if the class defines it,
        2. ``HParams().model.class_meta_dict[model_name]``,
        3. the attribute ``model_name`` of ``HParams().model``.

        If all of them are empty the class is called without arguments and a
        notice is printed.

        Args:
            model_name: Name of both the file and the class to instantiate.
            root_path: Directory searched for ``<model_name>.py``.

        Returns:
            The constructed model instance.

        Raises:
            ModuleNotFoundError: If the model file cannot be located.
        """
        # The return annotation is a string on purpose: `nn` is undefined when
        # the guarded torch import at the top of the file failed, and an
        # evaluated annotation would then break the whole class definition.
        class_module = GetModule.get_module_class(root_path, model_name)
        argument_getter:Callable[[],dict] = getattr(class_module,'get_argument_of_this_model',lambda: dict())
        model_parameter = argument_getter()
        if not model_parameter:
            # Fall back to the hyper-parameter object; build it only once.
            model_h_params = HParams().model
            model_parameter = model_h_params.class_meta_dict.get(model_name,{})
            if not model_parameter:
                model_parameter = getattr(model_h_params,model_name,dict())
            if not model_parameter:
                print(f'''GetModule: Model [{model_name}] doesn't have changed arguments''')
        # `or {}` keeps the call valid if a source yielded None.
        return class_module(**(model_parameter or {}))