# EasyTemporalPointProcess-main/examples/train_nhp_omegaconf.py
# Uploaded by Abigail99216 via huggingface_hub (commit f43af3c, verified).
from omegaconf import OmegaConf
from easy_tpp.config_factory import ModelConfig
from easy_tpp.model.torch_model.torch_nhp import NHP
def _select_model_config(config, experiment_id='NHP_train'):
    """Return the ``model_config`` section of one experiment in a loaded config.

    Args:
        config: mapping loaded from the experiment YAML (e.g. an OmegaConf
            ``DictConfig``); it only needs to support ``.get``.
        experiment_id: top-level key of the experiment to read.

    Returns:
        The ``model_config`` node of that experiment (the same object, not a
        copy, so callers may update it in place).

    Raises:
        KeyError: if the experiment or its ``model_config`` section is
            missing. Chained ``.get`` calls would instead fail with an opaque
            ``AttributeError`` on ``None``.
    """
    experiment = config.get(experiment_id)
    if experiment is None:
        raise KeyError(f"experiment {experiment_id!r} not found in config")
    model_config = experiment.get('model_config')
    if model_config is None:
        raise KeyError(
            f"experiment {experiment_id!r} has no 'model_config' section")
    return model_config


def _event_type_settings(num_event_types):
    """Return the event-type fields derived from the number of real event types.

    One extra type index is reserved for padding: with ``num_event_types``
    real types, the padding index is ``num_event_types`` and the padded
    vocabulary size is ``num_event_types + 1``.

    Raises:
        ValueError: if ``num_event_types`` is smaller than 1.
    """
    if num_event_types < 1:
        raise ValueError(
            f"num_event_types must be >= 1, got {num_event_types!r}")
    return {
        'num_event_types': num_event_types,
        'num_event_types_pad': num_event_types + 1,
        'event_pad_index': num_event_types,
    }


def main(config_path='configs/experiment_config.yaml',
         experiment_id='NHP_train', num_event_types=10):
    """Build an NHP model from an experiment YAML config and print its attributes.

    Args:
        config_path: path to the experiment YAML file. A relative path is
            resolved against the current working directory, so the default
            only works when run from the directory containing ``configs/``.
        experiment_id: top-level experiment key whose ``model_config`` is used.
        num_event_types: number of real event types; the padding-related
            fields are derived from it so the three values stay consistent.

    Returns:
        The constructed ``NHP`` model.

    Raises:
        KeyError: if the experiment or its ``model_config`` section is missing.
    """
    config_omegaconf = OmegaConf.load(config_path)
    model_config_dict = _select_model_config(config_omegaconf, experiment_id)
    # The event-type fields are injected here rather than read from the YAML;
    # deriving all three from one count keeps them mutually consistent.
    for key, value in _event_type_settings(num_event_types).items():
        model_config_dict[key] = value
    model_config = ModelConfig.parse_from_yaml_config(model_config_dict)
    nhp_model = NHP(model_config)
    print(nhp_model.__dict__)
    # For end-to-end training, the Config/Runner pipeline can be used instead
    # (those names are not imported in this file):
    #   config = Config.build_from_yaml_file(config_dir, experiment_id=experiment_id)
    #   Runner.build_from_config(config).run()
    return nhp_model
# Script entry point: build and print the example model only when this file
# is executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()