"""
models.py
---------
Defines the typed Action and Observation for RAGDebugEnv, plus all
internal simulation models used by the environment logic.

Architecture
------------
Two tiers of models live here:

Tier 1 — OpenEnv interface types (must inherit from framework bases)
    RAGDebugAction       inherits openenv.core.env_server.types.Action
    RAGDebugObservation  inherits openenv.core.env_server.types.Observation

Tier 2 — Internal simulation models (plain Pydantic BaseModel)
    PipelineConfig, QueryResult, QualityMetrics, CorpusStats, Reward,
    FaultConfig, InternalState, EpisodeResult

The OpenEnv-provided State class is used directly for episode
metadata (episode_id, step_count). It is NOT subclassed — the
framework owns that contract.

Import convention
-----------------
    from models import RAGDebugAction, RAGDebugObservation
    from openenv.core.env_server.types import State  # for episode state
"""
from __future__ import annotations
import json
from enum import Enum
from typing import Any, Dict, List, Optional, Set
from pydantic import BaseModel, Field, field_validator, model_validator
# ── OpenEnv base types ───────────────────────────────────────────────────────
# These are the two types the framework requires us to subclass.
# Import path confirmed from official docs:
# https://meta-pytorch.org/OpenEnv/environment-builder/
from openenv.core.env_server.types import Action, Observation
# =============================================================================
# Enums shared across both tiers
# =============================================================================
class EmbeddingModel(str, Enum):
    """
    The four embedding models the pipeline can use.

    GENERAL — sentence-transformers/all-MiniLM-L6-v2.
              Fast, general-purpose.
              Works well on everyday text but degrades on specialist domains.
    MEDICAL — NeuML/pubmedbert-base-embeddings.
              Trained on biomedical retrieval tasks.
    LEGAL   — nlpaueb/legal-bert-base-uncased. Trained on legal corpora.
    CODE    — sentence-transformers/multi-qa-mpnet-base-dot-v1.
              Retrieval-tuned contrast model (keeps historical "code" slot).

    Members mix in ``str``, so each compares equal to its value string
    (e.g. ``EmbeddingModel.MEDICAL == "medical"``). PipelineConfig uses
    this enum for its ``embedding_model`` field.
    """
    GENERAL = "general"
    MEDICAL = "medical"
    LEGAL = "legal"
    CODE = "code"
class Domain(str, Enum):
    """
    The corpus domain for each task difficulty.

    SOFTWARE — Python docs. Clean prose, unambiguous vocabulary. Task 1.
    CLIMATE  — IPCC reports. Cross-disciplinary, more ambiguous. Task 2.
    MEDICAL  — MedRAG textbooks. Heavy domain terminology. Task 3.

    Stored as CorpusStats.domain. This is a separate enum from
    EmbeddingModel, even though both have a member valued "medical".
    """
    SOFTWARE = "software"
    CLIMATE = "climate"
    MEDICAL = "medical"
class ActionType(str, Enum):
    """
    Every action the agent can take against the pipeline.

    Config actions modify PipelineConfig in-place. The environment
    re-simulates retrieval on the updated config immediately.

    REWRITE_QUERY rewrites one query's text — simulated by
    perturbing its similarity scores toward the ground-truth set.

    SUBMIT declares the agent is done. Triggers grading.
    Submitting before the success threshold incurs a penalty.

    The ``params`` expected for each member are listed in the
    RAGDebugAction docstring.
    """
    ADJUST_CHUNK_SIZE = "adjust_chunk_size"
    ADJUST_CHUNK_OVERLAP = "adjust_chunk_overlap"
    ADJUST_THRESHOLD = "adjust_threshold"
    ADJUST_TOP_K = "adjust_top_k"
    SWAP_EMBEDDING_MODEL = "swap_embedding_model"
    TOGGLE_RERANKING = "toggle_reranking"
    ADJUST_CONTEXT_LIMIT = "adjust_context_limit"
    REWRITE_QUERY = "rewrite_query"
    SUBMIT = "submit"
