#!/usr/bin/env python3
"""scripts/push_cve_corpus_to_hf.py — pull enriched cybersec rows from
training_queue and publish them to cuilabs/bee-interactions HF dataset
in the trainer-expected schema.
Source rows must have BOTH `payload.prompt` and `payload.completion` —
the prompt is generated by `backfill_cve_prompts.py`, the completion by
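`backfill_cve_completions.py`. Rows that have only one of the two are
silently skipped (they're either still mid-pipeline or a teacher call
failed), as are rows whose prompt is shorter than 40 characters or whose
completion is shorter than 80 characters.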
Output schema mirrors what the Vertex/Kaggle trainers expect via their
`_row_user()` / `_row_assistant()` extractors:
{
"role": "assistant",
"prompt": "<question>",
"content": "<answer>",
"domain": "cybersecurity",
"task_type": "cve_analysis",
"target_tiers": ["cell", "cell-plus", "comb"],
"quality_score": 0.85,
"feedback": null,
"model_id": "<completion_model>",
"sample_id": "cve:<cve_id>",
"source": "cve_distillation:<provenance>",
"kev_flag": <bool>,
"cve_id": "<id>",
"cvss_severity": "<sev>",
"cwes": [...],
"created_at": "<ISO ts>"
}
The trainer filter (`is_acceptable` in workers/vertex-train/train.py)
keys on `domain` field equality plus minimum-length checks; everything
else is provenance the trainer ignores but is useful when auditing the
adapter.
Usage
-----
python3 scripts/push_cve_corpus_to_hf.py # full push
python3 scripts/push_cve_corpus_to_hf.py --dry-run # stage only
python3 scripts/push_cve_corpus_to_hf.py --limit 10 # smoke test
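Reads HF_TOKEN (falling back to HUGGINGFACE_HUB_TOKEN) and
POSTGRES_URL_NON_POOLING from the environment; the repo-root `.env` is
loaded first when python-dotenv is installed.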
"""
from __future__ import annotations
import argparse
import datetime
import json
import os
import sys
from pathlib import Path
try:
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parent.parent / ".env")
except ImportError:
pass
import psycopg
from psycopg import rows as psycopg_rows
# Hugging Face dataset that receives the staged JSONL upload (--dataset-id).
DEFAULT_DATASET_ID = "cuilabs/bee-interactions"

# Provenance defaults stamped onto every exported record by transform().
DEFAULT_TARGET_TIERS = [
    "cell",
    "cell-plus",
    "comb",
]
DEFAULT_QUALITY = 0.85

# Two directory levels above this file, i.e. the repository root when the
# script lives in <repo>/scripts/. The JSONL staging file is written beneath it.
REPO_ROOT = Path(__file__).resolve().parent.parent
def fetch_rows(conn, limit: int | None) -> list[dict]:
    """Return training_queue CVE rows that are ready to publish.

    A row qualifies when it has ``kind = 'cve'`` and
    ``domain = 'cybersecurity'``, and its payload has both a ``prompt`` of at
    least 40 characters and a ``completion`` of at least 80 characters.

    Rows are ordered as follows:
      * KEV-flagged rows first;
      * then by CVSS severity (CRITICAL, HIGH, MEDIUM, everything else);
      * then by ``payload.published``, newest first;
      * ``id`` breaks any remaining ties, so ``--limit`` selects the same
        subset on every run.

    Args:
        conn: An open psycopg connection.
        limit: Maximum number of rows to return, or ``None`` for no limit.
            ``0`` is passed through as ``LIMIT 0``; it is not treated as
            "unlimited".

    Returns:
        A list of dicts keyed by ``id``, ``external_id`` and ``payload``.
    """
    sql = """
        SELECT id, external_id, payload
        FROM public.training_queue
        WHERE kind = 'cve'
          AND domain = 'cybersecurity'
          AND payload ? 'prompt'
          AND payload ? 'completion'
          AND length(payload->>'prompt') >= 40
          AND length(payload->>'completion') >= 80
        ORDER BY
          CASE WHEN (payload->>'kev')::boolean THEN 0 ELSE 1 END,
          CASE payload->>'cvss_severity'
            WHEN 'CRITICAL' THEN 1
            WHEN 'HIGH' THEN 2
            WHEN 'MEDIUM' THEN 3
            ELSE 9
          END,
          (payload->>'published') DESC NULLS LAST,
          id
    """
    params: tuple[int, ...] = ()
    # Compare against None explicitly: a falsy 0 must not silently turn a
    # bounded smoke test into an export of every qualifying row.
    if limit is not None:
        sql += " LIMIT %s"
        params = (limit,)
    with conn.cursor(row_factory=psycopg_rows.dict_row) as cur:
        cur.execute(sql, params)
        return list(cur.fetchall())
def transform(row: dict) -> dict:
p = row["payload"]
cve_id = p.get("cve_id") or row.get("external_id") or ""
completion_model = p.get("completion_model") or "mistral-medium-latest"
enrich_model = p.get("enrich_model") or completion_model
return {
"role": "assistant",
"prompt": p["prompt"],
"content": p["completion"],
"domain": "cybersecurity",
"task_type": "cve_analysis",
"target_tiers": DEFAULT_TARGET_TIERS,
"quality_score": DEFAULT_QUALITY,
"feedback": None,
"model_id": completion_model,
"sample_id": f"cve:{cve_id}",
"source": f"cve_distillation:prompt={enrich_model};answer={completion_model}",
"kev_flag": bool(p.get("kev")),
"cve_id": cve_id,
"cvss_severity": p.get("cvss_severity"),
"cwes": p.get("cwes") or [],
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
}
def main() -> int:
    """Fetch ready CVE rows, stage them as JSONL, and upload to the HF dataset.

    Steps:
      1. Read configuration from the environment.
      2. Pull qualifying rows with ``fetch_rows`` and convert each one with
         ``transform``.
      3. Write the records to
         ``data/datasets/distilled/_upload-cve-cybersec-<stamp>.jsonl`` under
         the repo root.
      4. Unless ``--dry-run`` is given, upload that file to
         ``distilled/cve-cybersec-<stamp>.jsonl`` in the dataset repo.

    Returns:
        Process exit status. 0 on success, when there are no rows, and for
        ``--dry-run``. 1 when a required setting or ``huggingface_hub`` is
        missing. An invalid ``--limit`` exits through ``argparse``'s
        ``parser.error``.
    """
    # __doc__ is None under ``python -OO``; don't crash while building --help.
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n\n")[0])
    parser.add_argument("--dataset-id", default=DEFAULT_DATASET_ID)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()
    # A non-positive limit is never a meaningful smoke test. Fail fast rather
    # than risk it being read as "no limit" and pushing the full corpus.
    if args.limit is not None and args.limit < 1:
        parser.error("--limit must be a positive integer")

    pg_url = (os.environ.get("POSTGRES_URL_NON_POOLING") or "").strip()
    if not pg_url:
        print("ERROR: POSTGRES_URL_NON_POOLING not set", file=sys.stderr)
        return 1
    hf_token = (os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or "").strip()
    # The token is only needed for the upload step, so dry runs may omit it.
    if not hf_token and not args.dry_run:
        print("ERROR: HF_TOKEN not set (use --dry-run to bypass)", file=sys.stderr)
        return 1

    # Only the query needs the connection; it is closed before file/network work.
    with psycopg.connect(pg_url, autocommit=False) as conn:
        rows = fetch_rows(conn, args.limit)
    if not rows:
        print("No rows with both prompt and completion -- backfill incomplete?")
        return 0

    transformed = [transform(r) for r in rows]
    kev_count = sum(1 for r in transformed if r["kev_flag"])
    # Record the teacher model(s) actually present rather than assuming one.
    teachers = ",".join(sorted({str(r["model_id"]) for r in transformed}))
    print(f"Transformed {len(transformed)} rows ({kev_count} KEV-flagged)")

    stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
    staging_dir = REPO_ROOT / "data/datasets/distilled"
    staging_dir.mkdir(parents=True, exist_ok=True)
    staging_path = staging_dir / f"_upload-cve-cybersec-{stamp}.jsonl"
    with staging_path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in transformed)
    print(f"Staged {len(transformed)} rows at {staging_path}")

    if args.dry_run:
        print("[dry-run] not uploading. First row:")
        print(json.dumps(transformed[0], indent=2)[:1200])
        return 0

    # Imported lazily so --dry-run works without huggingface_hub installed.
    try:
        from huggingface_hub import HfApi
    except ImportError:
        print("ERROR: huggingface_hub not installed", file=sys.stderr)
        return 1

    api = HfApi(token=hf_token)
    upload_path = f"distilled/cve-cybersec-{stamp}.jsonl"
    print(f"Uploading {len(transformed)} rows -> {args.dataset_id}:{upload_path}")
    api.upload_file(
        path_or_fileobj=str(staging_path),
        path_in_repo=upload_path,
        repo_id=args.dataset_id,
        repo_type="dataset",
        commit_message=(
            f"cve cybersec corpus: {len(transformed)} (prompt, completion) "
            f"pairs, {kev_count} KEV-flagged. teacher={teachers}. "
            f"source=training_queue@{stamp}"
        ),
    )
    print(
        f"OK uploaded -> "
        f"https://huggingface.co/datasets/{args.dataset_id}/blob/main/{upload_path}"
    )
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status;
    # equivalent to sys.exit(main()).
    raise SystemExit(main())
|