File size: 19,198 Bytes
f23deb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac224ce
f23deb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af29724
f23deb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac224ce
 
 
 
 
 
 
 
 
 
 
f23deb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac224ce
 
 
 
f23deb1
 
 
 
 
 
ac224ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f23deb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
"""
models.py
---------
Defines the typed Action and Observation for RAGDebugEnv, plus all
internal simulation models used by the environment logic.

Architecture
------------
Two tiers of models live here:

  Tier 1 β€” OpenEnv interface types (must inherit from framework bases)
    RAGDebugAction     inherits openenv.core.env_server.types.Action
    RAGDebugObservation inherits openenv.core.env_server.types.Observation

  Tier 2 β€” Internal simulation models (plain Pydantic BaseModel)
    PipelineConfig, QueryResult, QualityMetrics, CorpusStats, Reward,
    InternalState, EpisodeResult

  The OpenEnv-provided State class is used directly for episode
  metadata (episode_id, step_count). It is NOT subclassed β€” the
  framework owns that contract.

Import convention
-----------------
    from models import RAGDebugAction, RAGDebugObservation
  from openenv.core.env_server.types import State   # for episode state
"""

from __future__ import annotations

import json
from enum import Enum
from typing import Any, Dict, List, Optional, Set

from pydantic import BaseModel, Field, field_validator, model_validator

# ── OpenEnv base types ────────────────────────────────────────────────────────
# These are the two types the framework requires us to subclass.
# Import path confirmed from official docs:
# https://meta-pytorch.org/OpenEnv/environment-builder/
from openenv.core.env_server.types import Action, Observation


# =============================================================================
# Enums shared across both tiers
# =============================================================================

class EmbeddingModel(str, Enum):
    """
    Embedding back-ends the pipeline can be configured with.

    Members are str-mixed, so each one compares equal to its string value.

    GENERAL — sentence-transformers/all-MiniLM-L6-v2.
              Fast, general-purpose.  Works well on everyday text but
              degrades on specialist domains.
    MEDICAL — NeuML/pubmedbert-base-embeddings.
              Trained on biomedical retrieval tasks.
    LEGAL   — nlpaueb/legal-bert-base-uncased.
              Trained on legal corpora.
    CODE    — sentence-transformers/multi-qa-mpnet-base-dot-v1.
              Retrieval-tuned contrast model (keeps historical "code" slot).
    """

    GENERAL = "general"
    MEDICAL = "medical"
    LEGAL = "legal"
    CODE = "code"


class Domain(str, Enum):
    """
    Corpus domain associated with each task difficulty.

    Members are str-mixed, so each one compares equal to its string value.

    SOFTWARE — Python docs.  Clean prose, unambiguous vocabulary.  Task 1.
    CLIMATE  — IPCC reports.  Cross-disciplinary, more ambiguous.  Task 2.
    MEDICAL  — MedRAG textbooks.  Heavy domain terminology.  Task 3.
    """

    SOFTWARE = "software"
    CLIMATE = "climate"
    MEDICAL = "medical"


class ActionType(str, Enum):
    """
    The closed set of operations an agent may apply to the pipeline.

    Config actions
        Modify PipelineConfig in-place.  The environment re-simulates
        retrieval on the updated config immediately.

    REWRITE_QUERY
        Rewrites one query's text — simulated by perturbing its
        similarity scores toward the ground-truth set.

    SUBMIT
        Declares the agent is done and triggers grading.  Submitting
        before the success threshold incurs a penalty.
    """

    ADJUST_CHUNK_SIZE = "adjust_chunk_size"
    ADJUST_CHUNK_OVERLAP = "adjust_chunk_overlap"
    ADJUST_THRESHOLD = "adjust_threshold"
    ADJUST_TOP_K = "adjust_top_k"
    SWAP_EMBEDDING_MODEL = "swap_embedding_model"
    TOGGLE_RERANKING = "toggle_reranking"
    ADJUST_CONTEXT_LIMIT = "adjust_context_limit"
    REWRITE_QUERY = "rewrite_query"
    SUBMIT = "submit"


