psy-6.9m / scripts /validate_bundle.py
AIIT-Threshold's picture
Psy v0.1 — byte-level defensive cyber-artifact encoder (6.9M params)
185d3f7 verified
Raw
History Blame Contribute Delete
4.83 kB
#!/usr/bin/env python3
from __future__ import annotations
from pathlib import Path
import json
import re
import sys
import zipfile
import torch
# Paths, relative to the bundle root, that must exist; main() records an
# error for every entry that is missing.
EXPECTED_FILES = [
    # Top-level files.
    "README.md",
    "MODEL_CARD.md",
    "requirements.txt",
    # Documentation.
    "docs/PSY_STATUS.md",
    "docs/EVAL_PROTOCOL.md",
    "docs/PSY_CONTACT_PROTOCOL.md",
    "docs/PSY_MEMORY_ARCHITECTURE.md",
    # Pipeline modules.
    "real_pipeline/psy_contact.py",
    "real_pipeline/psy_memory.py",
    "real_pipeline/psy_verdict.py",
    "real_pipeline/psy_heads.py",
    "real_pipeline/psy_infer.py",
    # Encoder checkpoint (main() also checks its parameter count).
    "checkpoints/psy_6.9m_encoder.pt",
    # Label maps (parsed as JSON by main()).
    "label_maps/cve_sanitized_labels.json",
    "label_maps/rule_ast_labels.json",
    "label_maps/network_flow_labels.json",
    # Demo artifacts (validated as JSONL by main()).
    "demo_artifacts/cve_record_sample.jsonl",
    "demo_artifacts/rule_ast_sample.jsonl",
    "demo_artifacts/network_flow_sample.jsonl",
    # Result files (parsed as JSON by main()).
    "results/cve_sanitized_result.json",
    "results/rule_ast_result.json",
    "results/network_flow_result.json",
    # Scripts, including this validator itself.
    "scripts/validate_bundle.py",
    "scripts/run_demo.py",
]
# Probe-head checkpoints that may be absent; main() only emits a warning
# (never an error) for each one that is not found.
OPTIONAL_HEADS = [
    "checkpoints/heads/cve_sanitized_head.pt",
    "checkpoints/heads/rule_ast_head.pt",
    "checkpoints/heads/network_flow_head.pt",
]
# Path components that must not occur anywhere inside the bundle. main()
# lowercases each relative path, splits it on "/", and records an error when
# any component is a member of this set, so entries must be lowercase.
FORBIDDEN_PATH_PARTS = {
    # Version-control, environment and bytecode-cache entries.
    ".git",
    ".env",
    "__pycache__",
    # Log files.
    "train_masked.log",
    "probe.log",
    "ab_probes.log",
    # Generic checkpoint file name.
    "checkpoint.pt",
    # Remaining forbidden names (presumably training data / infrastructure
    # directories -- confirm with the bundle owners).
    "train",
    "raw",
    "shards",
    "corpus",
    "runpod",
}
# Compiled regexes for credential-shaped strings; scan_text_file() reports a
# problem for every pattern here that matches a file's text.
SECRET_PATTERNS = [
    re.compile(source)
    for source in (
        # PEM-style private key header line.
        r"-----BEGIN [A-Z ]*PRIVATE KEY-----",
        # "AKIA" followed by 16 uppercase letters/digits (AWS access key ID shape).
        r"\bAKIA[0-9A-Z]{16}\b",
        # "ghp_" followed by 20+ token characters (GitHub token shape).
        r"\bghp_[A-Za-z0-9_]{20,}\b",
        # "xoxb-", "xoxa-", "xoxp-", "xoxr-" or "xoxs-" tokens (Slack token shape).
        r"\bxox[baprs]-[A-Za-z0-9-]{10,}\b",
        # "sk-" followed by 20+ alphanumerics (API secret key shape).
        r"\bsk-[A-Za-z0-9]{20,}\b",
    )
]
def iter_files(root: Path) -> list[Path]:
    """Return every regular file located anywhere beneath ``root``.

    The tree is walked recursively with ``Path.rglob("*")``; entries for
    which ``is_file()`` is false (directories, for instance) are left out.
    """
    found: list[Path] = []
    for entry in root.rglob("*"):
        if entry.is_file():
            found.append(entry)
    return found
def validate_jsonl(path: Path) -> int:
    """Check that ``path`` is a non-empty JSON Lines file and count its records.

    Blank (whitespace-only) lines are skipped; every other line must parse
    as a standalone JSON document.

    Args:
        path: Location of the JSONL file to validate.

    Returns:
        The number of non-blank lines, all of which parsed successfully.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with the file path and 1-based line number.
        ValueError: If the file contains no non-blank lines.
    """
    count = 0
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                json.loads(line)
            except json.JSONDecodeError as exc:
                # json.loads only ever sees a single line, so the position it
                # reports is always "line 1"; prefix the real file location
                # while keeping the exception type callers may already catch.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                ) from exc
            count += 1
    if count < 1:
        raise ValueError(f"empty JSONL: {path}")
    return count
