File size: 807 Bytes
f43af3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from omegaconf import OmegaConf
from easy_tpp.config_factory import ModelConfig
from easy_tpp.model.torch_model.torch_nhp import NHP
def main(*, config_path='configs/experiment_config.yaml', experiment_id='NHP_train',
         num_event_types=10):
    """Build an NHP model from one experiment's ``model_config`` and print its attributes.

    Loads the experiment YAML with OmegaConf, overrides the event-type
    settings in the selected ``model_config`` section, parses it into a
    ``ModelConfig``, constructs ``NHP`` from it and prints the model
    instance's ``__dict__``. Calling ``main()`` with no arguments reproduces
    the original hard-coded behaviour.

    Args:
        config_path: Path to the experiment YAML file; a relative path is
            resolved against the current working directory.
        experiment_id: Top-level key in the YAML whose ``model_config``
            section is used.
        num_event_types: Value written to ``num_event_types``; the padding
            settings are derived from it.

    Raises:
        KeyError: If ``experiment_id`` or its ``model_config`` section is
            missing from the YAML file.
    """
    config_omegaconf = OmegaConf.load(config_path)

    # Chained .get() would fail with an opaque "'NoneType' object has no
    # attribute 'get'" on a missing section; report which key is missing.
    experiment_config = config_omegaconf.get(experiment_id)
    if experiment_config is None:
        raise KeyError(f"experiment '{experiment_id}' not found in {config_path}")
    model_config_dict = experiment_config.get('model_config')
    if model_config_dict is None:
        raise KeyError(f"experiment '{experiment_id}' in {config_path} has no 'model_config' section")

    # The original script hard-coded 10 / 11 / 10, i.e. pad count is
    # num_event_types + 1 and pad index is num_event_types. Derive both from
    # a single value so the three settings cannot drift apart.
    model_config_dict['num_event_types'] = num_event_types
    model_config_dict['num_event_types_pad'] = num_event_types + 1
    model_config_dict['event_pad_index'] = num_event_types

    model_config = ModelConfig.parse_from_yaml_config(model_config_dict)
    nhp_model = NHP(model_config)
    print(nhp_model.__dict__)

    # NOTE(review): an earlier version ran a different path instead:
    #   config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)
    #   Runner.build_from_config(config).run()
    # `Config`, `Runner` and `args` are not defined or imported in this file.
if __name__ == '__main__':
    # Entry point: build and inspect the model only when run as a script,
    # never as a side effect of importing this module.
    main()
|