bee / scripts /backfill_cve_prompts.py
Bee Deploy
HF Space backend deploy [de0cba5]
5e21013
#!/usr/bin/env python3
"""scripts/backfill_cve_prompts.py — one-shot CVE training-prompt backfill.
Walks `training_queue` rows where `kind='cve'` and `payload` lacks a
`prompt` key, calls Mistral `mistral-medium-latest` (Experiment tier,
free, verified live 2026-05-04), generates an expert-level cybersecurity
training prompt, and writes it back into the row's payload via UPDATE.
Why this exists
---------------
The cve-ingest cron used to dump raw CVEs into `training_queue` with no
consumer-friendly shape. Now that gemma3:12b enriches new rows at ingest
time, ~1,100 historical rows still sit raw — they predate the
enrichment step and NVD's `lastModStartDate` window won't re-surface
them. This script clears the backlog in one go on Mistral's free
Experiment tier.
Quality filter (matches the ranking I gave Christopher 2026-05-04):
- description >= 150 chars (skips title-only rows)
- severity not LOW and not missing/empty — aimed at {CRITICAL, HIGH, MEDIUM}, though any other non-empty value also passes
- ordered by severity then recency so we get the best CVEs first
even if you ctrl-C halfway
Idempotent: re-run after interruption; the WHERE clause picks up where
it left off because the UPDATE adds `payload.prompt`.
Usage
-----
python3 scripts/backfill_cve_prompts.py # all rows
python3 scripts/backfill_cve_prompts.py --limit 50 # smoke test
python3 scripts/backfill_cve_prompts.py --dry-run # count only
BEE_BACKFILL_MODEL=mistral-large-latest python3 scripts/backfill_cve_prompts.py
Reads BEE_MISTRAL_API_KEY + POSTGRES_URL_NON_POOLING from `.env`.
Throughput
----------
Mistral Experiment tier caps at 23 req/min account-wide; we pace at 20
to leave headroom. Expect ~50 minutes for the full ~1,100-row backlog.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path
# Pick up `.env` from the directory that contains `scripts/` (located via this
# file's own path, so the current working directory does not matter) whenever
# the `dotenv` package is importable.
try:
    from dotenv import load_dotenv

    load_dotenv(Path(__file__).resolve().parents[1] / ".env")
except ImportError:
    pass  # not fatal — env may already be exported
import psycopg
from psycopg import rows as psycopg_rows
MISTRAL_ENDPOINT = "https://api.mistral.ai/v1/chat/completions"
DEFAULT_MODEL = "mistral-medium-latest"  # overridable via --model / BEE_BACKFILL_MODEL

# The Experiment tier allows 23 requests/minute per account (read from the
# x-ratelimit-limit-req-minute response header). Pacing at 20 keeps a margin
# for any cron job sharing the same key.
RATE_LIMIT_RPM = 20
# Minimum spacing between two consecutive requests, in seconds (float).
RATE_INTERVAL_S = 60 / RATE_LIMIT_RPM
# System message sent with every chat-completions request. It pins the output
# format: a bare 2-5 sentence prompt with no preface, JSON, or code fences.
# The em dash is written as an escape so the source stays ASCII and cannot be
# mangled by a wrong-codepage round trip (which previously sent "β€”").
SYSTEM_PROMPT = (
    "You generate concise, expert-level cybersecurity training prompts. "
    "Given a raw CVE record, write a self-contained question or analytical "
    "scenario that a senior security engineer would use to teach the "
    "vulnerability \u2014 root cause, exploitation pattern, mitigation, and "
    "detection signals. Output ONLY the prompt body, no preface, no JSON, "
    "no markdown fences. 2-5 sentences total."
)
def build_user_prompt(payload: dict) -> str:
    """Render a CVE payload as the user message sent to the model.

    Reads the ``cve_id``, ``cvss_score``, ``cvss_severity``, ``cwes`` and
    ``description`` keys. A missing key *or* an explicit JSON null falls back
    to a placeholder ('?', 'n/a', '?', 'none listed', '' respectively), so
    the literal text "None" never reaches the model.

    Args:
        payload: Decoded JSON payload of a ``training_queue`` row.

    Returns:
        A multi-line plain-text block: id, CVSS score/severity, CWEs, then
        the description.
    """

    def _field(key: str, default: str) -> object:
        # `.get(key, default)` returns None when the key exists with a null
        # value; treat that the same as a missing key.
        value = payload.get(key)
        return default if value is None else value

    cwes = payload.get("cwes") or []
    if isinstance(cwes, str):
        # A single CWE stored as a bare string would otherwise be joined
        # character by character.
        cwes = [cwes]
    # str() guards against non-string entries (e.g. numeric ids), which would
    # make join() raise TypeError and abort the whole backfill.
    cwe_text = ", ".join(str(cwe) for cwe in cwes) if cwes else "none listed"
    return (
        f"CVE: {_field('cve_id', '?')}\n"
        f"CVSS: {_field('cvss_score', 'n/a')} "
        f"({_field('cvss_severity', '?')})\n"
        f"CWEs: {cwe_text}\n\n"
        f"Description:\n{_field('description', '')}"
    )
def strip_markdown_fences(s: str) -> str:
    """Remove a surrounding ```…``` fence that some models add anyway.

    Text that does not start with a fence is returned stripped of outer
    whitespace. Otherwise the first fenced body is returned. A leading
    language-tag line (e.g. ``text``, ``json``, ``python3``, ``c++``) is
    dropped. A missing closing fence, as when output is truncated by
    ``max_tokens``, is tolerated.

    Args:
        s: Raw model output.

    Returns:
        The unwrapped, whitespace-stripped text.
    """
    s = s.strip()
    if not s.startswith("```"):
        return s
    # s starts with a fence, so split() always yields at least two parts:
    # parts[0] is the empty text before the opening fence and parts[1] is
    # the fenced body, whether or not a closing fence follows.
    inner = s.split("```")[1]
    if "\n" in inner:
        first, rest = inner.split("\n", 1)
        tag = first.strip()
        # A language tag is a single token (no spaces). An empty first line
        # means there is no tag; all() is True for it as well.
        if all(ch.isalnum() or ch in "+-#._" for ch in tag):
            return rest.strip()
    return inner.strip()
def call_mistral(
    api_key: str,
    model: str,
    user_prompt: str,
    timeout_s: int = 60,
) -> tuple[str | None, str | None]:
    """Request one generated training prompt from Mistral's chat-completions API.

    Args:
        api_key: Bearer token for the Mistral API.
        model: Model name placed in the request body.
        user_prompt: User message, normally produced by build_user_prompt().
        timeout_s: Timeout in seconds passed to ``urlopen``.

    Returns:
        ``(content, error_kind)``. On success ``content`` is the model's text
        with any surrounding markdown fence removed and ``error_kind`` is
        None. On failure ``content`` is None and ``error_kind`` is one of:

        - ``'429'``: rate limited; the caller is expected to back off.
        - ``'http_other'``: any other HTTP error status (logged to stderr).
        - ``'fetch'``: network, timeout or JSON-decoding failure (logged).
        - ``'empty'``: the response carried no usable text (missing, null or
          non-string ``content``), or fewer than 24 characters remained
          after fence stripping.
    """
    body = json.dumps(
        {
            "model": model,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": 600,
            "temperature": 0.5,
        }
    ).encode("utf-8")
    req = urllib.request.Request(
        MISTRAL_ENDPOINT,
        data=body,
        method="POST",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        if e.code == 429:
            return None, "429"
        try:
            # errors="replace" so a non-UTF-8 error body is still logged.
            msg = e.read().decode("utf-8", errors="replace")[:200]
        except Exception:
            msg = ""
        print(f" ! HTTP {e.code}: {msg}", file=sys.stderr)
        return None, "http_other"
    except Exception as e:
        print(f" ! fetch error: {e}", file=sys.stderr)
        return None, "fetch"

    # Walk choices[0].message.content defensively. A malformed body, a null
    # message, or null/non-string content is reported as 'empty' rather than
    # raising AttributeError outside the try above, which would kill the run.
    choices = data.get("choices") if isinstance(data, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    message = first.get("message") if isinstance(first, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    if not isinstance(content, str):
        return None, "empty"
    content = strip_markdown_fences(content)
    if len(content) < 24:
        return None, "empty"
    return content, None
def count_pending(conn) -> int:
    """Return how many CVE rows in ``training_queue`` still qualify for enrichment.

    A row qualifies when its payload has no ``prompt`` key, its description
    is at least 150 characters long, and its ``cvss_severity`` is neither
    ``'LOW'`` nor missing/empty.
    """
    query = """
        SELECT count(*)
        FROM public.training_queue
        WHERE kind = 'cve'
        AND NOT (payload ? 'prompt')
        AND length(payload->>'description') >= 150
        AND COALESCE(payload->>'cvss_severity', '') NOT IN ('LOW', '')
    """
    with conn.cursor() as cursor:
        cursor.execute(query)
        (total,) = cursor.fetchone()
    return total
def fetch_rows(conn, limit: int) -> list[dict]:
    """Fetch up to ``limit`` qualifying CVE rows, best candidates first.

    Uses the same WHERE filter as count_pending(). Ordering is by severity
    (CRITICAL, HIGH, MEDIUM, then anything else), then by ``payload.published``
    descending (compared as text) with missing dates last. Each row is a dict
    with ``id``, ``external_id`` and ``payload`` keys.
    """
    query = """
        SELECT id, external_id, payload
        FROM public.training_queue
        WHERE kind = 'cve'
        AND NOT (payload ? 'prompt')
        AND length(payload->>'description') >= 150
        AND COALESCE(payload->>'cvss_severity', '') NOT IN ('LOW', '')
        ORDER BY
        CASE payload->>'cvss_severity'
        WHEN 'CRITICAL' THEN 1
        WHEN 'HIGH' THEN 2
        WHEN 'MEDIUM' THEN 3
        ELSE 9
        END,
        (payload->>'published') DESC NULLS LAST
        LIMIT %s
    """
    with conn.cursor(row_factory=psycopg_rows.dict_row) as cursor:
        cursor.execute(query, (limit,))
        batch = cursor.fetchall()
    return list(batch)
def update_row(conn, row_id: int, prompt: str, model: str) -> None:
    """Merge the generated prompt and the model name into one row's payload.

    The two keys ``prompt`` and ``enrich_model`` are added to (or overwrite
    them in) the existing jsonb payload via ``||``. Other keys are kept.
    The transaction is committed immediately, so an interrupted run keeps
    every row enriched so far.
    """
    statement = """
        UPDATE public.training_queue
        SET payload = payload
        || jsonb_build_object('prompt', %s::text)
        || jsonb_build_object('enrich_model', %s::text)
        WHERE id = %s
    """
    params = (prompt, model, row_id)
    with conn.cursor() as cursor:
        cursor.execute(statement, params)
    conn.commit()
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    # Module __doc__ is None under `python -OO`; fall back to no description.
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n\n")[0])
    parser.add_argument("--limit", type=int, default=None, help="cap total rows enriched")
    parser.add_argument(
        "--batch", type=int, default=50, help="DB fetch batch size (rows per loop)"
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("BEE_BACKFILL_MODEL", DEFAULT_MODEL),
        help="model name (default: BEE_BACKFILL_MODEL or mistral-medium-latest)",
    )
    parser.add_argument(
        "--dry-run", action="store_true", help="count pending rows and exit without enriching"
    )
    return parser.parse_args()


def main() -> int:
    """Enrich pending CVE rows with model-written training prompts.

    Reads BEE_MISTRAL_API_KEY and POSTGRES_URL_NON_POOLING from the
    environment, spaces requests at least RATE_INTERVAL_S apart, and writes
    each successful prompt back with update_row(). Each pending row is
    attempted at most once per run, except rows that hit HTTP 429, which are
    retried later in the run.

    Returns:
        Process exit status: 1 if a required environment variable is missing,
        otherwise 0.
    """
    args = _parse_args()
    api_key = (os.environ.get("BEE_MISTRAL_API_KEY") or "").strip()
    if not api_key:
        print("ERROR: BEE_MISTRAL_API_KEY not set (.env or environment)", file=sys.stderr)
        return 1
    pg_url = (os.environ.get("POSTGRES_URL_NON_POOLING") or "").strip()
    if not pg_url:
        print("ERROR: POSTGRES_URL_NON_POOLING not set", file=sys.stderr)
        return 1
    print(f"Backfill \u2014 model={args.model} batch={args.batch} pace={RATE_LIMIT_RPM} req/min")
    started = time.monotonic()
    enriched = 0
    skipped = 0
    rate_limited = 0
    last_call = 0.0
    # Rows that failed this run (non-429 error or unusable output). They still
    # lack `payload.prompt`, so fetch_rows() would hand them back on every
    # batch and the loop would re-spend API calls and target budget on them.
    failed_ids: set[object] = set()
    with psycopg.connect(pg_url, autocommit=False) as conn:
        pending = count_pending(conn)
        print(f" pending rows worth enriching: {pending}")
        if args.dry_run:
            print("dry-run; exiting")
            return 0
        # `--limit 0` means nothing to do (it used to mean "everything"), and
        # a negative limit is clamped to 0.
        target = pending if args.limit is None else max(0, min(args.limit, pending))
        if target == 0:
            print("nothing to do")
            return 0
        print(f" target this run: {target}")
        print()
        while enriched + skipped < target:
            remaining = target - enriched - skipped
            # max(1, ...) keeps a zero/negative --batch from producing an
            # empty or negative SQL LIMIT.
            want = max(1, min(args.batch, remaining))
            # Over-fetch by len(failed_ids). At most that many of the returned
            # rows can be known failures, so `want` fresh rows come back
            # whenever that many are still pending.
            fetched = fetch_rows(conn, want + len(failed_ids))
            rows = [row for row in fetched if row["id"] not in failed_ids][:want]
            if not rows:
                break
            for row in rows:
                # Pace per-call so we stay under Mistral's 23 RPM cap.
                elapsed = time.monotonic() - last_call
                if elapsed < RATE_INTERVAL_S:
                    time.sleep(RATE_INTERVAL_S - elapsed)
                last_call = time.monotonic()
                content, err = call_mistral(
                    api_key, args.model, build_user_prompt(row["payload"])
                )
                if err == "429":
                    # Not counted: the row stays pending and is fetched again.
                    rate_limited += 1
                    print(" ! 429 \u2014 backing off 12s")
                    time.sleep(12.0)
                    continue
                if not content:
                    skipped += 1
                    failed_ids.add(row["id"])
                    continue
                update_row(conn, row["id"], content, args.model)
                enriched += 1
                if enriched % 10 == 0 or enriched == target:
                    elapsed_min = (time.monotonic() - started) / 60.0
                    rate = enriched / elapsed_min if elapsed_min > 0 else 0
                    eta_min = (target - enriched) / rate if rate > 0 else 0
                    print(
                        f" enriched {enriched}/{target} "
                        f"(skipped {skipped}, 429s {rate_limited}, "
                        f"~{rate:.1f}/min, ETA {eta_min:.1f}min)"
                    )
    elapsed_total = time.monotonic() - started
    print()
    print(
        f"Done. enriched={enriched} skipped={skipped} "
        f"rate_limited={rate_limited} in {elapsed_total/60:.1f} min"
    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())