kishl commited on
Commit
afe67aa
Β·
1 Parent(s): 57cd13b

llm as a judge

Browse files
Files changed (9) hide show
  1. Dockerfile +3 -2
  2. README.md +0 -8
  3. openenv.yaml +1 -1
  4. pyproject.toml +30 -0
  5. requirements-server.txt +1 -0
  6. rewards.py +283 -324
  7. server/__init__.py +0 -0
  8. server/app.py +18 -0
  9. uv.lock +0 -0
Dockerfile CHANGED
@@ -6,7 +6,8 @@ COPY requirements-server.txt .
6
  RUN pip install --no-cache-dir -r requirements-server.txt
7
 
8
  COPY models.py tasks.py rewards.py environment.py main.py openenv.yaml ./
 
9
 
10
- EXPOSE 8000
11
 
12
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
 
6
  RUN pip install --no-cache-dir -r requirements-server.txt
7
 
8
  COPY models.py tasks.py rewards.py environment.py main.py openenv.yaml ./
9
+ COPY server/ ./server/
10
 
11
+ EXPOSE 7860
12
 
13
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,3 @@
1
- ---
2
- license: mit
3
- language:
4
- - en
5
- - hi
6
- - sa
7
- pipeline_tag: reinforcement-learning
8
- ---
9
  # IndicScriptureQA β€” OpenEnv Environment
10
 
11
  **Semantic structure and factual grounding evaluation for low-resource Indic languages.**
 
 
 
 
 
 
 
 
 
1
  # IndicScriptureQA β€” OpenEnv Environment
2
 
3
  **Semantic structure and factual grounding evaluation for low-resource Indic languages.**
openenv.yaml CHANGED
@@ -13,7 +13,7 @@ license: MIT
13
 
14
  env:
15
  module: main:app
16
- port: 8000
17
  health_endpoint: /health
18
 
19
  action_space:
 
13
 
14
  env:
15
  module: main:app
16
+ port: 7860
17
  health_endpoint: /health
18
 
19
  action_space:
pyproject.toml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "indic-scripture-qa"
7
+ version = "1.1.0"
8
+ description = "OpenEnv environment for evaluating LLMs on Indic scripture factual accuracy and semantic structure quality"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ requires-python = ">=3.10"
12
+ authors = [
13
+ { name = "Kishlay Kisu", email = "kishlay.work1@gmail.com" },
14
+ ]
15
+ keywords = ["openenv", "rl", "indic", "nlp", "benchmark", "low-resource-languages"]
16
+
17
+ dependencies = [
18
+ "fastapi>=0.110.0",
19
+ "uvicorn[standard]>=0.27.0",
20
+ "pydantic>=2.0.0",
21
+ "openai>=1.0.0",
22
+ "requests>=2.31.0",
23
+ "openenv-core>=0.2.0",
24
+ ]
25
+
26
+ [project.scripts]
27
+ server = "server.app:main"
28
+
29
+ [project.urls]
30
+ Repository = "https://huggingface.co/spaces/kishl/indicQARL"
requirements-server.txt CHANGED
@@ -1,3 +1,4 @@
1
  fastapi>=0.110.0
2
  uvicorn[standard]>=0.27.0
3
  pydantic>=2.0.0
 
 
1
  fastapi>=0.110.0
2
  uvicorn[standard]>=0.27.0
3
  pydantic>=2.0.0
4
+ openai>=1.0.0
rewards.py CHANGED
@@ -1,396 +1,355 @@
1
  """
2
- Reward computation for IndicScriptureQA.
3
 
4
- Two evaluation axes, weighted into a single scalar:
5
- A. Factual quality β€” token-F1 similarity to ground truth, citation recall
6
- B. Structural quality β€” coherence, completeness, terminology, ordering
7
 
8
- All scoring is zero-dependency (no ML models) so the env runs on 2 vCPU / 8 GB.
 
 
 
9
  """
10
 
11
  from __future__ import annotations
12
 
 
 
13
  import re
14
- from typing import List, Tuple
 
 
15
 
16
  from models import ActionType, EnvState, StructuralMeta
17
 
18
 
19
  # ═══════════════════════════════════════════════════════════════════════════════
20
- # A. FACTUAL SCORING
21
  # ═══════════════════════════════════════════════════════════════════════════════
22
 
