Spaces:
Running
Running
| import argparse | |
| import atexit | |
| import enum | |
| import json | |
| import os | |
| import pickle | |
| import shutil | |
| import sys | |
| import time | |
| import uuid | |
| from copy import deepcopy | |
| from dataclasses import asdict, fields, is_dataclass | |
| from pathlib import Path | |
| from pprint import pprint | |
| from typing import Any, Callable, List, Dict, Type, Optional, Tuple, TypeVar, Union, cast, get_args, get_origin | |
| import __main__ | |
| import numpy as np | |
| import tomli | |
| import tomli_w | |
| import torch | |
| import zero | |
| from . import env | |
# Generic type variable used by the typed helpers (`start`, `from_dict`).
T = TypeVar('T')
# A config as loaded from a TOML file: a nested mapping of plain values.
RawConfig = Dict[str, Any]
# A run report (program, environment, config, metrics, ...), stored as JSON.
Report = Dict[str, Any]
class Part(enum.Enum):
    """Dataset split identifier; ``str(part)`` gives the raw split name."""

    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'

    def __str__(self) -> str:
        return str(self.value)
class TaskType(enum.Enum):
    """Kind of prediction task; ``str(task_type)`` gives the raw name."""

    BINCLASS = 'binclass'
    MULTICLASS = 'multiclass'
    REGRESSION = 'regression'

    def __str__(self) -> str:
        return str(self.value)
class Timer(zero.Timer):
    """``zero.Timer`` with a convenience constructor that starts it immediately."""

    @classmethod
    def launch(cls) -> 'Timer':
        """Create a timer, start it with ``run()`` and return it.

        Must be a classmethod: it is called on the class (``Timer.launch()``)
        and builds the instance itself via ``cls()``.
        """
        timer = cls()
        timer.run()
        return timer
def update_training_log(training_log, data, metrics):
    """Accumulate one step's values into a nested, list-valued training log.

    ``data`` is merged recursively:
    - nested dicts recurse;
    - list values extend the corresponding log list;
    - any other value is appended to it.

    ``metrics`` is read as ``{part: {metric_name: value}}``. It is stored
    transposed, i.e. under ``training_log[metric_name][part]``.

    ``training_log`` is mutated in place; nothing is returned.
    """

    def merge(target, source):
        for key, item in source.items():
            if isinstance(item, dict):
                merge(target.setdefault(key, {}), item)
                continue
            bucket = target.setdefault(key, [])
            if isinstance(item, list):
                bucket.extend(item)
            else:
                bucket.append(item)

    merge(training_log, data)

    by_metric = {}
    for part_name, part_values in metrics.items():
        for metric_name, metric_value in part_values.items():
            by_metric.setdefault(metric_name, {})[part_name] = metric_value
    merge(training_log, by_metric)
def raise_unknown(unknown_what: str, unknown_value: Any):
    """Raise ``ValueError`` reporting an unrecognised value of some kind."""
    message = f'Unknown {unknown_what}: {unknown_value}'
    raise ValueError(message)
def _replace(data, condition, value):
    """Return a copy of nested dicts/lists with matching leaves swapped for ``value``.

    Recurses into ``dict`` values and ``list`` items. Every other object is a
    leaf and is replaced by ``value`` when ``condition(leaf)`` is truthy.
    Containers are rebuilt, so ``data`` itself is not modified.
    """
    if isinstance(data, dict):
        return {key: _replace(item, condition, value) for key, item in data.items()}
    if isinstance(data, list):
        return [_replace(item, condition, value) for item in data]
    return value if condition(data) else data
# Placeholder string that stands in for `None` in TOML config files (TOML has
# no null type); `pack_config` writes it and `unpack_config` turns it back.
_CONFIG_NONE = '__none__'
def unpack_config(config: RawConfig) -> RawConfig:
    """Return a copy of ``config`` with every ``_CONFIG_NONE`` placeholder set to ``None``."""
    return cast(
        RawConfig, _replace(config, lambda leaf: leaf == _CONFIG_NONE, None)
    )