class FaultType(str, Enum):
    """
    Every fault that can be injected into the simulated pipeline.

    Stored in InternalState via FaultConfig.fault_type. Never exposed in
    RAGDebugObservation. InternalState.fault_names reports the string
    values of the injected members.
    """
    CHUNK_TOO_LARGE = "chunk_too_large"
    CHUNK_TOO_SMALL = "chunk_too_small"
    THRESHOLD_TOO_LOW = "threshold_too_low"
    THRESHOLD_TOO_HIGH = "threshold_too_high"
    TOP_K_TOO_SMALL = "top_k_too_small"
    CONTEXT_OVERFLOW = "context_overflow"
    DUPLICATE_FLOODING = "duplicate_flooding"
    WRONG_EMBEDDING_MODEL = "wrong_embedding_model"
    NO_RERANKING = "no_reranking"
# =============================================================================
# Tier 1 — OpenEnv interface types
# =============================================================================
class RAGDebugAction(Action):
    """
    The action an agent takes against the RAG pipeline.

    Inherits from openenv.core.env_server.types.Action as required by
    the OpenEnv spec. The framework uses this class for serialisation,
    deserialisation, and web-UI form generation.

    action_type selects the operation. params carries its arguments.

    Parameter schemas by action_type
    --------------------------------
    adjust_chunk_size     {"value": int}       64 ≤ value ≤ 2048
    adjust_chunk_overlap  {"value": int}       0 ≤ value ≤ 500
    adjust_threshold      {"value": float}     0.0 ≤ value ≤ 1.0
    adjust_top_k          {"value": int}       1 ≤ value ≤ 50
    swap_embedding_model  {"model": str}       EmbeddingModel enum value
    toggle_reranking      {"enabled": bool}
    adjust_context_limit  {"value": int}       512 ≤ value ≤ 16384
    rewrite_query         {"query_id": int,
                           "strategy": str}    currently only "rephrase" is supported
    submit                {}

    The numeric bounds mirror PipelineConfig's field constraints. This
    model does not check the schemas above. It only guarantees that
    ``params`` is a dict (see coerce_params_dict). Per-action checks
    presumably happen in the environment; confirm there.
    """

    action_type: ActionType = Field(
        ...,
        description="Which pipeline operation to perform.",
    )
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Arguments for the chosen action_type.",
    )

    @field_validator("params", mode="before")
    @classmethod
    def coerce_params_dict(cls, value: Any) -> Dict[str, Any]:
        """
        Accept dicts and JSON-stringified dicts from the web UI.

        Normalisation rules (applied before Pydantic's own dict check):
            * ``None`` or a blank/whitespace-only string  -> ``{}``
            * a dict                                      -> returned unchanged
            * a string holding a JSON object              -> the decoded dict

        Raises:
            ValueError: if a string is not valid JSON, if it decodes to
                anything other than a JSON object, or if ``value`` is of
                any other type. Pydantic turns these into a
                ``ValidationError`` for the caller.
        """
        if value is None:
            return {}
        if isinstance(value, dict):
            return value
        if isinstance(value, str):
            text = value.strip()
            if not text:
                return {}
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError as exc:
                raise ValueError("params must be a dictionary or valid JSON object string") from exc
            if not isinstance(parsed, dict):
                raise ValueError("params JSON must decode to an object")
            return parsed
        # ValueError, not TypeError: Pydantic v2 only converts ValueError /
        # AssertionError raised inside validators into a ValidationError.
        # A TypeError would escape as a raw exception instead of a clean
        # validation failure.
        raise ValueError("params must be a dictionary or JSON object string")

    def __str__(self) -> str:
        """Render compactly as ``action_type(k=v, ...)``, or ``action_type()`` with no params."""
        if self.params:
            param_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
            return f"{self.action_type.value}({param_str})"
        return f"{self.action_type.value}()"
