File size: 15,467 Bytes
f23deb1
 
 
 
 
ac224ce
 
f23deb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac224ce
f23deb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
from __future__ import annotations

import sys

"""
outputs/eval_agent.py
---------------------
Zero-shot eval agent using GPT-4o-mini to diagnose and fix broken RAG pipelines.

Purpose: validate the environment end-to-end — confirm reward signals are
meaningful, observations are interpretable, and the tasks are solvable by a
capable model before committing to GRPO training.

Usage:
    # Server must be running first:
    #   uvicorn server.app:app --host 0.0.0.0 --port 8000

    python baseline/eval_agent.py --task 1 --episodes 3
    python baseline/eval_agent.py --task all --episodes 2 --verbose
    python baseline/eval_agent.py --task 2 --seed 42 --server http://localhost:8000

Requirements:
    OPENAI_API_KEY environment variable must be set.
    pip install openai
"""


import argparse
import os
import time
from enum import Enum
from typing import Optional

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel

from client import RAGDebugEnv
from models import RAGDebugAction, RAGDebugObservation
#  RAGDebugEnv, RAGDebugAction, RAGDebugObservation

# Module-level side effect, run at import time before main() reads
# OPENAI_API_KEY from os.environ. Presumably this pulls variables from a local
# .env file into the environment — confirm against the python-dotenv docs.
load_dotenv()

# ---------------------------------------------------------------------------
# OpenAI structured output schema
# ---------------------------------------------------------------------------

# Closed set of action names the model may choose from. Each value is the
# string that _decision_to_action forwards as RAGDebugAction.action_type, and
# the same names appear in the "Available Actions" table of _SYSTEM_PROMPT.
# NOTE(review): documented with comments rather than a class docstring because
# pydantic may copy an enum docstring into the JSON schema sent to the model —
# confirm before adding one.
class _ActionType(str, Enum):
    # Parameterised by AgentDecision.int_value.
    ADJUST_CHUNK_SIZE    = "adjust_chunk_size"
    ADJUST_CHUNK_OVERLAP = "adjust_chunk_overlap"
    # Parameterised by AgentDecision.float_value.
    ADJUST_THRESHOLD     = "adjust_threshold"
    # Parameterised by AgentDecision.int_value.
    ADJUST_TOP_K         = "adjust_top_k"
    # Parameterised by AgentDecision.model_name.
    SWAP_EMBEDDING_MODEL = "swap_embedding_model"
    # Parameterised by AgentDecision.enabled.
    TOGGLE_RERANKING     = "toggle_reranking"
    # Parameterised by AgentDecision.int_value.
    ADJUST_CONTEXT_LIMIT = "adjust_context_limit"
    # Parameterised by AgentDecision.query_id.
    REWRITE_QUERY        = "rewrite_query"
    # Takes no parameters.
    SUBMIT               = "submit"


class AgentDecision(BaseModel):
    """Structured output schema enforced by OpenAI's API."""
    # NOTE(review): the docstring above is deliberately left unchanged —
    # pydantic may surface a model's docstring in the generated JSON schema,
    # in which case rewording it would alter what is sent to the model.
    # Confirm before editing it.
    #
    # This class is passed as `response_format` in run_episode and the parsed
    # instance is converted to a RAGDebugAction by _decision_to_action.

    # Free-text rationale from the model; only printed (first 200 characters)
    # when run_episode is called with verbose=True.
    reasoning: str
    # Selects which of the optional fields below _decision_to_action reads.
    action_type: _ActionType
    # Flat param fields — fill only the one(s) relevant to your action_type.
    # int_value   : chunk_size, top_k, context_limit, chunk_overlap
    # float_value : similarity_threshold
    # model_name  : embedding model ("general" | "medical" | "legal" | "code")
    # enabled     : reranking toggle (True/False)
    # query_id    : query to rewrite
    # Fields left as None are not copied into the action's params.
    int_value:   Optional[int]   = None
    float_value: Optional[float] = None
    model_name:  Optional[str]   = None
    enabled:     Optional[bool]  = None
    query_id:    Optional[int]   = None


# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------

