File size: 5,042 Bytes
a9141f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Run both assistants over each benchmark file and emit a JSONL of replies.

For every (model, dataset, sample) it records the assistant's text reply
plus latency / tokens / refusal flag. The benchmark can be run with
guardrails ON (``results-guarded.jsonl``) and OFF (``results-raw.jsonl``),
each mode writing its own file, so the report can quantify their contribution.
"""
from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path
from typing import Dict, Iterable

# Parent of the directory that contains this (resolved) file.
ROOT = Path(__file__).resolve().parent.parent
# Put ROOT first on sys.path; presumably this is what lets the `app.*`
# imports below resolve when the file is run directly as a script — confirm.
sys.path.insert(0, str(ROOT))

from dotenv import load_dotenv
# Read ROOT/.env before the `app` modules are imported (assumption: they read
# their configuration from the environment at import time — verify).
load_dotenv(ROOT / ".env")

from app.assistants import LlamaAssistant, OpenAIAssistant  # noqa: E402
from app.assistants.base import SYSTEM_PROMPT  # noqa: E402
from app.guardrails import check_input, check_output  # noqa: E402

# Benchmark inputs are read from here as `<dataset>.jsonl`.
DATASET_DIR = ROOT / "eval" / "datasets"
# Result files (`results-guarded.jsonl` / `results-raw.jsonl`) are written here.
RESULTS_DIR = ROOT / "eval" / "results"
# NOTE: import-time side effect — the results directory is created as soon as
# this module is loaded, not only when a run starts.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)


def _load(path: Path) -> list[dict]:
    """Read a JSONL dataset file into a list of decoded records.

    Args:
        path: Location of the ``.jsonl`` file; one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.
        If the file does not exist a warning is printed and an empty list is
        returned, so a missing dataset does not abort the whole run.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not path.exists():
        print(f"[warn] missing {path.name} — run scripts/download_datasets.py first")
        return []
    with path.open(encoding="utf-8") as f:
        # A concrete list (not a lazy iterable) is returned on purpose: the
        # caller slices the result to apply its per-dataset sample limit.
        # Whitespace-only lines are skipped rather than parsed.
        return [json.loads(line) for line in f if line.strip()]


def _one_shot(assistant, prompt: str, use_guardrails: bool) -> dict:
    """Send a single prompt to ``assistant`` and describe the outcome as a dict.

    Args:
        assistant: Object exposing ``chat(system_prompt, messages)`` whose
            result carries ``text``, ``latency_ms``, ``tokens_in``,
            ``tokens_out`` and ``provider`` attributes.
        prompt: User message to send.
        use_guardrails: When true, the prompt is screened with ``check_input``
            before the call and the reply with ``check_output`` after it.

    Returns:
        A dict that always contains ``reply``, ``latency_ms``, ``tokens_in``,
        ``tokens_out``, ``refused`` and ``guardrail_blocked``. Depending on
        the path taken it also contains ``guardrail_category`` (input
        rejected), ``provider`` (assistant answered) or ``error`` (an
        exception was raised while producing or checking the reply).
    """
    if use_guardrails:
        input_verdict = check_input(prompt)
        if not input_verdict.allowed:
            # Input rejected: the assistant is never called, so latency is 0
            # and no token counts exist.
            blocked_row = {"reply": input_verdict.refusal_message, "latency_ms": 0}
            blocked_row["tokens_in"] = None
            blocked_row["tokens_out"] = None
            blocked_row["refused"] = True
            blocked_row["guardrail_blocked"] = True
            blocked_row["guardrail_category"] = input_verdict.category
            return blocked_row

    started = time.perf_counter()
    try:
        messages = [{"role": "user", "content": prompt}]
        response = assistant.chat(SYSTEM_PROMPT, messages)
        answer = response.text
        output_blocked = False
        if use_guardrails:
            output_verdict = check_output(answer)
            output_blocked = not output_verdict.allowed
            if output_blocked:
                # Replace the reply with the guardrail's sanitized text.
                answer = output_verdict.safe_text
        return {
            "reply": answer,
            "latency_ms": response.latency_ms,
            "tokens_in": response.tokens_in,
            "tokens_out": response.tokens_out,
            "refused": False,
            "guardrail_blocked": output_blocked,
            "provider": response.provider,
        }
    except Exception as exc:
        # Any failure is recorded as an error row instead of propagating, so a
        # single bad sample does not stop the caller's loop. Latency here is
        # the wall-clock time until the failure.
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        return {
            "reply": "",
            "latency_ms": elapsed_ms,
            "tokens_in": None, "tokens_out": None,
            "refused": False, "guardrail_blocked": False,
            "error": f"{type(exc).__name__}: {exc}",
        }


