import json

import h5py
import numpy as np
from omegaconf import OmegaConf

def load_eval(dir):
    """Load the cached summaries and per-datapoint results of an eval."""
    summaries, results = {}, {}
    with h5py.File(str(dir / "results.h5"), "r") as hfile:
        for k in hfile.keys():
            r = np.array(hfile[k])
            # Only keep low-dimensional per-datapoint results in memory.
            if len(r.shape) < 3:
                results[k] = r
        for k, v in hfile.attrs.items():
            summaries[k] = v
    with open(dir / "summaries.json", "r") as f:
        s = json.load(f)
    # The JSON summaries take precedence over the HDF5 attributes. Non-finite
    # values were stored as null in JSON; restore them as NaN.
    summaries = {k: v if v is not None else np.nan for k, v in s.items()}
    return summaries, results

def save_eval(dir, summaries, figures, results):
    """Cache an eval: results to HDF5, summaries to JSON, figures to PNG."""
    with h5py.File(str(dir / "results.h5"), "w") as hfile:
        for k, v in results.items():
            arr = np.array(v)
            # Store non-numeric results (e.g. strings) as object arrays.
            if not np.issubdtype(arr.dtype, np.number):
                arr = arr.astype("object")
            hfile.create_dataset(k, data=arr)
        for k, v in summaries.items():
            hfile.attrs[k] = v
    # JSON cannot encode non-finite floats, so store them as null;
    # load_eval turns them back into NaN. Lists are passed through as-is.
    s = {
        k: float(v) if np.isfinite(v) else None
        for k, v in summaries.items()
        if not isinstance(v, list)
    }
    s = {**s, **{k: v for k, v in summaries.items() if isinstance(v, list)}}
    with open(dir / "summaries.json", "w") as f:
        json.dump(s, f, indent=4)

    for fig_name, fig in figures.items():
        fig.savefig(dir / f"{fig_name}.png")

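# Round-trip sketch (illustrative only; the directory, keys, and values below
# are hypothetical, not part of the pipeline):
#
#   from pathlib import Path
#   out_dir = Path("outputs/my_eval")
#   out_dir.mkdir(parents=True, exist_ok=True)
#   save_eval(
#       out_dir,
#       summaries={"mean_error": 0.61, "thresholds": [1, 3, 5]},
#       figures={},  # e.g. {"curve": some_matplotlib_figure}
#       results={"error": np.random.rand(100)},
#   )
#   summaries, results = load_eval(out_dir)  # null summaries come back as NaN
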
def exists_eval(dir):
    """Check whether an eval was already cached in this directory."""
    return (dir / "results.h5").exists() and (dir / "summaries.json").exists()

class EvalPipeline:
    default_conf = {}

    # Prediction keys to export for each datapoint; optional ones may be absent.
    export_keys = []
    optional_export_keys = []

    def __init__(self, conf):
        """Merge the given config over the pipeline defaults, then call _init."""
        self.default_conf = OmegaConf.create(self.default_conf)
        self.conf = OmegaConf.merge(self.default_conf, conf)
        self._init(self.conf)
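    # OmegaConf.merge gives the incoming conf precedence over the defaults:
    #   OmegaConf.merge(OmegaConf.create({"a": 1, "b": 2}), {"b": 3})
    #   -> {"a": 1, "b": 3}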
    def _init(self, conf):
        """Optional hook for subclasses to set up state from the config."""
        pass
    @classmethod
    def get_dataloader(cls, data_conf=None):
        """Returns a data loader with samples for each eval datapoint."""
        raise NotImplementedError

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint."""
        raise NotImplementedError

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions."""
        raise NotImplementedError
    def run(self, experiment_dir, model=None, overwrite=False, overwrite_eval=False):
        """Run the export+eval loop and return (summaries, figures, results)."""
        self.save_conf(experiment_dir, overwrite=overwrite, overwrite_eval=overwrite_eval)
        pred_file = self.get_predictions(experiment_dir, model=model, overwrite=overwrite)

        # Figures are only fresh when the eval is actually (re)run.
        f = {}
        if not exists_eval(experiment_dir) or overwrite_eval or overwrite:
            s, f, r = self.run_eval(self.get_dataloader(self.conf.data), pred_file)
            save_eval(experiment_dir, s, f, r)
        s, r = load_eval(experiment_dir)
        return s, f, r
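    # Typical driver call (sketch; the subclass and the path are hypothetical):
    #   summaries, figures, results = MyBenchmark(conf).run(Path("outputs/exp0"))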
    def save_conf(self, experiment_dir, overwrite=False, overwrite_eval=False):
        """Store the config in the experiment dir, guarding against conf changes."""
        conf_output_path = experiment_dir / "conf.yaml"
        if conf_output_path.exists():
            saved_conf = OmegaConf.load(conf_output_path)
            if (saved_conf.data != self.conf.data) or (saved_conf.model != self.conf.model):
                assert (
                    overwrite
                ), "configs changed, add --overwrite to rerun experiment with new conf"
            if saved_conf.eval != self.conf.eval:
                assert (
                    overwrite or overwrite_eval
                ), "eval configs changed, add --overwrite_eval to rerun evaluation"
        OmegaConf.save(self.conf, conf_output_path)

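# Minimal subclass sketch (hypothetical; `MyBenchmark` and its conf keys are
# illustrative, not part of this module). A concrete benchmark only fills in
# the three hooks; `run()` then handles caching, overwrites, and loading:
#
#   class MyBenchmark(EvalPipeline):
#       default_conf = {"data": {}, "model": {}, "eval": {}}
#
#       @classmethod
#       def get_dataloader(cls, data_conf=None):
#           ...  # build a loader over the eval datapoints
#
#       def get_predictions(self, experiment_dir, model=None, overwrite=False):
#           ...  # run the model, cache predictions, return the file path
#
#       def run_eval(self, loader, pred_file):
#           ...  # compute metrics; return (summaries, figures, results)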