23
- def _tokenize(text: str) -> List[str]:
24
- """Lowercase split on non-alphanumeric (keeps Devanagari chars)."""
25
- return [t for t in re.split(r"[^a-zA-Z0-9\u0900-\u097F]+", text.lower()) if t]
26
-
27
-
28
- def token_f1(candidate: str, reference: str) -> float:
29
- """Token-level F1 between candidate and reference. Returns 0–1."""
30
- cand_toks = _tokenize(candidate)
31
- ref_toks = _tokenize(reference)
32
- if not cand_toks or not ref_toks:
33
- return 0.0
34
- cand_set = set(cand_toks)
35
- ref_set = set(ref_toks)
36
- common = cand_set & ref_set
37
- if not common:
38
- return 0.0
39
- precision = len(common) / len(cand_set)
40
- recall = len(common) / len(ref_set)
41
- return 2 * precision * recall / (precision + recall)
42
-
43
-
44
- def _normalize_citation(c: str) -> str:
45
- return re.sub(r"\s+", " ", c.strip().lower())
46
-
47
-
48
- def citation_recall(predicted: List[str], ground_truth: List[str]) -> float:
49
- """Fraction of ground-truth citations matched (fuzzy substring)."""
50
- if not ground_truth:
51
- return 1.0
52
- gt_norms = [_normalize_citation(g) for g in ground_truth]
53
- pred_norms = [_normalize_citation(p) for p in predicted]
54
- matched = 0
55
- for gt in gt_norms:
56
- for pred in pred_norms:
57
- if gt in pred or pred in gt:
58
- matched += 1
59
- break
60
- return matched / len(gt_norms)
61
 
62
 
63
  # ═══════════════════════════════════════════════════════════════════════════════
64
- # B. STRUCTURAL SCORING
65
  # ═══════════════════════════════════════════════════════════════════════════════
66
 