class RAGDebugObservation(Observation):
    """
    Everything the agent is allowed to see after each step.

    Inherits from openenv.core.env_server.types.Observation as required
    by the OpenEnv spec.

    Intentional omissions
    ---------------------
    injected_faults is NOT here. The agent must infer faults from
    metrics alone — that reasoning IS the task. Faults are only
    revealed in InternalState (accessible via env.state(), used by
    graders and debuggers, not given to the agent).

    Fields
    ------
    pipeline_config     The current parameter set the agent may modify.
    query_results       Per-query retrieval results under current config.
    metrics             Aggregate quality metrics across all queries.
    corpus_stats        Static metadata about the corpus (domain, size).
    steps_taken         Actions taken so far this episode.
    max_steps           Budget before the episode force-terminates.
    task_id             1 = easy, 2 = medium, 3 = hard.
    task_description    Plain-language objective for the agent's prompt.
    done                True once the episode has ended (defaults to False).
    last_action_error   Error message if the last action was invalid or
                        failed (defaults to None).
    diagnostic_hints    Context-aware hints based on current metric
                        patterns (defaults to an empty list).
    reward_components   Named breakdown of the reward signal for
                        interpretability (defaults to an empty dict).

    PipelineConfig, QueryResult, QualityMetrics and CorpusStats are defined
    later in this module. The ``RAGDebugObservation.model_rebuild()`` call
    at the bottom of the file resolves these annotations.
    """
    pipeline_config: PipelineConfig = Field(
        ..., description="Current pipeline configuration the agent can modify."
    )
    query_results: List[QueryResult] = Field(
        ..., description="Per-query retrieval results under the current config."
    )
    metrics: QualityMetrics = Field(
        ..., description="Aggregate retrieval quality metrics."
    )
    corpus_stats: CorpusStats = Field(
        ..., description="Static metadata about the corpus for this episode."
    )
    steps_taken: int = Field(
        ..., description="Number of actions taken so far this episode."
    )
    max_steps: int = Field(
        ..., description="Maximum actions allowed before episode force-terminates."
    )
    task_id: int = Field(
        ..., description="Task identifier: 1 = easy, 2 = medium, 3 = hard."
    )
    task_description: str = Field(
        ..., description="Plain-language objective for the agent."
    )
    done: bool = Field(
        False, description="True once the episode has ended."
    )
    last_action_error: Optional[str] = Field(
        None, description="Error message if the last action was invalid or failed."
    )
    diagnostic_hints: List[str] = Field(
        default_factory=list,
        description="Context-aware diagnostic hints based on current metric patterns.",
    )
    reward_components: Dict[str, float] = Field(
        default_factory=dict,
        description="Named breakdown of the reward signal for interpretability.",
    )
# =============================================================================
# Tier 2 — Internal simulation models (plain Pydantic BaseModel)
# =============================================================================

# ── Pipeline Configuration ───────────────────────────────────────────────────
class PipelineConfig(BaseModel):
    """
    The tunable parameters that define the simulated RAG pipeline.

    These are the knobs the agent turns. Every RAGDebugAction ultimately
    changes one field here, or switches the active embedding model, which
    swaps the S_true matrix used in simulation.

    Field bounds reflect sensible real-world ranges. A model-level check
    also requires chunk_overlap to be strictly below chunk_size. An overlap
    equal to the chunk size would produce infinite identical chunks.

    NOTE(review): no ``model_config`` is set here, so Pydantic's default
    ``validate_assignment=False`` applies. Plain attribute assignment
    (e.g. ``cfg.chunk_size = 4000``) bypasses both the bounds and the
    overlap check. Confirm that in-place edits re-validate the config.
    """

    chunk_size: int = Field(512, ge=64, le=2048)
    chunk_overlap: int = Field(50, ge=0, le=500)
    similarity_threshold: float = Field(0.3, ge=0.0, le=1.0)
    top_k: int = Field(10, ge=1, le=50)
    embedding_model: EmbeddingModel = EmbeddingModel.GENERAL
    use_reranking: bool = False
    context_window_limit: int = Field(4096, ge=512, le=16384)

    @model_validator(mode="after")
    def overlap_less_than_chunk_size(self) -> "PipelineConfig":
        """Accept the config only when chunk_overlap < chunk_size; otherwise raise ValueError."""
        # Guard clause: the valid case exits early.
        if self.chunk_overlap < self.chunk_size:
            return self
        raise ValueError(
            f"chunk_overlap ({self.chunk_overlap}) must be "
            f"strictly less than chunk_size ({self.chunk_size})"
        )
