"""Environment in which an agent flags corrupted documents and answers a question."""

import json
import logging
import random
from pathlib import Path

from openenv.core import Environment

from environment.actions import (
    ActionType,
    ContextCorruptionAction,
    Document,
    EpisodeObservation,
    ContextCorruptionState,
)
from environment.reward import ContextCorruptionRubric

logger = logging.getLogger(__name__)

# Used when data/facts.json is missing or contains no facts.
_FALLBACK_FACTS = [
    {"question": "What is the capital of France?", "answer": "Paris"}
]


class ContextCorruptionEnv(Environment[ContextCorruptionAction, EpisodeObservation, ContextCorruptionState]):
    """Episode = one question plus NUM_DOCS documents, some of them corrupt.

    Every step spends one unit of budget. The agent may read, flag or unflag
    documents, or submit an answer. The episode ends on submission or when
    MAX_BUDGET steps have been taken, at which point the rubric
    (``ContextCorruptionRubric``) scores it.
    """

    MAX_BUDGET = 12  # steps allowed before the episode is force-ended
    NUM_DOCS = 8  # documents generated per episode
    # Candidate numbers of corrupt documents, sampled on each reset when the
    # constructor was not given a fixed difficulty.
    DIFFICULTY_LEVELS = [1, 2, 3, 4]
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, difficulty=None):
        """Create the environment.

        Args:
            difficulty: Fixed number of corrupt documents per episode, or
                ``None`` to sample one from ``DIFFICULTY_LEVELS`` on each reset.

        Raises:
            ValueError: If ``difficulty`` is outside ``0..NUM_DOCS`` (such a
                value could never be sampled from ``NUM_DOCS`` documents).
        """
        if difficulty is not None and not 0 <= difficulty <= self.NUM_DOCS:
            raise ValueError(
                f"difficulty must be between 0 and {self.NUM_DOCS}, got {difficulty!r}"
            )
        rubric = ContextCorruptionRubric(state_fn=self._state_dict)
        super().__init__(rubric=rubric)
        self.difficulty = difficulty
        self._facts = self._load_facts()
        self._reset_state()

    @staticmethod
    def _load_facts() -> list[dict]:
        """Load question/answer facts from ``data/facts.json`` or use the fallback."""
        facts_path = Path(__file__).parent.parent / "data" / "facts.json"
        if not facts_path.exists():
            return _FALLBACK_FACTS
        with open(facts_path, encoding="utf-8") as f:
            facts = json.load(f)
        # An empty file would make random.choice() fail on every reset.
        if not facts:
            logger.warning("%s contains no facts; using fallback facts", facts_path)
            return _FALLBACK_FACTS
        return facts

    def _reset_state(self):
        """Clear all per-episode state."""
        self._question = ""
        self._ground_truth = ""
        self._documents: list[dict] = []
        self._corrupt_ids: list[int] = []
        self._flagged_ids: list[int] = []
        self._budget_used = 0
        self._turn = 0
        self._done = False
        self._reward = None
        self._breakdown = None

    def reset(self, seed=None, episode_id=None, **kwargs) -> EpisodeObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Optional seed for reproducible episodes. NOTE(review): this
                seeds the module-level ``random`` generator, which is shared by
                every instance in the process (and presumably by
                ``data.generator``). Confirm whether that is acceptable given
                ``SUPPORTS_CONCURRENT_SESSIONS``.
            episode_id: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        self._reset_rubric()
        self._reset_state()
        if seed is not None:
            random.seed(seed)

        # The order of the random draws below is part of seeded reproducibility.
        fact = random.choice(self._facts)
        if self.difficulty is not None:
            n_corrupt = self.difficulty
        else:
            n_corrupt = random.choice(self.DIFFICULTY_LEVELS)
        self._corrupt_ids = random.sample(range(self.NUM_DOCS), n_corrupt)
        self._question = fact["question"]
        self._ground_truth = fact["answer"]
        self._documents = self._generate_documents(fact)
        return self._apply_transform(self._build_observation())

    def _generate_documents(self, fact: dict) -> list[dict]:
        """Build this episode's documents, preferring ``data.generator``.

        If the generator cannot be imported or raises, placeholder documents
        are returned instead. Every placeholder's content is the answer itself,
        so the "corrupt" ones are only labelled, not actually corrupted.
        """
        try:
            from data.generator import generate_documents

            return generate_documents(
                fact, num_docs=self.NUM_DOCS, corrupt_positions=self._corrupt_ids
            )
        except Exception:
            # The fallback is deliberate, but record why the generator failed.
            logger.warning(
                "Document generator unavailable or failed; using placeholder documents",
                exc_info=True,
            )
        corrupt = set(self._corrupt_ids)
        return [
            {
                "id": i,
                "title": f"Document {i}",
                "content": fact["answer"],
                "is_corrupt": i in corrupt,
            }
            for i in range(self.NUM_DOCS)
        ]

    def step(self, action: ContextCorruptionAction, timeout_s=None, **kwargs) -> EpisodeObservation:
        """Apply one action and spend one unit of budget.

        When the episode ends (answer submitted or budget exhausted), the
        rubric scores it with the final action. The reward and breakdown are
        then stored. Stepping a finished episode changes nothing and returns
        the current observation with an explanatory message.
        """
        if self._done:
            return self._apply_transform(self._build_observation(message="Episode already done."))

        self._turn += 1
        self._budget_used += 1

        if action.action_type == ActionType.read_doc:
            # No state change: all document contents are already part of
            # every observation.
            pass
        elif action.action_type == ActionType.flag_suspicious:
            if action.doc_id is not None and action.doc_id not in self._flagged_ids:
                self._flagged_ids.append(action.doc_id)
        elif action.action_type == ActionType.unflag_doc:
            if action.doc_id in self._flagged_ids:
                self._flagged_ids.remove(action.doc_id)
        elif action.action_type == ActionType.submit_answer:
            self._done = True

        # Running out of budget force-ends the episode; it is then scored
        # with whatever action was taken last.
        if self._budget_used >= self.MAX_BUDGET:
            self._done = True

        obs = self._build_observation()
        if obs.done:
            obs.reward = self._apply_rubric(action, obs)
            self._reward = obs.reward
            self._breakdown = self.rubric.last_breakdown if self.rubric else None
        return self._apply_transform(obs)

    @property
    def state(self) -> ContextCorruptionState:
        """Full episode state, including ground truth and corrupt ids."""
        return ContextCorruptionState(
            question=self._question,
            ground_truth=self._ground_truth,
            corrupt_ids=list(self._corrupt_ids),
            flagged_ids=list(self._flagged_ids),
            budget_used=self._budget_used,
            done=self._done,
            reward=self._reward,
            breakdown=self._breakdown,
        )

    def _state_dict(self) -> dict:
        """Snapshot of the fields the rubric reads via its ``state_fn`` hook."""
        return {
            "ground_truth": self._ground_truth,
            "flagged_ids": list(self._flagged_ids),
            "corrupt_ids": list(self._corrupt_ids),
            "budget_used": self._budget_used,
            "max_budget": self.MAX_BUDGET,
        }

    def _build_observation(self, message=None) -> EpisodeObservation:
        """Agent-facing view of the episode (no ground truth or corrupt ids)."""
        docs = [
            Document(
                id=d["id"],
                title=d["title"],
                content=d["content"],
                is_flagged=d["id"] in self._flagged_ids,
            )
            for d in self._documents
        ]
        return EpisodeObservation(
            question=self._question,
            documents=docs,
            flagged_ids=list(self._flagged_ids),
            budget_remaining=self.MAX_BUDGET - self._budget_used,
            turn=self._turn,
            done=self._done,
            reward=self._reward,
            message=message,
        )