67
- # ── B1. Terminology precision ────────────────────────────────────────────────
68
-
69
- def terminology_score(answer: str, meta: StructuralMeta) -> float:
70
- """
71
- Checks:
72
- + required_terms present β†’ recall over required_terms
73
- - banned_terms present β†’ hard penalty per banned term found
74
- Returns float in [-1.0, 1.0].
75
- """
76
- answer_lower = answer.lower()
77
-
78
- # required term recall
79
- if meta.required_terms:
80
- hits = sum(1 for t in meta.required_terms if t.lower() in answer_lower)
81
- term_recall = hits / len(meta.required_terms)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  else:
83
- term_recall = 1.0
84
-
85
- # banned term penalty
86
- ban_penalty = 0.0
87
- if meta.banned_terms:
88
- for bt in meta.banned_terms:
89
- if bt.lower() in answer_lower:
90
- ban_penalty += 0.25
91
- ban_penalty = min(ban_penalty, 1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- return term_recall - ban_penalty
94
 
 
 
 
95
 
96
- # ── B2. Completeness (section coverage) ──────────────────────────────────────
 
97
 
98
- def completeness_score(answer: str, meta: StructuralMeta) -> float:
99
- """
100
- Heuristic: for each required_section, check whether characteristic
101
- keywords from that section label appear in the answer.
102
- Returns 0–1 (fraction of sections covered).
103
- """
104
- if not meta.required_sections:
105
- return 1.0
106
- answer_lower = answer.lower()
107
- covered = 0
108
- for section in meta.required_sections:
109
- # use the keywords from the section label itself
110
- section_keywords = _tokenize(section)
111
- # count a section as covered if β‰₯ half its keywords appear
112
- if section_keywords:
113
- hits = sum(1 for kw in section_keywords if kw in answer_lower)
114
- if hits / len(section_keywords) >= 0.5:
115
- covered += 1
116
- return covered / len(meta.required_sections)
117
-
118
-
119
- # ── B3. Logical ordering (sequence adherence) ────────────────────────────────
120
-
121
- def ordering_score(answer: str, meta: StructuralMeta) -> float:
122
- """
123
- Checks whether concepts in expected_order appear in the correct sequence
124
- in the answer. Uses first-occurrence position of each concept's keywords.
125
- Returns 0–1.
126
- """
127
- if len(meta.expected_order) < 2:
128
- return 1.0
129
 
130
- answer_lower = answer.lower()
131
- positions: List[int] = []
132
-
133
- for concept in meta.expected_order:
134
- keywords = _tokenize(concept)
135
- # find earliest position of any keyword
136
- earliest = len(answer_lower) + 1
137
- for kw in keywords:
138
- idx = answer_lower.find(kw)
139
- if idx != -1 and idx < earliest:
140
- earliest = idx
141
- positions.append(earliest)
142
-
143
- # count correctly ordered adjacent pairs
144
- correct_pairs = sum(
145
- 1 for i in range(len(positions) - 1) if positions[i] <= positions[i + 1]
146
- )
147
- return correct_pairs / (len(positions) - 1)
148
-
149
-
150
- # ── B4. Coherence (transition quality + sentence structure) ──────────────────
151
-
152
- _TRANSITION_MARKERS = {
153
- "therefore", "however", "moreover", "furthermore", "thus", "consequently",
154
- "specifically", "in contrast", "for example", "similarly", "additionally",
155
- "because", "since", "although", "while", "first", "second", "third",
156
- "finally", "in particular", "notably", "according to", "this means",
157
- "as a result", "in other words",
158
- }
159
-
160
-
161
- def coherence_score(answer: str) -> float:
162
- """
163
- Lightweight coherence proxy:
164
- - Sentence count (more than 1 sentence expected)
165
- - Transition markers (discourse connectives)
166
- - Sentence-length variance (very uneven β†’ lower coherence)
167
- Returns 0–1.
168
- """
169
- sentences = [s.strip() for s in re.split(r"[.!?]+", answer) if s.strip()]
170
- if len(sentences) <= 1:
171
- return 0.3 # single sentence is structurally weak for these tasks
172
-
173
- # transition marker density
174
- answer_lower = answer.lower()
175
- marker_count = sum(1 for m in _TRANSITION_MARKERS if m in answer_lower)
176
- marker_density = min(marker_count / max(len(sentences) - 1, 1), 1.0)
177
-
178
- # sentence length variance (normalised). Very uneven β†’ incoherent.
179
- lengths = [len(s.split()) for s in sentences]
180
- mean_len = sum(lengths) / len(lengths)
181
- if mean_len == 0:
182
- return 0.2
183
- variance = sum((l - mean_len) ** 2 for l in lengths) / len(lengths)
184
- cv = (variance ** 0.5) / mean_len # coefficient of variation
185
- uniformity = max(0.0, 1.0 - cv) # lower CV β†’ more uniform β†’ higher score
186
-
187
- # blend: 50 % markers, 30 % uniformity, 20 % baseline for multi-sentence
188
- return 0.5 * marker_density + 0.3 * uniformity + 0.2
189
-
190
-
191
- # ── Composite structural score ───────────────────────────────────────────────
192
-
193
- def structural_quality(answer: str, meta: StructuralMeta) -> Tuple[float, dict]:
194
- """
195
- Weighted composite of all structural axes.
196
- Returns (score_0_to_1, breakdown_dict).
197
- """
198
- ts = terminology_score(answer, meta)
199
- cs = completeness_score(answer, meta)
200
- os_ = ordering_score(answer, meta)
201
- coh = coherence_score(answer)
202
-
203
- # weights
204
- composite = (
205
- 0.30 * max(ts, 0.0) # terminology (clamp negatives to 0 for composite)
206
- + 0.25 * cs # completeness
207
- + 0.25 * os_ # ordering
208
- + 0.20 * coh # coherence
209
- )
210
-
211
- # apply banned-term penalty on top
212
- if ts < 0:
213
- composite += 0.15 * ts # propagate penalty
214
 
215
- composite = max(0.0, min(1.0, composite))
216
 
217
- breakdown = {
218
- "terminology": round(ts, 3),
219
- "completeness": round(cs, 3),
220
- "ordering": round(os_, 3),
221
- "coherence": round(coh, 3),
222
- "composite": round(composite, 3),
223
- }
224
- return composite, breakdown
225
 
226
 
227
  # ═══════════════════════════════════════════════════════════════════════════════
228
  # PER-STEP REWARD
229
  # ═══════════════════════════════════════════════════════════════════════════════
230
 
231
- def step_reward(state: EnvState, action_type: ActionType, payload: str | None) -> Tuple[float, str]:
232
- """
233
- Compute per-step reward and feedback message.
234
- Now accounts for structural improvement on EDIT and RESTRUCTURE.
235
- """
236
- reward = 0.0
237
- feedback = ""
238
 
 
239
  if action_type == ActionType.RETRIEVE:
240
  if state.retrieval_count >= 3:
241
- reward = -0.15
242
- feedback = "Redundant retrieval β€” you've already retrieved 3 times."
243
  elif state.available_passages:
244
- reward = 0.05
245
- feedback = "Passages retrieved."
246
  else:
247
- reward = -0.05
248
- feedback = "No passages available for retrieval."
249
 
250
- elif action_type == ActionType.EDIT:
 
251
  if not payload:
252
- reward = -0.10
253
- feedback = "Empty edit β€” no content provided."
254
- else:
255
- # factual delta
256
- old_sim = token_f1(state.current_answer, state.ground_truth_answer)
257
- new_sim = token_f1(payload, state.ground_truth_answer)
258
- fact_delta = new_sim - old_sim
259
-
260
- # structural delta
261
- old_struct, _ = structural_quality(state.current_answer, state.structural_meta)
262
- new_struct, bk = structural_quality(payload, state.structural_meta)
263
- struct_delta = new_struct - old_struct
264
-
265
- combined_delta = 0.6 * fact_delta + 0.4 * struct_delta
266
-
267
- if combined_delta > 0.03:
268
- reward = 0.20 + combined_delta
269
- feedback = f"Edit improved answer (fact Ξ”{fact_delta:+.2f}, struct Ξ”{struct_delta:+.2f})."
270
- elif combined_delta < -0.03:
271
- reward = -0.20
272
- feedback = f"Edit degraded answer (fact Ξ”{fact_delta:+.2f}, struct Ξ”{struct_delta:+.2f})."
273
- else:
274
- reward = -0.05
275
- feedback = "Edit had negligible effect."
276
 
277
- elif action_type == ActionType.RESTRUCTURE:
278
- if not payload:
279
- reward = -0.10
280
- feedback = "Empty restructure β€” no content provided."
281
- else:
282
- # restructure should preserve facts but improve structure
283
- old_sim = token_f1(state.current_answer, state.ground_truth_answer)
284
- new_sim = token_f1(payload, state.ground_truth_answer)
285
- fact_delta = new_sim - old_sim
286
-
287
- old_struct, _ = structural_quality(state.current_answer, state.structural_meta)
288
- new_struct, bk = structural_quality(payload, state.structural_meta)
289
- struct_delta = new_struct - old_struct
290
-
291
- if fact_delta < -0.10:
292
- # restructure destroyed factual content
293
- reward = -0.25
294
- feedback = f"Restructure lost factual content (fact Ξ”{fact_delta:+.2f}). Use EDIT if changing facts."
295
- elif struct_delta > 0.05:
296
- reward = 0.25 + struct_delta
297
- feedback = (
298
- f"Restructure improved structure (Ξ”{struct_delta:+.2f}). "
299
- f"Breakdown: term={bk['terminology']:.2f} comp={bk['completeness']:.2f} "
300
- f"order={bk['ordering']:.2f} coh={bk['coherence']:.2f}"
301
- )
302
- elif struct_delta < -0.03:
303
- reward = -0.15
304
- feedback = f"Restructure degraded structure (Ξ”{struct_delta:+.2f})."
305
  else:
306
- reward = -0.05
307
- feedback = "Restructure had negligible structural effect."
308
-
309
- elif action_type == ActionType.CITE:
310
- if not payload:
311
- reward = -0.05
312
- feedback = "Empty citation."
313
- else:
314
- cr = citation_recall([payload], state.ground_truth_citations)
315
- if cr > 0:
316
- reward = 0.15
317
- feedback = "Correct citation added."
318
  else:
319
- reward = -0.05
320
- feedback = "Citation does not match expected sources."
321
 
322
- elif action_type in (ActionType.ACCEPT, ActionType.REJECT):
323
- pass # terminal rewards handled separately
 
 
 
324
 
325
- return reward, feedback
 
 
 
 
326
 
327
 
328
  # ═══════════════════════════════════════════════════════════════════════════════
329
  # TERMINAL REWARD
330
  # ═══════════════════════════════════════════════════════════════════════════════
331
 
332
- def terminal_reward(state: EnvState, action_type: ActionType) -> Tuple[float, str]:
333
- """
334
- Terminal reward blends factual quality AND structural quality.
335
- """
 
 
336
  if action_type == ActionType.REJECT:
 
 
 
 
 
 
 
 
337
  if not state.answer_is_correct:
338
- return 0.30, "Correctly rejected a flawed answer."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  else:
340
- return -0.50, "Incorrectly rejected a valid answer."
341
-
342
- # ── ACCEPT ────────────────────────────────────────────────────────────
343
- # factual component
344
- answer_sim = token_f1(state.current_answer, state.ground_truth_answer)
345
- cit_score = citation_recall(state.current_citations, state.ground_truth_citations)
346
-
347
- # structural component
348
- struct_score, struct_breakdown = structural_quality(
349
- state.current_answer, state.structural_meta
 
 
 
350
  )
351
-
352
- # efficiency bonus (0–0.2)
353
  efficiency = 0.20 * (state.steps_remaining / state.max_steps)
 
354
 
355
- # weighted terminal reward
356
- terminal = (
357
- 0.90 * answer_sim # factual similarity (max 0.90)
358
- + 0.30 * cit_score # citation recall (max 0.30)
359
- + 0.70 * struct_score # structural quality (max 0.70)
360
- + efficiency # efficiency bonus (max 0.20)
361
- )
362
- # theoretical max β‰ˆ 2.10
363
-
364
- # penalty for accepting a still-bad answer
365
- if answer_sim < 0.3 and struct_score < 0.3:
366
  terminal -= 0.50
367
- quality_label = "poor"
368
- elif answer_sim < 0.5:
369
- quality_label = "mediocre"
370
  else:
371
- quality_label = "good"
372
-
373
- feedback = (
374
- f"Accepted a {quality_label} answer "
375
- f"(fact={answer_sim:.2f}, cite={cit_score:.2f}, struct={struct_score:.2f} "
376
- f"[term={struct_breakdown['terminology']:.2f} "
377
- f"comp={struct_breakdown['completeness']:.2f} "
378
- f"ord={struct_breakdown['ordering']:.2f} "
379
- f"coh={struct_breakdown['coherence']:.2f}])"
380
- )
381
 
382
- return terminal, feedback
383
 
384
 
385
  # ═══════════════════════════════════════════════════════════════════════════════
386
  # SCORE NORMALISATION
387
  # ═══════════════════════════════════════════════════════════════════════════════
388
 
389
- # theoretical max: terminal ~2.10 + step bonuses ~0.5 β‰ˆ 2.6
390
  MAX_REASONABLE_REWARD = 2.80
391
 
392
 
393
  def normalize_score(cumulative_reward: float) -> float:
394
  """Clamp cumulative reward into [0, 1]."""
395
- score = cumulative_reward / MAX_REASONABLE_REWARD
396
- return max(0.0, min(1.0, score))
 
1
  """
2
+ Reward computation for IndicScriptureQA β€” LLM-as-a-Judge.
3
 
4
+ Uses an LLM (via OpenAI client) to evaluate both factual accuracy and
5
+ semantic structure quality. Falls back to lightweight token heuristics
6
+ if the LLM call fails.
7
 
8
+ Environment variables (shared with inference.py):
9
+ API_BASE_URL LLM endpoint
10
+ MODEL_NAME Model identifier
11
+ HF_TOKEN API key
12
  """
13
 
14
  from __future__ import annotations
15
 
16
+ import json
17
+ import os
18
  import re
19
+ from typing import List, Optional, Tuple
20
+
21
+ from openai import OpenAI
22
 
23
  from models import ActionType, EnvState, StructuralMeta
24
 
25
 
26
  # ═══════════════════════════════════════════════════════════════════════════════
27
+ # LLM CLIENT
28
  # ═══════════════════════════════════════════════════════════════════════════════
29
 
30
+ _client: Optional[OpenAI] = None
31
+
32
+
33
+ def _get_client() -> OpenAI:
34
+ global _client
35
+ if _client is None:
36
+ _client = OpenAI(
37
+ base_url=os.getenv("API_BASE_URL", "https://router.huggingface.co/v1"),
38
+ api_key=os.getenv("HF_TOKEN") or os.getenv("API_KEY", ""),
39
+ )
40
+ return _client
41
+
42
+
43
+ def _get_model() -> str:
44
+ return os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
45
+
46
+
47
+ def _llm_judge(system: str, user_prompt: str) -> Optional[dict]:
48
+ """Call the LLM and parse a JSON response. Returns None on any failure."""
49
+ try:
50
+ client = _get_client()
51
+ resp = client.chat.completions.create(
52
+ model=_get_model(),
53
+ messages=[
54
+ {"role": "system", "content": system},
55
+ {"role": "user", "content": user_prompt},
56
+ ],
57
+ temperature=0.1,
58
+ max_tokens=500,
59
+ )
60
+ raw = (resp.choices[0].message.content or "").strip()
61
+ if raw.startswith("```"):
62
+ raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
63
+ return json.loads(raw)
64
+ except Exception as exc:
65
+ print(f"[JUDGE] LLM call failed, using fallback: {exc}", flush=True)
66
+ return None
 
67
 
68
 
69
  # ═══════════════════════════════════════════════════════════════════════════════
70
+ # JUDGE PROMPTS
71
  # ═══════════════════════════════════════════════════════════════════════════════
72
 
73
+ JUDGE_SYSTEM = (
74
+ "You are an expert judge evaluating answers about Indic scriptures "
75
+ "(Vedas, Upanishads, Ramayana, Mahabharata, Bhagavad Gita, Puranas). "
76
+ "You evaluate both factual accuracy and semantic structure quality.\n\n"
77
+ "Respond with ONLY a valid JSON object. No markdown fences, no "
78
+ "explanation, no text outside the JSON braces."
79
+ )
80
+
81
+
82
+ def _terminal_accept_prompt(state: EnvState) -> str:
83
+ return json.dumps({
84
+ "task": "Score the candidate answer against the reference on all axes.",
85
+ "question": state.question,
86
+ "candidate_answer": state.current_answer,
87
+ "reference_answer": state.ground_truth_answer,
88
+ "candidate_citations": state.current_citations,
89
+ "expected_citations": state.ground_truth_citations,
90
+ "structural_requirements": {
91
+ "required_terms": state.structural_meta.required_terms,
92
+ "required_sections": state.structural_meta.required_sections,
93
+ "expected_order": state.structural_meta.expected_order,
94
+ "banned_terms": state.structural_meta.banned_terms,
95
+ },
96
+ "output_format": {
97
+ "factual_score": "0.0-1.0: semantic accuracy of candidate vs reference",
98
+ "citation_score": "0.0-1.0: fraction of expected citations covered",
99
+ "terminology_score": "-0.5 to 1.0: correct Sanskrit/domain terms present; NEGATIVE if banned terms found",
100
+ "completeness_score": "0.0-1.0: all required conceptual sections covered",
101
+ "ordering_score": "0.0-1.0: concepts appear in expected logical sequence",
102
+ "coherence_score": "0.0-1.0: smooth transitions, balanced structure, readable flow",
103
+ "feedback": "one-sentence summary of quality",
104
+ },
105
+ }, indent=2)
106
+
107
+
108
+ def _terminal_reject_prompt(state: EnvState) -> str:
109
+ return json.dumps({
110
+ "task": "Judge whether this answer deserves rejection.",
111
+ "question": state.question,
112
+ "candidate_answer": state.current_answer,
113
+ "reference_answer": state.ground_truth_answer,
114
+ "structural_requirements": {
115
+ "required_terms": state.structural_meta.required_terms,
116
+ "banned_terms": state.structural_meta.banned_terms,
117
+ },
118
+ "output_format": {
119
+ "answer_is_flawed": "boolean: true if the answer has significant factual or structural problems",
120
+ "feedback": "one-sentence explanation",
121
+ },
122
+ }, indent=2)
123
+
124
+
125
+ def _step_delta_prompt(
126
+ state: EnvState,
127
+ action_type: ActionType,
128
+ old_answer: str,
129
+ new_answer: str,
130
+ ) -> str:
131
+ if action_type == ActionType.EDIT:
132
+ focus = "Focus on FACTUAL improvement (60%) and STRUCTURAL improvement (40%)."
133
  else:
134
+ focus = (
135
+ "Focus primarily on STRUCTURAL improvement (ordering, terminology, "
136
+ "coherence). Penalise heavily if factual content was lost."
137
+ )
138
+ return json.dumps({
139
+ "task": f"Evaluate whether this {action_type.value} improved the answer.",
140
+ "focus": focus,
141
+ "question": state.question,
142
+ "old_answer": old_answer,
143
+ "new_answer": new_answer,
144
+ "reference_answer": state.ground_truth_answer,
145
+ "structural_requirements": {
146
+ "required_terms": state.structural_meta.required_terms,
147
+ "required_sections": state.structural_meta.required_sections,
148
+ "expected_order": state.structural_meta.expected_order,
149
+ "banned_terms": state.structural_meta.banned_terms,
150
+ },
151
+ "output_format": {
152
+ "factual_delta": "-1.0 to 1.0 (positive = factual improvement)",
153
+ "structural_delta": "-1.0 to 1.0 (positive = structural improvement)",
154
+ "feedback": "one-sentence explanation of what changed",
155
+ },
156
+ }, indent=2)
157
 
 
158
 
159
+ # ═══════════════════════════════════════════════════════════════════════════════
160
+ # FALLBACK HEURISTICS (used when LLM is unavailable)
161
+ # ═══════════════════════════════════════════════════════════════════════════════
162
 
163
+ def _tokenize(text: str) -> List[str]:
164
+ return [t for t in re.split(r"[^a-zA-Z0-9\u0900-\u097F]+", text.lower()) if t]
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ def _token_f1(candidate: str, reference: str) -> float:
168
+ cand = set(_tokenize(candidate))
169
+ ref = set(_tokenize(reference))
170
+ if not cand or not ref:
171
+ return 0.0
172
+ common = cand & ref
173
+ if not common:
174
+ return 0.0
175
+ p, r = len(common) / len(cand), len(common) / len(ref)
176
+ return 2 * p * r / (p + r)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
 
178
 
179
+ def _citation_recall_heuristic(predicted: List[str], ground_truth: List[str]) -> float:
180
+ if not ground_truth:
181
+ return 1.0
182
+ norm = lambda s: re.sub(r"\s+", " ", s.strip().lower())
183
+ gt = [norm(g) for g in ground_truth]
184
+ pr = [norm(p) for p in predicted]
185
+ matched = sum(1 for g in gt if any(g in p or p in g for p in pr))
186
+ return matched / len(gt)
187
 
188
 
189
  # ═══════════════════════════════════════════════════════════════════════════════
190
  # PER-STEP REWARD
191
  # ═══════════════════════════════════════════════════════════════════════════════
192
 
193
+ def step_reward(
194
+ state: EnvState, action_type: ActionType, payload: str | None,
195
+ ) -> Tuple[float, str]:
196
+ """Compute per-step reward and feedback. Uses LLM judge for EDIT/RESTRUCTURE."""
 
 
 
197
 
198
+ # ── RETRIEVE ──────────────────────────────────────────────────────────
199
  if action_type == ActionType.RETRIEVE:
200
  if state.retrieval_count >= 3:
201
+ return -0.15, "Redundant retrieval β€” already retrieved 3 times."
 
202
  elif state.available_passages:
203
+ return 0.05, "Passages retrieved."
 
204
  else:
205
+ return -0.05, "No passages available for retrieval."
 
206
 
207
+ # ── CITE ──────────────────────────────────────────────────────────────
208
+ if action_type == ActionType.CITE:
209
  if not payload:
210
+ return -0.05, "Empty citation."
211
+ cr = _citation_recall_heuristic([payload], state.ground_truth_citations)
212
+ if cr > 0:
213
+ return 0.15, "Correct citation added."
214
+ return -0.05, "Citation does not match expected sources."
215
+
216
+ # ── ACCEPT / REJECT ──────────────────────────────────────────────────
217
+ if action_type in (ActionType.ACCEPT, ActionType.REJECT):
218
+ return 0.0, ""
219
+
220
+ # ── EDIT / RESTRUCTURE β€” LLM judge ───────────────────────────────────
221
+ if not payload:
222
+ return -0.10, f"Empty {action_type.value.lower()} β€” no content provided."
223
+
224
+ old_answer = state.current_answer
225
+ result = _llm_judge(
226
+ JUDGE_SYSTEM,
227
+ _step_delta_prompt(state, action_type, old_answer, payload),
228
+ )
 
 
 
 
 
229
 
230
+ if result is not None:
231
+ fd = max(-1.0, min(1.0, float(result.get("factual_delta", 0.0))))
232
+ sd = max(-1.0, min(1.0, float(result.get("structural_delta", 0.0))))
233
+ fb = result.get("feedback", "")
234
+
235
+ if action_type == ActionType.EDIT:
236
+ combined = 0.6 * fd + 0.4 * sd
237
+ if combined > 0.03:
238
+ return 0.20 + combined, f"Edit improved (fact Ξ”{fd:+.2f}, struct Ξ”{sd:+.2f}). {fb}"
239
+ elif combined < -0.03:
240
+ return -0.20, f"Edit degraded (fact Ξ”{fd:+.2f}, struct Ξ”{sd:+.2f}). {fb}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  else:
242
+ return -0.05, f"Edit had negligible effect. {fb}"
243
+
244
+ else: # RESTRUCTURE
245
+ if fd < -0.10:
246
+ return -0.25, f"Restructure lost factual content (fact Ξ”{fd:+.2f}). {fb}"
247
+ elif sd > 0.05:
248
+ return 0.25 + sd, f"Restructure improved structure (Ξ”{sd:+.2f}). {fb}"
249
+ elif sd < -0.03:
250
+ return -0.15, f"Restructure degraded structure (Ξ”{sd:+.2f}). {fb}"
 
 
 
251
  else:
252
+ return -0.05, f"Restructure had negligible effect. {fb}"
 
253
 
254
+ # ── Fallback: token-F1 delta ──────────────────────────────────────────
255
+ old_sim = _token_f1(old_answer, state.ground_truth_answer)
256
+ new_sim = _token_f1(payload, state.ground_truth_answer)
257
+ delta = new_sim - old_sim
258
+ label = action_type.value
259
 
260
+ if delta > 0.03:
261
+ return 0.20 + delta, f"{label} improved (Ξ”{delta:+.2f}, fallback scoring)."
262
+ elif delta < -0.03:
263
+ return -0.20, f"{label} degraded (Ξ”{delta:+.2f}, fallback scoring)."
264
+ return -0.05, f"{label} negligible effect (fallback scoring)."
265
 
266
 
267
  # ═══════════════════════════════════════════════════════════════════════════════
268
  # TERMINAL REWARD
269
  # ═══════════════════════════════════════════════════════════════════════════════
270
 
271
+ def terminal_reward(
272
+ state: EnvState, action_type: ActionType,
273
+ ) -> Tuple[float, str]:
274
+ """Terminal reward using LLM-as-a-judge, with heuristic fallback."""
275
+
276
+ # ── REJECT ────────────────────────────────────────────────────────────
277
  if action_type == ActionType.REJECT:
278
+ result = _llm_judge(JUDGE_SYSTEM, _terminal_reject_prompt(state))
279
+ if result is not None:
280
+ is_flawed = result.get("answer_is_flawed", True)
281
+ fb = result.get("feedback", "")
282
+ if is_flawed:
283
+ return 0.30, f"Correctly rejected a flawed answer. {fb}"
284
+ else:
285
+ return -0.50, f"Incorrectly rejected a valid answer. {fb}"
286
  if not state.answer_is_correct:
287
+ return 0.30, "Correctly rejected a flawed answer (fallback)."
288
+ return -0.50, "Incorrectly rejected a valid answer (fallback)."
289
+
290
+ # ── ACCEPT β€” LLM judge ────────────────────────────────────────────────
291
+ result = _llm_judge(JUDGE_SYSTEM, _terminal_accept_prompt(state))
292
+
293
+ if result is not None:
294
+ fs = max(0.0, min(1.0, float(result.get("factual_score", 0.0))))
295
+ cs = max(0.0, min(1.0, float(result.get("citation_score", 0.0))))
296
+ ts = max(-0.5, min(1.0, float(result.get("terminology_score", 0.0))))
297
+ comp = max(0.0, min(1.0, float(result.get("completeness_score", 0.0))))
298
+ os_ = max(0.0, min(1.0, float(result.get("ordering_score", 0.0))))
299
+ coh = max(0.0, min(1.0, float(result.get("coherence_score", 0.0))))
300
+ fb = result.get("feedback", "")
301
+
302
+ # structural composite
303
+ struct_score = 0.30 * max(ts, 0.0) + 0.25 * comp + 0.25 * os_ + 0.20 * coh
304
+ if ts < 0:
305
+ struct_score += 0.15 * ts
306
+ struct_score = max(0.0, min(1.0, struct_score))
307
+
308
+ efficiency = 0.20 * (state.steps_remaining / state.max_steps)
309
+
310
+ terminal = 0.90 * fs + 0.30 * cs + 0.70 * struct_score + efficiency
311
+
312
+ if fs < 0.3 and struct_score < 0.3:
313
+ terminal -= 0.50
314
+ quality = "poor"
315
+ elif fs < 0.5:
316
+ quality = "mediocre"
317
  else:
318
+ quality = "good"
319
+
320
+ feedback = (
321
+ f"Accepted a {quality} answer "
322
+ f"(fact={fs:.2f}, cite={cs:.2f}, struct={struct_score:.2f} "
323
+ f"[term={ts:.2f} comp={comp:.2f} ord={os_:.2f} coh={coh:.2f}]). {fb}"
324
+ )
325
+ return terminal, feedback
326
+
327
+ # ── Fallback: heuristic scoring ───────────────────────────────────────
328
+ fs = _token_f1(state.current_answer, state.ground_truth_answer)
329
+ cs = _citation_recall_heuristic(
330
+ state.current_citations, state.ground_truth_citations,
331
  )
 
 
332
  efficiency = 0.20 * (state.steps_remaining / state.max_steps)
333
+ terminal = 0.90 * fs + 0.30 * cs + efficiency
334
 
335
+ if fs < 0.3:
 
 
 
 
 
 
 
 
 
 
336
  terminal -= 0.50
337
+ quality = "poor"
338
+ elif fs < 0.5:
339
+ quality = "mediocre"
340
  else:
341
+ quality = "good"
 
 
 
 
 
 
 
 
 
342
 
343
+ return terminal, f"Accepted a {quality} answer (fact={fs:.2f}, cite={cs:.2f}, fallback)."
344
 
345
 
346
  # ═══════════════════════════════════════════════════════════════════════════════
347
  # SCORE NORMALISATION
348
  # ═══════════════════════════════════════════════════════════════════════════════
349
 
 
350
  MAX_REASONABLE_REWARD = 2.80
351
 
352
 
353
  def normalize_score(cumulative_reward: float) -> float:
354
  """Clamp cumulative reward into [0, 1]."""
355
+ return max(0.0, min(1.0, cumulative_reward / MAX_REASONABLE_REWARD))
 
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Server entry point for IndicScriptureQA β€” OpenEnv compatible.
3
+
4
+ Exposes the FastAPI app and a `main()` callable for the `server` script.
5
+ """
6
+
7
+ import uvicorn
8
+
9
+ from main import app # noqa: F401 β€” re-export for openenv discovery
10
+
11
+
12
+ def main() -> None:
13
+ """Entry point used by `[project.scripts] server`."""
14
+ uvicorn.run("main:app", host="0.0.0.0", port=7860)
15
+
16
+
17
+ if __name__ == "__main__":
18
+ main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff