| |
| from __future__ import annotations |
|
|
| from pathlib import Path |
| import json |
| import re |
| import sys |
| import zipfile |
|
|
| import torch |
|
|
|
|
# Files every complete bundle must contain, as POSIX paths relative to the
# bundle root. main() reports each entry that does not exist as an error.
EXPECTED_FILES = [
    # top-level documentation and requirements
    "README.md", "MODEL_CARD.md", "requirements.txt",
    # design / protocol documents
    "docs/PSY_STATUS.md", "docs/EVAL_PROTOCOL.md",
    "docs/PSY_CONTACT_PROTOCOL.md", "docs/PSY_MEMORY_ARCHITECTURE.md",
    # pipeline source modules
    "real_pipeline/psy_contact.py", "real_pipeline/psy_memory.py",
    "real_pipeline/psy_verdict.py", "real_pipeline/psy_heads.py",
    "real_pipeline/psy_infer.py",
    # encoder weights
    "checkpoints/psy_6.9m_encoder.pt",
    # label maps
    "label_maps/cve_sanitized_labels.json", "label_maps/rule_ast_labels.json",
    "label_maps/network_flow_labels.json",
    # demo inputs (JSON Lines)
    "demo_artifacts/cve_record_sample.jsonl", "demo_artifacts/rule_ast_sample.jsonl",
    "demo_artifacts/network_flow_sample.jsonl",
    # result files
    "results/cve_sanitized_result.json", "results/rule_ast_result.json",
    "results/network_flow_result.json",
    # helper scripts
    "scripts/validate_bundle.py", "scripts/run_demo.py",
]
|
|
# Probe-head checkpoints that may accompany the encoder. They are optional:
# main() only records a warning (never an error) for each one that is absent.
OPTIONAL_HEADS = [
    "checkpoints/heads/cve_sanitized_head.pt", "checkpoints/heads/rule_ast_head.pt",
    "checkpoints/heads/network_flow_head.pt",
]
|
|
# Whole path components (directory or file names) that must not appear in any
# bundle file's relative path. Entries are lower-case because main() lower-cases
# the path before splitting it on "/" and intersecting with this set.
FORBIDDEN_PATH_PARTS = {
    # VCS metadata, .env files, bytecode caches
    ".git", ".env", "__pycache__",
    # log files
    "train_masked.log", "probe.log", "ab_probes.log",
    # generic checkpoint file name
    "checkpoint.pt",
    # other excluded names
    "train", "raw", "shards", "corpus", "runpod",
}
|
|
# Credential formats that must never ship in a bundle. The content scan runs
# every pattern over each UTF-8 readable file and reports any hit.
SECRET_PATTERNS = [
    # PEM private keys of any flavour (RSA, EC, OPENSSH, ENCRYPTED, ...).
    re.compile(r"-----BEGIN [A-Z ]*PRIVATE KEY-----"),
    # AWS access key IDs: long-term (AKIA) and temporary STS credentials (ASIA).
    re.compile(r"\b(?:AKIA|ASIA)[0-9A-Z]{16}\b"),
    # GitHub tokens: personal (p), OAuth (o), user-to-server (u),
    # server-to-server (s) and refresh (r) prefixes.
    re.compile(r"\bgh[pousr]_[A-Za-z0-9_]{20,}\b"),
    # GitHub fine-grained personal access tokens.
    re.compile(r"\bgithub_pat_[A-Za-z0-9_]{20,}\b"),
    # Slack tokens.
    re.compile(r"\bxox[baprs]-[A-Za-z0-9-]{10,}\b"),
    # OpenAI-style keys. The body may contain '-' and '_' (e.g. project keys),
    # which the previous alnum-only class could not match.
    re.compile(r"\bsk-[A-Za-z0-9_-]{20,}"),
]
|
|
|
|
def iter_files(root: Path) -> list[Path]:
    """Return every regular file beneath *root*, recursively, in sorted order.

    Directories themselves are excluded. Sorting makes the validator's report
    order independent of the filesystem's directory-enumeration order, so two
    runs over the same bundle produce identical output.
    """
    return sorted(p for p in root.rglob("*") if p.is_file())
|
|
|
|
def validate_jsonl(path: Path) -> int:
    """Parse every non-blank line of a JSON Lines file and return the record count.

    Args:
        path: JSONL file to check; it is read as UTF-8 text.

    Returns:
        The number of non-blank lines, each of which parsed as JSON.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON. The message
            names the file and the 1-based line number of the bad record.
        ValueError: the file contains no non-blank lines.
        OSError: the file cannot be opened.
        UnicodeDecodeError: the file is not valid UTF-8.
    """
    count = 0
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # blank lines are tolerated and not counted
            try:
                json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the exception type so existing handlers still match, but
                # say which file and which record failed.
                raise json.JSONDecodeError(
                    f"{path}: record on line {lineno}: {err.msg}", err.doc, err.pos
                ) from err
            count += 1
    if count == 0:
        raise ValueError(f"empty JSONL: {path}")
    return count
|
|
|
|
# Dotted-quad address-like strings, except the loopback / unspecified
# placeholders 127.0.0.1 and 0.0.0.0. Compiled once: the scan runs per file.
_IP_PATTERN = re.compile(r"\b(?!(?:127\.0\.0\.1|0\.0\.0\.0)\b)(?:\d{1,3}\.){3}\d{1,3}\b")


