File size: 6,175 Bytes
"""Validate the SimplexTasks-12 release package."""
from __future__ import annotations
import argparse
import hashlib
import json
from pathlib import Path
import numpy as np
# Top-level keys that manifest.json must define; validate_package raises
# ValueError listing whichever of them are absent.
REQUIRED_MANIFEST_KEYS = {
    # Scalar entries.
    "name",
    "slug",
    "version",
    "seed",
    "task_count",
    # Task listings (concatenated by validate_package).
    "real_tasks",
    "synthetic_tasks",
}
def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is opened in binary mode and fed to the hash ``chunk_size``
    bytes per read instead of being loaded in a single call.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        # Two-argument iter() keeps calling read() until it yields b"" (EOF).
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def check_simplex(name: str, arr: np.ndarray, tol: float = 1e-4) -> list[str]:
    """Check that every row of *arr* is a point on the probability simplex.

    Args:
        name: Label used as the prefix of every reported problem.
        arr: Array to check. It must be 2D; each row is tested separately.
        tol: Absolute tolerance applied both to the non-negativity test
            (values below ``-tol`` are rejected) and to the row-sum test
            (``|row_sum - 1|`` must not exceed ``tol``).

    Returns:
        Human-readable problem descriptions; an empty list means no problem
        was found. An array with zero rows has nothing to violate the
        constraints and yields no row-sum issue.
    """
    issues: list[str] = []
    if arr.ndim != 2:
        issues.append(f"{name} must be 2D, got shape {arr.shape}")
        return issues

    # NaN compares False against every threshold, so without this guard an
    # array full of NaN would pass both checks below unnoticed.
    if not np.all(np.isfinite(arr)):
        issues.append(f"{name} contains non-finite values (NaN or inf)")
        return issues

    if np.any(arr < -tol):
        issues.append(f"{name} contains values below 0")

    row_sums = arr.sum(axis=1)
    # np.max raises on an empty reduction, so skip the test when there are no rows.
    if row_sums.size:
        max_dev = float(np.max(np.abs(row_sums - 1.0)))
        if max_dev > tol:
            issues.append(
                f"{name} rows do not sum to 1 within tolerance; max deviation={max_dev:.2e}"
            )
    return issues
def validate_task(task_dir: Path, meta: dict) -> dict:
    """Validate a single task directory against its metadata entry.

    Args:
        task_dir: Directory expected to contain ``task.npz`` and
            ``metadata.json`` (plus ``config.yaml`` for synthetic tasks).
            ``metadata.json`` is only checked for existence here; the values
            compared against the arrays come from *meta*.
        meta: Metadata mapping for the task. Keys read: ``task_id``,
            ``subset``, ``available_arrays``, ``n_samples``, ``simplex_dim``,
            ``redistribution`` and ``notes``.

    Returns:
        A report dict with ``task_id``, ``issues`` and ``warnings``. When both
        required files exist, it additionally carries ``subset``,
        ``task_npz_sha256`` and ``task_npz_bytes``.
    """
    npz_path = task_dir / "task.npz"
    meta_path = task_dir / "metadata.json"
    # Fall back to the directory name so every return path reports an id.
    task_id = meta.get("task_id", task_dir.name)
    issues: list[str] = []
    warnings: list[str] = []

    if not npz_path.exists():
        issues.append("Missing task.npz")
        return {"task_id": task_id, "issues": issues, "warnings": warnings}
    if not meta_path.exists():
        issues.append("Missing metadata.json")
        return {"task_id": task_id, "issues": issues, "warnings": warnings}

    # The NpzFile keeps the archive open; the context manager closes it.
    with np.load(npz_path, allow_pickle=False) as data:
        files = sorted(data.files)
        declared = sorted(meta.get("available_arrays", []))
        if files != declared:
            issues.append(f"available_arrays mismatch: metadata={declared}, npz={files}")

        if "Y" not in data.files or "U" not in data.files:
            issues.append("Both Y and U must be present")
        else:
            Y = data["Y"]
            U = data["U"]
            if Y.shape != U.shape:
                issues.append(f"Y and U shape mismatch: {Y.shape} vs {U.shape}")
            elif Y.ndim == 2:
                # Only index the shape when it is known to have two axes;
                # arrays of any other rank are reported by check_simplex.
                if meta.get("n_samples") != int(Y.shape[0]):
                    issues.append(
                        f"n_samples mismatch: metadata={meta.get('n_samples')} actual={Y.shape[0]}"
                    )
                if meta.get("simplex_dim") != int(Y.shape[1]):
                    issues.append(
                        f"simplex_dim mismatch: metadata={meta.get('simplex_dim')} actual={Y.shape[1]}"
                    )
            issues.extend(check_simplex("Y", Y))
            issues.extend(check_simplex("U", U))

        if meta.get("subset") == "synthetic":
            for required in ["X", "R", "sigma_true", "split"]:
                if required not in data.files:
                    issues.append(f"Synthetic task missing {required}")
            if "split" in data.files:
                split = data["split"]
                if split.shape[0] != meta.get("n_samples"):
                    issues.append("split length does not match n_samples")
                unique = sorted({str(x) for x in split.tolist()})
                if not set(unique).issubset({"train", "scale", "cal", "test"}):
                    issues.append(f"Unexpected split labels: {unique}")
            if not (task_dir / "config.yaml").exists():
                issues.append("Synthetic task missing config.yaml")
        else:
            redistribution = meta.get("redistribution")
            if redistribution not in {"derived-only", "source-cited", "open"}:
                issues.append(f"Unexpected redistribution value: {redistribution}")

    if "notes" not in meta:
        warnings.append("metadata.notes missing")

    task_sha = sha256_file(npz_path)
    return {
        "task_id": task_id,
        # .get keeps this path as tolerant of sparse metadata as the early returns.
        "subset": meta.get("subset"),
        "task_npz_sha256": task_sha,
        "task_npz_bytes": npz_path.stat().st_size,
        "issues": issues,
        "warnings": warnings,
    }
