| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import functools |
| import os |
| import sys |
| from typing import Any, Callable, Optional |
|
|
| from hydra._internal.utils import _run_hydra, get_args_parser |
| from hydra.core.config_store import ConfigStore |
| from hydra.types import TaskFunction |
| from omegaconf import DictConfig, OmegaConf |
|
|
|
|
def _get_gpu_name():
    """Return a short name for GPU 0, derived from its CUDA compute capability.

    Only the *major* compute-capability version is inspected.

    Returns:
        ``"a100"`` for major version 8, ``"h100"`` for major version 9, otherwise ``None``.
        ``None`` is also returned when ``pynvml`` is not installed or when NVML raises
        ``NVMLError`` while initializing or querying device 0 (e.g. no driver / no GPU).
    """
    try:
        import pynvml
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        return None

    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        return None

    # Always pair a successful nvmlInit() with nvmlShutdown(), even if the query fails.
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        major, _minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
    except pynvml.NVMLError:
        return None
    finally:
        pynvml.nvmlShutdown()

    return {8: "a100", 9: "h100"}.get(major)
|
|
|
|
# `${gpu_name:}` resolves to the value of _get_gpu_name() ("a100", "h100" or None).
# replace=True, as for the resolvers below, so re-importing/reloading this module does not
# raise ValueError because the resolver name is already registered.
OmegaConf.register_new_resolver("gpu_name", _get_gpu_name, replace=True)

# `${multiply:x,y}` resolves to x * y.
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)

# `${sum:x,y}` resolves to x + y.
OmegaConf.register_new_resolver("sum", lambda x, y: x + y, replace=True)
|
|
|
|
def _store_schema(schema: Any, cli_config_name: Optional[str], default_config_name: Optional[str]) -> None:
    """Register ``schema`` in Hydra's ConfigStore under the config name that will be used.

    The name is the ``--config-name`` value from the command line when given, otherwise the
    decorator's ``config_name``. Writes an error to stderr and exits with status 1 when
    ``--config-name`` contains a directory component; with a schema, the directory must be
    passed via ``--config-path`` instead.
    """
    if cli_config_name is not None:
        path, name = os.path.split(cli_config_name)
        if path != '':
            sys.stderr.write(
                "ERROR Cannot set config file path using `--config-name` when "
                "using schema. Please set path using `--config-path` and file name using "
                "`--config-name` separately.\n"
            )
            sys.exit(1)
    else:
        name = default_config_name

    ConfigStore.instance().store(name=name, node=schema)


def hydra_runner(
    config_path: Optional[str] = ".", config_name: Optional[str] = None, schema: Optional[Any] = None
) -> Callable[[TaskFunction], Any]:
    """
    Decorator used for passing the Config paths to main function.
    Optionally registers a schema used for validation/providing default values.

    Args:
        config_path: Optional path that will be added to config search directory.
            NOTE: The default value of `config_path` has changed between Hydra 1.0 and Hydra 1.1+.
            Please refer to https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path/
            for details.
        config_name: Pathname of the config file.
        schema: Structured config type representing the schema used for validation/providing default values.

    Returns:
        A decorator. The decorated function may be called with a ``DictConfig``, in which case
        Hydra is bypassed and the task's return value is returned. Called with no argument,
        it parses the command line, runs the task through Hydra and returns ``None``.
    """

    def decorator(task_function: TaskFunction) -> Callable[[Optional[DictConfig]], Any]:
        @functools.wraps(task_function)
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            # A config handed in directly skips command-line parsing and Hydra entirely.
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)

            args = get_args_parser()
            parsed_args = args.parse_args()

            # Overrides always appended to the user's command-line overrides:
            #   hydra.output_subdir=null   -> no Hydra output subdirectory
            #   hydra/job_logging=stdout   -> select the `stdout` job_logging config
            #   hydra.run.dir=.            -> run directory is the current directory
            overrides = parsed_args.overrides
            overrides.append("hydra.output_subdir=null")
            overrides.append("hydra/job_logging=stdout")
            overrides.append("hydra.run.dir=.")

            if schema is not None:
                _store_schema(schema, parsed_args.config_name, config_name)

            # Mimic the ArgumentParser.parse_args() API on the parsed namespace. The function is
            # stored as an *instance* attribute, so it is not bound as a method and therefore
            # must not declare a `self` parameter.
            def _parse_args(args=None, namespace=None):
                return parsed_args

            parsed_args.parse_args = _parse_args

            # _run_hydra() has no useful return value: with --multirun it may run
            # task_function several times.
            _run_hydra(
                args=parsed_args,
                args_parser=args,
                task_function=task_function,
                config_path=config_path,
                config_name=config_name,
            )

        return wrapper

    return decorator
|
|