| import json |
| import random |
| from pathlib import Path |
|
|
| from openenv.core import Environment |
|
|
| from environment.actions import ( |
| ActionType, ContextCorruptionAction, Document, |
| EpisodeObservation, ContextCorruptionState, |
| ) |
| from environment.reward import ContextCorruptionRubric |
|
|
| _FALLBACK_FACTS = [ |
| {"question": "What is the capital of France?", "answer": "Paris"} |
| ] |
|
|
|
|
class ContextCorruptionEnv(Environment[ContextCorruptionAction, EpisodeObservation, ContextCorruptionState]):
    """Episodic environment in which an agent must spot corrupted documents.

    Each episode picks one question/answer fact, builds ``NUM_DOCS`` documents
    about it, and marks a random subset of them as corrupt. Every step on a
    live episode costs one unit of budget. The episode ends when the agent
    submits an answer or has spent ``MAX_BUDGET`` units. The reward is
    computed once, on the final step, by ``ContextCorruptionRubric``.
    """

    MAX_BUDGET = 12
    NUM_DOCS = 8
    DIFFICULTY_LEVELS = [1, 2, 3, 4]  # candidate numbers of corrupt documents per episode
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, difficulty=None):
        """Create the environment.

        Args:
            difficulty: Fixed number of corrupt documents per episode, or
                ``None`` to draw it from ``DIFFICULTY_LEVELS`` on every reset.
        """
        # The rubric receives a callback instead of a snapshot; presumably it
        # calls it when scoring — verify in ContextCorruptionRubric.
        rubric = ContextCorruptionRubric(state_fn=self._state_dict)
        super().__init__(rubric=rubric)
        self.difficulty = difficulty
        self._facts = self._load_facts()
        self._reset_state()

    @staticmethod
    def _load_facts():
        """Return the fact pool from ``data/facts.json``, or the built-in fallback.

        A file whose contents are empty (e.g. ``[]``) also falls back, because
        ``reset`` cannot choose a fact from an empty pool.
        """
        facts_path = Path(__file__).parent.parent / "data" / "facts.json"
        if not facts_path.exists():
            return _FALLBACK_FACTS
        with open(facts_path, encoding="utf-8") as f:
            facts = json.load(f)
        return facts or _FALLBACK_FACTS

    def _reset_state(self):
        """Clear all per-episode bookkeeping back to its initial values."""
        self._question = ""
        self._ground_truth = ""
        self._documents: list[dict] = []  # each dict has at least "id", "title", "content"
        self._corrupt_ids: list[int] = []
        self._flagged_ids: list[int] = []  # kept in the order the agent flagged them
        self._budget_used = 0
        self._turn = 0
        self._done = False
        self._reward = None
        self._breakdown = None

    def reset(self, seed=None, episode_id=None, **kwargs) -> EpisodeObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: If given, seeds the RNG so the fact choice and the corrupt
                positions are reproducible.
            episode_id: Accepted for interface compatibility; unused here.
            **kwargs: Ignored.

        Raises:
            ValueError: If the number of corrupt documents is outside
                ``0..NUM_DOCS``.
        """
        self._reset_rubric()
        self._reset_state()
        if seed is not None:
            # NOTE(review): this reseeds the process-global ``random`` module,
            # which every environment instance in the process shares (and
            # SUPPORTS_CONCURRENT_SESSIONS is True). It is kept because the
            # document generator may rely on the global RNG for reproducible
            # output. Moving to a per-instance ``random.Random`` needs the
            # generator to accept an RNG — confirm before changing.
            random.seed(seed)
        fact = random.choice(self._facts)
        n_corrupt = self.difficulty if self.difficulty is not None else random.choice(self.DIFFICULTY_LEVELS)
        if not 0 <= n_corrupt <= self.NUM_DOCS:
            raise ValueError(
                f"difficulty must be between 0 and {self.NUM_DOCS} corrupt documents, got {n_corrupt!r}"
            )
        self._corrupt_ids = random.sample(range(self.NUM_DOCS), n_corrupt)
        self._question = fact["question"]
        self._ground_truth = fact["answer"]
        self._documents = self._generate_documents(fact)
        return self._apply_transform(self._build_observation())

    def _generate_documents(self, fact):
        """Build the episode's documents, degrading to placeholders on failure.

        The generator lives in the optional ``data.generator`` module. If it
        cannot be imported or raises, placeholder documents are returned so
        ``reset`` still succeeds. A ``RuntimeWarning`` is emitted so the
        degradation is not silent.
        """
        try:
            from data.generator import generate_documents
            return generate_documents(fact, num_docs=self.NUM_DOCS, corrupt_positions=self._corrupt_ids)
        except Exception as exc:  # deliberate best-effort: reset() must not fail on generator problems
            import warnings
            warnings.warn(
                f"Document generator unavailable ({exc!r}); using placeholder documents.",
                RuntimeWarning,
                stacklevel=2,
            )
        # NOTE(review): every placeholder document contains the true answer,
        # including the ones marked corrupt, so there is nothing to detect.
        return [
            {"id": i, "title": f"Document {i}", "content": fact["answer"], "is_corrupt": i in self._corrupt_ids}
            for i in range(self.NUM_DOCS)
        ]

    def step(self, action: ContextCorruptionAction, timeout_s=None, **kwargs) -> EpisodeObservation:
        """Apply one agent action and return the resulting observation.

        Every action on a live episode costs one unit of budget. The episode
        ends on ``submit_answer`` or once ``MAX_BUDGET`` units are spent. The
        rubric reward is computed and cached on that final step. Steps taken
        after the episode has ended change nothing and carry a message.

        Args:
            action: The agent's action; ``doc_id`` is used by flag/unflag.
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: Ignored.
        """
        if self._done:
            return self._apply_transform(self._build_observation(message="Episode already done."))

        self._turn += 1
        self._budget_used += 1
        message = None

        if action.action_type == ActionType.read_doc:
            # Every document's content is already in the observation; reading
            # only spends budget.
            pass

        elif action.action_type == ActionType.flag_suspicious:
            doc_id = action.doc_id
            if doc_id is not None and not any(d["id"] == doc_id for d in self._documents):
                # Flagging a non-existent document would leak a bogus id into
                # the state handed to the rubric.
                message = f"Unknown doc_id {doc_id}; nothing flagged."
            elif doc_id is not None and doc_id not in self._flagged_ids:
                self._flagged_ids.append(doc_id)

        elif action.action_type == ActionType.unflag_doc:
            if action.doc_id in self._flagged_ids:
                self._flagged_ids.remove(action.doc_id)

        elif action.action_type == ActionType.submit_answer:
            self._done = True

        # Running out of budget ends the episode just like submitting does.
        if self._budget_used >= self.MAX_BUDGET:
            self._done = True

        obs = self._build_observation(message=message)

        if obs.done:
            # Score exactly once, on the step that ends the episode, and cache
            # the result so later observations and `state` report it.
            obs.reward = self._apply_rubric(action, obs)
            self._reward = obs.reward
            self._breakdown = self.rubric.last_breakdown if self.rubric else None

        return self._apply_transform(obs)

    @property
    def state(self) -> ContextCorruptionState:
        """Full episode state, including the ground truth and corrupt ids.

        Those two fields are not part of the agent-facing observation.
        """
        return ContextCorruptionState(
            question=self._question,
            ground_truth=self._ground_truth,
            corrupt_ids=list(self._corrupt_ids),
            flagged_ids=list(self._flagged_ids),
            budget_used=self._budget_used,
            done=self._done,
            reward=self._reward,
            breakdown=self._breakdown,
        )

    def _state_dict(self) -> dict:
        """Plain-dict view of the state, passed to the rubric as ``state_fn``."""
        return {
            "ground_truth": self._ground_truth,
            "flagged_ids": list(self._flagged_ids),
            "corrupt_ids": list(self._corrupt_ids),
            "budget_used": self._budget_used,
            "max_budget": self.MAX_BUDGET,
        }

    def _build_observation(self, message=None) -> EpisodeObservation:
        """Build the agent-facing observation.

        The observation contains every document's content plus flag state,
        but never the ground truth or the corrupt ids.
        """
        flagged = set(self._flagged_ids)
        docs = [
            Document(
                id=d["id"],
                title=d["title"],
                content=d["content"],
                is_flagged=d["id"] in flagged,
            )
            for d in self._documents
        ]
        return EpisodeObservation(
            question=self._question,
            documents=docs,
            flagged_ids=list(self._flagged_ids),
            budget_remaining=self.MAX_BUDGET - self._budget_used,
            turn=self._turn,
            done=self._done,
            reward=self._reward,
            message=message,
        )
|
|