# System message sent unchanged on every completion call in run_episode.
# The action names in the table must stay in sync with _ActionType, and the
# "Required param" column with the fields of AgentDecision.
# The dash and arrow glyphs below are real Unicode characters (U+2014, U+2192);
# they were previously stored as mis-decoded byte sequences, which also broke
# the column alignment of the table.
_SYSTEM_PROMPT = """\
You are an expert RAG (Retrieval-Augmented Generation) pipeline debugger.

Your job is to diagnose why a RAG pipeline is performing poorly and take
corrective actions to restore retrieval quality. You will be given an
observation describing the current pipeline state, per-query results, and
aggregate metrics.

## Available Actions

| Action               | Required param       | Effect                              |
|----------------------|----------------------|-------------------------------------|
| adjust_chunk_size    | int_value (64-2048)  | Change chunk size                   |
| adjust_chunk_overlap | int_value (0-500)    | Change chunk overlap                |
| adjust_threshold     | float_value (0.0-1.0)| Change similarity threshold         |
| adjust_top_k         | int_value (1-50)     | Change number of retrieved chunks   |
| swap_embedding_model | model_name           | Switch embedding model              |
| toggle_reranking     | enabled (bool)       | Enable/disable cross-encoder rerank |
| adjust_context_limit | int_value (512-16384)| Change context window limit         |
| rewrite_query        | query_id (int)       | Boost a specific query              |
| submit               | (none)               | Submit — ends the episode           |

## Embedding Models
- "general"  — all-purpose (sentence-transformers/all-MiniLM-L6-v2)
- "medical"  — biomedical text (PubMedBert-MS-MARCO)
- "legal"    — legal documents (legal-bert-base-uncased)
- "code"     — code + docstrings (codebert-base)

## Diagnostic Heuristics
- Low coverage + low precision + many empty retrievals → threshold may be too high, or top_k too small
- Low coverage + moderate precision → top_k too small, or embedding model mismatch
- Many retrieved chunks but low coverage → duplicate flooding, or threshold too low letting noise through
- Score distribution compressed (all scores similar) → wrong embedding model, or chunk too large
- Coverage plateaus despite config changes → wrong embedding model (especially on domain-specific text)
- Context overflow → increase context_limit or decrease top_k
- Submit only when mean_coverage >= 0.70 and no empty retrievals

Fill in only the param field relevant to your chosen action. Leave others as null.
"""


# ---------------------------------------------------------------------------
# Observation formatter
# ---------------------------------------------------------------------------

def _format_observation(obs: RAGDebugObservation, action_history: list[dict]) -> str:
    """Render an observation and the prior actions as the user-turn prompt.

    Args:
        obs: Current environment observation. The attributes read here are
            ``task_id``, ``task_description``, ``steps_taken``, ``max_steps``,
            ``pipeline_config``, ``corpus_stats``, ``metrics`` and
            ``query_results``.
        action_history: One dict per action already taken, each providing
            ``"action_type"`` and ``"reward"`` and optionally ``"params"``.

    Returns:
        A newline-joined report with sections for the pipeline config, corpus
        info, aggregate metrics, per-query results and, when
        ``action_history`` is non-empty, the actions taken so far.
    """
    cfg = obs.pipeline_config
    m   = obs.metrics
    cs  = obs.corpus_stats

    lines = [
        f"## Task {obs.task_id}: {obs.task_description}",
        f"Step {obs.steps_taken} / {obs.max_steps}",
        "",
        "## Current Pipeline Config",
        f"  chunk_size        = {cfg.chunk_size}",
        f"  chunk_overlap     = {cfg.chunk_overlap}",
        f"  similarity_threshold = {cfg.similarity_threshold}",
        f"  top_k             = {cfg.top_k}",
        f"  embedding_model   = {cfg.embedding_model.value}",
        f"  use_reranking     = {cfg.use_reranking}",
        f"  context_window_limit = {cfg.context_window_limit}",
        "",
        "## Corpus Info",
        f"  domain = {cs.domain.value}  |  {cs.n_chunks} chunks  |  {cs.n_queries} queries",
        f"  multi-hop queries: {cs.n_multi_hop_queries}",
        "",
        "## Aggregate Metrics",
        f"  mean_coverage    = {m.mean_coverage:.3f}",
        f"  mean_precision   = {m.mean_precision:.3f}",
        f"  empty retrievals = {m.n_empty_retrievals}",
        f"  context overflows = {m.n_context_overflows}",
    ]
    # multi_hop_coverage is optional; the line is omitted when it is None.
    if m.multi_hop_coverage is not None:
        lines.append(f"  multi_hop_coverage = {m.multi_hop_coverage:.3f}")

    lines += ["", "## Per-Query Results"]
    for qr in obs.query_results:
        mh_tag = " [multi-hop]" if qr.is_multi_hop else ""
        # Summarise the score distribution only when there are scores, which
        # also avoids dividing by zero for an empty retrieval.
        score_summary = ""
        if qr.retrieval_scores:
            score_summary = (
                f"  scores: min={min(qr.retrieval_scores):.3f} "
                f"max={max(qr.retrieval_scores):.3f} "
                f"mean={sum(qr.retrieval_scores)/len(qr.retrieval_scores):.3f}"
            )
        lines.append(
            f"  Q{qr.query_id}{mh_tag}: coverage={qr.coverage_score:.3f} "
            f"precision={qr.precision_score:.3f} "
            f"retrieved={qr.n_retrieved}{score_summary}"
        )
        if qr.n_retrieved == 0:
            # Real em dash: this text is sent to the model, and the previous
            # literal contained a mis-decoded byte sequence instead.
            lines.append("    !! empty retrieval — no chunks above threshold")

    if action_history:
        lines += ["", "## Actions Taken So Far"]
        for i, ah in enumerate(action_history, 1):
            lines.append(f"  {i}. {ah['action_type']}({ah.get('params', {})})  reward={ah['reward']:+.3f}")

    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Action builder
