# DeepChoice / experiments/run_validation_experiments.py
# Last commit 4ccb60a by antoine.carreaud67: "Restore paper experiment plans and update README"
import argparse
import copy
import csv
import json
import logging
import subprocess
import sys
from datetime import datetime
from pathlib import Path
import yaml
# Repository root: the parent of the directory containing this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root on sys.path when the script is run directly, so the
# project import below resolves; this must run before that import.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from train.train import compute_split_baselines, list_split_batches  # noqa: E402 -- requires the sys.path setup above

# Root logger configuration shared by every function in this script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def _load_yaml(path):
    """Parse the YAML file at *path* with ``yaml.safe_load`` and return the result.

    Args:
        path: ``str`` or ``pathlib.Path`` of a UTF-8 encoded YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.
    """
    with open(path, encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    return document
def _write_yaml(path, payload):
    """Dump *payload* to *path* as YAML, creating missing parent directories.

    Key order is preserved (``sort_keys=False``) so the written config mirrors
    the in-memory mapping.

    Args:
        path: ``pathlib.Path`` destination (its ``parent`` is created).
        payload: Data accepted by ``yaml.safe_dump``.
    """
    destination_dir = path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    with open(path, mode="w", encoding="utf-8") as stream:
        yaml.safe_dump(payload, stream, sort_keys=False)
def _deep_update(target, updates):
for key, value in updates.items():
if isinstance(value, dict) and isinstance(target.get(key), dict):
_deep_update(target[key], value)
else:
target[key] = value
def _slugify(name):
return "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in name)
def _to_jsonable(value):
if isinstance(value, dict):
return {key: _to_jsonable(item) for key, item in value.items()}
if isinstance(value, (list, tuple)):
return [_to_jsonable(item) for item in value]
if hasattr(value, "tolist"):
return value.tolist()
return value
def _run_command(command, log_path, cwd):
log_path.parent.mkdir(parents=True, exist_ok=True)
logging.info("Running: %s", " ".join(command))
with open(log_path, "w", encoding="utf-8") as handle:
process = subprocess.run(command, cwd=cwd, stdout=handle, stderr=subprocess.STDOUT, text=True)
return process.returncode
def _write_summary_csv(path, rows):
path.parent.mkdir(parents=True, exist_ok=True)
fieldnames = [
"experiment",
"status",
"weights_path",
"model_miou",
"model_mf1",
"majority_miou",
"majority_mf1",
"hard_vote_miou",
"hard_vote_mf1",
"anyview_miou",
"anyview_mf1",
"metrics_path",
"log_test",
]
with open(path, "w", encoding="utf-8", newline="") as handle:
writer = csv.DictWriter(handle, fieldnames=fieldnames)
writer.writeheader()
for row in rows:
writer.writerow({key: row.get(key, "") for key in fieldnames})
def build_eval_config(base_config, experiment, output_root, defaults=None):
    """Build the evaluation config for a single plan experiment.

    The steps are:

    1. Deep-copy *base_config*.
    2. Merge the experiment's ``overrides`` mapping (if any) into the copy.
    3. Set ``data.split_roots`` for train/val/test from the experiment's
       ``<split>_batches_root`` entries, falling back to *defaults*.
    4. Set ``training.compute_validation_baselines`` to False.
    5. Point ``training.output_dir`` at
       ``<output_root>/eval_runs/<slugified name>``.

    Args:
        base_config: Parsed base configuration dict; not modified.
        experiment: Plan entry; must contain ``name``. It may contain
            ``overrides`` and ``train/val/test_batches_root``.
        output_root: ``pathlib.Path`` under which per-experiment outputs go.
        defaults: Optional plan-level mapping that supplies fallback
            ``<split>_batches_root`` values.

    Returns:
        A new config dict that shares no mutable state with *base_config*
        or *experiment*.
    """
    cfg = copy.deepcopy(base_config)
    # An empty ``overrides:`` key in YAML loads as None; treat it as "no overrides".
    # Copy before merging so the returned config never aliases the plan's dicts,
    # which the caller mutates further.
    overrides = experiment.get("overrides") or {}
    _deep_update(cfg, copy.deepcopy(overrides))
    defaults = defaults or {}
    split_roots = cfg.setdefault("data", {}).setdefault("split_roots", {})
    for split_name in ("train", "val", "test"):
        split_key = f"{split_name}_batches_root"
        # An explicit key on the experiment (even null) takes precedence over defaults;
        # a falsy root leaves the base config's value untouched.
        split_root = experiment.get(split_key, defaults.get(split_key))
        if split_root:
            split_roots[split_name] = str(split_root)
    cfg.setdefault("training", {})
    cfg.setdefault("test", {})
    cfg["training"]["compute_validation_baselines"] = False
    cfg["training"]["output_dir"] = str(output_root / "eval_runs" / _slugify(experiment["name"]))
    return cfg
def build_test_command(config_path, weights_path, split_name, metrics_output, limit_files=None):
    """Assemble the argv list that runs ``main.py`` in test mode.

    The command uses the current interpreter (``sys.executable``).

    Args:
        config_path: Config file passed as ``--config``.
        weights_path: Checkpoint passed as ``--weights_path``.
        split_name: Split passed as ``--split``.
        metrics_output: Path passed as ``--metrics_output``.
        limit_files: If not None, appended as ``--limit_files`` after ``int()`` conversion.

    Returns:
        List of strings suitable for ``subprocess.run``.
    """
    flags = {
        "--config": str(config_path),
        "--mode": "test",
        "--weights_path": str(weights_path),
        "--split": split_name,
        "--metrics_output": str(metrics_output),
    }
    if limit_files is not None:
        flags["--limit_files"] = str(int(limit_files))
    argv = [sys.executable, "main.py"]
    for flag, flag_value in flags.items():
        argv += [flag, flag_value]
    return argv
def _baseline_columns(baselines):
    """Flatten a baselines mapping into the summary's ``<name>_miou``/``<name>_mf1`` columns.

    Reads ``baselines[name]["miou"]`` and ``baselines[name]["mf1"]`` for
    ``majority``, ``hard_vote`` and ``anyview``; a missing entry raises ``KeyError``.
    """
    columns = {}
    for name in ("majority", "hard_vote", "anyview"):
        columns[f"{name}_miou"] = baselines[name]["miou"]
        columns[f"{name}_mf1"] = baselines[name]["mf1"]
    return columns


def _write_summaries(output_root, summary_rows):
    """Write the accumulated rows to ``summary.csv`` and ``summary.json`` under *output_root*."""
    _write_summary_csv(output_root / "summary.csv", summary_rows)
    (output_root / "summary.json").write_text(json.dumps(_to_jsonable(summary_rows), indent=2), encoding="utf-8")


def _compute_baselines(base_config, defaults, split_name, limit_files, output_root):
    """Compute the split baselines once and save them as JSON under ``<output_root>/metrics``.

    Returns:
        ``(baselines, baseline_path)``.
    """
    baseline_cfg = copy.deepcopy(base_config)
    split_roots = baseline_cfg.setdefault("data", {}).setdefault("split_roots", {})
    # Honour the plan-level root for the split actually being evaluated (not only
    # "val"), matching build_eval_config's fallback, so baselines and experiments
    # read the same batches.
    split_root = defaults.get(f"{split_name}_batches_root")
    if split_root:
        split_roots[split_name] = str(split_root)
    baseline_paths = list_split_batches(baseline_cfg, split_name, limit=limit_files)
    fallback_batch_size = baseline_cfg.get("training", {}).get("eval_file_batch_size", 1)
    baselines = compute_split_baselines(
        baseline_cfg,
        paths=baseline_paths,
        file_batch_size=baseline_cfg.get("test", {}).get("file_batch_size", fallback_batch_size),
        desc=f"Computing {split_name} baselines",
    )
    baseline_path = output_root / "metrics" / f"baselines_{split_name}.json"
    baseline_path.parent.mkdir(parents=True, exist_ok=True)
    baseline_path.write_text(json.dumps(_to_jsonable(baselines), indent=2), encoding="utf-8")
    return baselines, baseline_path


def main():
    """Evaluate every experiment in a plan on one split and write a summary.

    Steps:
      1. Load the base config and the plan. Relative paths are resolved
         against the repository root.
      2. Compute the split baselines once and save them under
         ``<output_root>/metrics``.
      3. For each plan experiment:
         - write a derived config with those baselines injected;
         - run ``main.py --mode test`` in a subprocess (cwd = repository root);
         - read the metrics JSON it writes.
      4. Write ``summary.csv`` and ``summary.json`` under the output root.

    A failing experiment writes a partial summary and re-raises, unless
    ``--continue-on-error`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-config", default="configs/config_deepchoice_base.yaml")
    parser.add_argument("--plan", default="experiments/val_full_plan.yaml")
    parser.add_argument("--output-root", default=None)
    parser.add_argument("--continue-on-error", action="store_true")
    args = parser.parse_args()

    repo_root = REPO_ROOT
    # pathlib keeps an absolute right-hand operand as-is, so this covers both
    # repo-relative and absolute arguments.
    base_config = _load_yaml(repo_root / args.base_config)
    plan = _load_yaml(repo_root / args.plan)
    if args.output_root:
        # Resolve to an absolute path. Config/metrics paths derived from it are
        # handed to a child running with cwd=repo_root, so a cwd-relative root
        # would point the child at the wrong files.
        output_root = Path(args.output_root).resolve()
    else:
        output_root = repo_root / "artifacts" / "experiments" / datetime.now().strftime("val_full_%Y%m%d_%H%M%S")
    output_root.mkdir(parents=True, exist_ok=True)

    # Empty YAML sections load as None; treat them as absent.
    defaults = plan.get("defaults") or {}
    split_name = defaults.get("split", "val")
    limit_files = defaults.get("limit_files")

    baselines, baseline_path = _compute_baselines(base_config, defaults, split_name, limit_files, output_root)
    summary_rows = [
        {
            "experiment": f"baseline_{split_name}",
            "status": "ok",
            **_baseline_columns(baselines),
            "metrics_path": str(baseline_path),
        }
    ]

    for experiment in plan.get("experiments") or []:
        exp_name = experiment["name"]
        slug = _slugify(exp_name)
        # A relative weights path is interpreted by the child, i.e. relative to repo_root.
        weights_path = Path(experiment["weights_path"])
        config_path = output_root / "configs" / f"{slug}.yaml"
        test_log = output_root / "logs" / f"{slug}__test.log"
        metrics_path = output_root / "metrics" / f"{slug}__test.json"
        row = {
            "experiment": exp_name,
            "status": "pending",
            "weights_path": str(weights_path),
            "metrics_path": str(metrics_path),
            "log_test": str(test_log),
        }
        try:
            cfg = build_eval_config(base_config, experiment, output_root, defaults=defaults)
            # Separate copies per key so the dumped YAML stays free of anchors/aliases.
            cfg["training"]["precomputed_validation_baselines"] = _to_jsonable(baselines)
            cfg["test"]["precomputed_baselines"] = _to_jsonable(baselines)
            _write_yaml(config_path, cfg)
            # Drop metrics left by an earlier run in a reused output root, so a child
            # that exits 0 without writing metrics is not mistaken for a fresh result.
            if metrics_path.exists():
                metrics_path.unlink()
            command = build_test_command(config_path, weights_path, split_name, metrics_path, limit_files=limit_files)
            rc = _run_command(command, test_log, repo_root)
            if rc != 0:
                raise RuntimeError(f"Validation failed with return code {rc}")
            metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
            row["model_miou"] = metrics["miou"]
            row["model_mf1"] = metrics["mf1"]
            row.update(_baseline_columns(metrics["baselines"]))
            row["status"] = "ok"
        except Exception as exc:
            logging.exception("Validation failed: %s", exp_name)
            row["status"] = f"failed: {exc}"
            summary_rows.append(row)
            # Persist progress before aborting or moving on to the next experiment.
            _write_summaries(output_root, summary_rows)
            if not args.continue_on_error:
                raise
            continue
        summary_rows.append(row)

    _write_summaries(output_root, summary_rows)
    logging.info("Validation experiment run complete. Summary written under %s", output_root)


if __name__ == "__main__":
    main()