|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from omegaconf import OmegaConf |
|
|
from typing import List |
|
|
|
|
|
def load_and_merge_configs(config_paths: List[str]):
    """
    Load and merge multiple OmegaConf configs in order.

    Later configs override earlier ones. Top-level keys that appear in a
    later config but not in the schema so far are first registered on the
    schema with a ``None`` value, then merged in.

    Args:
        config_paths (List[str]): Paths to config files. The first config
            acts as the base schema.

    Returns:
        OmegaConf.DictConfig: The merged configuration.

    Raises:
        ValueError: If ``config_paths`` is empty.
    """
    if not config_paths:
        raise ValueError("No config paths provided.")

    merged = OmegaConf.load(config_paths[0])

    for override_path in config_paths[1:]:
        override = OmegaConf.load(override_path)

        # Register every top-level key the schema has not seen yet as None,
        # so the merge below accepts it rather than rejecting an unknown key.
        unseen = set(override.keys()) - set(merged.keys())
        for key in unseen:
            OmegaConf.update(merged, key, None, force_add=True)

        merged = OmegaConf.merge(merged, override)

    return merged
|
|
|
|
|
|
|
|
def seed_everything(seed: int) -> None:
    """
    Seed all common sources of randomness for reproducibility.

    Seeds Python's ``random``, the ``PYTHONHASHSEED`` environment variable,
    NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): The seed value applied to every RNG.
    """
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    # Fix hash randomization for newly spawned subprocesses; has no effect
    # on hashing in the already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # manual_seed only covers the current CUDA device; seed all devices so
    # multi-GPU runs are reproducible too (no-op when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
|
|
|
|
|
# Mapping from human-readable dtype strings (e.g. from a config file) to
# torch dtype objects. Aliases follow torch's own naming ('half' == float16,
# 'long' == int64, etc.). Extended with the remaining common integer,
# boolean, and low-precision types; existing keys are unchanged.
dtype_map = {
    'float32': torch.float32,
    'float': torch.float32,
    'float64': torch.float64,
    'double': torch.float64,
    'float16': torch.float16,
    'half': torch.float16,
    'bfloat16': torch.bfloat16,
    'int8': torch.int8,
    'uint8': torch.uint8,
    'int16': torch.int16,
    'short': torch.int16,
    'int32': torch.int32,
    'int': torch.int32,
    'int64': torch.int64,
    'long': torch.int64,
    'bool': torch.bool,
}
|
|
|