simplexuq-code / scripts /validate_simplextasks_12.py
anonymous0523ly's picture
Initial anonymous code release
fc329a3 verified
raw
history blame
6.18 kB
"""Validate the SimplexTasks-12 release package."""
from __future__ import annotations
import argparse
import hashlib
import json
from pathlib import Path
import numpy as np
# Top-level keys that manifest.json must contain; validate_package raises
# ValueError when any of them is absent.
REQUIRED_MANIFEST_KEYS = {
    # Package identity
    "name",
    "slug",
    "version",
    "seed",
    # Task inventory
    "task_count",
    "real_tasks",
    "synthetic_tasks",
}
def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in binary mode, ``chunk_size`` bytes at a time, so the
    whole content never has to be held in memory at once.

    Args:
        path: File to hash.
        chunk_size: Number of bytes requested per read call.

    Returns:
        The digest as a lowercase hexadecimal string.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        # iter(callable, sentinel) keeps reading until read() yields b"" (EOF).
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def check_simplex(name: str, arr: np.ndarray, tol: float = 1e-4) -> list[str]:
    """Check that every row of ``arr`` lies on the probability simplex.

    A row passes when all of its entries are finite, none is below ``-tol``,
    and the entries sum to 1 within ``tol``.

    Args:
        name: Label used as the prefix of every reported problem.
        arr: Array to check; it must be 2D, with one point per row.
        tol: Absolute tolerance used for both the non-negativity check and
            the row-sum check.

    Returns:
        Human-readable problem descriptions. An empty list means no problem
        was found.
    """
    issues: list[str] = []
    if arr.ndim != 2:
        issues.append(f"{name} must be 2D, got shape {arr.shape}")
        return issues

    # Comparisons and sums below are only defined for real numeric data;
    # report other dtypes instead of letting numpy raise.
    if arr.dtype.kind not in "biuf":
        issues.append(f"{name} must have a real numeric dtype, got {arr.dtype}")
        return issues

    # NaN compares False against every threshold, so without this explicit
    # check an array containing NaN would pass both tests below.
    if not np.all(np.isfinite(arr)):
        issues.append(f"{name} contains non-finite values (NaN or inf)")

    if np.any(arr < -tol):
        issues.append(f"{name} contains values below 0")

    row_sums = arr.sum(axis=1)
    # np.max raises on a zero-size array, so skip the check when there are
    # no rows to inspect.
    if row_sums.size:
        max_dev = float(np.max(np.abs(row_sums - 1.0)))
        if max_dev > tol:
            issues.append(
                f"{name} rows do not sum to 1 within tolerance; max deviation={max_dev:.2e}"
            )
    return issues
def validate_task(task_dir: Path, meta: dict) -> dict:
    """Validate one task directory against its manifest entry.

    The function checks that ``task.npz`` and ``metadata.json`` exist (the
    latter is only checked for existence; its content is not read here),
    that the arrays stored in the archive match ``meta["available_arrays"]``,
    that ``Y`` and ``U`` have equal shapes consistent with ``n_samples`` and
    ``simplex_dim`` and pass ``check_simplex``, and then applies the checks
    specific to the task's ``subset``.

    Args:
        task_dir: Directory holding the task files.
        meta: Manifest entry describing the task.

    Returns:
        A report dict. It always contains ``task_id``, ``issues`` and
        ``warnings``. When both required files exist it also contains
        ``subset``, ``task_npz_sha256`` and ``task_npz_bytes``.
    """
    npz_path = task_dir / "task.npz"
    meta_path = task_dir / "metadata.json"
    # Fall back to the directory name so a report can always be attributed.
    task_id = meta.get("task_id", task_dir.name)
    issues: list[str] = []
    warnings: list[str] = []

    if not npz_path.exists():
        issues.append("Missing task.npz")
        return {"task_id": task_id, "issues": issues, "warnings": warnings}
    if not meta_path.exists():
        issues.append("Missing metadata.json")
        return {"task_id": task_id, "issues": issues, "warnings": warnings}

    # np.load returns an NpzFile holding an open file handle; the context
    # manager guarantees it is closed once the arrays have been read.
    with np.load(npz_path, allow_pickle=False) as data:
        files = sorted(data.files)
        declared = sorted(meta.get("available_arrays", []))
        if files != declared:
            issues.append(f"available_arrays mismatch: metadata={declared}, npz={files}")

        if "Y" not in files or "U" not in files:
            issues.append("Both Y and U must be present")
        else:
            Y = data["Y"]
            U = data["U"]
            if Y.shape != U.shape:
                issues.append(f"Y and U shape mismatch: {Y.shape} vs {U.shape}")
            else:
                # Only index shape[0]/shape[1] when both axes exist; for any
                # other rank check_simplex reports the "must be 2D" problem.
                if Y.ndim == 2:
                    if meta.get("n_samples") != int(Y.shape[0]):
                        issues.append(
                            f"n_samples mismatch: metadata={meta.get('n_samples')} actual={Y.shape[0]}"
                        )
                    if meta.get("simplex_dim") != int(Y.shape[1]):
                        issues.append(
                            f"simplex_dim mismatch: metadata={meta.get('simplex_dim')} actual={Y.shape[1]}"
                        )
                issues.extend(check_simplex("Y", Y))
                issues.extend(check_simplex("U", U))

        if meta.get("subset") == "synthetic":
            for required in ["X", "R", "sigma_true", "split"]:
                if required not in files:
                    issues.append(f"Synthetic task missing {required}")
            if "split" in files:
                split = data["split"]
                # A 0-d array has no length, so it can never match n_samples.
                if split.ndim == 0 or split.shape[0] != meta.get("n_samples"):
                    issues.append("split length does not match n_samples")
                # ravel() makes the labels iterable whatever the array rank.
                unique = sorted({str(x) for x in split.ravel().tolist()})
                if not set(unique).issubset({"train", "scale", "cal", "test"}):
                    issues.append(f"Unexpected split labels: {unique}")
            if not (task_dir / "config.yaml").exists():
                issues.append("Synthetic task missing config.yaml")
        else:
            redistribution = meta.get("redistribution")
            if redistribution not in {"derived-only", "source-cited", "open"}:
                issues.append(f"Unexpected redistribution value: {redistribution}")
            if "notes" not in meta:
                warnings.append("metadata.notes missing")

    task_sha = sha256_file(npz_path)
    return {
        "task_id": task_id,
        "subset": meta.get("subset"),
        "task_npz_sha256": task_sha,
        "task_npz_bytes": npz_path.stat().st_size,
        "issues": issues,
        "warnings": warnings,
    }