def pack_config(config: RawConfig) -> RawConfig:
    """Return a copy of ``config`` with every ``None`` leaf replaced by ``_CONFIG_NONE``."""
    return cast(
        RawConfig, _replace(config, lambda leaf: leaf is None, _CONFIG_NONE)
    )
def load_config(path: Union[Path, str]) -> Any:
    """Read a TOML config and restore ``None`` values encoded by ``pack_config``."""
    with open(path, 'rb') as fp:
        parsed = tomli.load(fp)
    return unpack_config(parsed)
def dump_config(config: Any, path: Union[Path, str]) -> None:
    """Write ``config`` to ``path`` as TOML, encoding ``None`` via ``pack_config``.

    The file is then read back with ``load_config`` and compared with
    ``config``. A mismatch (a pack/unpack round-trip bug) raises
    ``AssertionError``; the check is skipped under ``python -O``.
    """
    with open(path, 'wb') as fp:
        tomli_w.dump(pack_config(config), fp)
    # Round-trip sanity check for the pack/unpack helpers.
    assert config == load_config(path)
def load_json(path: Union[Path, str], **kwargs) -> Any:
    """Parse the JSON file at ``path``; extra keyword arguments go to ``json.loads``."""
    text = Path(path).read_text()
    return json.loads(text, **kwargs)
def dump_json(x: Any, path: Union[Path, str], **kwargs) -> None:
    """Serialise ``x`` to ``path`` as JSON followed by a newline.

    Keyword arguments go to ``json.dumps``; ``indent`` defaults to 4 unless
    the caller passes it explicitly (even as ``None``).
    """
    target = Path(path)
    options = {'indent': 4, **kwargs}
    target.write_text(json.dumps(x, **options) + '\n')
def load_pickle(path: Union[Path, str], **kwargs) -> Any:
    """Unpickle the object stored at ``path``; kwargs go to ``pickle.loads``.

    NOTE(review): unpickling can execute arbitrary code — only use on trusted files.
    """
    payload = Path(path).read_bytes()
    return pickle.loads(payload, **kwargs)
def dump_pickle(x: Any, path: Union[Path, str], **kwargs) -> None:
    """Pickle ``x`` into ``path``; kwargs (e.g. ``protocol``) go to ``pickle.dumps``."""
    target = Path(path)
    target.write_bytes(pickle.dumps(x, **kwargs))
def load(path: Union[Path, str], **kwargs) -> Any:
    """Load ``path`` with the module's ``load_<ext>`` function matching its suffix.

    For example ``.json`` dispatches to ``load_json``. A suffix with no matching
    loader raises ``KeyError``.
    """
    path = Path(path)
    loader = globals()[f'load_{path.suffix[1:]}']
    return loader(path, **kwargs)
def dump(x: Any, path: Union[Path, str], **kwargs) -> Any:
    """Write ``x`` with the module's ``dump_<ext>`` function matching the suffix of ``path``.

    For example ``.pickle`` dispatches to ``dump_pickle``. A suffix with no
    matching writer raises ``KeyError``.
    """
    target = Path(path)
    writer = globals()['dump_' + target.suffix[1:]]
    return writer(x, target, **kwargs)
def _get_output_item_path(
    path: Union[str, Path], filename: str, must_exist: bool
) -> Path:
    """Locate ``filename`` inside an experiment's output directory.

    ``path`` is first normalised with ``env.get_path``. It may then be:
    - the ``.toml`` config, whose output directory is the same path without
      the suffix;
    - the output directory itself;
    - the item file directly.

    Assertions check that the item's parent directory exists and, when
    ``must_exist`` is true, that the item itself exists.
    """
    resolved = env.get_path(path)
    if resolved.suffix == '.toml':
        resolved = resolved.with_suffix('')
    if not resolved.is_dir():
        # `path` already names the item file itself.
        assert resolved.name == filename
    else:
        resolved = resolved / filename
    assert resolved.parent.exists()
    if must_exist:
        assert resolved.exists()
    return resolved