class FaultType(str, Enum):
    """
    The closed set of faults that can be injected into the simulated pipeline.

    Stored in InternalState.  Never exposed in RAGDebugObservation.
    """

    CHUNK_TOO_LARGE = "chunk_too_large"
    CHUNK_TOO_SMALL = "chunk_too_small"
    THRESHOLD_TOO_LOW = "threshold_too_low"
    THRESHOLD_TOO_HIGH = "threshold_too_high"
    TOP_K_TOO_SMALL = "top_k_too_small"
    CONTEXT_OVERFLOW = "context_overflow"
    DUPLICATE_FLOODING = "duplicate_flooding"
    WRONG_EMBEDDING_MODEL = "wrong_embedding_model"
    NO_RERANKING = "no_reranking"


# =============================================================================
# Tier 1 — OpenEnv interface types
# =============================================================================

class RAGDebugAction(Action):
    """
    The action an agent takes against the RAG pipeline.

    Inherits from openenv.core.env_server.types.Action as required by
    the OpenEnv spec.  The framework uses this class for serialisation,
    deserialisation, and web-UI form generation.

    action_type selects the operation.  params carries its arguments.

    Parameter schemas by action_type
    ---------------------------------
    adjust_chunk_size     {"value": int}        64 ≤ value ≤ 2048
    adjust_chunk_overlap  {"value": int}        0 ≤ value ≤ 500
    adjust_threshold      {"value": float}      0.0 ≤ value ≤ 1.0
    adjust_top_k          {"value": int}        1 ≤ value ≤ 50
    swap_embedding_model  {"model": str}        EmbeddingModel enum value
    toggle_reranking      {"enabled": bool}
    adjust_context_limit  {"value": int}        512 ≤ value ≤ 16384
    rewrite_query         {"query_id": int,
                           "strategy": str}    currently only "rephrase" is supported
    submit                {}

    Note that this class does not itself check params against the table
    above; it only guarantees that params is a dictionary.
    """
    action_type: ActionType = Field(
        ...,
        description="Which pipeline operation to perform.",
    )
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Arguments for the chosen action_type.",
    )

    @field_validator("params", mode="before")
    @classmethod
    def coerce_params_dict(cls, value: Any) -> Dict[str, Any]:
        """
        Normalise the raw ``params`` input to a dictionary.

        Accepts dicts and JSON-stringified dicts from the web UI.

        Args:
            value: Raw input for the ``params`` field, before Pydantic's
                own type validation runs.

        Returns:
            An empty dict for ``None`` or a blank/whitespace-only string,
            the dict itself when one is given, or the decoded object when
            a JSON string is given.

        Raises:
            ValueError: If a string is not valid JSON, if the JSON does not
                decode to an object, or if the input is neither ``None``,
                a dict, nor a string.
        """
        if value is None:
            return {}
        if isinstance(value, dict):
            return value
        if isinstance(value, str):
            text = value.strip()
            if not text:
                return {}
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError as exc:
                raise ValueError("params must be a dictionary or valid JSON object string") from exc
            if not isinstance(parsed, dict):
                raise ValueError("params JSON must decode to an object")
            return parsed
        # Pydantic v2 only converts ValueError / AssertionError raised inside
        # a validator into a ValidationError.  A TypeError would propagate
        # unwrapped to the caller, so the unsupported-type case is reported
        # as a ValueError like the other rejections above.
        raise ValueError("params must be a dictionary or JSON object string")

    def __str__(self) -> str:
        """Render as ``action_type(key=value, ...)``, or ``action_type()`` without params."""
        if self.params:
            param_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
            return f"{self.action_type.value}({param_str})"
        return f"{self.action_type.value}()"