# ---------------------------------------------------------------------------

def _decision_to_action(decision: AgentDecision) -> RAGDebugAction:
    """Translate the model's flat structured reply into an environment action.

    Exactly one field of ``decision`` is consulted, chosen by its
    ``action_type``. If that field is unset the action is built with empty
    ``params``; ``submit`` always has empty ``params``.
    """
    kind = decision.action_type.value

    # Select the (param key, raw value) pair that belongs to this action kind.
    if kind in {
        "adjust_chunk_size",
        "adjust_chunk_overlap",
        "adjust_top_k",
        "adjust_context_limit",
    }:
        key, raw = "value", decision.int_value
    elif kind == "adjust_threshold":
        key, raw = "value", decision.float_value
    elif kind == "swap_embedding_model":
        # An empty model name counts as "not provided", just like None.
        key, raw = "model", decision.model_name or None
    elif kind == "toggle_reranking":
        key, raw = "enabled", decision.enabled
    elif kind == "rewrite_query":
        key, raw = "query_id", decision.query_id
    else:
        # submit (or anything unrecognised) carries no parameters.
        key, raw = "", None

    payload: dict = {key: raw} if raw is not None else {}
    return RAGDebugAction(action_type=kind, params=payload)


# ---------------------------------------------------------------------------
# Single episode
# ---------------------------------------------------------------------------