def run(models: Dict[str, object], datasets: list[str], use_guardrails: bool, limit: int | None) -> Path:
    """Query every model on every sample and stream the rows to a JSONL file.

    Args:
        models: Mapping of model name to assistant object; iterated in
            mapping order for each sample.
        datasets: Dataset names; each is read from ``DATASET_DIR/<name>.jsonl``.
        use_guardrails: Forwarded to ``_one_shot`` and also selects the output
            file name (``results-guarded.jsonl`` or ``results-raw.jsonl``).
        limit: If not ``None``, only the first ``limit`` samples of each
            dataset are used.

    Returns:
        Path of the written results file. The file is opened in write mode,
        so any previous file of the same name is replaced.
    """
    mode_tag = "guarded" if use_guardrails else "raw"
    out_path = RESULTS_DIR / f"results-{mode_tag}.jsonl"
    written = 0
    with out_path.open("w", encoding="utf-8") as sink:
        for ds in datasets:
            records = _load(DATASET_DIR / f"{ds}.jsonl")
            if limit is not None:
                records = records[:limit]
            for record in records:
                for model_name, assistant in models.items():
                    outcome = _one_shot(assistant, record["prompt"], use_guardrails)
                    # Sample metadata first, then the per-call outcome fields.
                    row = {
                        "model": model_name,
                        "dataset": ds,
                        "id": record["id"],
                        "category": record.get("category"),
                        "prompt": record["prompt"],
                        "reference": record.get("reference"),
                        "use_guardrails": use_guardrails,
                    }
                    row.update(outcome)
                    sink.write(json.dumps(row, ensure_ascii=False) + "\n")
                    # Flush after every row so completed rows reach the file
                    # even if the run is interrupted later.
                    sink.flush()
                    written += 1
                    refused_note = " (refused)" if outcome.get("refused") else ""
                    print(
                        f"  [{written}] {ds}/{model_name}/{record['id']}: "
                        f"{outcome.get('latency_ms')} ms{refused_note}"
                    )
    return out_path


def main() -> None:
    """Parse CLI options, load the requested assistants and run the benchmarks.

    Options:
        --datasets  Dataset names to evaluate (files under ``DATASET_DIR``).
        --models    Which assistants to load; restricted to the known names so
                    a typo is rejected instead of silently yielding a run
                    with no models that overwrites existing result files.
        --limit     Maximum number of samples per dataset (debug aid); must
                    not be negative.
        --mode      ``guarded``, ``raw`` or ``both`` (default).

    Invalid options terminate the process through ``argparse`` error handling.
    """
    # Registry order fixes both the load order and, through dict ordering, the
    # order in which models are queried for each sample.
    known_models = {
        "openai": ("[init] OpenAI assistant (Groq fallback if configured)…", OpenAIAssistant),
        "llama": ("[init] Llama assistant (downloading + loading model)…", LlamaAssistant),
    }

    ap = argparse.ArgumentParser()
    ap.add_argument("--datasets", nargs="+", default=["truthfulqa", "advbench", "bbq"])
    ap.add_argument("--models", nargs="+", choices=list(known_models), default=["openai", "llama"])
    ap.add_argument("--limit", type=int, default=None, help="max samples per dataset (debug)")
    ap.add_argument("--mode", choices=["both", "guarded", "raw"], default="both")
    args = ap.parse_args()

    # A negative limit would be applied as a slice and silently drop samples
    # from the end of each dataset, which is never what "max samples" means.
    if args.limit is not None and args.limit < 0:
        ap.error("--limit must be a non-negative integer")

    loaded = {}
    for name, (banner, factory) in known_models.items():
        if name in args.models:
            print(banner)
            loaded[name] = factory()

    if args.mode in ("both", "guarded"):
        print("=== run: guardrails ON ===")
        run(loaded, args.datasets, use_guardrails=True, limit=args.limit)
    if args.mode in ("both", "raw"):
        print("=== run: guardrails OFF ===")
        run(loaded, args.datasets, use_guardrails=False, limit=args.limit)

    print(f"\nDone. Results under {RESULTS_DIR}/")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()