class RAGDebugObservation(Observation):
    """
    Everything the agent is allowed to see after each step.

    Inherits from openenv.core.env_server.types.Observation as required
    by the OpenEnv spec.

    Intentional omissions
    ---------------------
    injected_faults is NOT here.  The agent must infer faults from
    metrics alone — that reasoning IS the task.  Faults are only
    revealed in InternalState (accessible via env.state(), used by
    graders and debuggers, not given to the agent).

    Fields
    ------
    pipeline_config    The current parameter set the agent may modify.
    query_results      Per-query retrieval results under current config.
    metrics            Aggregate quality metrics across all queries.
    corpus_stats       Static metadata about the corpus (domain, size).
    steps_taken        Actions taken so far this episode.
    max_steps          Budget before the episode force-terminates.
    task_id            1 = easy, 2 = medium, 3 = hard.
    task_description   Plain-language objective for the agent's prompt.
    done               True once the episode has ended.  Defaults to False.
    last_action_error  Error message if the last action was invalid or
                       failed.  Defaults to None.
    diagnostic_hints   Context-aware hints based on current metric
                       patterns.  Defaults to an empty list.
    reward_components  Named breakdown of the reward signal.  Defaults
                       to an empty dict.
    """

    # Required fields: no default, must be supplied on construction.
    pipeline_config: PipelineConfig = Field(
        ...,
        description="Current pipeline configuration the agent can modify.",
    )
    query_results: List[QueryResult] = Field(
        ...,
        description="Per-query retrieval results under the current config.",
    )
    metrics: QualityMetrics = Field(
        ...,
        description="Aggregate retrieval quality metrics.",
    )
    corpus_stats: CorpusStats = Field(
        ...,
        description="Static metadata about the corpus for this episode.",
    )
    steps_taken: int = Field(
        ...,
        description="Number of actions taken so far this episode.",
    )
    max_steps: int = Field(
        ...,
        description="Maximum actions allowed before episode force-terminates.",
    )
    task_id: int = Field(
        ...,
        description="Task identifier: 1 = easy, 2 = medium, 3 = hard.",
    )
    task_description: str = Field(
        ...,
        description="Plain-language objective for the agent.",
    )

    # Optional fields: each carries a default.
    done: bool = Field(
        default=False,
        description="True once the episode has ended.",
    )
    last_action_error: Optional[str] = Field(
        default=None,
        description="Error message if the last action was invalid or failed.",
    )
    diagnostic_hints: List[str] = Field(
        default_factory=list,
        description="Context-aware diagnostic hints based on current metric patterns.",
    )
    reward_components: Dict[str, float] = Field(
        default_factory=dict,
        description="Named breakdown of the reward signal for interpretability.",
    )


# =============================================================================
# Tier 2 — Internal simulation models  (plain Pydantic BaseModel)
# =============================================================================

# ── Pipeline Configuration ────────────────────────────────────────────────────

class PipelineConfig(BaseModel):
    """
    The full parameter set that defines the RAG pipeline's behaviour.

    These are the knobs the agent turns.  Every RAGDebugAction ultimately
    modifies one field here (or switches the active embedding model, which
    swaps which S_true matrix is used in simulation).

    Field bounds reflect real-world sensible ranges.  In addition, the
    model-level validator requires chunk_overlap < chunk_size, because an
    overlap equal to chunk_size would produce infinite identical chunks.
    """

    chunk_size: int = Field(default=512, ge=64, le=2048)
    chunk_overlap: int = Field(default=50, ge=0, le=500)
    similarity_threshold: float = Field(default=0.3, ge=0.0, le=1.0)
    top_k: int = Field(default=10, ge=1, le=50)
    embedding_model: EmbeddingModel = EmbeddingModel.GENERAL
    use_reranking: bool = False
    context_window_limit: int = Field(default=4096, ge=512, le=16384)

    @model_validator(mode="after")
    def overlap_less_than_chunk_size(self) -> "PipelineConfig":
        """
        Reject configs whose overlap is not strictly below the chunk size.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If chunk_overlap >= chunk_size.
        """
        if self.chunk_overlap < self.chunk_size:
            return self
        message = (
            f"chunk_overlap ({self.chunk_overlap}) must be strictly less "
            f"than chunk_size ({self.chunk_size})"
        )
        raise ValueError(message)


