# Entry-point script for training / evaluating "Major Entity Tracking" models.
import os
from os import path
import logging
from omegaconf import OmegaConf
import hydra
import hashlib
import json
import wandb
import torch

# NOTE(review): despite the original "uncomment the following line" note, this
# call is active, so deterministic algorithms are always enforced. On CUDA this
# additionally requires CUBLAS_WORKSPACE_CONFIG=:4096:8 in the environment.
torch.use_deterministic_algorithms(True)

import random
import numpy as np

from experiment import Experiment

logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
logger = logging.getLogger()

# Silence the HuggingFace tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def get_model_name(config):
    """Derive a deterministic model name from the config and seed.

    The name has the form "<dataset>_<md5-hash>[_<key>]", where the hash
    covers the datasets/model/trainer/optimizer sub-configs plus the seed.
    """
    # Hash only the sub-configs that influence the trained model.
    relevant = OmegaConf.masked_copy(
        config, ["datasets", "model", "trainer", "optimizer"]
    )
    digest = hashlib.md5()
    digest.update(
        json.dumps(OmegaConf.to_container(relevant), sort_keys=True).encode()
    )
    digest.update(f"seed: {config.seed}".encode())
    model_hash = str(digest.hexdigest())

    # Several datasets -> a "joint" model; otherwise name after the dataset.
    if len(config.datasets) > 1:
        dataset_name = "joint"
    else:
        dataset_name = list(config.datasets.keys())[0]
        if dataset_name == "litbank":
            # LitBank runs are further distinguished by their CV split.
            cross_val_split = config.datasets[dataset_name].cross_val_split
            dataset_name += f"_cv_{cross_val_split}"

    key = f"_{config['key']}" if config["key"] != "" else ""
    return f"{dataset_name}_{model_hash}{key}"
def main_train(config):
    """Prepare model/output directories for a training run.

    Mutates ``config.paths`` in place (model_dir, best_model_dir, model_path,
    best_model_path), dumps the resolved config to ``config.json`` inside the
    model directory, and returns the resolved model name.
    """
    model_name = (
        get_model_name(config)
        if config.paths.model_name is None
        else config.paths.model_name
    )

    config.paths.model_dir = path.join(
        config.paths.base_model_dir, config.paths.model_name_prefix + model_name
    )
    config.paths.best_model_dir = path.join(config.paths.model_dir, "best")
    for directory in (config.paths.model_dir, config.paths.best_model_dir):
        if not path.exists(directory):
            os.makedirs(directory)

    if config.paths.model_path is None:
        config.paths.model_path = path.abspath(
            path.join(config.paths.model_dir, config.paths.model_filename)
        )
        config.paths.best_model_path = path.abspath(
            path.join(config.paths.best_model_dir, config.paths.model_filename)
        )
    if config.paths.best_model_path is None and (config.paths.model_path is not None):
        # A model path was supplied externally without a best-model path: reuse it.
        config.paths.best_model_path = config.paths.model_path

    # Dump config file
    config_file = path.join(config.paths.model_dir, "config.json")
    with open(config_file, "w") as f:
        f.write(json.dumps(OmegaConf.to_container(config), indent=4, sort_keys=True))

    return model_name
def main_eval(config):
    """Resolve evaluation paths on ``config.paths``.

    Points ``best_model_dir`` at "<model_dir>/best" when that directory
    exists, otherwise at ``model_dir`` itself, and sets ``best_model_path``
    to the absolute path of the model file inside it.

    Raises:
        ValueError: if ``config.paths.model_dir`` is not set — evaluation
            needs an existing model directory to load from.
    """
    if config.paths.model_dir is None:
        # Fix: original raised a bare ValueError with no message.
        raise ValueError("config.paths.model_dir must be set for evaluation")

    best_model_dir = path.join(config.paths.model_dir, "best")
    if path.exists(best_model_dir):
        config.paths.best_model_dir = best_model_dir
    else:
        # No "best" checkpoint directory: fall back to the model dir itself.
        config.paths.best_model_dir = config.paths.model_dir

    config.paths.best_model_path = path.abspath(
        path.join(config.paths.best_model_dir, config.paths.model_filename)
    )
def set_seed(seed):
    """Seed every RNG in use (hash, python, numpy, torch) for reproducibility.

    Also disables cuDNN and its benchmarking-based kernel selection, trading
    speed for run-to-run determinism.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: fixed deterministic kernels, no autotuning.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def main(config):
    """Entry point: seed RNGs, resolve paths, then run the experiment.

    NOTE(review): this is invoked below as ``main()`` with hydra-style
    overrides pushed onto sys.argv, so it is presumably decorated with
    ``@hydra.main(...)`` in the full file — confirm against the repository.
    """
    set_seed(config.seed)

    if config.train:
        model_name = main_train(config)
    else:
        main_eval(config)
        model_name = path.basename(path.normpath(config.paths.model_dir))
        # Strip prefix so the run id matches the one used during training.
        if model_name.startswith(config.paths.model_name_prefix):
            model_name = model_name[len(config.paths.model_name_prefix) :]

    if config.use_wandb:
        # Wandb Initialization; resume the existing run with the same id.
        try:
            wandb.init(
                id=model_name,
                project="Major Entity Tracking",
                config=dict(config),
                resume=True,
            )
        except Exception:
            # Fix: was a bare ``except:`` that also swallowed SystemExit /
            # KeyboardInterrupt and hid the failure. Keep the best-effort
            # fallback (disable wandb) but record why.
            logger.exception("wandb.init failed; disabling wandb logging")
            config.use_wandb = False

    logger.info(f"Model name: {model_name}")
    Experiment(config)
if __name__ == "__main__":
    import sys

    # Pin hydra's working directory to this file's directory and disable its
    # job logging before handing control to main().
    sys.argv.extend(
        [
            f"hydra.run.dir={path.dirname(path.realpath(__file__))}",
            "hydra/job_logging=none",
        ]
    )
    main()