| """Validate the SimplexTasks-12 release package.""" |
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
|
|
# Top-level keys that manifest.json must define; validate_package raises a
# ValueError naming whichever of these are absent.
REQUIRED_MANIFEST_KEYS = set(
    (
        "name",
        "slug",
        "version",
        "seed",
        "task_count",
        "real_tasks",
        "synthetic_tasks",
    )
)
|
|
|
|
def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of the file at *path*.

    The file is opened in binary mode and fed to the hash in reads of
    *chunk_size* bytes, stopping at the first empty read.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
def check_simplex(name: str, arr: np.ndarray, tol: float = 1e-4) -> list[str]:
    """Check that every row of a 2D array lies on the probability simplex.

    Args:
        name: Label used as the prefix of every issue message.
        arr: Array to check; anything ``np.asarray`` accepts is allowed.
        tol: Absolute tolerance for negative entries and for the deviation of
            each row sum from 1.

    Returns:
        Human-readable issue strings; an empty list means the array passed.
        An array with zero rows has nothing that can violate the constraints
        and therefore passes.
    """
    arr = np.asarray(arr)
    issues: list[str] = []
    if arr.ndim != 2:
        issues.append(f"{name} must be 2D, got shape {arr.shape}")
        return issues

    # NaN compares False against everything, so without this explicit check
    # a NaN entry would slip past both the sign check and the row-sum check.
    finite_rows = np.isfinite(arr).all(axis=1)
    if not finite_rows.all():
        issues.append(f"{name} contains non-finite values (NaN or inf)")

    if np.any(arr < -tol):
        issues.append(f"{name} contains values below 0")

    # Only fully finite rows have a meaningful sum (the others were reported
    # above). np.max raises on an empty array, which happens when there are
    # zero rows or every row is non-finite, so guard on size.
    row_sums = arr[finite_rows].sum(axis=1)
    if row_sums.size:
        max_dev = float(np.max(np.abs(row_sums - 1.0)))
        if max_dev > tol:
            issues.append(
                f"{name} rows do not sum to 1 within tolerance; max deviation={max_dev:.2e}"
            )
    return issues
|
|
|
|
def validate_task(task_dir: Path, meta: dict) -> dict:
    """Validate one task directory against its manifest entry.

    Args:
        task_dir: Directory expected to contain ``task.npz`` and
            ``metadata.json`` (plus ``config.yaml`` for synthetic tasks).
        meta: The task's entry from the package manifest. Keys read here:
            ``task_id``, ``subset``, ``available_arrays``, ``n_samples``,
            ``simplex_dim``, ``redistribution`` and ``notes``.

    Returns:
        A report dict with ``task_id``, ``subset``, ``task_npz_sha256``,
        ``task_npz_bytes``, ``issues`` and ``warnings``. When ``task.npz`` or
        ``metadata.json`` is missing, validation stops early and the checksum
        and size are ``None``.
    """
    npz_path = task_dir / "task.npz"
    meta_path = task_dir / "metadata.json"
    issues: list[str] = []
    warnings: list[str] = []
    # Same key order and schema on every return path; issues/warnings are
    # the live lists appended to below.
    report = {
        "task_id": meta.get("task_id", task_dir.name),
        "subset": meta.get("subset"),
        "task_npz_sha256": None,
        "task_npz_bytes": None,
        "issues": issues,
        "warnings": warnings,
    }

    if not npz_path.exists():
        issues.append("Missing task.npz")
        return report
    # NOTE(review): only the presence of metadata.json is checked; its
    # contents are not compared with `meta` here.
    if not meta_path.exists():
        issues.append("Missing metadata.json")
        return report

    # np.load returns an NpzFile that keeps the archive open until closed;
    # the context manager releases the handle even if a check raises.
    with np.load(npz_path, allow_pickle=False) as data:
        files = sorted(data.files)
        declared = sorted(meta.get("available_arrays", []))
        if files != declared:
            issues.append(f"available_arrays mismatch: metadata={declared}, npz={files}")

        _check_core_arrays(data, meta, issues)

        if meta.get("subset") == "synthetic":
            _check_synthetic(task_dir, data, meta, issues)
        else:
            _check_real(meta, issues, warnings)

    report["task_npz_sha256"] = sha256_file(npz_path)
    report["task_npz_bytes"] = npz_path.stat().st_size
    return report


def _check_core_arrays(data, meta: dict, issues: list[str]) -> None:
    """Check that Y and U exist, match each other and ``meta``, and lie on the simplex.

    ``data`` is the opened ``task.npz`` archive as returned by ``np.load``.
    """
    if "Y" not in data.files or "U" not in data.files:
        issues.append("Both Y and U must be present")
        return

    Y = data["Y"]
    U = data["U"]
    if Y.shape != U.shape:
        issues.append(f"Y and U shape mismatch: {Y.shape} vs {U.shape}")
    else:
        # Only compare the axes that exist; a wrong rank is reported by
        # check_simplex below instead of raising IndexError here.
        if Y.ndim >= 1 and meta.get("n_samples") != int(Y.shape[0]):
            issues.append(f"n_samples mismatch: metadata={meta.get('n_samples')} actual={Y.shape[0]}")
        if Y.ndim >= 2 and meta.get("simplex_dim") != int(Y.shape[1]):
            issues.append(f"simplex_dim mismatch: metadata={meta.get('simplex_dim')} actual={Y.shape[1]}")
    issues.extend(check_simplex("Y", Y))
    issues.extend(check_simplex("U", U))