# ── Per-Query Results ─────────────────────────────────────────────────────────

class QueryResult(BaseModel):
    """
    Retrieval outcome for one query under the current config.

    retrieved_chunk_ids and retrieval_scores are parallel — index i of
    each list refers to the same chunk.  (This model does not itself
    verify that the two lists have equal length.)

    coverage_score = |R_agent ∩ R*| / |R*|
      1.0 → all relevant chunks retrieved
      0.0 → no relevant chunks retrieved

    is_multi_hop flags queries that require two chunks to answer
    (relevant for Task 3 grading only).
    """

    query_id: int
    query_text: str
    retrieved_chunk_ids: List[int]
    retrieval_scores: List[float]
    n_retrieved: int
    # Both scores are required and constrained to the closed interval [0, 1].
    coverage_score: float = Field(..., ge=0.0, le=1.0)
    precision_score: float = Field(..., ge=0.0, le=1.0)
    is_multi_hop: bool = False


# ── Aggregate Metrics ─────────────────────────────────────────────────────────

class QualityMetrics(BaseModel):
    """
    Retrieval quality aggregated over all queries in the episode.

    mean_coverage        Primary signal.  Mean of per-query coverage scores.
    mean_precision       Fraction of retrieved chunks that are relevant.
    mean_recall          Fraction of relevant chunks that were retrieved.
                         Numerically equals mean_coverage when R* is the
                         ground-truth set, but tracked separately for clarity.
    n_empty_retrievals   Queries where nothing passed the threshold filter.
    n_context_overflows  Queries where retrieved chunks exceeded limit.
    multi_hop_coverage   Mean coverage on multi-hop queries only.
                         None when no multi-hop queries exist (Tasks 1 & 2).
    """

    # Ratios: required, each constrained to [0, 1].
    mean_coverage: float = Field(..., ge=0.0, le=1.0)
    mean_precision: float = Field(..., ge=0.0, le=1.0)
    mean_recall: float = Field(..., ge=0.0, le=1.0)
    # Counters: required, non-negative.
    n_empty_retrievals: int = Field(..., ge=0)
    n_context_overflows: int = Field(..., ge=0)
    # Optional ratio; defaults to None.
    multi_hop_coverage: Optional[float] = Field(default=None, ge=0.0, le=1.0)


# ── Corpus Metadata ───────────────────────────────────────────────────────────

class CorpusStats(BaseModel):
    """
    Static description of the corpus used in this episode.

    Gives the agent context about the data it's working with.  Every
    field is required; none carries a default.
    """

    domain: Domain
    n_documents: int
    n_chunks: int
    avg_chunk_tokens: int
    has_near_duplicates: bool
    n_queries: int
    n_multi_hop_queries: int


# ── Reward ────────────────────────────────────────────────────────────────────

class Reward(BaseModel):
    """
    The reward signal produced by env.step().

    All rewards are in [0.0, 1.0].  Non-terminal step rewards span
    [0.0, ~0.89] based on absolute quality progress; terminal rewards
    occupy [0.7, 1.0] (success) or [0.0, 0.15] (failure).

    value is the scalar used by the RL algorithm.

    components is a labelled breakdown for interpretability.  The
    environment always populates this — it aids debugging and makes
    reward shaping decisions auditable.

    Non-terminal step components
    ----------------------------
    progress_reward         0.10 + 0.55 × progress → [0.10, 0.65]
                            progress = min(1, quality_score / quality_target)
                            Absolute quality level signal; ensures full reward
                            range is utilised across the episode.
    delta_bonus             clip(Δquality × 2.0, -0.15, +0.15)
                            Direction signal: distinguishes an improving step
                            from a no-op at the same quality level.
    empty_retrieval_signal  Bidirectional: rewards fixing empties, penalizes
                            new ones, weight 0.06
    overflow_signal         Bidirectional: rewards fixing overflows, penalizes
                            new ones, weight 0.04
    step_cost               Fixed -0.01 per step (efficiency pressure)
    redundancy_penalty      -0.04 if same action type taken twice consecutively
    invalid_action_penalty  -0.05 if the action had invalid parameters

    Terminal SUBMIT components
    --------------------------
    terminal_success        0.7 + 0.3 × task_score → [0.7, 1.0] on successful SUBMIT
    terminal_failure        0.2 × task_score → [0.0, 0.2] on premature SUBMIT

    NOTE(review): the summary above quotes [0.0, 0.15] for failed submits
    while the terminal_failure formula spans [0.0, 0.2]; presumably the
    tighter bound reflects the success threshold — confirm against the
    environment's reward code.

    This model only stores the values; nothing here computes or
    range-checks them.
    """

    value: float
    components: Dict[str, float] = Field(default_factory=dict)

    def __str__(self) -> str:
        """Render the total and each component as signed three-decimal numbers."""
        rendered = [
            f"{label}={amount:+.3f}" for label, amount in self.components.items()
        ]
        breakdown = ", ".join(rendered)
        return f"Reward(total={self.value:+.3f} | {breakdown})"


