File size: 4,771 Bytes
fb5159d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121





import inspect
import logging
# Runs at import time: install the default root-logger handler (this is a
# no-op when the root logger already has handlers configured).
logging.basicConfig()
# Logger used by the functions below.
log = logging.getLogger("ModuleRegister")
# Let everything from DEBUG upward through, which includes the log.info()
# calls made throughout this file.
log.setLevel(logging.DEBUG)

# Registered module maps, in registration order. registerModuleMap() appends
# to this list; getModule() and getClassFromModule() scan it front to back.
MODULE_MAPS = []


def registerModuleMap(module_map):
    """Register ``module_map`` as a place to look up pluggable modules.

    The object is appended to ``MODULE_MAPS``; maps registered earlier are
    scanned first by ``getModule`` and ``getClassFromModule``.

    Args:
        module_map: Object whose members (as listed by
            ``inspect.getmembers``) are later matched by name.
    """
    MODULE_MAPS.append(module_map)
    # inspect.getsource() raises when the source cannot be retrieved (built-in
    # or dynamically created objects, source-less deployments). The source is
    # only wanted for the log line, so fall back to repr() rather than failing
    # a registration that has already taken effect.
    try:
        content = inspect.getsource(module_map)
    except (IOError, OSError, TypeError):
        content = repr(module_map)
    log.info("ModuleRegister get modules from  ModuleMap content: {}".
             format(content))


def constructTrainerClass(myTrainerClass, opts):
    """Patch ``myTrainerClass`` with functions taken from registered modules.

    Each option below names a member of a registered module map (resolved
    with ``getModule``); the listed attributes are copied from the resolved
    object onto ``myTrainerClass``.

    Required (an unresolved name raises ``AttributeError``):
        opts['model']['model_name_py']: init_model, run_training_net,
            fun_per_iter_b4RunNet, fun_per_epoch_b4RunNet
        opts['input']['input_name_py']: get_input_dataset,
            get_model_input_fun, gen_input_builder_fun
        opts['model']['forward_pass_py']: gen_forward_pass_builder_fun
        opts['output']['gen_output_py']: fun_conclude_operator,
            assembleAllOutputs

    Optional (the attribute is set to ``None`` when the name is unresolved):
        opts['model']['parameter_update_py']: gen_param_update_builder_fun
        opts['model']['optimizer_py']: gen_optimizer_fun
        opts['model']['rendezvous_py']: gen_rendezvous_ctx

    Args:
        myTrainerClass: Class to patch. It is modified in place.
        opts: Nested mapping holding the option values listed above.

    Returns:
        ``myTrainerClass`` itself, after patching.

    Raises:
        AttributeError: If a required module is not registered, or a resolved
            module lacks one of the attributes copied from it.
    """

    def requireModule(section, key):
        # Resolve a mandatory module. A missing one used to surface as
        # "'NoneType' object has no attribute ..."; raise the same exception
        # type but say which module and which option are at fault.
        moduleName = opts[section][key]
        module = getModule(moduleName)
        if module is None:
            raise AttributeError(
                "ModuleRegister, module {!r} named by opts[{!r}][{!r}] was "
                "not found in any registered module map".
                format(moduleName, section, key))
        return module

    log.info("ModuleRegister, myTrainerClass name is {}".
             format(myTrainerClass.__name__))
    log.info("ModuleRegister, myTrainerClass type is {}".
             format(type(myTrainerClass)))
    log.info("ModuleRegister, myTrainerClass dir is {}".
             format(dir(myTrainerClass)))

    # Model construction and per-iteration / per-epoch hooks.
    myInitializeModelModule = requireModule('model', 'model_name_py')
    log.info("ModuleRegister, myInitializeModelModule dir is {}".
             format(dir(myInitializeModelModule)))

    myTrainerClass.init_model = myInitializeModelModule.init_model
    myTrainerClass.run_training_net = myInitializeModelModule.run_training_net
    myTrainerClass.fun_per_iter_b4RunNet = \
        myInitializeModelModule.fun_per_iter_b4RunNet
    myTrainerClass.fun_per_epoch_b4RunNet = \
        myInitializeModelModule.fun_per_epoch_b4RunNet

    myInputModule = requireModule('input', 'input_name_py')
    # NOTE(review): the message says "dir" but the value logged is the
    # module's __name__; left as-is to keep the emitted text unchanged.
    log.info("ModuleRegister, myInputModule {} dir is {}".
             format(opts['input']['input_name_py'], myInputModule.__name__))

    # Override input methods of the myTrainerClass class
    myTrainerClass.get_input_dataset = myInputModule.get_input_dataset
    myTrainerClass.get_model_input_fun = myInputModule.get_model_input_fun
    myTrainerClass.gen_input_builder_fun = myInputModule.gen_input_builder_fun

    myForwardPassModule = requireModule('model', 'forward_pass_py')
    myTrainerClass.gen_forward_pass_builder_fun = \
        myForwardPassModule.gen_forward_pass_builder_fun

    # The next three modules are optional: when the name does not resolve,
    # the corresponding trainer attribute is explicitly set to None.
    myParamUpdateModule = getModule(opts['model']['parameter_update_py'])
    myTrainerClass.gen_param_update_builder_fun =\
        myParamUpdateModule.gen_param_update_builder_fun \
        if myParamUpdateModule is not None else None

    myOptimizerModule = getModule(opts['model']['optimizer_py'])
    myTrainerClass.gen_optimizer_fun = \
        myOptimizerModule.gen_optimizer_fun \
        if myOptimizerModule is not None else None

    myRendezvousModule = getModule(opts['model']['rendezvous_py'])
    myTrainerClass.gen_rendezvous_ctx = \
        myRendezvousModule.gen_rendezvous_ctx \
        if myRendezvousModule is not None else None

    # override output module
    myOutputModule = requireModule('output', 'gen_output_py')

    log.info("ModuleRegister, myOutputModule is {}".
             format(myOutputModule.__name__))
    myTrainerClass.fun_conclude_operator = myOutputModule.fun_conclude_operator
    myTrainerClass.assembleAllOutputs = myOutputModule.assembleAllOutputs

    return myTrainerClass