def scan_text_file(path: Path) -> list[str]:
    """Return content problems found in *path* as human-readable strings.

    Reports at most one "secret-like" problem (any of SECRET_PATTERNS matched)
    and at most one "ip-like" problem per file. Files that are not valid UTF-8
    (e.g. binary checkpoints) are skipped and yield no problems. A file that
    cannot be read at all is reported instead of aborting the whole scan.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        return []  # not UTF-8 text; content scanning is deliberately skipped
    except OSError as err:
        return [f"unreadable file {path}: {err}"]

    problems: list[str] = []
    # One message per file even if several secret formats match, to avoid
    # identical duplicate lines in the report.
    if any(pattern.search(text) for pattern in SECRET_PATTERNS):
        problems.append(f"secret-like pattern in {path}")
    if _IP_PATTERN.search(text):
        problems.append(f"ip-like pattern in {path}")
    return problems
|
|
|
|
# Encoder checkpoint whose tensor element count is verified (relative to the bundle root).
_ENCODER_CHECKPOINT = "checkpoints/psy_6.9m_encoder.pt"
# Total number of tensor elements the shipped encoder checkpoint must contain.
_EXPECTED_ENCODER_PARAMS = 6904064
# JSON Lines artifacts: each must parse line by line and hold at least one record.
_JSONL_ARTIFACTS = (
    "demo_artifacts/cve_record_sample.jsonl",
    "demo_artifacts/rule_ast_sample.jsonl",
    "demo_artifacts/network_flow_sample.jsonl",
)
# Single-document JSON artifacts: each must parse as JSON.
_JSON_ARTIFACTS = (
    "label_maps/cve_sanitized_labels.json",
    "label_maps/rule_ast_labels.json",
    "label_maps/network_flow_labels.json",
    "results/cve_sanitized_result.json",
    "results/rule_ast_result.json",
    "results/network_flow_result.json",
)


def _check_presence(root: Path) -> tuple[list[str], list[str]]:
    """Return ``(errors, warnings)`` for missing required files and optional probe heads."""
    errors = [
        f"missing required file: {rel}"
        for rel in EXPECTED_FILES
        if not (root / rel).exists()
    ]
    warnings = [
        f"optional probe head not present: {rel}"
        for rel in OPTIONAL_HEADS
        if not (root / rel).exists()
    ]
    return errors, warnings


def _check_tree(root: Path) -> list[str]:
    """Return errors for forbidden path components and suspicious content in every file."""
    errors: list[str] = []
    for path in iter_files(root):
        rel = path.relative_to(root).as_posix()
        if set(rel.lower().split("/")) & FORBIDDEN_PATH_PARTS:
            errors.append(f"forbidden path component: {rel}")
        errors.extend(scan_text_file(path))
    return errors


def _check_encoder(root: Path) -> list[str]:
    """Return errors if the encoder checkpoint fails to load or has the wrong size.

    A missing checkpoint is skipped here; the presence check already reports it.
    """
    path = root / _ENCODER_CHECKPOINT
    if not path.exists():
        return []
    try:
        state = torch.load(path, map_location="cpu", weights_only=True)
    except Exception as err:  # torch.load can fail in several ways (corrupt archive,
        # content rejected under weights_only=True); report instead of aborting.
        return [f"cannot load encoder checkpoint: {err}"]
    if not hasattr(state, "values"):
        return [f"unexpected encoder checkpoint type: {type(state).__name__}"]
    params = sum(v.numel() for v in state.values() if hasattr(v, "numel"))
    if params != _EXPECTED_ENCODER_PARAMS:
        return [f"unexpected encoder parameter count: {params}"]
    return []


def _check_json_artifacts(root: Path) -> list[str]:
    """Return errors for present JSONL/JSON artifacts that cannot be read or parsed.

    Missing artifacts are skipped here; the presence check already reports them.
    """
    errors: list[str] = []
    for rel in _JSONL_ARTIFACTS:
        path = root / rel
        if not path.exists():
            continue
        try:
            validate_jsonl(path)
        except (OSError, ValueError) as err:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
            errors.append(f"invalid JSONL {rel}: {err}")
    for rel in _JSON_ARTIFACTS:
        path = root / rel
        if not path.exists():
            continue
        try:
            json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as err:
            errors.append(f"invalid JSON {rel}: {err}")
    return errors


def _check_zip(root: Path) -> list[str]:
    """Return errors for ``.git`` entries in the sibling archive ``<bundle-dir>.zip``, if present."""
    # with_suffix() would clobber a dotted directory name ("psy_6.9m_bundle"
    # -> "psy_6.zip"), so append the extension to the full name instead.
    zip_path = root.with_name(root.name + ".zip")
    if not zip_path.exists():
        return []
    try:
        with zipfile.ZipFile(zip_path) as zf:
            names = zf.namelist()
    except zipfile.BadZipFile as err:
        return [f"unreadable zip {zip_path.name}: {err}"]
    # Match ".git" as a whole path component so top-level ".git/..." entries and a
    # bare ".git" file are caught, not only ".../.git/..." ones.
    return [f".git entry in zip: {name}" for name in names if ".git" in name.split("/")]


def main() -> None:
    """Validate the bundle one directory above this script and print a JSON report.

    Checks, in report order: required/optional file presence; forbidden path
    components and secret/IP-like content in every file; the encoder
    checkpoint's parameter count; JSON/JSONL parseability; ``.git`` entries in
    a sibling ``<bundle-dir>.zip``. Checkpoint, JSON and zip failures are
    recorded as errors instead of aborting the run, so the report is printed.
    Exits with status 1 if any error was found.
    """
    root = Path(__file__).resolve().parents[1]

    errors, warnings = _check_presence(root)
    errors.extend(_check_tree(root))
    errors.extend(_check_encoder(root))
    errors.extend(_check_json_artifacts(root))
    errors.extend(_check_zip(root))

    report = {
        "status": "PASSED" if not errors else "FAILED",
        "bundle_root": str(root),
        "errors": errors,
        "warnings": warnings,
    }
    print(json.dumps(report, sort_keys=True, indent=2))
    if errors:
        sys.exit(1)
|
|
|
|
# Run the validator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|