def load_report(path: Path) -> Report:
    """Load ``report.json`` (which must exist) from the output location ``path``."""
    report_path = _get_output_item_path(path, 'report.json', True)
    return load_json(report_path)
def dump_report(report: dict, path: Path) -> None:
    """Write ``report`` as ``report.json`` at the output location ``path``."""
    report_path = _get_output_item_path(path, 'report.json', False)
    dump_json(report, report_path)
def load_predictions(path: Path) -> Dict[str, np.ndarray]:
    """Read every array in ``predictions.npz`` (which must exist) into a plain dict."""
    archive_path = _get_output_item_path(path, 'predictions.npz', True)
    with np.load(archive_path) as archive:
        return {name: archive[name] for name in archive.files}
def dump_predictions(predictions: Dict[str, np.ndarray], path: Path) -> None:
    """Save ``predictions`` (name -> array) as ``predictions.npz`` with ``np.savez``."""
    archive_path = _get_output_item_path(path, 'predictions.npz', False)
    np.savez(archive_path, **predictions)
def dump_metrics(metrics: Dict[str, Any], path: Path) -> None:
    """Write ``metrics`` as ``metrics.json`` at the output location ``path``."""
    metrics_path = _get_output_item_path(path, 'metrics.json', False)
    dump_json(metrics, metrics_path)
def load_checkpoint(path: Path, *args, **kwargs) -> Any:
    """Load ``checkpoint.pt`` (which must exist) from the output location ``path``.

    Extra positional/keyword arguments (e.g. ``map_location``) are forwarded
    to ``torch.load``. The return value is whatever object was saved; it is
    annotated as ``Any`` because ``torch.load`` makes no guarantee that it is
    a dict of numpy arrays.

    NOTE(review): ``torch.load`` unpickles — only load trusted checkpoints.
    """
    checkpoint_path = _get_output_item_path(path, 'checkpoint.pt', True)
    return torch.load(checkpoint_path, *args, **kwargs)
def get_device() -> torch.device:
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device.

    On the CUDA path, asserts that ``CUDA_VISIBLE_DEVICES`` is set
    (presumably so GPU selection is always explicit — confirm).
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    assert os.environ.get('CUDA_VISIBLE_DEVICES') is not None
    return torch.device('cuda:0')
def _print_sep(c, size=100):
    """Print a separator line made of ``c`` repeated ``size`` times."""
    line = c * size
    print(line)
