File size: 6,740 Bytes
5e21013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#!/usr/bin/env python3
"""scripts/push_cve_corpus_to_hf.py — pull enriched cybersec rows from
training_queue and publish them to the cuilabs/bee-interactions HF dataset
in the trainer-expected schema.

Source rows must have BOTH `payload.prompt` and `payload.completion` —
the prompt is generated by `backfill_cve_prompts.py`, the completion by
`backfill_cve_completions.py`. Rows that have only one of the two are
silently skipped (they're either still mid-pipeline or a teacher call
failed), as are rows whose prompt is shorter than 40 characters or whose
completion is shorter than 80 characters.

Output schema mirrors what the Vertex/Kaggle trainers expect via their
`_row_user()` / `_row_assistant()` extractors:

    {
      "role": "assistant",
      "prompt": "<question>",
      "content": "<answer>",
      "domain": "cybersecurity",
      "task_type": "cve_analysis",
      "target_tiers": ["cell", "cell-plus", "comb"],
      "quality_score": 0.85,
      "feedback": null,
      "model_id": "<completion_model>",
      "sample_id": "cve:<cve_id>",
      "source": "cve_distillation:prompt=<enrich_model>;answer=<completion_model>",
      "kev_flag": <bool>,
      "cve_id": "<id>",
      "cvss_severity": "<sev>",
      "cwes": [...],
      "created_at": "<ISO ts>"
    }

The trainer filter (`is_acceptable` in workers/vertex-train/train.py)
keys on `domain` field equality plus minimum-length checks; everything
else is provenance the trainer ignores but is useful when auditing the
adapter.

Usage
-----
    python3 scripts/push_cve_corpus_to_hf.py             # full push
    python3 scripts/push_cve_corpus_to_hf.py --dry-run   # stage only
    python3 scripts/push_cve_corpus_to_hf.py --limit 10  # smoke test

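Reads POSTGRES_URL_NON_POOLING and HF_TOKEN (falling back to
HUGGINGFACE_HUB_TOKEN) from the environment, after loading the repo-root
`.env` when python-dotenv is installed. No HF token is needed with --dry-run.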
"""

from __future__ import annotations

import argparse
import datetime
import json
import os
import sys
from pathlib import Path

try:
    from dotenv import load_dotenv

    load_dotenv(Path(__file__).resolve().parent.parent / ".env")
except ImportError:
    pass

import psycopg
from psycopg import rows as psycopg_rows

# Repository root: this script lives one level down, in <repo>/scripts/.
REPO_ROOT: Path = Path(__file__).resolve().parent.parent
# Hugging Face dataset repo the corpus is pushed to (overridable via --dataset-id).
DEFAULT_DATASET_ID: str = "cuilabs/bee-interactions"
# Tier labels written to every exported row's "target_tiers" field.
DEFAULT_TARGET_TIERS: list[str] = ["cell", "cell-plus", "comb"]
# Fixed score written to every exported row's "quality_score" field.
DEFAULT_QUALITY: float = 0.85


def fetch_rows(conn, limit: int | None) -> list[dict]:
    """Fetch the training_queue CVE rows that are ready to publish.

    A row qualifies when ``kind = 'cve'``, ``domain = 'cybersecurity'`` and
    its jsonb ``payload`` has a ``prompt`` of at least 40 characters and a
    ``completion`` of at least 80 characters.

    Ordering: KEV-flagged rows first, then CVSS severity (CRITICAL, HIGH,
    MEDIUM, anything else), then ``payload->>'published'`` descending as text
    (NULLs last). ``id`` is a final tiebreak so that a given ``limit``
    selects the same subset on every run.

    Args:
        conn: Open psycopg connection.
        limit: Maximum number of rows to return, or ``None`` for no limit.
            ``0`` is sent as ``LIMIT 0`` (no rows) rather than being
            mistaken for "no limit".

    Returns:
        One dict per row with keys ``id``, ``external_id`` and ``payload``
        (the jsonb column as decoded by psycopg).
    """
    # NOTE(review): the ::boolean cast below raises if any row stores a kev
    # value PostgreSQL cannot read as a boolean — confirm the payload schema.
    sql = """
        SELECT id, external_id, payload
          FROM public.training_queue
         WHERE kind = 'cve'
           AND domain = 'cybersecurity'
           AND payload ? 'prompt'
           AND payload ? 'completion'
           AND length(payload->>'prompt') >= 40
           AND length(payload->>'completion') >= 80
         ORDER BY
            CASE WHEN (payload->>'kev')::boolean THEN 0 ELSE 1 END,
            CASE payload->>'cvss_severity'
                WHEN 'CRITICAL' THEN 1
                WHEN 'HIGH' THEN 2
                WHEN 'MEDIUM' THEN 3
                ELSE 9
            END,
            (payload->>'published') DESC NULLS LAST,
            id
    """
    params: tuple[int, ...] = ()
    # Explicit None check: a truthiness test would turn limit=0 into
    # "fetch everything", i.e. a full push instead of a smoke test.
    if limit is not None:
        sql += " LIMIT %s"
        params = (limit,)
    with conn.cursor(row_factory=psycopg_rows.dict_row) as cur:
        cur.execute(sql, params)
        return cur.fetchall()


# Spellings PostgreSQL reads as boolean TRUE (case-insensitive, surrounding
# whitespace ignored). Used so the Python-side KEV flag agrees with the
# ``(payload->>'kev')::boolean`` cast used for ordering in fetch_rows.
_PG_TRUE_LITERALS = frozenset({"t", "true", "y", "yes", "on", "1"})


def _as_bool(value) -> bool:
    """Interpret a payload flag the way PostgreSQL's ``::boolean`` cast would.

    Strings are matched against PostgreSQL's TRUE spellings, so a stored
    ``"false"`` yields False instead of the truthy non-empty string it is in
    plain Python. Every other value falls back to ``bool(value)``.
    """
    if isinstance(value, str):
        return value.strip().lower() in _PG_TRUE_LITERALS
    return bool(value)


def transform(row: dict) -> dict:
    """Convert one training_queue row into the trainer-facing dataset record.

    Args:
        row: A row as returned by ``fetch_rows``. ``payload`` must contain
            ``prompt`` and ``completion``. ``external_id`` is used as the
            CVE id when the payload has no ``cve_id``.

    Returns:
        A flat dict with the keys listed in the module docstring.
        ``created_at`` is the UTC time of this call, not a CVE date.

    Raises:
        KeyError: if ``row`` has no ``payload``, or the payload lacks
            ``prompt`` or ``completion``.
    """
    p = row["payload"]
    cve_id = p.get("cve_id") or row.get("external_id") or ""
    # Provenance defaults: an unrecorded completion model is assumed to be
    # mistral-medium-latest, and an unrecorded prompt (enrich) model is
    # assumed to be the same model as the completion.
    completion_model = p.get("completion_model") or "mistral-medium-latest"
    enrich_model = p.get("enrich_model") or completion_model
    return {
        "role": "assistant",
        "prompt": p["prompt"],
        "content": p["completion"],
        "domain": "cybersecurity",
        "task_type": "cve_analysis",
        # Copy so that no two rows share (and can mutate) one list object.
        "target_tiers": list(DEFAULT_TARGET_TIERS),
        "quality_score": DEFAULT_QUALITY,
        "feedback": None,
        "model_id": completion_model,
        "sample_id": f"cve:{cve_id}",
        "source": f"cve_distillation:prompt={enrich_model};answer={completion_model}",
        "kev_flag": _as_bool(p.get("kev")),
        "cve_id": cve_id,
        "cvss_severity": p.get("cvss_severity"),
        "cwes": p.get("cwes") or [],
        "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
    }


def _positive_int(text: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def _stage_jsonl(rows: list[dict], stamp: str) -> Path:
    """Write *rows* as JSONL under <repo>/data/datasets/distilled/ and return the path."""
    staging_dir = REPO_ROOT / "data/datasets/distilled"
    staging_dir.mkdir(parents=True, exist_ok=True)
    staging_path = staging_dir / f"_upload-cve-cybersec-{stamp}.jsonl"
    with staging_path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in rows)
    return staging_path


def main() -> int:
    """CLI entry point: fetch, transform, stage and (unless --dry-run) upload.

    Returns:
        Process exit code. 0 on success or when there is nothing to push.
        1 when configuration is missing, huggingface_hub is not installed,
        or the upload fails.
    """
    # __doc__ is None under `python -OO`; fall back to an empty description.
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n\n")[0])
    parser.add_argument("--dataset-id", default=DEFAULT_DATASET_ID)
    # Reject values below 1 so a typo like `--limit 0` can never turn a smoke
    # test into a full push.
    parser.add_argument(
        "--limit",
        type=_positive_int,
        default=None,
        help="push at most N rows (smoke test); default: all eligible rows",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="stage the JSONL locally but do not upload",
    )
    args = parser.parse_args()

    pg_url = (os.environ.get("POSTGRES_URL_NON_POOLING") or "").strip()
    if not pg_url:
        print("ERROR: POSTGRES_URL_NON_POOLING not set", file=sys.stderr)
        return 1
    hf_token = (os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or "").strip()
    if not hf_token and not args.dry_run:
        print("ERROR: HF_TOKEN not set (use --dry-run to bypass)", file=sys.stderr)
        return 1

    with psycopg.connect(pg_url, autocommit=False) as conn:
        rows = fetch_rows(conn, args.limit)

    if not rows:
        print("No rows with both prompt and completion — backfill incomplete?")
        return 0

    transformed = [transform(r) for r in rows]
    kev_count = sum(1 for r in transformed if r["kev_flag"])
    print(f"Transformed {len(transformed)} rows ({kev_count} KEV-flagged)")

    stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
    staging_path = _stage_jsonl(transformed, stamp)
    print(f"Staged {len(transformed)} rows at {staging_path}")

    if args.dry_run:
        print("[dry-run] not uploading. First row:")
        print(json.dumps(transformed[0], indent=2)[:1200])
        return 0

    try:
        from huggingface_hub import HfApi
    except ImportError:
        print("ERROR: huggingface_hub not installed", file=sys.stderr)
        return 1

    # Record the teacher model(s) actually present in the rows instead of
    # assuming one: transform() takes model_id from each payload's
    # completion_model.
    teachers = ",".join(sorted({str(r["model_id"]) for r in transformed}))

    api = HfApi(token=hf_token)
    upload_path = f"distilled/cve-cybersec-{stamp}.jsonl"
    print(f"Uploading {len(transformed)} rows → {args.dataset_id}:{upload_path}")
    try:
        api.upload_file(
            path_or_fileobj=str(staging_path),
            path_in_repo=upload_path,
            repo_id=args.dataset_id,
            repo_type="dataset",
            commit_message=(
                f"cve cybersec corpus: {len(transformed)} (prompt, completion) "
                f"pairs, {kev_count} KEV-flagged. teacher={teachers}. "
                f"source=training_queue@{stamp}"
            ),
        )
    except Exception as exc:  # CLI boundary: report, keep the staged file, exit 1.
        print(f"ERROR: upload failed ({type(exc).__name__}: {exc})", file=sys.stderr)
        print(f"Staged file left at {staging_path} for a retry.", file=sys.stderr)
        return 1
    print(
        f"OK uploaded → "
        f"https://huggingface.co/datasets/{args.dataset_id}/blob/main/{upload_path}"
    )
    return 0


if __name__ == "__main__":
    # Raising SystemExit hands main()'s integer return value to the
    # interpreter as the process exit status (same effect as sys.exit).
    raise SystemExit(main())