File size: 6,175 Bytes
fc329a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""Validate the SimplexTasks-12 release package."""
from __future__ import annotations

import argparse
import hashlib
import json
from pathlib import Path

import numpy as np

# Top-level keys that manifest.json must define. validate_package raises
# ValueError listing whichever of these are absent before any task is checked.
REQUIRED_MANIFEST_KEYS = set(
    (
        "name",
        "slug",
        "version",
        "seed",
        "task_count",
        "real_tasks",
        "synthetic_tasks",
    )
)


def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is read in binary mode in pieces of at most *chunk_size* bytes
    (1 MiB by default), so large archives are hashed without being loaded
    into memory all at once.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        # iter(callable, sentinel) keeps calling read() until it returns b"".
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()


def check_simplex(name: str, arr: np.ndarray, tol: float = 1e-4) -> list[str]:
    """Check that every row of a 2-D array lies on the probability simplex.

    Args:
        name: Label used to prefix each issue message (e.g. ``"Y"``).
        arr: Array whose rows should be non-negative and sum to 1.
        tol: Absolute tolerance used both for the non-negativity test and
            for the row-sum test.

    Returns:
        A list of human-readable issue strings; empty when ``arr`` passes.
        A non-2-D array yields a single shape issue and no further checks.
    """
    issues: list[str] = []
    if arr.ndim != 2:
        issues.append(f"{name} must be 2D, got shape {arr.shape}")
        return issues
    # NaN compares False against everything, so without this explicit check a
    # NaN entry would slip past both the sign test and the row-sum test below.
    if not np.all(np.isfinite(arr)):
        issues.append(f"{name} contains non-finite values (NaN or inf)")
    if np.any(arr < -tol):
        issues.append(f"{name} contains values below 0")
    deviations = np.abs(arr.sum(axis=1) - 1.0)
    # Drop NaN rows (already reported above) so they cannot mask the deviation
    # of the remaining rows; an empty result (zero rows, or all rows NaN) has
    # nothing left to compare and would make np.max raise.
    deviations = deviations[~np.isnan(deviations)]
    if deviations.size:
        max_dev = float(np.max(deviations))
        if max_dev > tol:
            issues.append(f"{name} rows do not sum to 1 within tolerance; max deviation={max_dev:.2e}")
    return issues


def validate_task(task_dir: Path, meta: dict) -> dict:
    """Validate one task directory against its manifest entry.

    Args:
        task_dir: Directory expected to hold ``task.npz`` and
            ``metadata.json`` (plus ``config.yaml`` for synthetic tasks).
        meta: The task's entry from the package manifest. Keys read here are
            ``task_id``, ``subset``, ``available_arrays``, ``n_samples``,
            ``simplex_dim``, ``redistribution`` and ``notes``.

    Returns:
        A report dict with ``task_id``, ``subset``, ``issues`` (blocking
        problems) and ``warnings`` (non-blocking notes). When ``task.npz``
        could be loaded it also carries ``task_npz_sha256`` and
        ``task_npz_bytes``.
    """
    import zipfile  # local: only needed to recognise corrupt archives below

    task_id = meta.get("task_id", task_dir.name)
    subset = meta.get("subset")
    npz_path = task_dir / "task.npz"
    meta_path = task_dir / "metadata.json"
    issues: list[str] = []
    warnings: list[str] = []

    def partial_report() -> dict:
        # Report shape used when validation stops before task.npz is hashed.
        return {"task_id": task_id, "subset": subset, "issues": issues, "warnings": warnings}

    if not npz_path.exists():
        issues.append("Missing task.npz")
        return partial_report()
    # NOTE(review): metadata.json is only checked for existence; its contents
    # are never compared with `meta` (the manifest entry). Confirm intended.
    if not meta_path.exists():
        issues.append("Missing metadata.json")
        return partial_report()

    try:
        data = np.load(npz_path, allow_pickle=False)
    except (OSError, ValueError, EOFError, zipfile.BadZipFile) as exc:
        # Empty, corrupt or non-NumPy files become a task issue instead of
        # aborting validation of the whole package.
        issues.append(f"Could not load task.npz: {exc}")
        return partial_report()
    if not hasattr(data, "files"):
        # np.load returns a bare ndarray when the file holds .npy content.
        issues.append("task.npz is not an .npz archive")
        return partial_report()

    # The NpzFile keeps the underlying zip open until closed; the with-block
    # releases it even if one of the checks below raises.
    with data:
        files = sorted(data.files)
        declared = sorted(meta.get("available_arrays", []))
        if files != declared:
            issues.append(f"available_arrays mismatch: metadata={declared}, npz={files}")

        if "Y" not in data.files or "U" not in data.files:
            issues.append("Both Y and U must be present")
        else:
            Y = data["Y"]
            U = data["U"]
            if Y.shape != U.shape:
                issues.append(f"Y and U shape mismatch: {Y.shape} vs {U.shape}")
            elif Y.ndim == 2:
                # Index the dimensions only once rank 2 is confirmed; other
                # ranks are reported by check_simplex below.
                if meta.get("n_samples") != int(Y.shape[0]):
                    issues.append(f"n_samples mismatch: metadata={meta.get('n_samples')} actual={Y.shape[0]}")
                if meta.get("simplex_dim") != int(Y.shape[1]):
                    issues.append(f"simplex_dim mismatch: metadata={meta.get('simplex_dim')} actual={Y.shape[1]}")
            issues.extend(check_simplex("Y", Y))
            issues.extend(check_simplex("U", U))

        if subset == "synthetic":
            for required in ("X", "R", "sigma_true", "split"):
                if required not in data.files:
                    issues.append(f"Synthetic task missing {required}")
            if "split" in data.files:
                split = data["split"]
                if split.ndim != 1:
                    issues.append(f"split must be 1D, got shape {split.shape}")
                else:
                    if split.shape[0] != meta.get("n_samples"):
                        issues.append("split length does not match n_samples")
                    labels = sorted({str(x) for x in split.tolist()})
                    if not set(labels).issubset({"train", "scale", "cal", "test"}):
                        issues.append(f"Unexpected split labels: {labels}")

    if subset == "synthetic":
        if not (task_dir / "config.yaml").exists():
            issues.append("Synthetic task missing config.yaml")
    else:
        redistribution = meta.get("redistribution")
        if redistribution not in {"derived-only", "source-cited", "open"}:
            issues.append(f"Unexpected redistribution value: {redistribution}")
        if "notes" not in meta:
            warnings.append("metadata.notes missing")

    return {
        "task_id": task_id,
        "subset": subset,
        "task_npz_sha256": sha256_file(npz_path),
        "task_npz_bytes": npz_path.stat().st_size,
        "issues": issues,
        "warnings": warnings,
    }


def validate_package(root: Path) -> dict:
    """Validate a whole release package rooted at *root*.

    Reads ``manifest.json``, checks its required keys and declared
    ``task_count``, validates every listed task with ``validate_task``, and
    checks that the root-level ``README.md`` and ``LICENSE_NOTES.md`` exist.

    Args:
        root: Package root directory containing ``manifest.json``.

    Returns:
        A report dict with the package ``name``, ``version`` and
        ``task_count``, the per-task reports, and package-wide ``issues`` /
        ``warnings`` (task-level messages are prefixed with the task id).

    Raises:
        FileNotFoundError: ``manifest.json`` is missing.
        ValueError: The manifest lacks required keys, or its ``task_count``
            disagrees with the number of listed task entries.
    """
    manifest_path = root / "manifest.json"
    if not manifest_path.exists():
        raise FileNotFoundError(f"Missing manifest: {manifest_path}")

    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    missing_manifest = sorted(REQUIRED_MANIFEST_KEYS - set(manifest))
    if missing_manifest:
        raise ValueError(f"Manifest missing keys: {missing_manifest}")

    task_entries = manifest["real_tasks"] + manifest["synthetic_tasks"]
    if len(task_entries) != manifest["task_count"]:
        raise ValueError(
            f"task_count mismatch: manifest says {manifest['task_count']} but lists {len(task_entries)} tasks"
        )

    report = {
        "name": manifest["name"],
        "version": manifest["version"],
        "task_count": manifest["task_count"],
        "task_reports": [],
        "issues": [],
        "warnings": [],
    }

    seen: set[tuple[str, str]] = set()
    for index, task_meta in enumerate(task_entries):
        # Malformed entries are reported rather than crashing the run with
        # KeyError/TypeError when the task path is built.
        if not isinstance(task_meta, dict):
            report["issues"].append(f"Task entry #{index} is not an object")
            continue
        subset = task_meta.get("subset")
        task_id = task_meta.get("task_id")
        if not isinstance(subset, str) or not isinstance(task_id, str) or not subset or not task_id:
            report["issues"].append(f"Task entry #{index} needs non-empty string 'subset' and 'task_id'")
            continue

        # task_count only counts entries, so a repeated entry would otherwise
        # let a package with fewer distinct tasks pass validation.
        key = (subset, task_id)
        if key in seen:
            report["issues"].append(f"Duplicate task entry: {subset}/{task_id}")
            continue
        seen.add(key)

        rel_dir = Path(subset) / task_id
        task_dir = root / rel_dir
        if not task_dir.exists():
            report["issues"].append(f"Missing task directory: {rel_dir}")
            continue
        task_report = validate_task(task_dir, task_meta)
        report["task_reports"].append(task_report)
        report["issues"].extend(f"{task_id}: {msg}" for msg in task_report["issues"])
        report["warnings"].extend(f"{task_id}: {msg}" for msg in task_report["warnings"])

    for required in ("README.md", "LICENSE_NOTES.md"):
        if not (root / required).exists():
            report["issues"].append(f"Missing root file: {required}")

    return report


def main() -> None:
    """Command-line entry point: validate a package and print its report.

    Prints the JSON report to stdout. With ``--write-report`` the same JSON
    (plus a trailing newline) is also written to
    ``<root>/VALIDATION_REPORT.json``. Exits with status 1 when the report
    lists any issues.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--root", default="release/simplextasks-12")
    cli.add_argument("--write-report", action="store_true")
    options = cli.parse_args()

    package_root = Path(options.root)
    report = validate_package(package_root)
    rendered = json.dumps(report, indent=2)
    print(rendered)

    if options.write_report:
        destination = package_root / "VALIDATION_REPORT.json"
        destination.write_text(rendered + "\n", encoding="utf-8")

    if report["issues"]:
        raise SystemExit(1)


if __name__ == "__main__":
    main()