def _check_synthetic(task_dir: Path, data, meta: dict, issues: list[str]) -> None:
    """Check the arrays, split labels and config file a synthetic task must ship."""
    for required in ["X", "R", "sigma_true", "split"]:
        if required not in data.files:
            issues.append(f"Synthetic task missing {required}")

    if "split" in data.files:
        split = data["split"]
        # A 0-D split has no length; report it instead of raising IndexError.
        if split.ndim != 1:
            issues.append(f"split must be 1D, got shape {split.shape}")
        else:
            if split.shape[0] != meta.get("n_samples"):
                issues.append("split length does not match n_samples")
            unique = sorted({str(x) for x in split.tolist()})
            if not set(unique).issubset({"train", "scale", "cal", "test"}):
                issues.append(f"Unexpected split labels: {unique}")

    if not (task_dir / "config.yaml").exists():
        issues.append("Synthetic task missing config.yaml")


def _check_real(meta: dict, issues: list[str], warnings: list[str]) -> None:
    """Check the ``redistribution`` and ``notes`` fields of a non-synthetic task."""
    redistribution = meta.get("redistribution")
    # isinstance guard: an unhashable JSON value (list/dict) would make the
    # set-membership test raise TypeError.
    if not isinstance(redistribution, str) or redistribution not in {
        "derived-only",
        "source-cited",
        "open",
    }:
        issues.append(f"Unexpected redistribution value: {redistribution}")
    if "notes" not in meta:
        warnings.append("metadata.notes missing")
|
|
|
|
def validate_package(root: Path) -> dict:
    """Validate a release package rooted at *root* and collect a report.

    Structural problems with ``manifest.json`` raise immediately; problems
    with individual task entries, task directories or root files are
    collected into the returned report instead.

    Args:
        root: Package root expected to contain ``manifest.json``,
            ``README.md``, ``LICENSE_NOTES.md`` and one
            ``<subset>/<task_id>`` directory per manifest task entry.

    Returns:
        A dict with ``name``, ``version``, ``task_count``, ``task_reports``
        (one per task directory found), and package-wide ``issues`` and
        ``warnings``; per-task messages are prefixed with the task id.

    Raises:
        FileNotFoundError: If ``manifest.json`` is missing.
        ValueError: If the manifest is not a JSON object, lacks a required
            key, has a non-list ``real_tasks``/``synthetic_tasks`` section, or
            its ``task_count`` disagrees with the number of listed tasks.
    """
    manifest_path = root / "manifest.json"
    if not manifest_path.exists():
        raise FileNotFoundError(f"Missing manifest: {manifest_path}")

    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    if not isinstance(manifest, dict):
        raise ValueError(f"Manifest must be a JSON object, got {type(manifest).__name__}")
    missing_manifest = sorted(REQUIRED_MANIFEST_KEYS - set(manifest))
    if missing_manifest:
        raise ValueError(f"Manifest missing keys: {missing_manifest}")
    for section in ("real_tasks", "synthetic_tasks"):
        if not isinstance(manifest[section], list):
            raise ValueError(f"Manifest {section} must be a list")

    task_entries = manifest["real_tasks"] + manifest["synthetic_tasks"]
    if len(task_entries) != manifest["task_count"]:
        raise ValueError(
            f"task_count mismatch: manifest says {manifest['task_count']} but lists {len(task_entries)} tasks"
        )

    report = {
        "name": manifest["name"],
        "version": manifest["version"],
        "task_count": manifest["task_count"],
        "task_reports": [],
        "issues": [],
        "warnings": [],
    }

    for index, task_meta in enumerate(task_entries):
        # Report a malformed entry instead of aborting the whole run with a
        # KeyError (missing key) or TypeError (non-string path component).
        is_object = isinstance(task_meta, dict)
        task_id = task_meta.get("task_id") if is_object else None
        subset = task_meta.get("subset") if is_object else None
        if not isinstance(task_id, str) or not isinstance(subset, str):
            report["issues"].append(
                f"Task entry {index} must be an object with string task_id and subset"
            )
            continue

        # Build the relative path directly rather than via relative_to(),
        # which raises ValueError if `subset` happens to be absolute.
        relative_dir = Path(subset) / task_id
        task_dir = root / relative_dir
        if not task_dir.is_dir():
            report["issues"].append(f"Missing task directory: {relative_dir}")
            continue

        task_report = validate_task(task_dir, task_meta)
        report["task_reports"].append(task_report)
        report["issues"].extend(f"{task_id}: {x}" for x in task_report["issues"])
        report["warnings"].extend(f"{task_id}: {x}" for x in task_report["warnings"])

    for required in ("README.md", "LICENSE_NOTES.md"):
        if not (root / required).exists():
            report["issues"].append(f"Missing root file: {required}")

    return report
|
|
|
|
def main() -> None:
    """Command-line entry point: validate a package and print its JSON report.

    With ``--write-report`` the same report is also saved as
    ``VALIDATION_REPORT.json`` inside the package root. The process exits
    with status 1 when the report lists any issue.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--root", default="release/simplextasks-12")
    cli.add_argument("--write-report", action="store_true")
    options = cli.parse_args()

    package_root = Path(options.root)
    result = validate_package(package_root)
    rendered = json.dumps(result, indent=2)
    print(rendered)

    if options.write_report:
        target = package_root / "VALIDATION_REPORT.json"
        target.write_text(rendered + "\n", encoding="utf-8")

    # Warnings alone do not fail the run; only issues do.
    if result["issues"]:
        raise SystemExit(1)


if __name__ == "__main__":
    main()
|
|