Beens committed on
Commit
593a998
·
verified ·
1 Parent(s): 4550c86

Update environment.py

Browse files
Files changed (1) hide show
  1. environment.py +144 -158
environment.py CHANGED
@@ -1,158 +1,144 @@
1
- """
2
- Core environment logic for IndicScriptureQA.
3
-
4
- Implements the OpenEnv interface:
5
- reset(task_name, scenario_index) → StepResult
6
- step(action) → StepResult
7
- state() → EnvState
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- import random
13
- from typing import Optional
14
-
15
- from models import Action, ActionType, EnvState, Observation, StepResult, StructuralMeta
16
- from rewards import normalize_score, step_reward, terminal_reward
17
- from tasks import TASKS, Scenario, TaskConfig
18
-
19
-
20
- class IndicScriptureQAEnv:
21
- """Stateful environment — one instance per episode."""
22
-
23
- def __init__(self) -> None:
24
- self._state: Optional[EnvState] = None
25
-
26
- # ── reset ─────────────────────────────────────────────────────────────
27
-
28
- def reset(
29
- self,
30
- task_name: str = "verify-factual",
31
- scenario_index: Optional[int] = None,
32
- ) -> StepResult:
33
- if task_name not in TASKS:
34
- raise ValueError(f"Unknown task {task_name!r}. Choose from {list(TASKS)}")
35
-
36
- cfg: TaskConfig = TASKS[task_name]
37
- if scenario_index is not None:
38
- idx = scenario_index % len(cfg.scenarios)
39
- else:
40
- idx = random.randint(0, len(cfg.scenarios) - 1)
41
-
42
- sc: Scenario = cfg.scenarios[idx]
43
-
44
- self._state = EnvState(
45
- question=sc.question,
46
- current_answer=sc.given_answer,
47
- original_answer=sc.given_answer,
48
- ground_truth_answer=sc.ground_truth_answer,
49
- ground_truth_citations=list(sc.ground_truth_citations),
50
- available_passages=list(sc.available_passages),
51
- answer_is_correct=sc.answer_is_correct,
52
- factual_is_correct=sc.factual_is_correct,
53
- structural_meta=sc.structural_meta,
54
- structural_hints=list(sc.structural_hints),
55
- task_name=task_name,
56
- max_steps=cfg.max_steps,
57
- steps_remaining=cfg.max_steps,
58
- step_count=0,
59
- done=False,
60
- cumulative_reward=0.0,
61
- rewards=[],
62
- retrieval_count=0,
63
- edit_count=0,
64
- restructure_count=0,
65
- feedback="Episode started. Examine the answer for factual accuracy AND semantic structure.",
66
- )
67
- return StepResult(observation=self._state.to_observation(), reward=0.0, done=False)
68
-
69
- # ── step ──────────────────────────────────────────────────────────────
70
-
71
- def step(self, action: Action) -> StepResult:
72
- s = self._state
73
- if s is None:
74
- raise RuntimeError("Call reset() before step().")
75
- if s.done:
76
- raise RuntimeError("Episode already finished. Call reset().")
77
-
78
- s.step_count += 1
79
- s.steps_remaining -= 1
80
- act = action.action_type
81
- payload = (action.payload or "").strip()
82
-
83
- reward = 0.0
84
- feedback = ""
85
- done = False
86
-
87
- # ── action dispatch ───────────────────────────────────────────────
88
- if act == ActionType.RETRIEVE:
89
- s.retrieval_count += 1
90
- if s.available_passages:
91
- idx = (s.retrieval_count - 1) % len(s.available_passages)
92
- passage = s.available_passages[idx]
93
- if passage not in s.retrieved_passages:
94
- s.retrieved_passages.append(passage)
95
- reward, feedback = step_reward(s, act, payload)
96
-
97
- elif act == ActionType.EDIT:
98
- s.edit_count += 1
99
- reward, feedback = step_reward(s, act, payload)
100
- if payload:
101
- s.current_answer = payload
102
-
103
- elif act == ActionType.RESTRUCTURE:
104
- s.restructure_count += 1
105
- reward, feedback = step_reward(s, act, payload)
106
- if payload:
107
- s.current_answer = payload
108
-
109
- elif act == ActionType.CITE:
110
- if payload and payload not in s.current_citations:
111
- s.current_citations.append(payload)
112
- reward, feedback = step_reward(s, act, payload)
113
-
114
- elif act == ActionType.ACCEPT:
115
- t_reward, feedback = terminal_reward(s, act)
116
- reward = t_reward
117
- done = True
118
-
119
- elif act == ActionType.REJECT:
120
- t_reward, feedback = terminal_reward(s, act)
121
- reward = t_reward
122
- done = True
123
-
124
- else:
125
- reward = -0.10
126
- feedback = f"Unknown action type: {act}"
127
-
128
- # ── check step limit ──────────────────────────────────────────────
129
- if not done and s.steps_remaining <= 0:
130
- t_reward, t_fb = terminal_reward(s, ActionType.ACCEPT)
131
- reward += t_reward - 0.20
132
- feedback += f" | Forced termination (step limit). {t_fb}"
133
- done = True
134
-
135
- # ── bookkeeping ──────────────────────────────────────────────────
136
- s.rewards.append(reward)
137
- s.cumulative_reward += reward
138
- s.done = done
139
- s.feedback = feedback
140
-
141
- info = {}
142
- if done:
143
- info["score"] = normalize_score(s.cumulative_reward)
144
- info["cumulative_reward"] = s.cumulative_reward
145
-
146
- return StepResult(
147
- observation=s.to_observation(),
148
- reward=reward,
149
- done=done,
150
- info=info,
151
- )
152
-
153
- # ── state ─────────────────────────────────────────────────────────────
154
-
155
- def state(self) -> EnvState:
156
- if self._state is None:
157
- raise RuntimeError("Call reset() first.")
158
- return self._state.model_copy(deep=True)
 
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from typing import Optional
5
+
6
+ from models import Action, ActionType, EnvState, Observation, StepResult, StructuralMeta
7
+ from rewards import normalize_score, step_reward, terminal_reward
8
+ from tasks import TASKS, Scenario, TaskConfig
9
+
10
+
11
+ class IndicScriptureQAEnv:
12
+
13
+ def __init__(self) -> None:
14
+ self._state: Optional[EnvState] = None
15
+
16
+ def reset(
17
+ self,
18
+ task_name: str = "verify-factual",
19
+ scenario_index: Optional[int] = None,
20
+ ) -> StepResult:
21
+ if task_name not in TASKS:
22
+ raise ValueError(f"Unknown task {task_name!r}. Choose from {list(TASKS)}")
23
+
24
+ cfg: TaskConfig = TASKS[task_name]
25
+ if scenario_index is not None:
26
+ idx = scenario_index % len(cfg.scenarios)
27
+ else:
28
+ idx = random.randint(0, len(cfg.scenarios) - 1)
29
+
30
+ sc: Scenario = cfg.scenarios[idx]
31
+
32
+ self._state = EnvState(
33
+ question=sc.question,
34
+ current_answer=sc.given_answer,
35
+ original_answer=sc.given_answer,
36
+ ground_truth_answer=sc.ground_truth_answer,
37
+ ground_truth_citations=list(sc.ground_truth_citations),
38
+ available_passages=list(sc.available_passages),
39
+ answer_is_correct=sc.answer_is_correct,
40
+ factual_is_correct=sc.factual_is_correct,
41
+ structural_meta=sc.structural_meta,
42
+ structural_hints=list(sc.structural_hints),
43
+ task_name=task_name,
44
+ max_steps=cfg.max_steps,
45
+ steps_remaining=cfg.max_steps,
46
+ step_count=0,
47
+ done=False,
48
+ cumulative_reward=0.0,
49
+ rewards=[],
50
+ retrieval_count=0,
51
+ edit_count=0,
52
+ restructure_count=0,
53
+ feedback="Episode started. Examine the answer for factual accuracy AND semantic structure.",
54
+ )
55
+ return StepResult(observation=self._state.to_observation(), reward=0.0, done=False)
56
+
57
+
58
+ def step(self, action: Action) -> StepResult:
59
+ s = self._state
60
+ if s is None:
61
+ raise RuntimeError("Call reset() before step().")
62
+ if s.done:
63
+ raise RuntimeError("Episode already finished. Call reset().")
64
+
65
+ s.step_count += 1
66
+ s.steps_remaining -= 1
67
+ act = action.action_type
68
+ payload = (action.payload or "").strip()
69
+
70
+ reward = 0.0
71
+ feedback = ""
72
+ done = False
73
+
74
+ # action dispatch
75
+ if act == ActionType.RETRIEVE:
76
+ s.retrieval_count += 1
77
+ if s.available_passages:
78
+ idx = (s.retrieval_count - 1) % len(s.available_passages)
79
+ passage = s.available_passages[idx]
80
+ if passage not in s.retrieved_passages:
81
+ s.retrieved_passages.append(passage)
82
+ reward, feedback = step_reward(s, act, payload)
83
+
84
+ elif act == ActionType.EDIT:
85
+ s.edit_count += 1
86
+ reward, feedback = step_reward(s, act, payload)
87
+ if payload:
88
+ s.current_answer = payload
89
+
90
+ elif act == ActionType.RESTRUCTURE:
91
+ s.restructure_count += 1
92
+ reward, feedback = step_reward(s, act, payload)
93
+ if payload:
94
+ s.current_answer = payload
95
+
96
+ elif act == ActionType.CITE:
97
+ if payload and payload not in s.current_citations:
98
+ s.current_citations.append(payload)
99
+ reward, feedback = step_reward(s, act, payload)
100
+
101
+ elif act == ActionType.ACCEPT:
102
+ t_reward, feedback = terminal_reward(s, act)
103
+ reward = t_reward
104
+ done = True
105
+
106
+ elif act == ActionType.REJECT:
107
+ t_reward, feedback = terminal_reward(s, act)
108
+ reward = t_reward
109
+ done = True
110
+
111
+ else:
112
+ reward = -0.10
113
+ feedback = f"Unknown action type: {act}"
114
+
115
+ # check steps
116
+ if not done and s.steps_remaining <= 0:
117
+ t_reward, t_fb = terminal_reward(s, ActionType.ACCEPT)
118
+ reward += t_reward - 0.20
119
+ feedback += f" | Forced termination (step limit). {t_fb}"
120
+ done = True
121
+
122
+ # book-keep the logs
123
+ s.rewards.append(reward)
124
+ s.cumulative_reward += reward
125
+ s.done = done
126
+ s.feedback = feedback
127
+
128
+ info = {}
129
+ if done:
130
+ info["score"] = normalize_score(s.cumulative_reward)
131
+ info["cumulative_reward"] = s.cumulative_reward
132
+
133
+ return StepResult(
134
+ observation=s.to_observation(),
135
+ reward=reward,
136
+ done=done,
137
+ info=info,
138
+ )
139
+
140
+ # update state
141
+ def state(self) -> EnvState:
142
+ if self._state is None:
143
+ raise RuntimeError("Call reset() first.")
144
+ return self._state.model_copy(deep=True)