Spaces:
Sleeping
Sleeping
File size: 36,404 Bytes
f23deb1 b401c21 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 af29724 f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 b401c21 f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce f23deb1 ac224ce af29724 f23deb1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 
841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 | """
server/rag_debug_env_environment.py
------------------------------------
Real RAGDebugEnvironment implementation.
Replaces the echo stub with a full RL environment for training agents to
diagnose and fix broken RAG pipelines.
Architecture summary
--------------------
- Each episode samples N queries from one domain's corpus.
- S_true_{model}.npy matrices (precomputed by Stage 5) hold the ground-truth
cosine similarity scores for every (query, chunk) pair.
- Faults are applied as mathematical transformations: S_faulted = f(S_true, config, faults).
- The agent's job is to modify PipelineConfig (and/or query rewrites) until
the retrieval simulation (run on S_faulted) recovers adequate coverage of R*.
- Every config change triggers _recompute_S_faulted() so the fault math can
modulate its severity based on the new config values.
- Noise arrays are pre-generated at reset() time for determinism across
recomputation calls within a single episode.
"""
from __future__ import annotations
import json
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4
import numpy as np
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata, State
from pydantic import ValidationError
from server.constants import (
_TASK_DOMAIN,
_TASK_DESCRIPTION,
_N_EPISODE_QUERIES,
_MAX_STEPS,
_MODEL_FILE,
_TASK1_FAULT_SETS,
_TASK2_FAULT_SETS,
_TASK3_FAULTS,
)
from server.corpus import _load_corpus
from server.fault_math import apply_faults
from models import (
RAGDebugAction,
RAGDebugObservation,
ActionType,
EmbeddingModel,
FaultType,
FaultConfig,
InternalState,
PipelineConfig,
QueryResult,
QualityMetrics,
CorpusStats,
Domain,
Reward,
)
class RAGDebugEnvironment(Environment):
    """
    RL environment for diagnosing and fixing broken RAG pipelines.

    Each episode samples a small query set from one domain, injects faults
    into the similarity matrix, and rewards the agent for recovering
    retrieval quality through pipeline config changes.

    Tasks
    -----
    Task 1 (software): One or two config faults. Success: task_score >= 0.75.
    Task 2 (climate):  Compound config faults.   Success: task_score >= 0.75.
    Task 3 (medical):  Wrong embedding model + config faults + multi-hop.
                       Success: task_score >= 0.70 AND multi_hop_coverage > 0.60.

    The thresholds above are the ones enforced by ``_check_success``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Accepted (case-insensitive) string spellings for boolean action params.
    _TRUE_STRINGS = frozenset({"true", "1", "yes", "on"})
    _FALSE_STRINGS = frozenset({"false", "0", "no", "off"})

    def __init__(self):
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Episode-level data, set during reset()
        self._task_id: int = 1
        self._domain: str = "software"
        self._chunks: List[Dict] = []
        self._episode_queries: List[Dict] = []
        self._ground_truth: Dict[str, List[int]] = {}
        self._corpus_stats_dict: Dict = {}

        # Per-episode numpy data
        self._chunk_ids: List[int] = []
        self._chunk_tokens: List[int] = []
        self._chunk_id_to_tokens: Dict[int, int] = {}  # O(1) token lookup
        # model_name -> (n_q, n_c) slice of S_true for this episode's queries
        self._s_true_episode: Dict[str, np.ndarray] = {}
        self._active_model: EmbeddingModel = EmbeddingModel.GENERAL

        # Deterministic noise (pre-generated at reset, scaled during recompute)
        self._noise: Dict[FaultType, np.ndarray] = {}
        self._dupe_ids: np.ndarray = np.array([], dtype=int)

        # REWRITE_QUERY persistent overlay
        self._rewrite_boosts: np.ndarray = np.zeros((0, 0), dtype=np.float32)

        # Current faulted matrix (rebuilt by _recompute_S_faulted)
        self._S_faulted: np.ndarray = np.zeros((0, 0), dtype=np.float32)

        # Episode config & state
        self._config: PipelineConfig = PipelineConfig()
        self._injected_faults: List[FaultConfig] = []
        self._internal_state: InternalState = InternalState(
            injected_faults=[], episode_seed=0
        )
        self._prev_metrics: Optional[QualityMetrics] = None
        self._prev_action_type: Optional[ActionType] = None
        self._done: bool = False
        self._last_action_error: Optional[str] = None
        self._last_reward_components: Dict[str, float] = {}

    # ------------------------------------------------------------------
    # OpenEnv required interface
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> RAGDebugObservation:
        """Reset the environment and return the initial observation.

        Args:
            seed: Seed for query, fault, noise and config sampling. When None, a
                random seed is drawn. Either way it is recorded in
                ``InternalState.episode_seed``.
            episode_id: Optional episode identifier. A UUID4 string is
                generated when omitted.
            **kwargs: ``task_id`` (1, 2 or 3; default 1) selects the task and
                therefore the domain.

        Raises:
            ValueError: if ``task_id`` is not 1, 2 or 3.

        NOTE: the order of ``rng`` calls below defines the episode produced by
        a given seed. Reordering them changes which queries, faults, noise and
        starting config a seed yields.
        """
        # Inherited from openenv.core.env_server.interfaces.Environment.
        # Presumably clears framework-level rubric state before the
        # domain-specific state is re-initialised (implementation not shown here).
        self._reset_rubric()

        # --- Task selection ---
        task_id = int(kwargs.get("task_id", 1))
        if task_id not in (1, 2, 3):
            raise ValueError(f"task_id must be 1, 2, or 3; got {task_id}")
        self._task_id = task_id

        # --- RNG ---
        if seed is None:
            seed = int(np.random.default_rng().integers(0, 2**31))
        rng = np.random.default_rng(seed)

        # --- State bookkeeping ---
        ep_id = episode_id or str(uuid4())
        self._state = State(episode_id=ep_id, step_count=0)
        self._done = False
        self._prev_action_type = None
        self._last_action_error = None
        self._last_reward_components = {}

        # --- Domain & corpus ---
        self._domain = _TASK_DOMAIN[task_id].value
        corpus = _load_corpus(self._domain)
        self._chunks = corpus["chunks"]
        self._ground_truth = corpus["ground_truth"]
        self._corpus_stats_dict = corpus["corpus_stats"]
        all_queries: List[Dict] = corpus["queries"]

        # Sample episode queries (first consumer of rng).
        self._episode_queries = self._sample_queries(all_queries, task_id, rng)

        # Build episode index structures (chunks without n_tokens count as 100).
        self._chunk_ids = [c["chunk_id"] for c in self._chunks]
        self._chunk_tokens = [c.get("n_tokens", 100) for c in self._chunks]
        self._chunk_id_to_tokens = {
            c["chunk_id"]: c.get("n_tokens", 100) for c in self._chunks
        }
        n_q = len(self._episode_queries)
        n_c = len(self._chunks)

        # Slice S_true to the episode's query rows. This assumes rows of the
        # full S_true follow the order of corpus["queries"] and columns follow
        # the order of corpus["chunks"].
        self._s_true_episode = {}
        all_query_ids = [q["query_id"] for q in all_queries]
        ep_query_ids = [q["query_id"] for q in self._episode_queries]
        # Build index map: query_id -> row in the full S_true matrix
        qid_to_row = {qid: i for i, qid in enumerate(all_query_ids)}
        ep_rows = [qid_to_row[qid] for qid in ep_query_ids]
        s_true_full = corpus["s_true"]
        for model_name, s_full in s_true_full.items():
            self._s_true_episode[model_name] = s_full[ep_rows, :].copy()

        # Pre-generate noise (unit normal; scaled during recompute) so that
        # repeated recomputes within an episode are deterministic.
        shape = (n_q, n_c)
        self._noise = {
            FaultType.CHUNK_TOO_SMALL: rng.standard_normal(shape).astype(np.float32),
            FaultType.THRESHOLD_TOO_LOW: rng.standard_normal(shape).astype(np.float32),
            FaultType.NO_RERANKING: rng.standard_normal(shape).astype(np.float32),
        }
        self._dupe_ids = rng.choice(n_c, size=max(1, n_c // 7), replace=False)

        # Reset overlay and config
        self._rewrite_boosts = np.zeros(shape, dtype=np.float32)
        self._config = PipelineConfig()

        # Task 3 starts with the LEGAL embedding model (wrong model for medical text).
        # S_true_legal has ~0.62 coverage vs ~0.90 for GENERAL on medical data.
        # The agent must discover this and swap to GENERAL or MEDICAL.
        self._active_model = EmbeddingModel.LEGAL if task_id == 3 else EmbeddingModel.GENERAL
        self._config = self._config.model_copy(update={"embedding_model": self._active_model})

        # Start from a mildly constrained baseline so most episodes leave
        # headroom for meaningful improvement steps.
        self._config = self._config.model_copy(
            update={
                "top_k": int(rng.integers(5, 9)),
                "similarity_threshold": float(rng.uniform(0.34, 0.48)),
            }
        )

        # Sample and inject faults
        self._injected_faults = self._sample_faults(task_id, rng)
        self._internal_state = InternalState(
            injected_faults=self._injected_faults,
            episode_seed=seed,
        )

        # Fault-specific initial config: some faults only degrade coverage when
        # the related parameter starts at a bad value.
        # TOP_K_TOO_SMALL: score compression preserves rank order, so coverage
        #   stays high at top_k=10. Start with a small top_k so the agent must
        #   increase it to recover coverage.
        # DUPLICATE_FLOODING: flooded chunks can't displace high-scoring relevant
        #   chunks in a top_k=10 pool. Start with reduced top_k so the flooding
        #   actually crowds out relevant chunks.
        fault_types_active = {f.fault_type for f in self._injected_faults}
        if FaultType.TOP_K_TOO_SMALL in fault_types_active:
            self._config = self._config.model_copy(update={"top_k": int(rng.integers(2, 4))})
        elif FaultType.DUPLICATE_FLOODING in fault_types_active:
            self._config = self._config.model_copy(update={"top_k": int(rng.integers(4, 8))})

        # If reset still lands in a very strong state, add one extra nudge
        # so the agent usually has room to improve from step 1.
        self._calibrate_initial_difficulty(rng)

        # Initial matrix + metrics
        self._recompute_S_faulted()
        initial_results = self._simulate_retrieval()
        self._prev_metrics = self._compute_metrics(initial_results)
        return self._build_observation(
            initial_results, reward=None, metrics=self._prev_metrics
        )

    def step(
        self,
        action: RAGDebugAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> RAGDebugObservation:
        """Execute one agent action and return the updated observation.

        ``timeout_s`` and extra keyword arguments are accepted for interface
        compatibility and are not used.

        When the step budget (_MAX_STEPS) is exhausted, the episode is marked
        done. That step still receives the ordinary per-step reward; terminal
        grading happens only on SUBMIT.

        Raises:
            RuntimeError: if called after the episode has ended.
        """
        if self._done:
            raise RuntimeError("Episode is already done. Call reset() to start a new episode.")

        self._state.step_count += 1
        prev_metrics = self._prev_metrics

        # Route action (returns Reward for SUBMIT, None otherwise). This also
        # sets self._last_action_error, which _compute_reward reads.
        reward_obj = self._apply_action(action)

        # Recompute retrieval results and metrics
        new_results = self._simulate_retrieval()
        new_metrics = self._compute_metrics(new_results)

        # Compute the per-step reward unless SUBMIT already produced one.
        # Must run before _prev_action_type is updated (redundancy penalty).
        if reward_obj is None:
            reward_obj = self._compute_reward(prev_metrics, new_metrics, action)

        # Keep per-step reward components available in observations.
        self._last_reward_components = dict(reward_obj.components)
        self._internal_state.action_history.append(action)
        self._internal_state.reward_history.append(reward_obj.value)
        self._prev_metrics = new_metrics
        self._prev_action_type = action.action_type

        # Auto-terminate on max steps
        if self._state.step_count >= _MAX_STEPS and not self._done:
            self._done = True

        return self._build_observation(
            new_results, reward=reward_obj.value, metrics=new_metrics
        )

    @property
    def state(self) -> State:
        """Current OpenEnv ``State`` (episode id and step count)."""
        return self._state

    def get_metadata(self) -> EnvironmentMetadata:
        """Return environment metadata, embedding the project README if present.

        The README is read from ``<parent of this file's directory>/README.md``.
        A leading YAML frontmatter block (``---`` ... ``---``) is stripped so
        the UI renders clean Markdown.
        """
        readme_path = Path(__file__).parent.parent / "README.md"
        readme_content: Optional[str] = None
        if readme_path.exists():
            raw = readme_path.read_text(encoding="utf-8")
            if raw.startswith("---"):
                # The closing delimiter must start its own line; a bare
                # find("---") would also match dashes inside a YAML value.
                closing = "\n---"
                end = raw.find(closing, 3)
                if end != -1:
                    raw = raw[end + len(closing):].lstrip("\r\n")
            readme_content = raw
        return EnvironmentMetadata(
            name="RAGDebugEnv",
            description="Debug broken RAG pipelines by tuning config and swapping embedding models.",
            readme_content=readme_content,
            version="1.0.0",
        )

    # ------------------------------------------------------------------
    # Action routing
    # ------------------------------------------------------------------

    def _update_config(self, **updates) -> Optional[str]:
        """
        Apply updates to PipelineConfig through the constructor, so that all
        Pydantic validators (including model_validators) run.

        model_copy(update=...) does NOT run validators in Pydantic v2, so
        invalid combinations (e.g. chunk_overlap >= chunk_size) would be
        silently accepted and later cause crashes when the config is embedded
        in an observation.

        Returns an error string if validation fails (config left unchanged),
        None on success.
        """
        try:
            self._config = PipelineConfig(**{**self._config.model_dump(), **updates})
            return None
        except (ValueError, TypeError, ValidationError) as exc:
            return str(exc)

    def _apply_action(self, action: RAGDebugAction) -> Optional[Reward]:
        """
        Apply an action to the config and/or the score overlays.

        Malformed or invalid parameters never raise. They are recorded in
        ``self._last_action_error``, which later triggers the invalid-action
        penalty in _compute_reward, and they leave the config unchanged.

        Returns a terminal Reward if the action is SUBMIT. Otherwise returns
        None, and the caller computes a per-step reward via _compute_reward.
        """
        self._last_action_error = None
        t = action.action_type
        p: Dict[str, Any] = action.params or {}

        if t == ActionType.ADJUST_CHUNK_SIZE:
            self._set_config_value(p, "chunk_size", int, "chunk_size")
            self._recompute_S_faulted()
        elif t == ActionType.ADJUST_CHUNK_OVERLAP:
            self._set_config_value(p, "chunk_overlap", int, "chunk_overlap")
            # Recompute required: apply_faults() uses chunk_overlap to modulate
            # CHUNK_TOO_SMALL noise sigma, so the change must reach S_faulted now.
            self._recompute_S_faulted()
        elif t == ActionType.ADJUST_THRESHOLD:
            # Threshold is applied in retrieval simulation; no matrix recompute.
            self._set_config_value(p, "similarity_threshold", float, "threshold")
        elif t == ActionType.ADJUST_TOP_K:
            # top_k is applied in retrieval simulation; no matrix recompute.
            self._set_config_value(p, "top_k", int, "top_k")
        elif t == ActionType.ADJUST_CONTEXT_LIMIT:
            self._set_config_value(p, "context_window_limit", int, "context_limit")
            self._recompute_S_faulted()
        elif t == ActionType.SWAP_EMBEDDING_MODEL:
            self._swap_embedding_model(p)
            self._recompute_S_faulted()
        elif t == ActionType.TOGGLE_RERANKING:
            self._toggle_reranking(p)
            self._recompute_S_faulted()
        elif t == ActionType.REWRITE_QUERY:
            self._rewrite_query(p)
            self._recompute_S_faulted()
        elif t == ActionType.SUBMIT:
            return self._submit()

        return None  # signal caller to compute reward via _compute_reward

    def _set_config_value(
        self, params: Dict[str, Any], field: str, cast: type, label: str
    ) -> None:
        """Parse ``params["value"]`` with *cast* and store it in config *field*.

        A missing "value" falls back to the current config value. Parse or
        validation failures set self._last_action_error and keep the config.
        """
        raw = params.get("value", getattr(self._config, field))
        try:
            value = cast(raw)
        except (TypeError, ValueError, OverflowError) as exc:
            self._last_action_error = f"Invalid {label} {raw!r}: {exc}"
            return
        err = self._update_config(**{field: value})
        if err:
            self._last_action_error = f"Invalid {label} {value}: {err}"

    def _swap_embedding_model(self, params: Dict[str, Any]) -> None:
        """Switch the active embedding model to ``params["model"]`` if valid."""
        model_str = str(params.get("model", self._active_model.value))
        try:
            new_model = EmbeddingModel(model_str)
        except (ValueError, KeyError) as exc:
            self._last_action_error = f"Invalid embedding model '{model_str}': {exc}"
            return
        err = self._update_config(embedding_model=new_model)
        if err:
            self._last_action_error = f"Invalid embedding model '{model_str}': {err}"
            return
        # Only switch once the config accepted the model, so the similarity
        # matrix in use and the reported config never disagree.
        self._active_model = new_model

    def _toggle_reranking(self, params: Dict[str, Any]) -> None:
        """Set use_reranking from ``params["enabled"]``, or flip it if absent/None.

        Strings are parsed ("false"/"0"/"no"/"off" -> False, etc.) because
        ``bool("false")`` is True in Python.
        """
        raw = params.get("enabled")
        if raw is None:
            enabled = not self._config.use_reranking
        elif isinstance(raw, str):
            key = raw.strip().lower()
            if key in self._TRUE_STRINGS:
                enabled = True
            elif key in self._FALSE_STRINGS:
                enabled = False
            else:
                self._last_action_error = (
                    f"Invalid enabled flag {raw!r}: expected a boolean"
                )
                return
        else:
            enabled = bool(raw)
        err = self._update_config(use_reranking=enabled)
        if err:
            self._last_action_error = f"Invalid use_reranking {enabled}: {err}"

    def _rewrite_query(self, params: Dict[str, Any]) -> None:
        """Boost the ground-truth chunks of one episode query by +0.20.

        The boost persists in self._rewrite_boosts for the rest of the episode.
        Ids are compared as strings (ground truth is keyed by str(query_id)),
        so e.g. 3 and "3" refer to the same query. Unknown ids are reported as
        an invalid action instead of being silently ignored.
        """
        query_id = params.get("query_id")
        ep_keys = [str(q["query_id"]) for q in self._episode_queries]
        key = str(query_id)
        if query_id is None or key not in ep_keys:
            self._last_action_error = (
                f"Unknown query_id {query_id!r}: not part of this episode's queries"
            )
            return
        row = ep_keys.index(key)
        r_star = self._ground_truth.get(key, [])
        # Map chunk_id -> column index
        cid_to_col = {cid: col for col, cid in enumerate(self._chunk_ids)}
        cols = [cid_to_col[cid] for cid in r_star if cid in cid_to_col]
        if cols:
            self._rewrite_boosts[row, cols] = 0.20

    def _submit(self) -> Reward:
        """Grade the current pipeline, end the episode and return the terminal reward.

        Success maps task_score into [0.7, 1.0]; failure maps it into [0.0, 0.2].
        """
        # Compute final metrics at SUBMIT time for accurate grading.
        results = self._simulate_retrieval()
        metrics = self._compute_metrics(results)
        task_score = self._compute_task_score(metrics)
        self._done = True
        if self._check_success(metrics, task_score):
            terminal_value = float(np.clip(0.7 + 0.3 * task_score, 0.7, 1.0))
            return Reward(
                value=terminal_value,
                components={"terminal_success": terminal_value},
            )
        terminal_value = float(np.clip(0.2 * task_score, 0.0, 0.2))
        return Reward(
            value=terminal_value,
            components={"terminal_failure": terminal_value},
        )

    # ------------------------------------------------------------------
    # Fault math
    # ------------------------------------------------------------------

    def _recompute_S_faulted(self) -> None:
        """
        Apply all active faults to the S_true matrix of the current active model.

        WRONG_EMBEDDING_MODEL is implicit. Task 3 starts with active_model=LEGAL,
        whose score distribution on medical text is fundamentally wrong
        (compressed: mean ~0.84, std ~0.033 vs GENERAL std ~0.13). The agent
        must diagnose this from retrieval score distributions and swap models.
        All other faults are applied as matrix transformations by apply_faults().
        """
        # Lock to GENERAL on tasks without WRONG_EMBEDDING_MODEL so that
        # swap_embedding_model actions don't accidentally shift coverage/precision
        # via raw score-level differences between embedding model matrices.
        fault_types = {f.fault_type for f in self._injected_faults}
        if FaultType.WRONG_EMBEDDING_MODEL in fault_types:
            model_key = _MODEL_FILE[self._active_model]
        else:
            model_key = "general"
        S = self._s_true_episode.get(model_key)
        if S is None:
            # Fall back to GENERAL if the model's matrix was not loaded.
            S = self._s_true_episode["general"]
        self._S_faulted = apply_faults(
            S=S,
            fault_types=fault_types,
            config_chunk_size=self._config.chunk_size,
            config_context_limit=self._config.context_window_limit,
            config_use_reranking=self._config.use_reranking,
            config_chunk_overlap=self._config.chunk_overlap,
            noise=self._noise,
            dupe_ids=self._dupe_ids,
            rewrite_boosts=self._rewrite_boosts,
        )

    # ------------------------------------------------------------------
    # Retrieval simulation
    # ------------------------------------------------------------------

    def _simulate_retrieval(self) -> List[QueryResult]:
        """Run retrieval over all episode queries using S_faulted.

        For each query: take the top_k columns by descending score, then drop
        any whose score is below similarity_threshold. Coverage is
        |retrieved ∩ R*| / |R*|; precision is |retrieved ∩ R*| / |retrieved|.
        Both are 0.0 when their denominator is empty. Context-overflow
        accounting happens in _compute_metrics.
        """
        results: List[QueryResult] = []
        config = self._config
        threshold = config.similarity_threshold
        top_k = min(config.top_k, len(self._chunk_ids))

        for i, query in enumerate(self._episode_queries):
            query_id = query["query_id"]
            scores = self._S_faulted[i]  # (n_chunks,)

            # Top-K by descending score, then filter by threshold
            top_indices = np.argsort(scores)[::-1][:top_k]
            retrieved: List[Tuple[int, float]] = [
                (self._chunk_ids[j], float(scores[j]))
                for j in top_indices
                if scores[j] >= threshold
            ]
            retrieved_ids = [cid for cid, _ in retrieved]
            retrieved_scores = [s for _, s in retrieved]

            # Coverage and precision
            r_star = set(self._ground_truth.get(str(query_id), []))
            r_agent = set(retrieved_ids)
            n_hit = len(r_agent & r_star)
            coverage = n_hit / len(r_star) if r_star else 0.0
            precision = n_hit / len(r_agent) if r_agent else 0.0

            results.append(
                QueryResult(
                    query_id=query_id,
                    query_text=query["text"],
                    retrieved_chunk_ids=retrieved_ids,
                    retrieval_scores=retrieved_scores,
                    n_retrieved=len(retrieved_ids),
                    coverage_score=float(np.clip(coverage, 0.0, 1.0)),
                    precision_score=float(np.clip(precision, 0.0, 1.0)),
                    is_multi_hop=bool(query.get("is_multi_hop", False)),
                )
            )
        return results

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------

    def _compute_metrics(self, results: List[QueryResult]) -> QualityMetrics:
        """Aggregate per-query results into episode-level QualityMetrics.

        - mean_recall is reported equal to mean_coverage.
        - A context overflow is a query whose retrieved chunks' token total
          exceeds config.context_window_limit (unknown chunks count as 100).
        - multi_hop_coverage is None when no multi-hop queries are present.
        """
        coverages = [r.coverage_score for r in results]
        precisions = [r.precision_score for r in results]
        n_empty = sum(1 for r in results if r.n_retrieved == 0)

        limit = self._config.context_window_limit
        tokens_of = self._chunk_id_to_tokens.get
        n_overflow = sum(
            1
            for r in results
            if sum(tokens_of(cid, 100) for cid in r.retrieved_chunk_ids) > limit
        )

        multi_hop_covs = [r.coverage_score for r in results if r.is_multi_hop]
        multi_hop_cov = float(np.mean(multi_hop_covs)) if multi_hop_covs else None
        mean_cov = float(np.mean(coverages)) if coverages else 0.0

        return QualityMetrics(
            mean_coverage=mean_cov,
            mean_precision=float(np.mean(precisions)) if precisions else 0.0,
            mean_recall=mean_cov,
            n_empty_retrievals=n_empty,
            n_context_overflows=n_overflow,
            multi_hop_coverage=multi_hop_cov,
        )

    # ------------------------------------------------------------------
    # Reward
    # ------------------------------------------------------------------

    def _compute_reward(
        self,
        prev: Optional[QualityMetrics],
        new: QualityMetrics,
        action: RAGDebugAction,
    ) -> Reward:
        """Compute a progress-based per-step reward in [0.0, 1.0].

        The reward reflects the absolute quality level (progress toward the
        task's quality target) plus a small bonus for the direction of change.
        Non-terminal steps therefore span roughly [0.0, 0.89], giving the RL
        agent a strong per-step learning signal:
          - Terrible state, no improvement            -> ~0.09
          - Half the quality target, no change        -> ~0.365
          - At/above the quality target, no change    -> ~0.64
          - Large improvement step                    -> up to ~0.89
          - Large regression + penalties              -> clipped to 0.00

        Terminal rewards (SUBMIT, produced in _apply_action) live in their own
        zones: [0.7, 1.0] on success and [0.0, 0.2] on failure.

        Reads self._prev_action_type (redundancy penalty) and
        self._last_action_error (invalid-action penalty).
        """
        components: Dict[str, float] = {}
        n_queries = max(len(self._episode_queries), 1)

        # --- Progress reward: absolute quality level signal ---
        # Maps quality_score to [0.10, 0.65] in proportion to how close we are
        # to the task's quality target. This spans the reward range across the
        # full episode regardless of per-step delta magnitude.
        quality_target = 0.75 if self._task_id in (1, 2) else 0.70
        current_quality = self._quality_score(new)
        progress = min(1.0, current_quality / quality_target)
        components["progress_reward"] = 0.10 + 0.55 * progress

        # --- Delta bonus: immediate direction feedback ---
        # Distinguishes an improving step from a no-op at the same quality level.
        # The scaled delta is capped at ±0.15 so a single step never dominates.
        if prev is not None:
            prev_quality = self._quality_score(prev)
            q_delta = current_quality - prev_quality
            components["delta_bonus"] = float(np.clip(q_delta * 2.0, -0.15, 0.15))

            # Empty retrieval signal: bidirectional (weight 0.06)
            empty_change = prev.n_empty_retrievals - new.n_empty_retrievals
            components["empty_retrieval_signal"] = float(np.clip(empty_change / n_queries, -1.0, 1.0)) * 0.06

            # Context overflow signal: bidirectional (weight 0.04)
            overflow_change = prev.n_context_overflows - new.n_context_overflows
            components["overflow_signal"] = float(np.clip(overflow_change / n_queries, -1.0, 1.0)) * 0.04
        else:
            components["delta_bonus"] = 0.0
            components["empty_retrieval_signal"] = 0.0
            components["overflow_signal"] = 0.0

        # --- Efficiency penalties ---
        components["step_cost"] = -0.01

        # Redundancy penalty for repeating the same action type consecutively
        if self._prev_action_type is not None and action.action_type == self._prev_action_type:
            components["redundancy_penalty"] = -0.04
        else:
            components["redundancy_penalty"] = 0.0

        # Penalty for invalid action parameters
        if self._last_action_error is not None:
            components["invalid_action_penalty"] = -0.05

        # --- Combine: no fixed base; progress_reward IS the base ---
        # NOTE: components insertion order is part of the output (it is shown
        # in observations) and fixes the float summation order.
        raw = sum(components.values())
        value = float(np.clip(raw, 0.0, 1.0))
        return Reward(value=value, components=components)

    def _compute_task_score(self, metrics: QualityMetrics) -> float:
        """Compute the scalar task score used for terminal grading.

        Tasks 1/2: 0.60*coverage + 0.25*precision + 0.15*efficiency, where
        efficiency = 1 - step_count/_MAX_STEPS.
        Task 3: 0.55*coverage + 0.25*precision + 0.20*multi_hop_coverage
        (missing multi-hop coverage counts as 0).
        """
        n_steps = self._state.step_count
        if self._task_id in (1, 2):
            efficiency = 1.0 - n_steps / _MAX_STEPS
            return (
                0.60 * metrics.mean_coverage
                + 0.25 * metrics.mean_precision
                + 0.15 * efficiency
            )
        else:  # task 3
            mh_cov = metrics.multi_hop_coverage or 0.0
            return (
                0.55 * metrics.mean_coverage
                + 0.25 * metrics.mean_precision
                + 0.20 * mh_cov
            )

    def _quality_score(self, metrics: QualityMetrics) -> float:
        """Quality portion of task_score (no efficiency term), normalized to [0, 1].

        Uses the same coverage:precision weighting as _compute_task_score so
        that step rewards are aligned with the terminal success criterion.
        """
        if self._task_id in (1, 2):
            # 0.60 + 0.25 = 0.85 max; normalize to [0, 1]
            return (0.60 * metrics.mean_coverage + 0.25 * metrics.mean_precision) / 0.85
        else:  # task 3: includes multi-hop coverage; weights already sum to 1.0
            mh_cov = metrics.multi_hop_coverage or 0.0
            return 0.55 * metrics.mean_coverage + 0.25 * metrics.mean_precision + 0.20 * mh_cov

    def _check_success(self, metrics: QualityMetrics, task_score: float) -> bool:
        """Return True if the submitted pipeline meets the task's success bar.

        Tasks 1 and 2 need task_score >= 0.75. Task 3 needs task_score >= 0.70
        and multi-hop coverage strictly above 0.60.
        """
        if self._task_id in (1, 2):
            return task_score >= 0.75
        mh_cov = metrics.multi_hop_coverage or 0.0
        return task_score >= 0.70 and mh_cov > 0.60

    # ------------------------------------------------------------------
    # Fault sampling
    # ------------------------------------------------------------------

    def _sample_faults(self, task_id: int, rng: np.random.Generator) -> List[FaultConfig]:
        """Pick the episode's fault set.

        Tasks 1/2 draw one fault set uniformly from their catalogue (one rng
        call). Task 3 always uses the fixed _TASK3_FAULTS (no rng use).
        """
        if task_id == 1:
            idx = int(rng.integers(0, len(_TASK1_FAULT_SETS)))
            fault_types = _TASK1_FAULT_SETS[idx]
        elif task_id == 2:
            idx = int(rng.integers(0, len(_TASK2_FAULT_SETS)))
            fault_types = _TASK2_FAULT_SETS[idx]
        else:
            fault_types = _TASK3_FAULTS
        return [FaultConfig(fault_type=ft) for ft in fault_types]

    # ------------------------------------------------------------------
    # Query sampling
    # ------------------------------------------------------------------

    def _sample_queries(
        self, all_queries: List[Dict], task_id: int, rng: np.random.Generator
    ) -> List[Dict]:
        """Sample up to _N_EPISODE_QUERIES[task_id] queries without replacement.

        Task 3 reserves up to two slots for multi-hop queries (never more than
        the episode size) and fills the rest with regular queries. Fewer
        queries are returned if the corpus does not have enough.
        """
        n = _N_EPISODE_QUERIES[task_id]
        if task_id == 3:
            regular = [q for q in all_queries if not q.get("is_multi_hop")]
            multi_hop = [q for q in all_queries if q.get("is_multi_hop")]
            # Clamp to n so n_reg can never go negative (rng.choice would raise).
            n_mh = min(2, len(multi_hop), n)
            n_reg = min(n - n_mh, len(regular))
            reg_sample = rng.choice(len(regular), size=n_reg, replace=False)
            mh_sample = rng.choice(len(multi_hop), size=n_mh, replace=False)
            sampled = [regular[i] for i in reg_sample] + [multi_hop[i] for i in mh_sample]
        else:
            indices = rng.choice(len(all_queries), size=min(n, len(all_queries)), replace=False)
            sampled = [all_queries[i] for i in indices]
        return sampled

    def _calibrate_initial_difficulty(self, rng: np.random.Generator) -> None:
        """Nudge overly strong reset states toward improvable starting points.

        If mean coverage exceeds the per-task cap, or too many queries already
        have full coverage, shrink top_k by 1-2 (never below 3). Unless
        THRESHOLD_TOO_HIGH is active, also raise the similarity threshold by
        0.05-0.12 (capped at 0.75). Updates go through the validating
        _update_config. If validation fails, a RuntimeWarning is emitted and
        the config is left as is.
        """
        self._recompute_S_faulted()
        results = self._simulate_retrieval()
        if not results:
            return
        metrics = self._compute_metrics(results)

        full_cov_rate = float(
            np.mean([1.0 if r.coverage_score >= 0.999 else 0.0 for r in results])
        )
        cov_caps = {1: 0.60, 2: 0.52, 3: 0.48}
        full_cov_caps = {1: 0.50, 2: 0.45, 3: 0.40}
        cov_cap = cov_caps.get(self._task_id, 0.60)
        full_cov_cap = full_cov_caps.get(self._task_id, 0.50)
        if metrics.mean_coverage <= cov_cap and full_cov_rate <= full_cov_cap:
            return

        # rng is only consumed on the branches taken below; keep this order so
        # seeded episodes stay reproducible.
        updates: Dict[str, Any] = {}
        if self._config.top_k > 3:
            shrink = int(rng.integers(1, 3))
            updates["top_k"] = max(3, self._config.top_k - shrink)
        fault_types_active = {f.fault_type for f in self._injected_faults}
        if FaultType.THRESHOLD_TOO_HIGH not in fault_types_active:
            bump = float(rng.uniform(0.05, 0.12))
            updates["similarity_threshold"] = min(
                0.75, self._config.similarity_threshold + bump
            )
        if updates:
            err = self._update_config(**updates)
            if err:
                warnings.warn(
                    f"Initial difficulty calibration skipped: {err}",
                    RuntimeWarning,
                    stacklevel=2,
                )

    # ------------------------------------------------------------------
    # Diagnostic hints
    # ------------------------------------------------------------------

    def _generate_diagnostic_hints(
        self, metrics: QualityMetrics, results: List[QueryResult]
    ) -> List[str]:
        """Generate context-aware hints based on current metric patterns."""
        hints: List[str] = []
        cfg = self._config

        if metrics.n_empty_retrievals > 0:
            hints.append(
                f"{metrics.n_empty_retrievals} queries have empty retrievals — "
                "consider lowering similarity_threshold or increasing top_k."
            )
        if metrics.n_context_overflows > 0:
            hints.append(
                f"{metrics.n_context_overflows} queries exceed context window — "
                "consider increasing context_window_limit or reducing top_k."
            )
        if metrics.mean_coverage < 0.5 and metrics.mean_precision > 0.3:
            hints.append(
                "Low coverage with moderate precision suggests top_k is too small "
                "or the embedding model may not suit this domain."
            )
        if metrics.mean_coverage < 0.4 and metrics.mean_precision < 0.3:
            hints.append(
                "Both coverage and precision are low — check if the similarity threshold "
                "is filtering out too many chunks, or if the embedding model is mismatched."
            )

        # Check for score compression (sign of wrong embedding model or TOP_K_TOO_SMALL)
        all_scores = [s for r in results for s in r.retrieval_scores]
        if all_scores:
            score_std = float(np.std(all_scores))
            score_mean = float(np.mean(all_scores))
            if score_std < 0.05 and len(all_scores) > 3:
                hints.append(
                    f"Retrieval scores are tightly compressed (std={score_std:.3f}) — "
                    "this may indicate the wrong embedding model or score compression fault."
                )
            if score_mean > 0.7 and metrics.mean_precision < 0.5:
                hints.append(
                    "High retrieval scores but low precision — many irrelevant chunks are "
                    "scoring high. Consider enabling reranking or checking for duplicate flooding."
                )

        # Task 3 multi-hop hint
        if self._task_id == 3:
            mh_cov = metrics.multi_hop_coverage
            if mh_cov is not None and mh_cov < 0.5:
                hints.append(
                    f"Multi-hop coverage is low ({mh_cov:.3f}) — multi-hop queries need "
                    "broad retrieval. Consider increasing top_k and checking the embedding model."
                )

        # Reranking hint
        if not cfg.use_reranking and metrics.mean_precision < 0.4:
            hints.append(
                "Reranking is disabled. Enabling it can improve precision by re-scoring "
                "candidates with a cross-encoder."
            )

        return hints

    # ------------------------------------------------------------------
    # Observation builder
    # ------------------------------------------------------------------

    def _build_observation(
        self,
        results: List[QueryResult],
        reward: Optional[float],
        metrics: Optional[QualityMetrics] = None,
    ) -> RAGDebugObservation:
        """Assemble the observation returned by reset() and step().

        ``metrics`` may be passed when the caller already computed them for
        the same ``results`` and config. Otherwise they are recomputed here.
        Missing corpus_stats fields fall back to 0 (n_chunks falls back to the
        loaded chunk count).
        """
        cs = self._corpus_stats_dict
        corpus_stats = CorpusStats(
            domain=Domain(self._domain),
            n_documents=cs.get("n_documents", 0),
            n_chunks=cs.get("n_chunks", len(self._chunks)),
            avg_chunk_tokens=cs.get("avg_chunk_tokens", 0),
            has_near_duplicates=bool(cs.get("has_near_duplicates", False)),
            n_queries=cs.get("n_queries", 0),
            n_multi_hop_queries=cs.get("n_multi_hop_queries", 0),
        )
        if metrics is None:
            metrics = self._compute_metrics(results)
        hints = self._generate_diagnostic_hints(metrics, results)

        return RAGDebugObservation(
            pipeline_config=self._config,
            query_results=results,
            metrics=metrics,
            corpus_stats=corpus_stats,
            steps_taken=self._state.step_count,
            max_steps=_MAX_STEPS,
            task_id=self._task_id,
            task_description=_TASK_DESCRIPTION[self._task_id],
            done=self._done,
            last_action_error=self._last_action_error,
            diagnostic_hints=hints,
            reward_components=self._last_reward_components,
            reward=reward,
        )
# Backward-compatible alias: server/__init__.py and app.py import the class
# under the name RagDebugEnvironment, so keep that name bound to the same class.
RagDebugEnvironment = RAGDebugEnvironment
|