# Imported from Zilong-Zhao's repository (first commit, c4ac745).
import argparse
import atexit
import enum
import json
import os
import pickle
import shutil
import sys
import time
import uuid
from copy import deepcopy
from dataclasses import asdict, fields, is_dataclass
from pathlib import Path
from pprint import pprint
from typing import Any, Callable, List, Dict, Type, Optional, Tuple, TypeVar, Union, cast, get_args, get_origin
import __main__
import numpy as np
import tomli
import tomli_w
import torch
import zero
from . import env
# A raw (TOML-loaded) configuration: a nested dict of plain values.
RawConfig = Dict[str, Any]
# A JSON-serializable experiment report (persisted as `report.json`).
Report = Dict[str, Any]
# Generic type variable for the config class used by `start` / `from_dict`.
T = TypeVar('T')
class Part(enum.Enum):
    """A dataset split identifier."""

    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'

    def __str__(self) -> str:
        """Return the raw split name, e.g. ``'train'``."""
        return self.value
class TaskType(enum.Enum):
    """The kind of prediction task a dataset defines."""

    BINCLASS = 'binclass'
    MULTICLASS = 'multiclass'
    REGRESSION = 'regression'

    def __str__(self) -> str:
        """Return the raw task name, e.g. ``'regression'``."""
        return self.value
class Timer(zero.Timer):
    """A `zero.Timer` with a convenience constructor that starts immediately."""

    @classmethod
    def launch(cls) -> 'Timer':
        """Create a timer and start it right away."""
        instance = cls()
        instance.run()
        return instance
def update_training_log(training_log, data, metrics):
    """Append one step of `data` and `metrics` to the accumulating `training_log`.

    Scalars are appended to per-key lists, lists are extended, and nested
    dicts are merged recursively. `metrics` (shaped ``{part: {metric: value}}``)
    is transposed to ``{metric: {part: value}}`` before merging so each metric
    accumulates a per-part history.
    """

    def _merge(dst, src):
        # Recursively fold `src` into `dst`, accumulating leaves into lists.
        for key, value in src.items():
            if isinstance(value, dict):
                _merge(dst.setdefault(key, {}), value)
            elif isinstance(value, list):
                dst.setdefault(key, []).extend(value)
            else:
                dst.setdefault(key, []).append(value)

    _merge(training_log, data)

    by_metric = {}
    for part, part_metrics in metrics.items():
        for metric_name, value in part_metrics.items():
            by_metric.setdefault(metric_name, {})[part] = value
    _merge(training_log, by_metric)
def raise_unknown(unknown_what: str, unknown_value: Any):
    """Raise a uniformly formatted ValueError for an unrecognized option."""
    message = f'Unknown {unknown_what}: {unknown_value}'
    raise ValueError(message)
def _replace(data, condition, value):
    """Return a deep transformation of `data` where every leaf satisfying
    `condition` is replaced with `value`; dicts and lists are traversed
    recursively, all other nodes are treated as leaves."""

    def transform(node):
        if isinstance(node, dict):
            return {key: transform(child) for key, child in node.items()}
        if isinstance(node, list):
            return [transform(child) for child in node]
        return value if condition(node) else node

    return transform(data)


# TOML has no null value, so `None` round-trips through this sentinel string.
_CONFIG_NONE = '__none__'


def unpack_config(config: RawConfig) -> RawConfig:
    """Convert the `'__none__'` sentinels in a loaded config back to `None`."""
    return cast(RawConfig, _replace(config, lambda x: x == _CONFIG_NONE, None))


def pack_config(config: RawConfig) -> RawConfig:
    """Convert `None` values to the `'__none__'` sentinel before TOML dumping."""
    return cast(RawConfig, _replace(config, lambda x: x is None, _CONFIG_NONE))
def load_config(path: Union[Path, str]) -> Any:
    """Read a TOML config file and restore `None` values from their sentinel."""
    with open(path, 'rb') as f:
        raw = tomli.load(f)
    return unpack_config(raw)