def run_episode(
    client: OpenAI,
    env: RAGDebugEnv,
    task_id: int,
    seed: Optional[int],
    episode_num: int,
    verbose: bool = False,
    model: str = "gpt-4o-mini",
) -> dict:
    """
    Run one episode and return a result dict.

    Each step formats the current observation, asks the chat model for a
    structured ``AgentDecision``, converts it to a ``RAGDebugAction`` and
    applies it with ``env.step``. The loop stops when the observation reports
    ``done`` or the server answers "Episode is already done".

    Parameters
    ----------
    client : OpenAI client used for the structured-output completion call.
    env : environment client providing ``reset(**kwargs)`` and ``step(action)``.
    task_id : forwarded to ``env.reset``.
    seed : forwarded to ``env.reset`` only when it is not ``None``.
    episode_num : label used in the printed output and in the result dict.
    verbose : when true, also print every observation and the model reasoning.
    model : chat model name passed to the completion call.

    Returns
    -------
    {task_id, episode, seed, steps, final_coverage, final_precision,
     success, total_reward, actions}

    Raises
    ------
    RuntimeError
        If the model reply carries no parsed decision, or if ``env.step``
        raises a ``RuntimeError`` other than "Episode is already done".
    """
    reset_kwargs: dict = {"task_id": task_id}
    if seed is not None:
        reset_kwargs["seed"] = seed

    result = env.reset(**reset_kwargs)
    obs: RAGDebugObservation = result.observation

    action_history: list[dict] = []
    total_reward = 0.0
    success = False
    rule = "─" * 50

    print(f"\n  Episode {episode_num} (task={task_id})")
    print(f"  {rule}")
    print(f"  Initial state: coverage={obs.metrics.mean_coverage:.3f}  "
          f"precision={obs.metrics.mean_precision:.3f}  "
          f"empty={obs.metrics.n_empty_retrievals}")

    while not obs.done:
        observation_text = _format_observation(obs, action_history)

        if verbose:
            print(f"\n--- Observation (step {obs.steps_taken}) ---")
            print(observation_text)

        # Ask the chat model for a decision constrained to the AgentDecision schema.
        response = client.beta.chat.completions.parse(
            model=model,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user",   "content": observation_text},
            ],
            response_format=AgentDecision,
            temperature=0.2,
        )

        message = response.choices[0].message
        decision: Optional[AgentDecision] = message.parsed
        if decision is None:
            # Without this guard a missing parse (e.g. a refusal) surfaces as
            # an opaque AttributeError inside _decision_to_action.
            raise RuntimeError(
                "Model returned no parsed decision "
                f"(refusal={getattr(message, 'refusal', None)!r})"
            )
        action = _decision_to_action(decision)

        if verbose:
            print(f"\n  Reasoning: {decision.reasoning[:200]}")

        try:
            step_result = env.step(action)
        except RuntimeError as e:
            # OpenEnv can report terminal state from server side even if the
            # local observation's done flag has not yet been updated.
            if "Episode is already done" in str(e):
                break
            raise
        # A missing (None) reward counts as zero.
        reward = step_result.reward or 0.0
        total_reward += reward
        obs = step_result.observation

        action_history.append({
            "action_type": action.action_type,
            "params": action.params,
            "reward": reward,
        })

        cov_str = f"coverage={obs.metrics.mean_coverage:.3f}"
        print(
            f"  Step {obs.steps_taken:2d}: {action.action_type:<22} "
            f"reward={reward:+.3f}  {cov_str}"
        )

        if obs.done:
            # Infer success from terminal reward; the while condition then
            # ends the loop.
            success = reward >= 0.7

    outcome = "SUCCESS ✓" if success else "failed ✗"
    print(f"  {rule}")
    print(f"  {outcome}  |  total_reward={total_reward:+.3f}  "
          f"final_coverage={obs.metrics.mean_coverage:.3f}  "
          f"steps={obs.steps_taken}")

    return {
        "task_id":          task_id,
        "episode":          episode_num,
        "seed":             seed,
        "steps":            obs.steps_taken,
        "final_coverage":   obs.metrics.mean_coverage,
        "final_precision":  obs.metrics.mean_precision,
        "success":          success,
        "total_reward":     total_reward,
        "actions":          [ah["action_type"] for ah in action_history],
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Command-line entry point: run eval episodes and print a summary.

    Parses ``--task``, ``--episodes``, ``--server``, ``--seed`` and
    ``--verbose``, requires ``OPENAI_API_KEY`` in the environment (exits with
    status 1 otherwise), runs ``run_episode`` for every requested task and
    episode against the environment at ``--server``, and prints one summary
    line per task that produced at least one result. A failing episode is
    reported on stderr with its traceback and does not stop the run.
    """
    # Imported once here instead of inside the per-episode except handler.
    import traceback

    parser = argparse.ArgumentParser(
        description="GPT-4o-mini zero-shot eval agent for RAGDebugEnv"
    )
    parser.add_argument(
        "--task", choices=["1", "2", "3", "all"], default="1",
        help="Task ID to evaluate (default: 1)",
    )
    parser.add_argument(
        "--episodes", type=int, default=3,
        help="Number of episodes per task (default: 3)",
    )
    parser.add_argument(
        "--server", default="http://localhost:8000",
        help="Environment server URL (default: http://localhost:8000)",
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed for reproducibility (default: random)",
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="Print full observation each step",
    )
    args = parser.parse_args()

    # Zero or negative counts used to run nothing and still exit successfully.
    if args.episodes < 1:
        parser.error("--episodes must be a positive integer")

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: OPENAI_API_KEY not set in environment.", file=sys.stderr)
        sys.exit(1)

    openai_client = OpenAI(api_key=api_key)

    tasks = [1, 2, 3] if args.task == "all" else [int(args.task)]
    all_results: list[dict] = []

    with RAGDebugEnv(base_url=args.server).sync() as env:
        for task_id in tasks:
            print(f"\n{'='*60}")
            print(f"  Task {task_id}  ({args.episodes} episodes)")
            print(f"{'='*60}")

            for ep in range(1, args.episodes + 1):
                try:
                    # NOTE(review): when --seed is given, the same seed is
                    # passed to every episode — confirm that is intended.
                    result = run_episode(
                        client=openai_client,
                        env=env,
                        task_id=task_id,
                        seed=args.seed,
                        episode_num=ep,
                        verbose=args.verbose,
                    )
                    all_results.append(result)
                except Exception as e:
                    # Best effort: report the failure and continue with the
                    # next episode.
                    print(f"\n  ERROR in episode {ep}: {e}", file=sys.stderr)
                    traceback.print_exc()

    # Summary table
    if all_results:
        print(f"\n{'='*60}")
        print("  Summary")
        print(f"{'='*60}")
        for task_id in tasks:
            task_results = [r for r in all_results if r["task_id"] == task_id]
            if not task_results:
                continue
            n_success = sum(1 for r in task_results if r["success"])
            avg_cov = sum(r["final_coverage"] for r in task_results) / len(task_results)
            avg_steps = sum(r["steps"] for r in task_results) / len(task_results)
            avg_reward = sum(r["total_reward"] for r in task_results) / len(task_results)
            print(
                f"  Task {task_id}: {n_success}/{len(task_results)} success  "
                f"avg_coverage={avg_cov:.3f}  avg_steps={avg_steps:.1f}  "
                f"avg_reward={avg_reward:+.3f}"
            )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()