"""
Core environment logic for IndicScriptureQA.
Implements the OpenEnv interface:
reset(task_name, scenario_index) → StepResult
step(action) → StepResult
state() → EnvState
"""
from __future__ import annotations
import random
from typing import Optional
from models import Action, ActionType, EnvState, Observation, StepResult, StructuralMeta
from rewards import normalize_score, step_reward, terminal_reward
from tasks import TASKS, Scenario, TaskConfig
class IndicScriptureQAEnv:
    """Stateful IndicScriptureQA environment — one instance drives one episode.

    Lifecycle: ``reset()`` picks a scenario and initialises the episode,
    ``step()`` applies a single agent action and returns the outcome, and
    ``state()`` hands back a deep copy of the internal state so callers
    cannot mutate the live episode.
    """

    def __init__(self) -> None:
        # No active episode until reset() is invoked.
        self._state: Optional[EnvState] = None

    # ── reset ─────────────────────────────────────────────────────────────
    def reset(
        self,
        task_name: str = "verify-factual",
        scenario_index: Optional[int] = None,
    ) -> StepResult:
        """Begin a fresh episode for *task_name*.

        A specific scenario may be pinned via *scenario_index* (wrapped
        modulo the scenario count); otherwise one is drawn uniformly at
        random. Raises ``ValueError`` for an unrecognised task name.
        """
        if task_name not in TASKS:
            raise ValueError(f"Unknown task {task_name!r}. Choose from {list(TASKS)}")
        cfg: TaskConfig = TASKS[task_name]
        if scenario_index is None:
            chosen = random.randint(0, len(cfg.scenarios) - 1)
        else:
            # Out-of-range indices wrap rather than raise.
            chosen = scenario_index % len(cfg.scenarios)
        scenario: Scenario = cfg.scenarios[chosen]
        self._state = EnvState(
            question=scenario.question,
            current_answer=scenario.given_answer,
            original_answer=scenario.given_answer,
            ground_truth_answer=scenario.ground_truth_answer,
            ground_truth_citations=list(scenario.ground_truth_citations),
            available_passages=list(scenario.available_passages),
            answer_is_correct=scenario.answer_is_correct,
            factual_is_correct=scenario.factual_is_correct,
            structural_meta=scenario.structural_meta,
            structural_hints=list(scenario.structural_hints),
            task_name=task_name,
            max_steps=cfg.max_steps,
            steps_remaining=cfg.max_steps,
            step_count=0,
            done=False,
            cumulative_reward=0.0,
            rewards=[],
            retrieval_count=0,
            edit_count=0,
            restructure_count=0,
            feedback="Episode started. Examine the answer for factual accuracy AND semantic structure.",
        )
        return StepResult(observation=self._state.to_observation(), reward=0.0, done=False)

    # ── step ──────────────────────────────────────────────────────────────
    def step(self, action: Action) -> StepResult:
        """Apply *action* to the live episode and return the resulting step.

        Raises ``RuntimeError`` when no episode is active or the episode has
        already terminated.
        """
        episode = self._state
        if episode is None:
            raise RuntimeError("Call reset() before step().")
        if episode.done:
            raise RuntimeError("Episode already finished. Call reset().")
        episode.step_count += 1
        episode.steps_remaining -= 1
        kind = action.action_type
        text = (action.payload or "").strip()
        reward, feedback, finished = 0.0, "", False

        # ── action dispatch ───────────────────────────────────────────────
        if kind == ActionType.RETRIEVE:
            episode.retrieval_count += 1
            if episode.available_passages:
                # Cycle through the passage pool, deduplicating as we go.
                cursor = (episode.retrieval_count - 1) % len(episode.available_passages)
                passage = episode.available_passages[cursor]
                if passage not in episode.retrieved_passages:
                    episode.retrieved_passages.append(passage)
            reward, feedback = step_reward(episode, kind, text)
        elif kind in (ActionType.EDIT, ActionType.RESTRUCTURE):
            if kind == ActionType.EDIT:
                episode.edit_count += 1
            else:
                episode.restructure_count += 1
            # Reward is scored against the pre-edit answer, then the edit
            # is applied (empty payloads leave the answer untouched).
            reward, feedback = step_reward(episode, kind, text)
            if text:
                episode.current_answer = text
        elif kind == ActionType.CITE:
            # Record the citation first so the reward sees it.
            if text and text not in episode.current_citations:
                episode.current_citations.append(text)
            reward, feedback = step_reward(episode, kind, text)
        elif kind in (ActionType.ACCEPT, ActionType.REJECT):
            reward, feedback = terminal_reward(episode, kind)
            finished = True
        else:
            reward = -0.10
            feedback = f"Unknown action type: {kind}"

        # ── check step limit ──────────────────────────────────────────────
        if not finished and episode.steps_remaining <= 0:
            # Budget exhausted: score as a forced ACCEPT with a penalty.
            final_reward, final_fb = terminal_reward(episode, ActionType.ACCEPT)
            reward += final_reward - 0.20
            feedback += f" | Forced termination (step limit). {final_fb}"
            finished = True

        # ── bookkeeping ───────────────────────────────────────────────────
        episode.rewards.append(reward)
        episode.cumulative_reward += reward
        episode.done = finished
        episode.feedback = feedback
        info: dict = {}
        if finished:
            info["score"] = normalize_score(episode.cumulative_reward)
            info["cumulative_reward"] = episode.cumulative_reward
        return StepResult(
            observation=episode.to_observation(),
            reward=reward,
            done=finished,
            info=info,
        )

    # ── state ─────────────────────────────────────────────────────────────
    def state(self) -> EnvState:
        """Return a deep copy of the current episode state (read-only view)."""
        if self._state is None:
            raise RuntimeError("Call reset() first.")
        return self._state.model_copy(deep=True)