Siddh12334's picture
feat: training space with manual start UI
204fa23 verified
import json
import random
from pathlib import Path
from openenv.core import Environment
from environment.actions import (
ActionType, ContextCorruptionAction, Document,
EpisodeObservation, ContextCorruptionState,
)
from environment.reward import ContextCorruptionRubric
_FALLBACK_FACTS = [
{"question": "What is the capital of France?", "answer": "Paris"}
]
class ContextCorruptionEnv(Environment[ContextCorruptionAction, EpisodeObservation, ContextCorruptionState]):
    """Environment in which an agent must identify corrupted documents.

    Each episode picks one question/answer fact and builds ``NUM_DOCS``
    documents for it, some of which are marked corrupt. The agent may read,
    flag and unflag documents, then submit. Every step costs one unit of
    budget, and the episode is force-ended once ``MAX_BUDGET`` units are
    spent. Scoring is delegated to ``ContextCorruptionRubric``, which reads
    the episode state through ``_state_dict``.
    """

    MAX_BUDGET = 12  # steps allowed per episode before the episode is forced to end
    NUM_DOCS = 8  # documents presented per episode
    DIFFICULTY_LEVELS = [1, 2, 3, 4]  # corrupt-doc counts drawn from when difficulty is None
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, difficulty=None):
        """Create the environment.

        Args:
            difficulty: Fixed number of corrupt documents per episode, or
                ``None`` to draw one from ``DIFFICULTY_LEVELS`` on each reset.
        """
        # The rubric pulls episode state lazily through this bound method.
        rubric = ContextCorruptionRubric(state_fn=self._state_dict)
        super().__init__(rubric=rubric)
        self.difficulty = difficulty
        self._facts = self._load_facts()
        self._reset_state()

    @staticmethod
    def _load_facts():
        """Return the fact list from ``data/facts.json``, or the built-in fallback."""
        facts_path = Path(__file__).parent.parent / "data" / "facts.json"
        if not facts_path.exists():
            return _FALLBACK_FACTS
        with open(facts_path, encoding="utf-8") as f:
            facts = json.load(f)
        # An empty fact file would make random.choice() fail on every reset.
        return facts or _FALLBACK_FACTS

    def _reset_state(self):
        """Clear all per-episode bookkeeping."""
        self._question = ""
        self._ground_truth = ""
        self._documents: list[dict] = []
        self._corrupt_ids: list[int] = []
        self._flagged_ids: list[int] = []
        self._budget_used = 0
        self._turn = 0
        self._done = False
        self._reward = None
        self._breakdown = None

    def reset(self, seed=None, episode_id=None, **kwargs) -> EpisodeObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Optional seed for reproducible fact and corruption selection.
            episode_id: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.

        Raises:
            ValueError: If the corrupt-document count is negative or larger
                than ``NUM_DOCS``.
        """
        self._reset_rubric()
        self._reset_state()
        if seed is not None:
            # NOTE(review): this seeds the process-wide RNG, which is shared by
            # every concurrent session. It is kept because data.generator may
            # rely on the global RNG for reproducible documents; confirm that
            # before switching to a per-instance random.Random.
            random.seed(seed)
        fact = random.choice(self._facts)
        n_corrupt = self.difficulty if self.difficulty is not None else random.choice(self.DIFFICULTY_LEVELS)
        if not 0 <= n_corrupt <= self.NUM_DOCS:
            raise ValueError(
                f"number of corrupt documents must be between 0 and {self.NUM_DOCS}, got {n_corrupt}"
            )
        self._corrupt_ids = random.sample(range(self.NUM_DOCS), n_corrupt)
        self._question = fact["question"]
        self._ground_truth = fact["answer"]
        try:
            from data.generator import generate_documents

            raw_docs = generate_documents(fact, num_docs=self.NUM_DOCS, corrupt_positions=self._corrupt_ids)
        except Exception as exc:
            # Best-effort fallback, but surface it. Every placeholder document
            # carries the true answer, so corrupt ones are indistinguishable
            # and the episode cannot be solved meaningfully.
            import warnings

            warnings.warn(
                f"document generator failed, using placeholder documents: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            raw_docs = [
                {"id": i, "title": f"Document {i}", "content": fact["answer"], "is_corrupt": i in self._corrupt_ids}
                for i in range(self.NUM_DOCS)
            ]
        self._documents = raw_docs
        return self._apply_transform(self._build_observation())

    def step(self, action: ContextCorruptionAction, timeout_s=None, **kwargs) -> EpisodeObservation:
        """Apply one agent action and return the resulting observation.

        Every action, including reads, flags on unknown ids and submits,
        costs one unit of budget. Once the episode ends (by submit or budget
        exhaustion), the rubric is applied and the reward is attached.
        """
        if self._done:
            return self._apply_transform(self._build_observation(message="Episode already done."))
        self._turn += 1
        self._budget_used += 1
        notes = []
        if action.action_type == ActionType.read_doc:
            # Observations already include full document contents; reading only spends budget.
            pass
        elif action.action_type == ActionType.flag_suspicious:
            if action.doc_id is not None:
                known_ids = {d["id"] for d in self._documents}
                if action.doc_id not in known_ids:
                    # Don't record ids that match no document; they would leak into the rubric state.
                    notes.append(f"Unknown doc_id {action.doc_id}; nothing flagged.")
                elif action.doc_id not in self._flagged_ids:
                    self._flagged_ids.append(action.doc_id)
        elif action.action_type == ActionType.unflag_doc:
            if action.doc_id in self._flagged_ids:
                self._flagged_ids.remove(action.doc_id)
        elif action.action_type == ActionType.submit_answer:
            self._done = True
        # Force-submit on budget exhaustion.
        if self._budget_used >= self.MAX_BUDGET and not self._done:
            self._done = True
            notes.append("Budget exhausted; episode ended.")
        obs = self._build_observation(message=" ".join(notes) or None)
        if obs.done:
            obs.reward = self._apply_rubric(action, obs)
            self._reward = obs.reward
            self._breakdown = self.rubric.last_breakdown if self.rubric else None
        return self._apply_transform(obs)

    @property
    def state(self) -> ContextCorruptionState:
        """Snapshot of the full episode state, including hidden ground truth."""
        return ContextCorruptionState(
            question=self._question,
            ground_truth=self._ground_truth,
            corrupt_ids=list(self._corrupt_ids),
            flagged_ids=list(self._flagged_ids),
            budget_used=self._budget_used,
            done=self._done,
            reward=self._reward,
            breakdown=self._breakdown,
        )

    def _state_dict(self) -> dict:
        """Plain-dict state handed to the rubric (copies, so it can't mutate the env)."""
        return {
            "ground_truth": self._ground_truth,
            "flagged_ids": list(self._flagged_ids),
            "corrupt_ids": list(self._corrupt_ids),
            "budget_used": self._budget_used,
            "max_budget": self.MAX_BUDGET,
        }

    def _build_observation(self, message=None) -> EpisodeObservation:
        """Build the agent-visible observation; corruption labels are not exposed."""
        flagged = set(self._flagged_ids)
        docs = [
            Document(
                id=d["id"],
                title=d["title"],
                content=d["content"],
                is_flagged=d["id"] in flagged,
            )
            for d in self._documents
        ]
        return EpisodeObservation(
            question=self._question,
            documents=docs,
            flagged_ids=list(self._flagged_ids),
            budget_remaining=self.MAX_BUDGET - self._budget_used,
            turn=self._turn,
            done=self._done,
            reward=self._reward,
            message=message,
        )