# ── Fault Config (internal, never sent to agent) ──────────────────────────────

class FaultConfig(BaseModel):
    """
    Description of one injected fault and its parameters.

    Stored in InternalState.  Never included in RAGDebugObservation.

    fault_type   Which fault this is (required).
    params       Free-form fault arguments; defaults to an empty dict.
    description  Human-readable text; defaults to the empty string.
    """

    fault_type: FaultType
    params: Dict[str, Any] = Field(default_factory=dict)
    description: str = ""


# ── Internal State (server-side only) ─────────────────────────────────────────

class InternalState(BaseModel):
    """
    Full server-side state of the environment.

    Returned by env.state() and used by graders and the
    RealPipelineBackend adapter.  NOT given to the agent during training.

    The OpenEnv framework's State class (with episode_id and step_count)
    is used alongside this for the parts the framework owns.  This class
    carries the domain-specific internal state.

    injected_faults  Faults active in this episode (required).
    episode_seed     Seed for the episode (required).
    action_history   Actions recorded so far; defaults to an empty list.
    reward_history   Scalar rewards recorded so far; defaults to an empty list.
    """

    injected_faults: List[FaultConfig]
    episode_seed: int
    action_history: List[RAGDebugAction] = Field(default_factory=list)
    reward_history: List[float] = Field(default_factory=list)

    @property
    def total_reward(self) -> float:
        """Sum of all entries in reward_history (0 when the history is empty)."""
        history = self.reward_history
        return sum(history)

    @property
    def fault_names(self) -> List[str]:
        """String values of the injected fault types, in injected_faults order."""
        names: List[str] = []
        for fault in self.injected_faults:
            names.append(fault.fault_type.value)
        return names


# ── Episode Result (post-episode summary) ─────────────────────────────────────

class EpisodeResult(BaseModel):
    """
    Summary returned by env.grade() after a completed episode.

    task_score     0.0–1.0 from the task's grader function.
    success        True if task_score >= the task's success_threshold.
    fault_names    Which faults were injected (revealed post-episode).

    Every field is required.  Only task_score is range-checked by this
    model; the relationship between success and task_score is not
    enforced here.
    """

    task_id: int
    task_score: float = Field(..., ge=0.0, le=1.0)
    success: bool
    n_steps: int
    total_reward: float
    final_metrics: QualityMetrics
    fault_names: List[str]
    action_history: List[RAGDebugAction]


# =============================================================================
# Rebuild forward references
# =============================================================================
# RAGDebugObservation references PipelineConfig, QueryResult, QualityMetrics,
# and CorpusStats, all of which are defined after it in this file, so those
# annotations are still unresolved forward references when the class body is
# processed.  model_rebuild() resolves them now that every referenced model
# exists.
RAGDebugObservation.model_rebuild()
# InternalState and EpisodeResult only name models declared above them
# (FaultConfig, RAGDebugAction, QualityMetrics); presumably these rebuilds are
# a defensive measure — confirm before removing.
InternalState.model_rebuild()
EpisodeResult.model_rebuild()