def validate_package(root: Path) -> dict:
    """Validate the release package rooted at *root* and return a report.

    The manifest is loaded from ``root / "manifest.json"``. Every task entry
    listed under ``real_tasks`` and ``synthetic_tasks`` is checked with
    ``validate_task``, after which the presence of the root-level
    ``README.md`` and ``LICENSE_NOTES.md`` files is verified.

    Raises:
        FileNotFoundError: The manifest file does not exist.
        ValueError: The manifest lacks a required key, or ``task_count``
            disagrees with the number of listed tasks.

    Returns:
        A dict with ``name``, ``version``, ``task_count``, the per-task
        ``task_reports``, and the aggregated ``issues`` and ``warnings``
        (task-level messages are prefixed with the task id).
    """
    manifest_path = root / "manifest.json"
    if not manifest_path.exists():
        raise FileNotFoundError(f"Missing manifest: {manifest_path}")
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

    missing_manifest = sorted(REQUIRED_MANIFEST_KEYS.difference(manifest))
    if missing_manifest:
        raise ValueError(f"Manifest missing keys: {missing_manifest}")

    task_entries = manifest["real_tasks"] + manifest["synthetic_tasks"]
    if len(task_entries) != manifest["task_count"]:
        raise ValueError(
            f"task_count mismatch: manifest says {manifest['task_count']} but lists {len(task_entries)} tasks"
        )

    collected_reports = []
    collected_issues = []
    collected_warnings = []

    for entry in task_entries:
        task_dir = root / entry["subset"] / entry["task_id"]
        if not task_dir.exists():
            collected_issues.append(f"Missing task directory: {task_dir.relative_to(root)}")
            continue
        outcome = validate_task(task_dir, entry)
        collected_reports.append(outcome)
        prefix = entry["task_id"]
        for message in outcome["issues"]:
            collected_issues.append(f"{prefix}: {message}")
        for message in outcome["warnings"]:
            collected_warnings.append(f"{prefix}: {message}")

    # Root-level files are checked last so their issues follow the task issues.
    for filename in ("README.md", "LICENSE_NOTES.md"):
        if not (root / filename).exists():
            collected_issues.append(f"Missing root file: {filename}")

    return {
        "name": manifest["name"],
        "version": manifest["version"],
        "task_count": manifest["task_count"],
        "task_reports": collected_reports,
        "issues": collected_issues,
        "warnings": collected_warnings,
    }
def main() -> None:
    """Command-line entry point: validate a package and print the report.

    Options:
        --root: Package directory (default ``release/simplextasks-12``).
        --write-report: Also write the JSON report to
            ``VALIDATION_REPORT.json`` inside the package directory.

    Exits with status 1 when the report contains at least one issue.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--root", default="release/simplextasks-12")
    cli.add_argument("--write-report", action="store_true")
    options = cli.parse_args()

    package_root = Path(options.root)
    report = validate_package(package_root)

    # Serialize once; the same text goes to stdout and, optionally, to disk.
    rendered = json.dumps(report, indent=2)
    print(rendered)
    if options.write_report:
        report_file = package_root / "VALIDATION_REPORT.json"
        report_file.write_text(rendered + "\n", encoding="utf-8")

    if report["issues"]:
        raise SystemExit(1)


if __name__ == "__main__":
    main()
|