# ── Per-Query Results ────────────────────────────────────────────────────────
class QueryResult(BaseModel):
    """
    Retrieval outcome for a single query under the current config.

    retrieved_chunk_ids and retrieval_scores are parallel — index i of
    each list refers to the same chunk.

    coverage_score = |R_agent ∩ R*| / |R*|
        1.0 → all relevant chunks retrieved
        0.0 → no relevant chunks retrieved

    precision_score is constrained to [0.0, 1.0]. It is presumably the
    fraction of retrieved chunks that are relevant (cf.
    QualityMetrics.mean_precision); confirm against the simulator.

    is_multi_hop flags queries that require two chunks to answer
    (relevant for Task 3 grading only).

    NOTE(review): no validator checks that the two parallel lists have
    equal length, or that n_retrieved == len(retrieved_chunk_ids). The
    code constructing this model must keep them consistent.
    """
    query_id: int
    query_text: str
    retrieved_chunk_ids: List[int]
    retrieval_scores: List[float]
    n_retrieved: int
    coverage_score: float = Field(ge=0.0, le=1.0)
    precision_score: float = Field(ge=0.0, le=1.0)
    is_multi_hop: bool = False
# ── Aggregate Metrics ────────────────────────────────────────────────────────
class QualityMetrics(BaseModel):
    """
    Aggregate retrieval quality across all queries in the episode.

    mean_coverage        Primary signal. Mean of per-query coverage scores.
    mean_precision       Fraction of retrieved chunks that are relevant.
    mean_recall          Fraction of relevant chunks that were retrieved.
                         Numerically equals mean_coverage when R* is the
                         ground-truth set, but tracked separately for clarity.
    n_empty_retrievals   Queries where nothing passed the threshold filter.
    n_context_overflows  Queries where retrieved chunks exceeded the limit.
    multi_hop_coverage   Mean coverage on multi-hop queries only.
                         None (the default) when no multi-hop queries
                         exist (Tasks 1 & 2).

    Pydantic enforces that every ratio lies in [0.0, 1.0] and every
    count is >= 0.
    """
    mean_coverage: float = Field(ge=0.0, le=1.0)
    mean_precision: float = Field(ge=0.0, le=1.0)
    mean_recall: float = Field(ge=0.0, le=1.0)
    n_empty_retrievals: int = Field(ge=0)
    n_context_overflows: int = Field(ge=0)
    multi_hop_coverage: Optional[float] = Field(None, ge=0.0, le=1.0)
# ── Corpus Metadata ──────────────────────────────────────────────────────────
class CorpusStats(BaseModel):
    """
    Static metadata about the corpus for this episode.

    Gives the agent context about the data it's working with (exposed as
    RAGDebugObservation.corpus_stats).

    NOTE(review): the integer counts carry no bounds (no ``ge=0``), so
    negative values would be accepted. Callers are trusted to supply
    sane figures.
    """
    domain: Domain
    n_documents: int
    n_chunks: int
    avg_chunk_tokens: int
    has_near_duplicates: bool
    n_queries: int
    n_multi_hop_queries: int
# ── Reward ───────────────────────────────────────────────────────────────────
class Reward(BaseModel):
    """
    The reward signal produced by env.step().

    All rewards are in [0.0, 1.0]. Non-terminal step rewards span
    [0.0, ~0.89] based on absolute quality progress; terminal rewards
    occupy [0.7, 1.0] (success) or [0.0, 0.15] (failure).

    value is the scalar used by the RL algorithm.
    components is a labelled breakdown for interpretability. The
    environment always populates this — it aids debugging and makes
    reward shaping decisions auditable.

    Non-terminal step components
    ----------------------------
    progress_reward          0.10 + 0.55 × progress ∈ [0.10, 0.65]
                             progress = min(1, quality_score / quality_target)
                             Absolute quality level signal; ensures the full
                             reward range is used across the episode.
    delta_bonus              clip(Δquality × 2.0, -0.15, +0.15)
                             Direction signal: distinguishes an improving step
                             from a no-op at the same quality level.
    empty_retrieval_signal   Bidirectional: rewards fixing empties, penalizes
                             new ones; weight 0.06.
    overflow_signal          Bidirectional: rewards fixing overflows, penalizes
                             new ones; weight 0.04.
    step_cost                Fixed -0.01 per step (efficiency pressure).
    redundancy_penalty       -0.04 if the same action type is taken twice
                             consecutively.
    invalid_action_penalty   -0.05 if the action had invalid parameters.

    Terminal SUBMIT components
    --------------------------
    terminal_success         0.7 + 0.3 × task_score ∈ [0.7, 1.0] on successful SUBMIT
    terminal_failure         0.2 × task_score ∈ [0.0, 0.2] on premature SUBMIT

    NOTE(review): the summary above quotes [0.0, 0.15] for failure, while
    terminal_failure allows up to 0.2. The two agree only if a failing
    task_score is always below 0.75. Confirm which bound the environment
    guarantees.

    This model does not clamp or validate ``value``. The ranges above
    describe what the environment is expected to produce.
    """

    value: float
    components: Dict[str, float] = Field(default_factory=dict)

    def __str__(self) -> str:
        """Render as ``Reward(total=<signed value> | name=<signed value>, ...)`` to 3 decimals."""
        labelled = [f"{k}={v:+.3f}" for k, v in self.components.items()]
        parts = ", ".join(labelled)
        return f"Reward(total={self.value:+.3f} | {parts})"