def start(
    config_cls: Type[T] = RawConfig,
    argv: Optional[List[str]] = None,
    patch_raw_config: Optional[Callable[[RawConfig], None]] = None,
) -> Tuple[T, Path, Report]:  # config # output dir # report
    """Parse CLI arguments, load the config and prepare the output directory.

    Command line: ``FILE [--force] [--continue]``. ``FILE`` is a TOML config.
    The output directory is the config path without its ``.toml`` suffix.

    Args:
        config_cls: a dataclass to build the config with (via ``from_dict``), or
            ``dict`` / ``RawConfig`` to keep the raw dict.
        argv: explicit argument list whose FIRST item is the program path
            (relative to the project root). ``None`` means the real command
            line, with ``__main__.__file__`` as the program.
        patch_raw_config: optional callback that mutates the raw config in
            place before it is converted.

    Behaviour when the output directory already exists:
    - ``--force`` recreates it;
    - without ``--continue``, it is backed up and the process exits;
    - with ``--continue``, the process exits if a ``DONE`` marker exists,
      and otherwise reuses the directory.

    Returns:
        ``(config, output_dir, report)``. The initial report is also written
        to ``output_dir`` together with ``raw_config.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', metavar='FILE')
    parser.add_argument('--force', action='store_true')
    parser.add_argument('--continue', action='store_true', dest='continue_')
    if argv is None:
        program = __main__.__file__
        args = parser.parse_args()
    else:
        program = argv[0]
        try:
            # argv[0] is the program path, not a CLI argument.
            args = parser.parse_args(argv[1:])
        except (Exception, SystemExit) as err:
            # argparse reports bad arguments by raising SystemExit, which is not
            # an Exception subclass; stay quiet on a successful `--help` exit.
            if not (isinstance(err, SystemExit) and err.code == 0):
                print(
                    'Failed to parse `argv`.'
                    ' Remember that the first item of `argv` must be the path (relative to'
                    ' the project root) to the script/notebook.'
                )
            raise

    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
    if snapshot_dir and Path(snapshot_dir).joinpath('CHECKPOINTS_RESTORED').exists():
        # Checkpoints were restored into the snapshot dir: require `--continue`.
        assert args.continue_

    config_path = env.get_path(args.config)
    output_dir = config_path.with_suffix('')
    _print_sep('=')
    print(f'[output] {output_dir}')
    _print_sep('=')

    assert config_path.exists()
    raw_config = load_config(config_path)
    if patch_raw_config is not None:
        patch_raw_config(raw_config)
    if is_dataclass(config_cls):
        config = from_dict(config_cls, raw_config)
        full_raw_config = asdict(config)
    else:
        # Both plain `dict` and the default `RawConfig` alias (Dict[str, Any])
        # mean "keep the raw dict"; `asdict` must not be applied to it.
        assert config_cls is dict or get_origin(config_cls) is dict
        full_raw_config = config = raw_config

    if output_dir.exists():
        if args.force:
            print('Removing the existing output and creating a new one...')
            shutil.rmtree(output_dir)
            output_dir.mkdir()
        elif not args.continue_:
            backup_output(output_dir)
            print('The output directory already exists. Done!\n')
            sys.exit()
        elif output_dir.joinpath('DONE').exists():
            backup_output(output_dir)
            print('The "DONE" file already exists. Done!')
            sys.exit()
        else:
            print('Continuing with the existing output...')
    else:
        print('Creating the output...')
        output_dir.mkdir()

    report = {
        'program': str(env.get_relative_path(program)),
        'environment': {},
        'config': full_raw_config,
    }
    if torch.cuda.is_available():  # type: ignore[code]
        report['environment'].update(
            {
                'CUDA_VISIBLE_DEVICES': os.environ.get('CUDA_VISIBLE_DEVICES'),
                'gpus': zero.hardware.get_gpus_info(),
                'torch.version.cuda': torch.version.cuda,
                'torch.backends.cudnn.version()': torch.backends.cudnn.version(),  # type: ignore[code]
                'torch.cuda.nccl.version()': torch.cuda.nccl.version(),  # type: ignore[code]
            }
        )
    dump_report(report, output_dir)
    dump_json(raw_config, output_dir / 'raw_config.json')
    _print_sep('-')
    pprint(full_raw_config, width=100)
    _print_sep('-')
    return cast(config_cls, config), output_dir, report
# `time.time()` of the last snapshot dump performed by `backup_output`;
# `None` until the first one.
_LAST_SNAPSHOT_TIME: Optional[float] = None
def backup_output(output_dir: Path) -> None:
    """Mirror ``output_dir`` into the backup and snapshot directories, if configured.

    When it does nothing:
    - ``TMP_OUTPUT_PATH`` is unset (``SNAPSHOT_PATH`` must then be unset too);
    - ``output_dir`` lies outside ``env.PROJ``.

    For each of the two directories, the copy lives at the same path relative
    to the project root. An existing copy is kept as ``<name>_prev`` until the
    new copy is complete. A sibling ``.toml`` config, if present, is copied
    too. Finally, at most once every 10 minutes, a ``nirvana_dl`` snapshot is
    dumped.
    """
    global _LAST_SNAPSHOT_TIME

    backup_dir = os.environ.get('TMP_OUTPUT_PATH')
    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
    if backup_dir is None:
        assert snapshot_dir is None
        return
    assert snapshot_dir is not None

    try:
        relative_output_dir = output_dir.relative_to(env.PROJ)
    except ValueError:
        # Outputs outside the project are not backed up.
        return

    for dir_ in [backup_dir, snapshot_dir]:
        new_output_dir = Path(dir_) / relative_output_dir
        prev_backup_output_dir = new_output_dir.with_name(new_output_dir.name + '_prev')
        new_output_dir.parent.mkdir(exist_ok=True, parents=True)
        if new_output_dir.exists():
            # A stale `_prev` left by an interrupted earlier backup would make
            # `rename` fail (the target directory exists), so drop it first.
            if prev_backup_output_dir.exists():
                shutil.rmtree(prev_backup_output_dir)
            new_output_dir.rename(prev_backup_output_dir)
        shutil.copytree(output_dir, new_output_dir)
        # the case for evaluate.py which automatically creates configs
        config_path = output_dir.with_suffix('.toml')
        if config_path.exists():
            shutil.copyfile(config_path, new_output_dir.with_suffix('.toml'))
        if prev_backup_output_dir.exists():
            shutil.rmtree(prev_backup_output_dir)

    if _LAST_SNAPSHOT_TIME is None or time.time() - _LAST_SNAPSHOT_TIME > 10 * 60:
        import nirvana_dl.snapshot  # type: ignore[code]

        nirvana_dl.snapshot.dump_snapshot()
        _LAST_SNAPSHOT_TIME = time.time()
        print('The snapshot was saved!')
def _get_scores(metrics: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, float]]:
    """Return ``{part: metrics[part]['score']}``, or ``None`` if there are no scores.

    Whether scores exist is decided from the first part only. An empty
    ``metrics`` mapping yields ``None`` instead of raising ``StopIteration``.
    """
    if not metrics:
        return None
    first_part = next(iter(metrics.values()))
    if 'score' not in first_part:
        return None
    return {part: part_metrics['score'] for part, part_metrics in metrics.items()}
def format_scores(metrics: Dict[str, Dict[str, Any]]) -> str:
    """Render scores as e.g. ``"[test] 0.123 [val] 0.456 [train] 0.789"``.

    Parts are listed in test/val/train order and skipped when absent from
    ``metrics``. Each present part must have a numeric ``'score'``.
    """
    pieces = []
    for part in ('test', 'val', 'train'):
        if part not in metrics:
            continue
        score = metrics[part]['score']
        pieces.append(f'[{part}] {score:.3f}')
    return ' '.join(pieces)
def finish(output_dir: Path, report: dict) -> None:
    """Finalise a run: save scores and the report, mark it DONE and back it up.

    Steps, in order:
    - If ``report['metrics']`` carries scores, write ``scores.json`` and print
      them.
    - Write ``report.json``.
    - If ``JSON_OUTPUT_FILE`` is set and ``output_dir`` is inside
      ``env.PROJ``, merge the report into that JSON file under the output
      directory's project-relative path. When ``SNAPSHOT_PATH`` is set, also
      copy the file there as ``json_output.json``.
    - Create the ``DONE`` marker and call ``backup_output``.
    """
    print()
    _print_sep('=')
    metrics = report.get('metrics')
    if metrics is not None:
        scores = _get_scores(metrics)
        if scores is not None:
            dump_json(scores, output_dir / 'scores.json')
            print(format_scores(metrics))
            _print_sep('-')
    dump_report(report, output_dir)

    json_output_file = os.environ.get('JSON_OUTPUT_FILE')
    if json_output_file:
        try:
            key = str(output_dir.relative_to(env.PROJ))
        except ValueError:
            # Only outputs inside the project are recorded.
            pass
        else:
            json_output_path = Path(json_output_file)
            try:
                json_data = json.loads(json_output_path.read_text())
            except (FileNotFoundError, json.decoder.JSONDecodeError):
                json_data = {}
            json_data[key] = load_json(output_dir / 'report.json')
            json_output_path.write_text(json.dumps(json_data, indent=4))
            # SNAPSHOT_PATH is optional: a missing variable must not abort the
            # run before it is marked DONE and backed up.
            snapshot_dir = os.environ.get('SNAPSHOT_PATH')
            if snapshot_dir is not None:
                shutil.copyfile(
                    json_output_path,
                    os.path.join(snapshot_dir, 'json_output.json'),
                )

    output_dir.joinpath('DONE').touch()
    backup_output(output_dir)
    print(f'Done! | {report.get("time")} | {output_dir}')
    _print_sep('=')
    print()
def from_dict(datacls: Type[T], data: dict) -> T:
    """Build a ``datacls`` instance from a (nested) dict.

    Fields whose type is a dataclass, or ``Optional[<dataclass>]`` written as
    ``Union[X, None]``, are converted recursively. A ``None`` value for an
    optional field is left as is. Keys missing from ``data`` fall back to the
    dataclass defaults. ``data`` is deep-copied and never mutated.
    """
    assert is_dataclass(datacls)
    kwargs = deepcopy(data)
    for dc_field in fields(datacls):
        name = dc_field.name
        if name not in kwargs:
            continue
        annotation = dc_field.type
        if is_dataclass(annotation):
            kwargs[name] = from_dict(annotation, kwargs[name])
            continue
        type_args = get_args(annotation)
        is_optional_dataclass = (
            get_origin(annotation) is Union
            and len(type_args) == 2
            and type_args[1] is type(None)
            and is_dataclass(type_args[0])
        )
        if is_optional_dataclass and kwargs[name] is not None:
            kwargs[name] = from_dict(type_args[0], kwargs[name])
    return datacls(**kwargs)
def replace_factor_with_value(
    config: RawConfig,
    key: str,
    reference_value: int,
    bounds: Tuple[float, float],
) -> None:
    """Turn ``config[key + '_factor']`` into an absolute ``config[key]``, in place.

    Exactly one of ``key`` and ``key + '_factor'`` must be present (checked by
    assertions). If the factor entry is present:
    - it is removed;
    - it is checked to lie within the inclusive ``bounds``;
    - ``config[key]`` is set to ``int(factor * reference_value)``.
    """
    factor_key = key + '_factor'
    if factor_key in config:
        assert key not in config
        factor = config.pop(factor_key)
        assert bounds[0] <= factor <= bounds[1]
        config[key] = int(factor * reference_value)
    else:
        assert key in config
def get_temporary_copy(path: Union[str, Path]) -> Path:
    """Copy a regular file next to itself under a unique name; delete it at exit.

    The copy is named ``<stem>___<32 hex chars of a uuid4><suffix>`` and lives
    in the same directory as the original. ``path`` is normalised with
    ``env.get_path`` and must not be a directory or a symlink.
    """
    path = env.get_path(path)
    assert not path.is_dir() and not path.is_symlink()
    tmp_path = path.with_name(f'{path.stem}___{uuid.uuid4().hex}{path.suffix}')
    shutil.copyfile(path, tmp_path)
    # missing_ok: the caller may already have removed the copy; that must not
    # raise during interpreter shutdown.
    atexit.register(tmp_path.unlink, missing_ok=True)
    return tmp_path
def get_python():
    """Return ``'python3.9'`` if a path of that name exists in the CWD, else ``'python'``."""
    candidate = Path('python3.9')
    if candidate.exists():
        return str(candidate)
    return 'python'
def get_catboost_config(real_data_path, is_cv=False):
    """Load tuned CatBoost settings for the dataset at ``real_data_path``.

    The dataset name is the last path component of ``real_data_path``. The
    settings are read from ``tuned_models/catboost/<name>_cv.json``, relative
    to the current working directory.

    NOTE(review): ``is_cv`` is currently ignored — the ``_cv`` file is always
    loaded.
    """
    dataset_name = Path(real_data_path).name
    return load_json(f'tuned_models/catboost/{dataset_name}_cv.json')