File size: 10,651 Bytes
5e21013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
#!/usr/bin/env python3
"""scripts/backfill_cve_prompts.py β€” one-shot CVE training-prompt backfill.

Walks `training_queue` rows where `kind='cve'` and `payload` lacks a
`prompt` key, calls Mistral `mistral-medium-latest` (Experiment tier,
free, verified live 2026-05-04), generates an expert-level cybersecurity
training prompt, and writes it back into the row's payload via UPDATE.

Why this exists
---------------
The cve-ingest cron used to dump raw CVEs into `training_queue` with no
consumer-friendly shape. Now that gemma3:12b enriches new rows at ingest
time, ~1,100 historical rows still sit raw — they predate the
enrichment step and NVD's `lastModStartDate` window won't re-surface
them. This script clears the backlog in one go on Mistral's free
Experiment tier.

Quality filter (matches the ranking I gave Christopher 2026-05-04):
  - description >= 150 chars (skips title-only rows)
  - severity in {CRITICAL, HIGH, MEDIUM} (skips LOW + missing)
  - ordered by severity then recency so we get the best CVEs first
    even if you ctrl-C halfway

Idempotent: re-run after interruption; the WHERE clause picks up where
it left off because the UPDATE adds `payload.prompt`.

Usage
-----
    python3 scripts/backfill_cve_prompts.py                # all rows
    python3 scripts/backfill_cve_prompts.py --limit 50     # smoke test
    python3 scripts/backfill_cve_prompts.py --dry-run      # count only
    BEE_BACKFILL_MODEL=mistral-large-latest python3 scripts/backfill_cve_prompts.py

Reads BEE_MISTRAL_API_KEY + POSTGRES_URL_NON_POOLING from `.env`.

Throughput
----------
Mistral Experiment tier caps at 23 req/min account-wide; we pace at 20
to leave headroom. Expect ~50 minutes for the full ~1,100-row backlog.
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path

# Load `.env` if present so the script "just works" from the repo root.
# python-dotenv is an optional dependency: when it is not installed the
# import fails and the script relies on variables already exported in the
# process environment.
try:
    from dotenv import load_dotenv

    # The `.env` file is looked up two directories above this file.
    load_dotenv(Path(__file__).resolve().parent.parent / ".env")
except ImportError:
    pass  # not fatal — env may already be exported

import psycopg
from psycopg import rows as psycopg_rows

# Chat-completions endpoint every enrichment request is POSTed to.
MISTRAL_ENDPOINT = "https://api.mistral.ai/v1/chat/completions"
# Model used when neither --model nor BEE_BACKFILL_MODEL overrides it.
DEFAULT_MODEL = "mistral-medium-latest"

# Mistral Experiment tier: 23 req/min account-wide (verified live via
# x-ratelimit-limit-req-minute header). Pace at 20 to leave headroom for
# any concurrent cron call.
RATE_LIMIT_RPM = 20
# Minimum spacing between two consecutive API calls, in seconds.
RATE_INTERVAL_S = 60.0 / RATE_LIMIT_RPM

# System message sent with every request. The dash below is a real em dash;
# a mis-decoded byte sequence here would be forwarded verbatim to the model.
SYSTEM_PROMPT = (
    "You generate concise, expert-level cybersecurity training prompts. "
    "Given a raw CVE record, write a self-contained question or analytical "
    "scenario that a senior security engineer would use to teach the "
    "vulnerability — root cause, exploitation pattern, mitigation, and "
    "detection signals. Output ONLY the prompt body, no preface, no JSON, "
    "no markdown fences. 2-5 sentences total."
)


def build_user_prompt(payload: dict) -> str:
    """Render a CVE payload as the user message sent to the model.

    Args:
        payload: The row's JSON payload. The keys read are ``cve_id``,
            ``cvss_score``, ``cvss_severity``, ``cwes`` and ``description``;
            each one falls back to a placeholder when absent.

    Returns:
        A short header (CVE id, CVSS score and severity, CWE list) followed
        by a blank line and the description.
    """
    cwe_list = payload.get("cwes")
    # A missing, null or empty CWE list all render as the same placeholder.
    cwe_text = ", ".join(cwe_list) if cwe_list else "none listed"

    header_lines = [
        f"CVE: {payload.get('cve_id', '?')}",
        f"CVSS: {payload.get('cvss_score', 'n/a')} ({payload.get('cvss_severity', '?')})",
        f"CWEs: {cwe_text}",
    ]
    header = "\n".join(header_lines)
    return f"{header}\n\nDescription:\n{payload.get('description', '')}"


def strip_markdown_fences(s: str) -> str:
    """Remove a Markdown code fence wrapped around model output.

    Some models wrap their answer in triple backticks even when asked not
    to. Text that does not start with a fence is returned stripped but
    otherwise untouched.

    Args:
        s: Raw model output.

    Returns:
        The text inside the first fence, without a leading language tag
        (a first line that is blank or purely alphabetic, e.g. ``text`` or
        ``json``). An opening fence that is never closed is handled too:
        the end of the text is treated as the end of the fence, so the
        backticks and tag never leak into the result.
    """
    s = s.strip()
    if not s.startswith("```"):
        return s

    # Everything after the opening fence up to the next fence, or up to the
    # end of the text when the closing fence is missing.
    inner, _, _ = s[3:].partition("```")

    # Drop a leading language tag like "json\n" or "text\n".
    tag, newline, rest = inner.partition("\n")
    if newline and (not tag.strip() or tag.strip().isalpha()):
        return rest.strip()
    return inner.strip()


def call_mistral(
    api_key: str,
    model: str,
    user_prompt: str,
    timeout_s: int = 60,
) -> tuple[str | None, str | None]:
    """Send one chat-completion request and return the generated prompt.

    Args:
        api_key: Bearer token placed in the ``Authorization`` header.
        model: Model name passed through in the request body.
        user_prompt: Content of the user message; ``SYSTEM_PROMPT`` is sent
            as the system message.
        timeout_s: Timeout in seconds handed to ``urlopen``.

    Returns:
        ``(content, error_kind)``. On success ``content`` is the text with
        any Markdown fence removed and ``error_kind`` is ``None``. On
        failure ``content`` is ``None`` and ``error_kind`` is one of:

        * ``'429'``        - the API answered HTTP 429
        * ``'http_other'`` - any other HTTP error status (logged to stderr)
        * ``'fetch'``      - network failure, timeout or undecodable body
          (logged to stderr)
        * ``'empty'``      - the response carried no usable text, or fewer
          than 24 characters of it

    This function does not raise for malformed responses: a JSON body with
    an unexpected shape is reported as ``'empty'`` so one odd response
    cannot abort a long backfill.
    """
    body = json.dumps(
        {
            "model": model,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": 600,
            "temperature": 0.5,
        }
    ).encode("utf-8")
    req = urllib.request.Request(
        MISTRAL_ENDPOINT,
        data=body,
        method="POST",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        if e.code == 429:
            return None, "429"
        try:
            # errors="replace" keeps the diagnostic even when the 200-byte
            # budget is irrelevant and the body is not valid UTF-8.
            msg = e.read().decode("utf-8", errors="replace")[:200]
        except Exception:
            msg = ""
        print(f"  ! HTTP {e.code}: {msg}", file=sys.stderr)
        return None, "http_other"
    except Exception as e:
        # Deliberately broad: this is the network boundary, and any failure
        # here (URLError, timeout, bad JSON, ...) is a per-row skip.
        print(f"  ! fetch error: {e}", file=sys.stderr)
        return None, "fetch"

    # Walk choices[0].message.content defensively. Valid JSON can still
    # have an unexpected shape (null message, null content, content given
    # as a list of chunks, a non-object top level); none of those may raise.
    content = None
    if isinstance(data, dict):
        choices = data.get("choices")
        first = choices[0] if isinstance(choices, list) and choices else None
        message = first.get("message") if isinstance(first, dict) else None
        if isinstance(message, dict):
            content = message.get("content")
    if not isinstance(content, str):
        return None, "empty"

    content = strip_markdown_fences(content)
    # Anything shorter than 24 characters cannot be a 2-5 sentence prompt.
    if len(content) < 24:
        return None, "empty"
    return content, None


def count_pending(conn) -> int:
    """Count the CVE rows that still need a generated prompt.

    A row is counted when its kind is ``'cve'``, its payload has no
    ``prompt`` key, its description is at least 150 characters long and
    its ``cvss_severity`` is neither ``'LOW'`` nor missing/empty.

    Args:
        conn: Open database connection providing ``cursor()``.

    Returns:
        The value of the single ``count(*)`` column.
    """
    query = """
            SELECT count(*)
              FROM public.training_queue
             WHERE kind = 'cve'
               AND NOT (payload ? 'prompt')
               AND length(payload->>'description') >= 150
               AND COALESCE(payload->>'cvss_severity', '') NOT IN ('LOW', '')
            """
    with conn.cursor() as cursor:
        cursor.execute(query)
        first_row = cursor.fetchone()
    return first_row[0]


def fetch_rows(conn, limit: int) -> list[dict]:
    """Fetch the next batch of CVE rows that still lack a prompt.

    Uses the same filter as the pending count (kind ``'cve'``, no
    ``prompt`` key, description of 150+ characters, severity not ``'LOW'``
    or empty). Rows are ordered CRITICAL, HIGH, MEDIUM, then anything else,
    and within a severity by ``published`` descending with nulls last.

    Args:
        conn: Open database connection providing ``cursor()``.
        limit: Maximum number of rows to return.

    Returns:
        A list of dict rows with the keys ``id``, ``external_id`` and
        ``payload``.
    """
    query = """
        SELECT id, external_id, payload
          FROM public.training_queue
         WHERE kind = 'cve'
           AND NOT (payload ? 'prompt')
           AND length(payload->>'description') >= 150
           AND COALESCE(payload->>'cvss_severity', '') NOT IN ('LOW', '')
         ORDER BY
            CASE payload->>'cvss_severity'
                WHEN 'CRITICAL' THEN 1
                WHEN 'HIGH' THEN 2
                WHEN 'MEDIUM' THEN 3
                ELSE 9
            END,
            (payload->>'published') DESC NULLS LAST
         LIMIT %s
    """
    params = (limit,)
    # dict_row makes each record addressable by column name.
    with conn.cursor(row_factory=psycopg_rows.dict_row) as cursor:
        cursor.execute(query, params)
        records = cursor.fetchall()
        return list(records)


def update_row(conn, row_id: int, prompt: str, model: str) -> None:
    """Store the generated prompt on one row and commit.

    Merges two keys into the row's JSON payload: ``prompt`` (the generated
    text) and ``enrich_model`` (the model that produced it). The commit
    happens per row.

    Args:
        conn: Open database connection providing ``cursor()`` and
            ``commit()``.
        row_id: Primary key of the ``training_queue`` row to update.
        prompt: Generated training prompt.
        model: Name of the model that generated the prompt.
    """
    statement = """
        UPDATE public.training_queue
           SET payload = payload
                         || jsonb_build_object('prompt', %s::text)
                         || jsonb_build_object('enrich_model', %s::text)
         WHERE id = %s
    """
    # Parameter order follows the placeholders: prompt, model, then row id.
    values = (prompt, model, row_id)
    with conn.cursor() as cursor:
        cursor.execute(statement, values)
    conn.commit()


def main() -> int:
    """Parse the command line and run the backfill.

    Reads ``BEE_MISTRAL_API_KEY`` and ``POSTGRES_URL_NON_POOLING`` from the
    environment, counts the pending rows and, unless ``--dry-run`` is
    given, enriches up to ``--limit`` of them while pacing API calls at
    ``RATE_LIMIT_RPM`` requests per minute.

    Failure handling:

    * A 429 is retried on the *same* row after a 12 s pause. If the row is
      still rate limited after ``max_429_retries`` pauses (one full minute,
      i.e. a whole rate-limit window) the run is aborted instead of
      spinning forever on an exhausted quota.
    * A row that fails for any other reason is counted as skipped and is
      not attempted again during this run. It stays pending in the
      database, so a later run picks it up.

    Returns:
        0 on normal completion (including ``--dry-run`` and "nothing to
        do"); 1 when a required environment variable is missing or the
        run was aborted because of sustained rate limiting.
    """
    # __doc__ is the module docstring; it is None under `python -OO`.
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n\n")[0])
    parser.add_argument("--limit", type=int, default=None, help="cap total rows enriched")
    parser.add_argument(
        "--batch", type=int, default=50, help="DB fetch batch size (rows per loop)"
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("BEE_BACKFILL_MODEL", DEFAULT_MODEL),
        help="model name (default: BEE_BACKFILL_MODEL or mistral-medium-latest)",
    )
    parser.add_argument(
        "--dry-run", action="store_true", help="count pending rows and exit without enriching"
    )
    args = parser.parse_args()

    # A negative LIMIT is a SQL error and a zero batch would end the loop
    # silently, so reject both up front.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be >= 0")
    if args.batch < 1:
        parser.error("--batch must be >= 1")

    api_key = (os.environ.get("BEE_MISTRAL_API_KEY") or "").strip()
    if not api_key:
        print("ERROR: BEE_MISTRAL_API_KEY not set (.env or environment)", file=sys.stderr)
        return 1
    pg_url = (os.environ.get("POSTGRES_URL_NON_POOLING") or "").strip()
    if not pg_url:
        print("ERROR: POSTGRES_URL_NON_POOLING not set", file=sys.stderr)
        return 1

    print(f"Backfill — model={args.model}  batch={args.batch}  pace={RATE_LIMIT_RPM} req/min")

    backoff_s = 12.0
    max_429_retries = 5  # 5 x 12 s = one full per-minute rate-limit window

    started = time.monotonic()
    enriched = 0
    skipped = 0
    rate_limited = 0
    # -inf means "never called": the first request is not delayed whatever
    # the absolute value of the monotonic clock happens to be.
    last_call = float("-inf")
    # Ids that failed during this run. They are still pending in the DB and
    # would otherwise come back at the top of every subsequent batch.
    failed_ids: set = set()
    aborted = False

    with psycopg.connect(pg_url, autocommit=False) as conn:
        pending = count_pending(conn)
        print(f"  pending rows worth enriching: {pending}")
        if args.dry_run:
            print("dry-run; exiting")
            return 0

        # `is None`, not truthiness: an explicit `--limit 0` means zero rows.
        target = pending if args.limit is None else min(args.limit, pending)
        if target <= 0:
            print("nothing to do")
            return 0
        print(f"  target this run: {target}")
        print()

        while not aborted and enriched + skipped < target:
            remaining = target - enriched - skipped
            want = min(args.batch, remaining)
            # Over-fetch by the number of known failures, then drop them, so
            # the batch is made of rows not yet attempted in this run.
            candidates = fetch_rows(conn, want + len(failed_ids))
            rows = [r for r in candidates if r["id"] not in failed_ids][:want]
            if not rows:
                break

            for row in rows:
                user_prompt = build_user_prompt(row["payload"])
                retries = 0
                while True:
                    # Pace per-call so we never exceed Mistral's 23 RPM.
                    elapsed = time.monotonic() - last_call
                    if elapsed < RATE_INTERVAL_S:
                        time.sleep(RATE_INTERVAL_S - elapsed)
                    last_call = time.monotonic()

                    content, err = call_mistral(api_key, args.model, user_prompt)
                    if err != "429":
                        break
                    rate_limited += 1
                    if retries >= max_429_retries:
                        break
                    retries += 1
                    print(f"  ! 429 — backing off {backoff_s:.0f}s")
                    time.sleep(backoff_s)

                if err == "429":
                    print(
                        f"ERROR: still rate limited after {max_429_retries} retries; "
                        "aborting (re-run later to resume)",
                        file=sys.stderr,
                    )
                    aborted = True
                    break
                if not content:
                    skipped += 1
                    failed_ids.add(row["id"])
                    continue

                update_row(conn, row["id"], content, args.model)
                enriched += 1

                if enriched % 10 == 0 or enriched == target:
                    elapsed_min = (time.monotonic() - started) / 60.0
                    rate = enriched / elapsed_min if elapsed_min > 0 else 0
                    eta_min = (target - enriched) / rate if rate > 0 else 0
                    print(
                        f"  enriched {enriched}/{target}  "
                        f"(skipped {skipped}, 429s {rate_limited}, "
                        f"~{rate:.1f}/min, ETA {eta_min:.1f}min)"
                    )

    elapsed_total = time.monotonic() - started
    print()
    print(
        f"Done. enriched={enriched}  skipped={skipped}  "
        f"rate_limited={rate_limited}  in {elapsed_total/60:.1f} min"
    )
    return 1 if aborted else 0


if __name__ == "__main__":
    sys.exit(main())