def dump_config(config: Any, path: Union[Path, str]) -> None:
    """Write `config` as TOML, encoding `None` values via the sentinel string."""
    with open(path, 'wb') as f:
        tomli_w.dump(pack_config(config), f)
    # Round-trip self-check: guards against bugs in the pack/unpack sentinel logic.
    assert config == load_config(path)
def load_json(path: Union[Path, str], **kwargs) -> Any:
    """Read and parse a JSON file."""
    return json.loads(Path(path).read_text(), **kwargs)


def dump_json(x: Any, path: Union[Path, str], **kwargs) -> None:
    """Write `x` as JSON (4-space indent by default) with a trailing newline."""
    kwargs.setdefault('indent', 4)
    content = json.dumps(x, **kwargs)
    Path(path).write_text(content + '\n')


def load_pickle(path: Union[Path, str], **kwargs) -> Any:
    """Read and unpickle a file."""
    return pickle.loads(Path(path).read_bytes(), **kwargs)


def dump_pickle(x: Any, path: Union[Path, str], **kwargs) -> None:
    """Pickle `x` to a file."""
    Path(path).write_bytes(pickle.dumps(x, **kwargs))


def load(path: Union[Path, str], **kwargs) -> Any:
    """Dispatch to `load_<suffix>` based on the file extension."""
    loader = globals()[f'load_{Path(path).suffix[1:]}']
    return loader(Path(path), **kwargs)


def dump(x: Any, path: Union[Path, str], **kwargs) -> Any:
    """Dispatch to `dump_<suffix>` based on the file extension."""
    dumper = globals()[f'dump_{Path(path).suffix[1:]}']
    return dumper(x, Path(path), **kwargs)
def _get_output_item_path(
    path: Union[str, Path], filename: str, must_exist: bool
) -> Path:
    """Resolve `path` to the location of `filename` inside an output directory.

    Accepts the output directory itself, its `.toml` config path (mapped to the
    directory with the same stem), or the full path to the file. When
    `must_exist` is true, the resolved file must already exist.
    """
    resolved = env.get_path(path)
    if resolved.suffix == '.toml':
        # A config path maps to the output directory sharing its stem.
        resolved = resolved.with_suffix('')
    if resolved.is_dir():
        resolved = resolved / filename
    else:
        assert resolved.name == filename
        assert resolved.parent.exists()
    if must_exist:
        assert resolved.exists()
    return resolved
def load_report(path: Path) -> Report:
    """Load `report.json` from an output directory (or an equivalent path)."""
    return load_json(_get_output_item_path(path, 'report.json', True))


def dump_report(report: dict, path: Path) -> None:
    """Write `report.json` into the output directory."""
    dump_json(report, _get_output_item_path(path, 'report.json', False))


def load_predictions(path: Path) -> Dict[str, np.ndarray]:
    """Load `predictions.npz` as a ``{name: array}`` dict."""
    with np.load(_get_output_item_path(path, 'predictions.npz', True)) as predictions:
        return {x: predictions[x] for x in predictions}


def dump_predictions(predictions: Dict[str, np.ndarray], path: Path) -> None:
    """Write a ``{name: array}`` dict as `predictions.npz`."""
    np.savez(_get_output_item_path(path, 'predictions.npz', False), **predictions)


def dump_metrics(metrics: Dict[str, Any], path: Path) -> None:
    """Write `metrics.json` into the output directory."""
    dump_json(metrics, _get_output_item_path(path, 'metrics.json', False))


def load_checkpoint(path: Path, *args, **kwargs) -> Dict[str, np.ndarray]:
    """Load `checkpoint.pt` via `torch.load`; extra arguments are forwarded."""
    return torch.load(
        _get_output_item_path(path, 'checkpoint.pt', True), *args, **kwargs
    )
def get_device() -> torch.device:
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device.

    When CUDA is available, `CUDA_VISIBLE_DEVICES` must be set explicitly —
    this forces deliberate GPU selection.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    assert os.environ.get('CUDA_VISIBLE_DEVICES') is not None
    return torch.device('cuda:0')
