| import argparse |
| import copy |
| import csv |
| import json |
| import logging |
| import subprocess |
| import sys |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import yaml |
|
|
# Repository root: the parent of the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root at the front of sys.path (once) -- presumably so the
# project-local `train.train` import that follows resolves when this file is
# run directly as a script.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from train.train import compute_split_baselines, list_split_batches |
|
|
|
|
# Configure the root logger when the module is loaded: INFO level, with each
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
def _load_yaml(path):
    """Parse the YAML file at *path* and return the loaded object.

    The file is opened as UTF-8 text and parsed with ``yaml.safe_load``.
    """
    with open(path, mode="r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    return document
|
|
|
|
def _write_yaml(path, payload):
    """Dump *payload* as YAML to *path*, creating parent directories as needed.

    *path* must be a ``pathlib.Path``. Mapping keys are emitted in their
    existing order (``sort_keys=False``) and the file is written as UTF-8.
    """
    directory = path.parent
    directory.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as stream:
        yaml.safe_dump(payload, stream, sort_keys=False)
|
|
|
|
| def _deep_update(target, updates): |
| for key, value in updates.items(): |
| if isinstance(value, dict) and isinstance(target.get(key), dict): |
| _deep_update(target[key], value) |
| else: |
| target[key] = value |
|
|
|
|
| def _slugify(name): |
| return "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in name) |
|
|
|
|
| def _to_jsonable(value): |
| if isinstance(value, dict): |
| return {key: _to_jsonable(item) for key, item in value.items()} |
| if isinstance(value, (list, tuple)): |
| return [_to_jsonable(item) for item in value] |
| if hasattr(value, "tolist"): |
| return value.tolist() |
| return value |
|
|
|
|
| def _run_command(command, log_path, cwd): |
| log_path.parent.mkdir(parents=True, exist_ok=True) |
| logging.info("Running: %s", " ".join(command)) |
| with open(log_path, "w", encoding="utf-8") as handle: |
| process = subprocess.run(command, cwd=cwd, stdout=handle, stderr=subprocess.STDOUT, text=True) |
| return process.returncode |
|
|
|
|
| def _write_summary_csv(path, rows): |
| path.parent.mkdir(parents=True, exist_ok=True) |
| fieldnames = [ |
| "experiment", |
| "status", |
| "weights_path", |
| "model_miou", |
| "model_mf1", |
| "majority_miou", |
| "majority_mf1", |
| "hard_vote_miou", |
| "hard_vote_mf1", |
| "anyview_miou", |
| "anyview_mf1", |
| "metrics_path", |
| "log_test", |
| ] |
| with open(path, "w", encoding="utf-8", newline="") as handle: |
| writer = csv.DictWriter(handle, fieldnames=fieldnames) |
| writer.writeheader() |
| for row in rows: |
| writer.writerow({key: row.get(key, "") for key in fieldnames}) |
|
|
|
|
def build_eval_config(base_config, experiment, output_root, defaults=None):
    """Build the evaluation config for one experiment.

    Starts from a deep copy of *base_config*, merges the experiment's
    ``overrides`` into it, applies per-split batch roots, and sets the
    training options used for evaluation runs.

    Args:
        base_config: Base configuration dict; it is not modified.
        experiment: Experiment entry. Reads ``name`` (required), ``overrides``
            (optional) and the optional ``train_batches_root`` /
            ``val_batches_root`` / ``test_batches_root`` keys.
        output_root: Root output directory (``Path`` or ``str``); the run's
            ``training.output_dir`` is placed under ``<output_root>/eval_runs``.
        defaults: Optional dict of fallback ``*_batches_root`` values used
            when the experiment does not set them.

    Returns:
        A new config dict, independent of *base_config* and *experiment*.

    Raises:
        KeyError: If *experiment* has no ``name``.
    """
    defaults = defaults or {}
    cfg = copy.deepcopy(base_config)

    # ``overrides:`` with no value parses as None, so fall back to {}. The
    # copy keeps ``cfg`` from sharing nested dicts with the experiment entry;
    # otherwise the writes below could leak back into the plan.
    _deep_update(cfg, copy.deepcopy(experiment.get("overrides") or {}))

    split_roots = cfg.setdefault("data", {}).setdefault("split_roots", {})
    for split_name in ("train", "val", "test"):
        split_key = f"{split_name}_batches_root"
        # An experiment-level root wins over the plan-wide default.
        split_root = experiment.get(split_key, defaults.get(split_key))
        if split_root:
            split_roots[split_name] = str(split_root)

    training_cfg = cfg.setdefault("training", {})
    cfg.setdefault("test", {})
    training_cfg["compute_validation_baselines"] = False
    training_cfg["output_dir"] = str(
        Path(output_root) / "eval_runs" / _slugify(experiment["name"])
    )
    return cfg
|
|
|
|
def build_test_command(config_path, weights_path, split_name, metrics_output, limit_files=None):
    """Assemble the argv list that runs ``main.py`` in test mode.

    The command uses the current interpreter (``sys.executable``). Path-like
    arguments are converted with ``str``. ``--limit_files`` is appended, as
    an integer string, only when *limit_files* is not ``None``.

    Returns:
        The command as a list of arguments.
    """
    flag_values = (
        ("--config", str(config_path)),
        ("--mode", "test"),
        ("--weights_path", str(weights_path)),
        ("--split", split_name),
        ("--metrics_output", str(metrics_output)),
    )
    command = [sys.executable, "main.py"]
    for flag, value in flag_values:
        command += [flag, value]
    if limit_files is None:
        return command
    return command + ["--limit_files", str(int(limit_files))]
|
|
|
|
_BASELINE_NAMES = ("majority", "hard_vote", "anyview")


def _baseline_columns(baselines):
    """Flatten baseline metrics into ``<name>_miou`` / ``<name>_mf1`` summary columns."""
    columns = {}
    for name in _BASELINE_NAMES:
        columns[f"{name}_miou"] = baselines[name]["miou"]
        columns[f"{name}_mf1"] = baselines[name]["mf1"]
    return columns


def _write_summaries(output_root, summary_rows):
    """Rewrite ``summary.csv`` and ``summary.json`` under *output_root*."""
    _write_summary_csv(output_root / "summary.csv", summary_rows)
    (output_root / "summary.json").write_text(
        json.dumps(_to_jsonable(summary_rows), indent=2), encoding="utf-8"
    )


def main():
    """Compute split baselines once, then evaluate every experiment in the plan.

    Command-line options:
        --base-config: Base YAML config (relative paths resolve against the repo root).
        --plan: YAML plan with optional ``defaults`` and ``experiments`` sections.
        --output-root: Output directory; defaults to a timestamped folder
            under ``artifacts/experiments`` in the repo.
        --continue-on-error: Record a failed experiment and move on instead of
            re-raising.

    For each experiment a config is written, ``main.py --mode test`` is run in
    a subprocess with the repo root as working directory, and the resulting
    metrics JSON is folded into ``summary.csv`` / ``summary.json``. Both
    summary files are rewritten after every experiment.

    Raises:
        Exception: Whatever an experiment raised, unless ``--continue-on-error``
            is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-config", default="configs/config_deepchoice_base.yaml")
    parser.add_argument("--plan", default="experiments/val_full_plan.yaml")
    parser.add_argument("--output-root", default=None)
    parser.add_argument("--continue-on-error", action="store_true")
    args = parser.parse_args()

    repo_root = REPO_ROOT
    # Joining onto repo_root is enough for both cases: pathlib discards the
    # left operand when the right one is absolute. An empty YAML file loads
    # as None, so fall back to an empty mapping.
    base_config = _load_yaml(repo_root / args.base_config) or {}
    plan = _load_yaml(repo_root / args.plan) or {}

    if args.output_root:
        # Child processes run with cwd=repo_root. A relative --output-root
        # must be made absolute here, otherwise the child would look for the
        # config and metrics files in a different directory than this process.
        output_root = Path(args.output_root).resolve()
    else:
        run_name = datetime.now().strftime("val_full_%Y%m%d_%H%M%S")
        output_root = repo_root / "artifacts" / "experiments" / run_name
    output_root.mkdir(parents=True, exist_ok=True)

    # ``defaults:`` with no value parses as None.
    defaults = plan.get("defaults") or {}
    split_name = defaults.get("split", "val")
    limit_files = defaults.get("limit_files")
    summary_rows = []

    baseline_cfg = copy.deepcopy(base_config)
    baseline_cfg.setdefault("data", {})
    baseline_cfg["data"].setdefault("split_roots", {})
    # Apply every default split root, not only "val": the baselines must be
    # computed on the same data the experiments use when ``defaults.split``
    # selects another split.
    for root_split in ("train", "val", "test"):
        default_root = defaults.get(f"{root_split}_batches_root")
        if default_root:
            baseline_cfg["data"]["split_roots"][root_split] = str(default_root)

    # Look the sections up defensively: a config without a "training" section
    # must not raise KeyError just to compute a fallback value.
    test_section = baseline_cfg.get("test") or {}
    training_section = baseline_cfg.get("training") or {}
    file_batch_size = test_section.get(
        "file_batch_size", training_section.get("eval_file_batch_size", 1)
    )

    baseline_paths = list_split_batches(baseline_cfg, split_name, limit=limit_files)
    baselines = compute_split_baselines(
        baseline_cfg,
        paths=baseline_paths,
        file_batch_size=file_batch_size,
        desc=f"Computing {split_name} baselines",
    )
    baseline_path = output_root / "metrics" / f"baselines_{split_name}.json"
    baseline_path.parent.mkdir(parents=True, exist_ok=True)
    baseline_path.write_text(json.dumps(_to_jsonable(baselines), indent=2), encoding="utf-8")

    baseline_row = {"experiment": f"baseline_{split_name}", "status": "ok"}
    baseline_row.update(_baseline_columns(baselines))
    baseline_row["metrics_path"] = str(baseline_path)
    summary_rows.append(baseline_row)
    # Write the summary now so it exists even if the plan has no experiments.
    _write_summaries(output_root, summary_rows)

    for experiment in plan.get("experiments") or []:
        exp_name = experiment["name"]
        slug = _slugify(exp_name)
        weights_path = Path(experiment["weights_path"])
        config_path = output_root / "configs" / f"{slug}.yaml"
        test_log = output_root / "logs" / f"{slug}__test.log"
        metrics_path = output_root / "metrics" / f"{slug}__test.json"

        row = {
            "experiment": exp_name,
            "status": "pending",
            "weights_path": str(weights_path),
            "metrics_path": str(metrics_path),
            "log_test": str(test_log),
        }

        try:
            cfg = build_eval_config(base_config, experiment, output_root, defaults=defaults)
            # Convert twice on purpose: sharing one object between the two
            # keys would make the YAML dumper emit an anchor/alias pair.
            cfg["training"]["precomputed_validation_baselines"] = _to_jsonable(baselines)
            cfg["test"]["precomputed_baselines"] = _to_jsonable(baselines)
            _write_yaml(config_path, cfg)

            command = build_test_command(
                config_path, weights_path, split_name, metrics_path, limit_files=limit_files
            )
            rc = _run_command(command, test_log, repo_root)
            if rc != 0:
                raise RuntimeError(f"Validation failed with return code {rc}")

            metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
            row["model_miou"] = metrics["miou"]
            row["model_mf1"] = metrics["mf1"]
            row.update(_baseline_columns(metrics["baselines"]))
            row["status"] = "ok"
        except Exception as exc:
            logging.exception("Validation failed: %s", exp_name)
            row["status"] = f"failed: {exc}"
            # Persist the failure before (possibly) aborting the whole run.
            summary_rows.append(row)
            _write_summaries(output_root, summary_rows)
            if not args.continue_on_error:
                raise
            continue

        summary_rows.append(row)
        _write_summaries(output_root, summary_rows)

    logging.info("Validation experiment run complete. Summary written under %s", output_root)
|
|
|
|
# Run the evaluation sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|