# ── Fault Config (internal, never sent to agent) ─────────────────────────────
class FaultConfig(BaseModel):
    """
    Parameters of a single injected fault.

    Stored in InternalState.injected_faults. Never included in
    RAGDebugObservation.

    fault_type   Which FaultType was injected.
    params       Free-form arguments for the fault. They are not validated
                 here, and their keys presumably depend on fault_type
                 (confirm with the fault injector).
    description  Free-text string, empty by default. Presumably a
                 human-readable summary of the fault.
    """
    fault_type: FaultType
    params: Dict[str, Any] = Field(default_factory=dict)
    description: str = ""
# ── Internal State (server-side only) ────────────────────────────────────────
class InternalState(BaseModel):
    """
    The domain-specific, server-side state of the environment.

    Returned by env.state() and consumed by graders and the
    RealPipelineBackend adapter. It is NOT handed to the agent during
    training.

    The framework-owned episode metadata (episode_id, step_count) lives
    in OpenEnv's State class, which is used alongside this model. This
    class only carries the RAG-debugging specifics.
    """

    injected_faults: List[FaultConfig]
    episode_seed: int
    action_history: List[RAGDebugAction] = Field(default_factory=list)
    reward_history: List[float] = Field(default_factory=list)

    @property
    def total_reward(self) -> float:
        """Cumulative reward: the plain ``sum`` of reward_history (0 when empty)."""
        history = self.reward_history
        return sum(history)

    @property
    def fault_names(self) -> List[str]:
        """The ``.value`` string of each injected fault type, in injected_faults order."""
        return [fault.fault_type.value for fault in self.injected_faults]
# ── Episode Result (post-episode summary) ────────────────────────────────────
class EpisodeResult(BaseModel):
    """
    Summary returned by env.grade() after a completed episode.

    task_score      0.0–1.0 from the task's grader function (validated
                    to that range).
    success         True if task_score >= the task's success_threshold.
    fault_names     Which faults were injected (revealed post-episode).
    final_metrics   A QualityMetrics snapshot, presumably taken under the
                    final pipeline config.
    total_reward    Presumably InternalState.total_reward at episode end;
                    confirm with env.grade().
    action_history  The RAGDebugAction objects from the episode.
    """
    task_id: int
    task_score: float = Field(ge=0.0, le=1.0)
    success: bool
    n_steps: int
    total_reward: float
    final_metrics: QualityMetrics
    fault_names: List[str]
    action_history: List[RAGDebugAction]
# =============================================================================
# Rebuild forward references
# =============================================================================
# RAGDebugObservation references PipelineConfig, QueryResult, QualityMetrics,
# and CorpusStats, which are defined after it in this file. model_rebuild()
# resolves those forward refs now that every name exists. InternalState and
# EpisodeResult only reference classes defined before them. Rebuilding them
# is a cheap no-op today, and it keeps them correct if the file is
# reordered later.
RAGDebugObservation.model_rebuild()
InternalState.model_rebuild()
EpisodeResult.model_rebuild()