def scan_text_file(path: Path) -> list[str]:
    """Scan one file for secret-like and IP-like strings.

    The file is read as UTF-8 text. A file that does not decode as UTF-8 is
    not scanned and yields an empty list.

    Args:
        path: File to inspect.

    Returns:
        One ``"secret-like pattern in ..."`` message for each entry of
        ``SECRET_PATTERNS`` that matches the text, followed by at most one
        ``"ip-like pattern in ..."`` message.
    """
    try:
        content = path.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        # Not UTF-8 text (a binary file, for instance): nothing to report.
        return []

    findings: list[str] = [
        f"secret-like pattern in {path}"
        for secret_re in SECRET_PATTERNS
        if secret_re.search(content)
    ]

    # Dotted quads made of 1-3 digit groups; the negative lookahead exempts
    # exactly 127.0.0.1 and 0.0.0.0. Octet ranges are not checked.
    ip_like = re.compile(r"\b(?!(?:127\.0\.0\.1|0\.0\.0\.0)\b)(?:\d{1,3}\.){3}\d{1,3}\b")
    if ip_like.search(content) is not None:
        findings.append(f"ip-like pattern in {path}")
    return findings
def main() -> None:
    """Validate the bundle containing this script and print a JSON report.

    The bundle root is the parent of the directory this script lives in.
    The following checks run in order, each appending to ``errors`` or
    ``warnings`` instead of aborting:

    1. Every path in ``EXPECTED_FILES`` exists (error when missing).
    2. Every path in ``OPTIONAL_HEADS`` exists (warning when missing).
    3. No file has a path component listed in ``FORBIDDEN_PATH_PARTS``, and
       ``scan_text_file`` finds nothing in it.
    4. The encoder checkpoint, if present, loads and holds exactly
       6904064 parameters.
    5. The demo JSONL files, if present, pass ``validate_jsonl``.
    6. The label-map and result JSON files, if present, parse as JSON.
    7. A zip archive next to the bundle root and named after it, if present,
       contains no ``.git`` path component.

    The report (status, bundle root, errors, warnings) is printed to stdout
    as JSON. Exits with status 1 when at least one error was recorded.
    """
    root = Path(__file__).resolve().parents[1]
    errors: list[str] = []
    warnings: list[str] = []

    for rel in EXPECTED_FILES:
        if not (root / rel).exists():
            errors.append(f"missing required file: {rel}")
    for rel in OPTIONAL_HEADS:
        if not (root / rel).exists():
            warnings.append(f"optional probe head not present: {rel}")

    for path in iter_files(root):
        rel = path.relative_to(root).as_posix()
        # Lowercase so the comparison with FORBIDDEN_PATH_PARTS ignores case.
        parts = set(rel.lower().split("/"))
        if parts & FORBIDDEN_PATH_PARTS:
            errors.append(f"forbidden path component: {rel}")
        errors.extend(scan_text_file(path))

    encoder_rel = "checkpoints/psy_6.9m_encoder.pt"
    encoder_path = root / encoder_rel
    if encoder_path.exists():
        try:
            state = torch.load(encoder_path, map_location="cpu", weights_only=True)
            params = sum(v.numel() for v in state.values() if hasattr(v, "numel"))
        except Exception as exc:
            # Top-level boundary: a corrupt or unexpected checkpoint can fail
            # in many ways; report it instead of losing the whole report to a
            # traceback.
            errors.append(f"cannot load encoder checkpoint {encoder_rel}: {exc}")
        else:
            if params != 6904064:
                errors.append(f"unexpected encoder parameter count: {params}")

    for rel in [
        "demo_artifacts/cve_record_sample.jsonl",
        "demo_artifacts/rule_ast_sample.jsonl",
        "demo_artifacts/network_flow_sample.jsonl",
    ]:
        if (root / rel).exists():
            try:
                validate_jsonl(root / rel)
            except (OSError, ValueError) as exc:
                # ValueError covers malformed JSON, bad UTF-8 and empty files.
                errors.append(f"invalid JSONL {rel}: {exc}")

    for rel in [
        "label_maps/cve_sanitized_labels.json",
        "label_maps/rule_ast_labels.json",
        "label_maps/network_flow_labels.json",
        "results/cve_sanitized_result.json",
        "results/rule_ast_result.json",
        "results/network_flow_result.json",
    ]:
        if (root / rel).exists():
            try:
                json.loads((root / rel).read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                errors.append(f"invalid JSON {rel}: {exc}")

    # Append ".zip" to the directory name. Path.with_suffix would replace
    # everything after the last dot, turning e.g. "psy-6.9m" into "psy-6.zip".
    zip_path = root.with_name(root.name + ".zip")
    if zip_path.exists():
        try:
            with zipfile.ZipFile(zip_path) as zf:
                names = zf.namelist()
        except (OSError, zipfile.BadZipFile) as exc:
            errors.append(f"cannot read zip {zip_path.name}: {exc}")
        else:
            for name in names:
                # Component check also catches ".git/" at the archive root.
                if ".git" in name.split("/"):
                    errors.append(f".git entry in zip: {name}")

    report = {
        "status": "PASSED" if not errors else "FAILED",
        "bundle_root": str(root),
        "errors": errors,
        "warnings": warnings,
    }
    print(json.dumps(report, sort_keys=True, indent=2))
    if errors:
        sys.exit(1)
# Run the validation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()