def overrideAdditionalMethods(myTrainerClass, opts):
    """Copy every function of the additional-override module onto the class.

    The module is named by ``opts['model']['additional_override_py']`` and
    resolved with ``getModule``. When it resolves, each member accepted by
    ``inspect.isfunction`` is set on ``myTrainerClass`` under its own name,
    replacing any attribute of that name. When it does not resolve, the class
    is left unchanged.

    Args:
        myTrainerClass: Class to patch. It is modified in place.
        opts: Nested mapping providing
            ``opts['model']['additional_override_py']``.

    Returns:
        ``myTrainerClass`` itself.
    """

    def describe(cls):
        # inspect.getsource() raises when the source is unavailable (classes
        # built dynamically, source-less deployments). The text is only used
        # for logging, so fall back to repr() instead of aborting the override.
        try:
            return inspect.getsource(cls)
        except (IOError, OSError, TypeError):
            return repr(cls)

    log.info("B4 additional override myTrainerClass source {}".
             format(describe(myTrainerClass)))
    # override any additional modules
    myAdditionalOverride = getModule(opts['model']['additional_override_py'])
    if myAdditionalOverride is not None:
        for funcName, funcValue in inspect.getmembers(myAdditionalOverride,
                                                      inspect.isfunction):
            setattr(myTrainerClass, funcName, funcValue)
    # NOTE(review): getsource() reports the class's source text, which
    # setattr() does not modify, so this line is not expected to show the
    # functions that were just attached.
    log.info("Aft additional override myTrainerClass's source {}".
             format(describe(myTrainerClass)))
    return myTrainerClass


def getModule(moduleName):
    """Return the first registered member named ``moduleName``.

    The maps in ``MODULE_MAPS`` are scanned in registration order, and the
    members of each map are listed with ``inspect.getmembers``. The first
    member whose name equals ``moduleName`` is returned.

    Args:
        moduleName: Member name to look for.

    Returns:
        The matching member, or ``None`` when no registered map has a member
        with that name.
    """
    log.info("get module {} from MODULE_MAPS content {}".format(moduleName, str(MODULE_MAPS)))
    for ModuleMap in MODULE_MAPS:
        log.info("iterate through MODULE_MAPS content {}".
                 format(str(ModuleMap)))
        for name, obj in inspect.getmembers(ModuleMap):
            log.info("iterate through MODULE_MAPS a name {}".format(str(name)))
            if name != moduleName:
                continue
            # inspect.getsource() raises for objects whose source cannot be
            # retrieved. The source is only logged, so fall back to repr()
            # rather than turning a successful lookup into an error.
            try:
                source = inspect.getsource(obj)
            except (IOError, OSError, TypeError):
                source = repr(obj)
            log.info("AnyExp get module {} with source:{}".
                     format(moduleName, source))
            return obj
    return None


def getClassFromModule(moduleName, className):
    """Return attribute ``className`` of the registered member ``moduleName``.

    The maps in ``MODULE_MAPS`` are scanned in registration order. For the
    first member whose name equals ``moduleName``, the result of
    ``getattr(member, className)`` is returned.

    Args:
        moduleName: Member name to look for in the registered maps.
        className: Attribute to fetch from the matching member.

    Returns:
        The requested attribute, or ``None`` when no registered map has a
        member named ``moduleName``.

    Raises:
        AttributeError: If the matching member has no attribute ``className``.
    """
    for ModuleMap in MODULE_MAPS:
        for name, obj in inspect.getmembers(ModuleMap):
            if name != moduleName:
                continue
            # inspect.getsource() raises for objects whose source cannot be
            # retrieved. The source is only logged, so fall back to repr()
            # rather than failing the lookup.
            try:
                source = inspect.getsource(obj)
            except (IOError, OSError, TypeError):
                source = repr(obj)
            log.info("ModuleRegistry from module {} get class {} of source:{}".
                     format(moduleName, className, source))
            return getattr(obj, className)
    return None