File size: 6,153 Bytes
fe406e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
Core environment logic for IndicScriptureQA.

Implements the OpenEnv interface:
  reset(task_name, scenario_index) → StepResult
  step(action)                     → StepResult
  state()                          → EnvState
"""

from __future__ import annotations

import random
from typing import Optional

from models import Action, ActionType, EnvState, Observation, StepResult, StructuralMeta
from rewards import normalize_score, step_reward, terminal_reward
from tasks import TASKS, Scenario, TaskConfig


class IndicScriptureQAEnv:
    """Stateful environment — one instance per episode.

    Call ``reset()`` to start an episode, then ``step()`` until the returned
    ``StepResult.done`` is true. ``state()`` exposes a deep copy of the
    internal ``EnvState``.
    """

    def __init__(self) -> None:
        # No episode is active until reset() has been called.
        self._state: Optional[EnvState] = None

    # ── reset ─────────────────────────────────────────────────────────────

    def reset(
        self,
        task_name: str = "verify-factual",
        scenario_index: Optional[int] = None,
    ) -> StepResult:
        """Start a new episode and return its initial observation.

        Args:
            task_name: Key into ``TASKS`` selecting the task configuration.
            scenario_index: Scenario to load. The value is wrapped modulo the
                number of scenarios, so any integer is accepted. When
                ``None``, a scenario is picked uniformly at random.

        Returns:
            A ``StepResult`` carrying the initial observation, with
            ``reward=0.0`` and ``done=False``.

        Raises:
            ValueError: If ``task_name`` is not in ``TASKS`` or the selected
                task has no scenarios.
        """
        if task_name not in TASKS:
            raise ValueError(f"Unknown task {task_name!r}. Choose from {list(TASKS)}")

        cfg: TaskConfig = TASKS[task_name]
        n_scenarios = len(cfg.scenarios)
        if n_scenarios == 0:
            # Without this guard the modulo below raises ZeroDivisionError and
            # the random path raises an opaque "empty range" ValueError.
            raise ValueError(f"Task {task_name!r} has no scenarios.")

        if scenario_index is not None:
            idx = scenario_index % n_scenarios
        else:
            idx = random.randint(0, n_scenarios - 1)

        sc: Scenario = cfg.scenarios[idx]

        # List-valued scenario fields are copied so the episode state does not
        # alias the scenario's own lists.
        self._state = EnvState(
            question=sc.question,
            current_answer=sc.given_answer,
            original_answer=sc.given_answer,
            ground_truth_answer=sc.ground_truth_answer,
            ground_truth_citations=list(sc.ground_truth_citations),
            available_passages=list(sc.available_passages),
            answer_is_correct=sc.answer_is_correct,
            factual_is_correct=sc.factual_is_correct,
            structural_meta=sc.structural_meta,
            structural_hints=list(sc.structural_hints),
            task_name=task_name,
            max_steps=cfg.max_steps,
            steps_remaining=cfg.max_steps,
            step_count=0,
            done=False,
            cumulative_reward=0.0,
            rewards=[],
            retrieval_count=0,
            edit_count=0,
            restructure_count=0,
            feedback="Episode started. Examine the answer for factual accuracy AND semantic structure.",
        )
        return StepResult(observation=self._state.to_observation(), reward=0.0, done=False)

    # ── step ──────────────────────────────────────────────────────────────

    def step(self, action: Action) -> StepResult:
        """Apply one action to the current episode.

        Every call consumes one step, including unrecognised action types.
        ACCEPT and REJECT end the episode with ``terminal_reward``. If the
        step budget runs out first, the episode is force-terminated: it is
        scored as an ACCEPT and an extra 0.20 is subtracted from the reward.

        Args:
            action: The action to apply; its ``payload`` is stripped of
                surrounding whitespace and treated as ``""`` when missing.

        Returns:
            A ``StepResult`` with the new observation, this step's reward and
            the ``done`` flag. When the episode ends, ``info`` holds
            ``"score"`` (normalized cumulative reward) and
            ``"cumulative_reward"``; otherwise ``info`` is empty.

        Raises:
            RuntimeError: If ``reset()`` has not been called, or the episode
                has already finished.
        """
        s = self._state
        if s is None:
            raise RuntimeError("Call reset() before step().")
        if s.done:
            raise RuntimeError("Episode already finished. Call reset().")

        s.step_count += 1
        s.steps_remaining -= 1
        act = action.action_type
        payload = (action.payload or "").strip()

        reward = 0.0
        feedback = ""
        done = False

        # ── action dispatch ───────────────────────────────────────────────
        if act == ActionType.RETRIEVE:
            s.retrieval_count += 1
            if s.available_passages:
                # Passages are handed out round-robin by retrieval count.
                idx = (s.retrieval_count - 1) % len(s.available_passages)
                passage = s.available_passages[idx]
                if passage not in s.retrieved_passages:
                    s.retrieved_passages.append(passage)
            # Scored after the passage has been recorded in the state.
            reward, feedback = step_reward(s, act, payload)

        elif act == ActionType.EDIT:
            s.edit_count += 1
            # Scored before the answer is overwritten, so step_reward sees the
            # pre-edit answer in the state alongside the proposed payload.
            reward, feedback = step_reward(s, act, payload)
            if payload:
                s.current_answer = payload

        elif act == ActionType.RESTRUCTURE:
            s.restructure_count += 1
            # Same ordering as EDIT: score first, then replace the answer.
            reward, feedback = step_reward(s, act, payload)
            if payload:
                s.current_answer = payload

        elif act == ActionType.CITE:
            if payload and payload not in s.current_citations:
                s.current_citations.append(payload)
            # Scored after the citation has been recorded in the state.
            reward, feedback = step_reward(s, act, payload)

        elif act in (ActionType.ACCEPT, ActionType.REJECT):
            # Both verdicts end the episode; terminal_reward decides the score.
            reward, feedback = terminal_reward(s, act)
            done = True

        else:
            reward = -0.10
            feedback = f"Unknown action type: {act}"

        # ── check step limit ──────────────────────────────────────────────
        if not done and s.steps_remaining <= 0:
            # Out of steps: score as an implicit ACCEPT, minus a 0.20 penalty,
            # on top of whatever this step already earned.
            t_reward, t_fb = terminal_reward(s, ActionType.ACCEPT)
            reward += t_reward - 0.20
            feedback += f" | Forced termination (step limit). {t_fb}"
            done = True

        # ── bookkeeping ───────────────────────────────────────────────────
        s.rewards.append(reward)
        s.cumulative_reward += reward
        s.done = done
        s.feedback = feedback

        info = {}
        if done:
            info["score"] = normalize_score(s.cumulative_reward)
            info["cumulative_reward"] = s.cumulative_reward

        return StepResult(
            observation=s.to_observation(),
            reward=reward,
            done=done,
            info=info,
        )

    # ── state ─────────────────────────────────────────────────────────────

    def state(self) -> EnvState:
        """Return a deep copy of the current episode state.

        The copy means callers cannot mutate the live episode through it.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first.")
        return self._state.model_copy(deep=True)