import os

from omegaconf import OmegaConf

from easy_tpp.config_factory import ModelConfig
from easy_tpp.model.torch_model.torch_nhp import NHP
def main(
    config_path: "str | os.PathLike[str]" = 'configs/experiment_config.yaml',
    experiment_id: str = 'NHP_train',
    num_event_types: int = 10,
) -> None:
    """Build an NHP model from an experiment YAML config and print its attributes.

    Loads the YAML file with OmegaConf and takes the ``model_config`` section
    of one experiment. It then fills in the event-type vocabulary fields,
    parses the result into a ``ModelConfig`` and instantiates ``NHP`` from it.
    Finally it prints the model instance's ``__dict__``.

    Args:
        config_path: Path to the experiment YAML file. A relative path is
            opened relative to the current working directory.
        experiment_id: Top-level key of the experiment in the YAML file; its
            ``model_config`` sub-section is used.
        num_event_types: Number of real (non-padding) event types. Padding is
            one extra type placed after them, so ``num_event_types_pad`` is
            set to ``num_event_types + 1`` and ``event_pad_index`` to
            ``num_event_types``.

    Raises:
        ValueError: If ``num_event_types`` is smaller than 1.
        KeyError: If ``experiment_id`` or its ``model_config`` section is
            missing from the YAML file.
    """
    if num_event_types < 1:
        raise ValueError(f"num_event_types must be >= 1, got {num_event_types}")

    config_omegaconf = OmegaConf.load(config_path)

    # DictConfig.get returns None for a missing key. Chaining another .get on
    # that result would fail with an opaque AttributeError, so fail explicitly.
    experiment_config = config_omegaconf.get(experiment_id)
    if experiment_config is None:
        raise KeyError(f"experiment {experiment_id!r} not found in {config_path}")
    model_config_dict = experiment_config.get('model_config')
    if model_config_dict is None:
        raise KeyError(
            f"experiment {experiment_id!r} in {config_path} has no 'model_config' section"
        )

    # Derive all three fields from one count so they cannot drift apart:
    # the pad type is appended after the real types.
    model_config_dict['num_event_types'] = num_event_types
    model_config_dict['num_event_types_pad'] = num_event_types + 1
    model_config_dict['event_pad_index'] = num_event_types

    model_config = ModelConfig.parse_from_yaml_config(model_config_dict)
    nhp_model = NHP(model_config)
    print(nhp_model.__dict__)

    # NOTE(review): alternative runner-based entry point, kept for reference.
    # It relies on a parsed ``args`` object that this script does not define.
    # config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)
    #
    # model_runner = Runner.build_from_config(config)
    #
    # model_runner.run()
# Run the model-construction smoke test only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()