| |
| """scripts/backfill_cve_prompts.py — one-shot CVE training-prompt backfill. |
| |
| Walks `training_queue` rows where `kind='cve'` and `payload` lacks a |
| `prompt` key, calls Mistral `mistral-medium-latest` (Experiment tier, |
| free, verified live 2026-05-04), generates an expert-level cybersecurity |
| training prompt, and writes it back into the row's payload via UPDATE. |
| |
| Why this exists |
| --------------- |
| The cve-ingest cron used to dump raw CVEs into `training_queue` with no |
| consumer-friendly shape. Now that gemma3:12b enriches new rows at ingest |
| time, ~1,100 historical rows still sit raw — they predate the |
| enrichment step and NVD's `lastModStartDate` window won't re-surface |
| them. This script clears the backlog in one go on Mistral's free |
| Experiment tier. |
| |
| Quality filter (matches the ranking I gave Christopher 2026-05-04): |
| - description >= 150 chars (skips title-only rows) |
| - severity in {CRITICAL, HIGH, MEDIUM} (skips LOW + missing) |
| - ordered by severity then recency so we get the best CVEs first |
| even if you ctrl-C halfway |
| |
| Idempotent: re-run after interruption; the WHERE clause picks up where |
| it left off because the UPDATE adds `payload.prompt`. |
| |
| Usage |
| ----- |
| python3 scripts/backfill_cve_prompts.py # all rows |
| python3 scripts/backfill_cve_prompts.py --limit 50 # smoke test |
| python3 scripts/backfill_cve_prompts.py --dry-run # count only |
| BEE_BACKFILL_MODEL=mistral-large-latest python3 scripts/backfill_cve_prompts.py |
| |
| Reads BEE_MISTRAL_API_KEY + POSTGRES_URL_NON_POOLING from `.env`. |
| |
| Throughput |
| ---------- |
| Mistral Experiment tier caps at 23 req/min account-wide; we pace at 20 |
| to leave headroom. Expect ~50 minutes for the full ~1,100-row backlog. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| import urllib.error |
| import urllib.request |
| from pathlib import Path |
|
|
| |
# python-dotenv is optional: when it is installed, load the `.env` file that
# sits one directory above this script's own directory; when it is missing,
# fall through silently and rely on whatever is already in the environment.
try:
    from dotenv import load_dotenv

    load_dotenv(Path(__file__).resolve().parent.parent / ".env")
except ImportError:
    pass
|
|
| import psycopg |
| from psycopg import rows as psycopg_rows |
|
|
# Chat-completions endpoint every enrichment request is POSTed to.
MISTRAL_ENDPOINT = "https://api.mistral.ai/v1/chat/completions"
# Model used when neither --model nor BEE_BACKFILL_MODEL overrides it.
DEFAULT_MODEL = "mistral-medium-latest"


# Client-side pacing. The module docstring puts the Experiment-tier cap at
# 23 req/min account-wide; pacing at 20 leaves headroom.
RATE_LIMIT_RPM = 20
# Minimum spacing between consecutive requests, in seconds (60 / 20 = 3.0).
RATE_INTERVAL_S = 60.0 / RATE_LIMIT_RPM
|
|
# System message sent with every chat-completion request. It asks the model
# for a short, plain-text training prompt with no wrapper formatting.
# NOTE: the dash after "vulnerability" had been mangled by an encoding
# round-trip; the intended em dash is restored here.
SYSTEM_PROMPT = (
    "You generate concise, expert-level cybersecurity training prompts. "
    "Given a raw CVE record, write a self-contained question or analytical "
    "scenario that a senior security engineer would use to teach the "
    "vulnerability — root cause, exploitation pattern, mitigation, and "
    "detection signals. Output ONLY the prompt body, no preface, no JSON, "
    "no markdown fences. 2-5 sentences total."
)
|
|
|
|
def build_user_prompt(payload: dict) -> str:
    """Render a CVE payload as the user message sent to the model.

    Reads ``cve_id``, ``cvss_score``, ``cvss_severity``, ``cwes`` and
    ``description`` from *payload*. Missing keys fall back to ``?``,
    ``n/a``, ``?``, ``none listed`` and an empty string respectively.
    """
    cwe_list = payload.get("cwes") or []
    if cwe_list:
        cwe_text = ", ".join(cwe_list)
    else:
        cwe_text = "none listed"

    cve_id = payload.get("cve_id", "?")
    score = payload.get("cvss_score", "n/a")
    severity = payload.get("cvss_severity", "?")
    description = payload.get("description", "")

    # The empty element produces the blank line between header and body.
    lines = [
        f"CVE: {cve_id}",
        f"CVSS: {score} ({severity})",
        f"CWEs: {cwe_text}",
        "",
        "Description:",
        f"{description}",
    ]
    return "\n".join(lines)
|
|
|
|
def strip_markdown_fences(s: str) -> str:
    """Remove a wrapping ```…``` fence that some models add despite instructions.

    Text that does not start with a fence is returned stripped but otherwise
    untouched. For fenced text, only the content between the opening fence
    and the first closing fence is kept. A first line that is blank or a
    single alphabetic word is treated as a language tag and dropped.

    An opening fence with no closing fence (typical when the completion was
    cut off by the token limit) is handled the same way: the dangling fence
    is removed and everything after it is kept.
    """
    s = s.strip()
    if not s.startswith("```"):
        return s

    # partition() yields the text up to the first closing fence, or the whole
    # remainder when the fence was never closed.
    inner, _fence, _tail = s[3:].partition("```")

    first, newline, rest = inner.partition("\n")
    if newline:
        tag = first.strip()
        if not tag or tag.isalpha():
            return rest.strip()
    return inner.strip()
|
|
|
|
def call_mistral(
    api_key: str,
    model: str,
    user_prompt: str,
    timeout_s: int = 60,
) -> tuple[str | None, str | None]:
    """POST one chat-completion request and return ``(content, error_kind)``.

    Args:
        api_key: Bearer token for the Mistral API.
        model: Model name placed in the request body.
        user_prompt: User message; the system message is ``SYSTEM_PROMPT``.
        timeout_s: Socket timeout handed to ``urlopen``.

    Returns:
        ``(content, None)`` on success, otherwise ``(None, error_kind)`` where
        ``error_kind`` is one of:

        * ``'429'``        - the API answered HTTP 429
        * ``'http_other'`` - any other HTTP error status (logged to stderr)
        * ``'fetch'``      - network failure, timeout or undecodable body
        * ``'empty'``      - no usable text in the response, or fewer than
          24 characters after fence stripping

    This function never raises for a bad response; every failure is reported
    through ``error_kind`` so one bad row cannot abort the whole run.
    """
    body = json.dumps(
        {
            "model": model,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": 600,
            "temperature": 0.5,
        }
    ).encode("utf-8")
    req = urllib.request.Request(
        MISTRAL_ENDPOINT,
        data=body,
        method="POST",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        if e.code == 429:
            return None, "429"
        # Best effort: the error body is only used for the log line.
        try:
            msg = e.read().decode("utf-8")[:200]
        except Exception:
            msg = ""
        print(f" ! HTTP {e.code}: {msg}", file=sys.stderr)
        return None, "http_other"
    except Exception as e:
        # Deliberately broad: URLError, timeouts and JSON/Unicode decode
        # errors are all reported as a fetch failure.
        print(f" ! fetch error: {e}", file=sys.stderr)
        return None, "fetch"

    # A 200 response can still lack usable text: no choices, a null message,
    # or a content field that is null or not a plain string. Treat all of
    # these as "empty" instead of raising AttributeError further down.
    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        content = None
    if not isinstance(content, str):
        return None, "empty"

    content = strip_markdown_fences(content)
    # Anything this short is a refusal fragment or noise, not a prompt.
    if len(content) < 24:
        return None, "empty"
    return content, None
|
|
|
|
def count_pending(conn) -> int:
    """Return the number of CVE rows that still qualify for enrichment.

    A row qualifies when ``kind = 'cve'``, its payload has no ``prompt`` key,
    its description is at least 150 characters long, and its
    ``cvss_severity`` is neither ``LOW`` nor missing/empty.
    """
    pending_sql = """
        SELECT count(*)
        FROM public.training_queue
        WHERE kind = 'cve'
          AND NOT (payload ? 'prompt')
          AND length(payload->>'description') >= 150
          AND COALESCE(payload->>'cvss_severity', '') NOT IN ('LOW', '')
        """
    with conn.cursor() as cur:
        cur.execute(pending_sql)
        first_row = cur.fetchone()
        # count(*) always yields exactly one row with one column.
        return first_row[0]
|
|
|
|
def fetch_rows(conn, limit: int) -> list[dict]:
    """Fetch up to *limit* qualifying rows as dicts, best candidates first.

    Uses the same filter as ``count_pending``. Rows are ordered by severity
    (CRITICAL, then HIGH, then MEDIUM, then anything else) and, within a
    severity, by ``payload->>'published'`` descending with NULLs last. Each
    dict carries ``id``, ``external_id`` and ``payload``.
    """
    query = """
        SELECT id, external_id, payload
        FROM public.training_queue
        WHERE kind = 'cve'
          AND NOT (payload ? 'prompt')
          AND length(payload->>'description') >= 150
          AND COALESCE(payload->>'cvss_severity', '') NOT IN ('LOW', '')
        ORDER BY
          CASE payload->>'cvss_severity'
            WHEN 'CRITICAL' THEN 1
            WHEN 'HIGH' THEN 2
            WHEN 'MEDIUM' THEN 3
            ELSE 9
          END,
          (payload->>'published') DESC NULLS LAST
        LIMIT %s
        """
    params = (limit,)
    # dict_row makes each record addressable by column name.
    with conn.cursor(row_factory=psycopg_rows.dict_row) as cur:
        cur.execute(query, params)
        fetched = cur.fetchall()
        return list(fetched)
|
|
|
|
def update_row(conn, row_id: int, prompt: str, model: str) -> None:
    """Merge the generated prompt and model name into one row's payload.

    Adds (or overwrites) the ``prompt`` and ``enrich_model`` keys of the
    JSONB payload for the row with the given id, then commits immediately so
    an interrupted run keeps every row it has already finished.
    """
    statement = """
        UPDATE public.training_queue
        SET payload = payload
          || jsonb_build_object('prompt', %s::text)
          || jsonb_build_object('enrich_model', %s::text)
        WHERE id = %s
        """
    values = (prompt, model, row_id)
    with conn.cursor() as cur:
        cur.execute(statement, values)
    conn.commit()
|
|
|
|
def main() -> int:
    """Run the backfill from the command line and return a process exit code.

    Flow: parse flags, read ``BEE_MISTRAL_API_KEY`` and
    ``POSTGRES_URL_NON_POOLING`` from the environment, count the pending
    rows, then loop fetch -> generate -> UPDATE until the target is reached
    or no further rows can be processed.

    Returns:
        0 on success (including ``--dry-run`` and "nothing to do"); 1 when a
        required environment variable is missing or the run stopped early
        because the API kept answering 429. Invalid ``--limit``/``--batch``
        values exit through ``argparse`` with status 2.
    """
    # __doc__ is None under `python -OO`; fall back to an empty description.
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n\n")[0])
    parser.add_argument("--limit", type=int, default=None, help="cap total rows enriched")
    parser.add_argument(
        "--batch", type=int, default=50, help="DB fetch batch size (rows per loop)"
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("BEE_BACKFILL_MODEL", DEFAULT_MODEL),
        help="model name (default: BEE_BACKFILL_MODEL or mistral-medium-latest)",
    )
    parser.add_argument(
        "--dry-run", action="store_true", help="count pending rows and exit without enriching"
    )
    args = parser.parse_args()
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be >= 0")
    if args.batch < 1:
        parser.error("--batch must be >= 1")

    api_key = (os.environ.get("BEE_MISTRAL_API_KEY") or "").strip()
    if not api_key:
        print("ERROR: BEE_MISTRAL_API_KEY not set (.env or environment)", file=sys.stderr)
        return 1
    pg_url = (os.environ.get("POSTGRES_URL_NON_POOLING") or "").strip()
    if not pg_url:
        print("ERROR: POSTGRES_URL_NON_POOLING not set", file=sys.stderr)
        return 1

    print(f"Backfill — model={args.model} batch={args.batch} pace={RATE_LIMIT_RPM} req/min")

    # Attempts per row before a persistent 429 is treated as an exhausted
    # quota. Five attempts with a 12 s backoff span about a minute.
    max_attempts_429 = 5

    started = time.monotonic()
    enriched = 0
    skipped = 0
    rate_limited = 0
    aborted = False
    # time.monotonic() has an undefined reference point, so "never called"
    # is represented as -inf instead of 0.0.
    last_call = float("-inf")
    # Ids whose generation failed in this run. They still match the pending
    # filter, so they must be excluded from later batches by hand.
    failed_ids: set = set()

    with psycopg.connect(pg_url, autocommit=False) as conn:
        pending = count_pending(conn)
        print(f" pending rows worth enriching: {pending}")
        if args.dry_run:
            print("dry-run; exiting")
            return 0

        # `is not None` so that `--limit 0` means zero rows, not "no limit".
        target = min(args.limit, pending) if args.limit is not None else pending
        if target == 0:
            print("nothing to do")
            return 0
        print(f" target this run: {target}")
        print()

        while enriched + skipped < target and not aborted:
            remaining = target - enriched - skipped
            want = min(args.batch, remaining)
            # Over-fetch by the number of known failures and filter them out,
            # so failed rows are neither re-sent to the API nor counted as
            # skipped a second time.
            candidates = fetch_rows(conn, want + len(failed_ids))
            rows = [r for r in candidates if r["id"] not in failed_ids][:want]
            if not rows:
                break
            for row in rows:
                user_prompt = build_user_prompt(row["payload"])
                content, err = None, "429"
                for _attempt in range(max_attempts_429):
                    # Pace requests to stay below the account-wide rate cap.
                    elapsed = time.monotonic() - last_call
                    if elapsed < RATE_INTERVAL_S:
                        time.sleep(RATE_INTERVAL_S - elapsed)
                    last_call = time.monotonic()

                    content, err = call_mistral(api_key, args.model, user_prompt)
                    if err != "429":
                        break
                    rate_limited += 1
                    print(" ! 429 — backing off 12s")
                    time.sleep(12.0)

                if err == "429":
                    # Still throttled after every retry: stop cleanly. The
                    # run is idempotent, so it can simply be started again.
                    print(
                        f" ! still rate-limited after {max_attempts_429} attempts; stopping",
                        file=sys.stderr,
                    )
                    aborted = True
                    break

                if not content:
                    skipped += 1
                    failed_ids.add(row["id"])
                    continue

                update_row(conn, row["id"], content, args.model)
                enriched += 1

                if enriched % 10 == 0 or enriched == target:
                    elapsed_min = (time.monotonic() - started) / 60.0
                    rate = enriched / elapsed_min if elapsed_min > 0 else 0
                    eta_min = (target - enriched) / rate if rate > 0 else 0
                    print(
                        f" enriched {enriched}/{target} "
                        f"(skipped {skipped}, 429s {rate_limited}, "
                        f"~{rate:.1f}/min, ETA {eta_min:.1f}min)"
                    )

    elapsed_total = time.monotonic() - started
    status = "Stopped early." if aborted else "Done."
    print()
    print(
        f"{status} enriched={enriched} skipped={skipped} "
        f"rate_limited={rate_limited} in {elapsed_total/60:.1f} min"
    )
    return 1 if aborted else 0
|
|
|
|
# Script entry point: main()'s return value becomes the process exit code.
if __name__ == "__main__":
    sys.exit(main())
|
|