def validate_package(root: Path) -> dict:
    """Validate the release package rooted at ``root``.

    Reads ``manifest.json``, validates every task listed under
    ``real_tasks`` and ``synthetic_tasks`` with ``validate_task``, and checks
    that the required root-level files are present.

    Args:
        root: Root directory of the release package.

    Returns:
        A report dict with the manifest's ``name``, ``version`` and
        ``task_count``, the per-task reports under ``task_reports``, and the
        aggregated ``issues`` and ``warnings`` (task messages are prefixed
        with the task id).

    Raises:
        FileNotFoundError: If ``manifest.json`` does not exist.
        ValueError: If the manifest lacks a required key, or ``task_count``
            disagrees with the number of listed tasks.
    """
    manifest_file = root / "manifest.json"
    if not manifest_file.exists():
        raise FileNotFoundError(f"Missing manifest: {manifest_file}")

    manifest = json.loads(manifest_file.read_text(encoding="utf-8"))
    absent_keys = sorted(REQUIRED_MANIFEST_KEYS.difference(manifest))
    if absent_keys:
        raise ValueError(f"Manifest missing keys: {absent_keys}")

    entries = manifest["real_tasks"] + manifest["synthetic_tasks"]
    declared_count = manifest["task_count"]
    if len(entries) != declared_count:
        raise ValueError(
            f"task_count mismatch: manifest says {declared_count} but lists {len(entries)} tasks"
        )

    report = {
        "name": manifest["name"],
        "version": manifest["version"],
        "task_count": declared_count,
        "task_reports": [],
        "issues": [],
        "warnings": [],
    }

    for entry in entries:
        task_dir = root / entry["subset"] / entry["task_id"]
        if task_dir.exists():
            outcome = validate_task(task_dir, entry)
            report["task_reports"].append(outcome)
            # Lift task-level messages to package level, tagged by task id.
            for bucket in ("issues", "warnings"):
                report[bucket].extend(
                    f"{entry['task_id']}: {message}" for message in outcome[bucket]
                )
        else:
            report["issues"].append(f"Missing task directory: {task_dir.relative_to(root)}")

    report["issues"].extend(
        f"Missing root file: {filename}"
        for filename in ("README.md", "LICENSE_NOTES.md")
        if not (root / filename).exists()
    )
    return report
def main() -> None:
    """Command-line entry point: validate a package and print the report.

    Options:
        --root: Package root directory (default ``release/simplextasks-12``).
        --write-report: Also write the report to ``VALIDATION_REPORT.json``
            inside the package root.

    The JSON report is printed to stdout. The process exits with status 1
    when the report contains at least one issue.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--root", default="release/simplextasks-12")
    cli.add_argument("--write-report", action="store_true")
    options = cli.parse_args()

    package_root = Path(options.root)
    report = validate_package(package_root)

    # Render once; the same text goes to stdout and, optionally, to disk.
    rendered = json.dumps(report, indent=2)
    print(rendered)
    if options.write_report:
        report_file = package_root / "VALIDATION_REPORT.json"
        report_file.write_text(rendered + "\n", encoding="utf-8")

    if report["issues"]:
        raise SystemExit(1)


if __name__ == "__main__":
    main()