File size: 807 Bytes
f43af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from omegaconf import OmegaConf

from easy_tpp.config_factory import ModelConfig
from easy_tpp.model.torch_model.torch_nhp import NHP


def main(config_path='configs/experiment_config.yaml',
         experiment_id='NHP_train',
         num_event_types=10):
    """Build an NHP model from an experiment config and print its attributes.

    Loads the YAML experiment config, takes the ``model_config`` section of
    the requested experiment, fills in the event-type sizes, builds a
    ``ModelConfig`` from it and instantiates the torch ``NHP`` model.

    Args:
        config_path: Path to the experiment YAML file. A relative path is
            resolved against the current working directory.
        experiment_id: Top-level key in the YAML whose ``model_config``
            section is used.
        num_event_types: Number of real (non-padding) event types. One padding
            type is appended after them. The padded vocabulary size is
            therefore ``num_event_types + 1`` and the pad index is
            ``num_event_types``, matching the original hard-coded 10 / 11 / 10.

    Returns:
        The constructed ``NHP`` model.

    Raises:
        ValueError: If ``num_event_types`` is smaller than 1.
        KeyError: If the config has no ``experiment_id`` section, or that
            section has no ``model_config`` entry.
    """
    if num_event_types < 1:
        raise ValueError(
            f'num_event_types must be a positive integer, got {num_event_types!r}')

    config_omegaconf = OmegaConf.load(config_path)

    # DictConfig.get returns None for a missing key. Check explicitly so that a
    # typo in the experiment id yields a clear error, not an AttributeError
    # from calling .get on None.
    experiment_config = config_omegaconf.get(experiment_id)
    if experiment_config is None:
        raise KeyError(f"experiment '{experiment_id}' not found in {config_path}")

    model_config_dict = experiment_config.get('model_config')
    if model_config_dict is None:
        raise KeyError(
            f"experiment '{experiment_id}' in {config_path} has no 'model_config' section")

    # The padding event occupies the slot right after the last real event type.
    model_config_dict['num_event_types'] = num_event_types
    model_config_dict['num_event_types_pad'] = num_event_types + 1
    model_config_dict['event_pad_index'] = num_event_types

    model_config = ModelConfig.parse_from_yaml_config(model_config_dict)

    nhp_model = NHP(model_config)

    print(nhp_model.__dict__)

    return nhp_model


if __name__ == "__main__":  # only build the model when executed as a script
    main()