def _print_sep(c, size=100):
print(c * size)
def start(
    config_cls: Type[T] = RawConfig,
    argv: Optional[List[str]] = None,
    patch_raw_config: Optional[Callable[[RawConfig], None]] = None,
) -> Tuple[T, Path, Report]:  # (config, output dir, report)
    """Standard experiment entry point: parse args, load config, set up output.

    Args:
        config_cls: a dataclass type (the raw TOML config is converted to it
            via `from_dict`) or `dict` (the raw config is used as-is).
        argv: explicit argument vector; its first item must be the path
            (relative to the project root) to the script/notebook. When None,
            `sys.argv` and `__main__.__file__` are used.
        patch_raw_config: optional callback mutating the raw config in place
            before conversion.

    Returns:
        (config, output_dir, report) where `output_dir` is the config path
        stripped of its `.toml` suffix and `report` is a fresh report dict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', metavar='FILE')
    parser.add_argument('--force', action='store_true')
    parser.add_argument('--continue', action='store_true', dest='continue_')
    if argv is None:
        program = __main__.__file__
        args = parser.parse_args()
    else:
        program = argv[0]
        try:
            args = parser.parse_args(argv[1:])
        except Exception:
            print(
                'Failed to parse `argv`.'
                ' Remember that the first item of `argv` must be the path (relative to'
                ' the project root) to the script/notebook.'
            )
            raise
        # BUG FIX: a stray `args = parser.parse_args(argv)` used to re-parse the
        # FULL argv here, which made argparse consume the program path as the
        # `config` positional. The parse of `argv[1:]` above is the correct one.

    # If a snapshot with restored checkpoints exists, the run must be a --continue run.
    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
    if snapshot_dir and Path(snapshot_dir).joinpath('CHECKPOINTS_RESTORED').exists():
        assert args.continue_

    config_path = env.get_path(args.config)
    output_dir = config_path.with_suffix('')
    _print_sep('=')
    print(f'[output] {output_dir}')
    _print_sep('=')

    assert config_path.exists()
    raw_config = load_config(config_path)
    if patch_raw_config is not None:
        patch_raw_config(raw_config)
    if is_dataclass(config_cls):
        config = from_dict(config_cls, raw_config)
        full_raw_config = asdict(config)
    else:
        assert config_cls is dict
        full_raw_config = config = raw_config
        # BUG FIX: removed a leftover `full_raw_config = asdict(config)` here —
        # in this branch `config` is a plain dict and `asdict` raises TypeError
        # on anything that is not a dataclass instance.

    if output_dir.exists():
        if args.force:
            print('Removing the existing output and creating a new one...')
            shutil.rmtree(output_dir)
            output_dir.mkdir()
        elif not args.continue_:
            backup_output(output_dir)
            print('The output directory already exists. Done!\n')
            sys.exit()
        elif output_dir.joinpath('DONE').exists():
            backup_output(output_dir)
            print('The "DONE" file already exists. Done!')
            sys.exit()
        else:
            print('Continuing with the existing output...')
    else:
        print('Creating the output...')
        output_dir.mkdir()

    report = {
        'program': str(env.get_relative_path(program)),
        'environment': {},
        'config': full_raw_config,
    }
    if torch.cuda.is_available():  # type: ignore[code]
        report['environment'].update(
            {
                'CUDA_VISIBLE_DEVICES': os.environ.get('CUDA_VISIBLE_DEVICES'),
                'gpus': zero.hardware.get_gpus_info(),
                'torch.version.cuda': torch.version.cuda,
                'torch.backends.cudnn.version()': torch.backends.cudnn.version(),  # type: ignore[code]
                'torch.cuda.nccl.version()': torch.cuda.nccl.version(),  # type: ignore[code]
            }
        )
    dump_report(report, output_dir)
    dump_json(raw_config, output_dir / 'raw_config.json')
    _print_sep('-')
    pprint(full_raw_config, width=100)
    _print_sep('-')
    return cast(config_cls, config), output_dir, report
# Timestamp of the last snapshot dump; used to rate-limit snapshots to one
# per 10 minutes (see `backup_output`).
_LAST_SNAPSHOT_TIME = None


def backup_output(output_dir: Path) -> None:
    """Mirror `output_dir` into the backup and snapshot locations.

    Reads `TMP_OUTPUT_PATH` (backup root) and `SNAPSHOT_PATH` (snapshot root)
    from the environment; a no-op when `TMP_OUTPUT_PATH` is unset or when
    `output_dir` is not under the project root. At most once per 10 minutes it
    also triggers a `nirvana_dl` snapshot dump.
    """
    backup_dir = os.environ.get('TMP_OUTPUT_PATH')
    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
    if backup_dir is None:
        # The two variables are expected to be set (or unset) together.
        assert snapshot_dir is None
        return
    assert snapshot_dir is not None

    try:
        relative_output_dir = output_dir.relative_to(env.PROJ)
    except ValueError:
        # Output lives outside the project tree — nothing to mirror.
        return

    for dir_ in [backup_dir, snapshot_dir]:
        new_output_dir = dir_ / relative_output_dir
        prev_backup_output_dir = new_output_dir.with_name(new_output_dir.name + '_prev')
        new_output_dir.parent.mkdir(exist_ok=True, parents=True)
        # Keep the previous copy as "<name>_prev" until the fresh copy succeeds.
        if new_output_dir.exists():
            new_output_dir.rename(prev_backup_output_dir)
        shutil.copytree(output_dir, new_output_dir)
        # the case for evaluate.py which automatically creates configs
        if output_dir.with_suffix('.toml').exists():
            shutil.copyfile(
                output_dir.with_suffix('.toml'), new_output_dir.with_suffix('.toml')
            )
        if prev_backup_output_dir.exists():
            shutil.rmtree(prev_backup_output_dir)

    global _LAST_SNAPSHOT_TIME
    if _LAST_SNAPSHOT_TIME is None or time.time() - _LAST_SNAPSHOT_TIME > 10 * 60:
        import nirvana_dl.snapshot  # type: ignore[code]

        nirvana_dl.snapshot.dump_snapshot()
        _LAST_SNAPSHOT_TIME = time.time()
        print('The snapshot was saved!')
def _get_scores(metrics: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, float]]:
return (
{k: v['score'] for k, v in metrics.items()}
if 'score' in next(iter(metrics.values()))
else None
)
def format_scores(metrics: Dict[str, Dict[str, Any]]) -> str:
    """Format the per-part scores as a one-line summary (test, val, train order)."""
    pieces = []
    for part in ('test', 'val', 'train'):
        if part in metrics:
            pieces.append(f"[{part}] {metrics[part]['score']:.3f}")
    return ' '.join(pieces)
def finish(output_dir: Path, report: dict) -> None:
    """Finalize a run: persist scores and the report, mark the output DONE.

    Also merges the report into the aggregate file pointed to by the
    `JSON_OUTPUT_FILE` environment variable (if set) and mirrors that file
    into `SNAPSHOT_PATH`, then backs up the output directory.
    """
    print()
    _print_sep('=')
    metrics = report.get('metrics')
    if metrics is not None:
        scores = _get_scores(metrics)
        if scores is not None:
            dump_json(scores, output_dir / 'scores.json')
            print(format_scores(metrics))
            _print_sep('-')
    dump_report(report, output_dir)
    json_output_path = os.environ.get('JSON_OUTPUT_FILE')
    if json_output_path:
        try:
            # Key the aggregate entry by the project-relative output path.
            key = str(output_dir.relative_to(env.PROJ))
        except ValueError:
            pass
        else:
            json_output_path = Path(json_output_path)
            try:
                json_data = json.loads(json_output_path.read_text())
            except (FileNotFoundError, json.decoder.JSONDecodeError):
                # Missing or corrupted aggregate file — start from scratch.
                json_data = {}
            json_data[key] = load_json(output_dir / 'report.json')
            json_output_path.write_text(json.dumps(json_data, indent=4))
        shutil.copyfile(
            json_output_path,
            os.path.join(os.environ['SNAPSHOT_PATH'], 'json_output.json'),
        )
    output_dir.joinpath('DONE').touch()
    backup_output(output_dir)
    print(f'Done! | {report.get("time")} | {output_dir}')
    _print_sep('=')
    print()
def from_dict(datacls: Type[T], data: dict) -> T:
    """Recursively build a dataclass instance of type `datacls` from `data`.

    Fields typed as nested dataclasses are converted recursively, as are
    ``Optional[<dataclass>]`` fields whose value is not None. The input dict
    is deep-copied and never mutated. Unknown keys propagate to the dataclass
    constructor (raising TypeError there).
    """
    assert is_dataclass(datacls)
    data = deepcopy(data)  # never mutate the caller's dict
    for field in fields(datacls):
        if field.name not in data:
            continue
        field_type = field.type
        if is_dataclass(field_type):
            data[field.name] = from_dict(field_type, data[field.name])
        elif get_origin(field_type) is Union:
            args = get_args(field_type)
            # IMPROVEMENT: accept Optional[<dataclass>] with NoneType at either
            # Union position; the original required it to be the second argument.
            if len(args) == 2 and type(None) in args:
                inner = args[0] if args[1] is type(None) else args[1]
                if is_dataclass(inner) and data[field.name] is not None:
                    data[field.name] = from_dict(inner, data[field.name])
    return datacls(**data)
def replace_factor_with_value(
    config: RawConfig,
    key: str,
    reference_value: int,
    bounds: Tuple[float, float],
) -> None:
    """Resolve `<key>_factor` into an absolute `<key>` value, in place.

    Exactly one of `<key>` or `<key>_factor` must be present. When the factor
    form is present, it is validated against `bounds` (inclusive) and replaced
    with ``int(factor * reference_value)``; otherwise the config is untouched.
    """
    factor_key = key + '_factor'
    if factor_key not in config:
        # The absolute value must already be present.
        assert key in config
        return
    assert key not in config
    factor = config.pop(factor_key)
    assert bounds[0] <= factor <= bounds[1]
    config[key] = int(factor * reference_value)
def get_temporary_copy(path: Union[str, Path]) -> Path:
    """Copy a regular file to a uniquely-named sibling, removed at exit.

    The copy is created next to the original as ``<stem>___<uuid><suffix>``
    and is scheduled for deletion via `atexit`. Directories and symlinks are
    rejected.
    """
    path = env.get_path(path)
    assert not path.is_dir() and not path.is_symlink()
    tmp_path = path.with_name(
        path.stem + '___' + str(uuid.uuid4()).replace('-', '') + path.suffix
    )
    shutil.copyfile(path, tmp_path)
    # BUG FIX: `missing_ok=True` — if the temp copy is deleted before exit,
    # the bare `unlink()` would raise FileNotFoundError from the atexit hook.
    atexit.register(lambda: tmp_path.unlink(missing_ok=True))
    return tmp_path
def get_python():
    """Return 'python3.9' if such a file exists in the CWD, otherwise 'python'."""
    candidate = Path('python3.9')
    if candidate.exists():
        return str(candidate)
    return 'python'
def get_catboost_config(real_data_path, is_cv=False):
    """Load the tuned CatBoost config for the dataset at `real_data_path`.

    The config is read from ``tuned_models/catboost/<dataset>_cv.json``.

    NOTE(review): the `is_cv` parameter is currently unused — the `_cv.json`
    variant is always loaded regardless of the flag; confirm whether a
    non-CV config file was ever intended.
    """
    ds_name = Path(real_data_path).name
    C = load_json(f'tuned_models/catboost/{ds_name}_cv.json')
    return C