NITISHRG15102007 committed on
Commit
e65a128
·
verified ·
1 Parent(s): 90e57c3

Harden env isolation and proxy validation

Browse files
Files changed (7) hide show
  1. app.py +222 -555
  2. env/environment.py +515 -514
  3. env/graders.py +145 -144
  4. inference.py +40 -37
  5. tests/test_api.py +47 -0
  6. tests/test_inference_proxy.py +119 -0
  7. validate.py +84 -26
app.py CHANGED
@@ -1,336 +1,173 @@
1
- """
2
- FastAPI server exposing the rag-context-optimizer OpenEnv HTTP API.
3
- """
4
-
5
- from __future__ import annotations
6
-
7
  from contextlib import asynccontextmanager
8
  from dataclasses import asdict, is_dataclass
 
9
  from pathlib import Path
10
  from typing import Any, Literal
11
-
 
12
  from fastapi import Body, FastAPI, HTTPException, Request
13
- from fastapi.middleware.cors import CORSMiddleware
14
- from fastapi.responses import HTMLResponse
15
- from pydantic import BaseModel
16
-
 
17
  from env.environment import RagContextOptimizerEnv
18
  from env.models import RagAction
19
- from env.corpus import list_corpus_families
20
  from env.prompt_optimizer import CompressionMode, optimize_prompt
21
  from env.tasks import ALL_TASKS, TASKS_BY_NAME
22
-
23
-
24
  class ResetRequest(BaseModel):
25
  task_name: Literal["single_domain_qa", "cross_domain_synthesis", "adversarial_compression"] = "single_domain_qa"
26
  custom_query: str | None = None
27
  token_budget: int | None = None
28
  max_steps: int | None = None
29
  corpus_family: str | None = None
30
-
31
-
32
  class OptimizePromptRequest(BaseModel):
33
  prompt: str
34
  corpus_family: str | None = None
35
  compression_mode: CompressionMode = "balanced"
36
-
37
-
38
- @asynccontextmanager
39
- async def lifespan(app: FastAPI):
40
- env = RagContextOptimizerEnv()
41
- await env.reset()
42
- app.state.env = env
43
- yield
44
- await app.state.env.close()
45
-
46
-
47
- app = FastAPI(
48
- title="rag-context-optimizer",
49
- version="1.0.0",
50
- description="RAG pipeline optimization environment — minimize tokens, maximize answer quality",
51
- lifespan=lifespan,
52
- )
53
-
54
- app.add_middleware(
55
- CORSMiddleware,
56
- allow_origins=["*"],
57
- allow_credentials=True,
58
- allow_methods=["*"],
59
- allow_headers=["*"],
60
- )
61
-
62
- UI_TEMPLATE_PATH = Path(__file__).resolve().parent / "server" / "templates" / "ui.html"
63
-
64
-
65
-
66
-
67
-
68
- @app.middleware("http")
69
- async def log_requests(request: Request, call_next):
70
- print(f"[request] {request.method} {request.url.path}")
71
- response = await call_next(request)
72
- print(f"[response] {request.method} {request.url.path} -> {response.status_code}")
73
- return response
74
-
75
-
76
- @app.get("/", response_class=HTMLResponse)
77
- async def home_page():
78
- return HTMLResponse(
79
- UI_TEMPLATE_PATH.read_text(encoding="utf-8"),
80
- headers={
81
- "Cache-Control": "no-store, max-age=0",
82
- "Pragma": "no-cache",
83
- },
84
- )
85
-
86
-
87
- def _serialize_observation(observation: Any) -> dict[str, Any]:
88
- if hasattr(observation, "model_dump"):
89
- return observation.model_dump()
90
- if is_dataclass(observation):
91
- return asdict(observation)
92
- return dict(observation)
93
-
94
-
95
- def _serialize_step_result(result: Any, reset: bool = False) -> dict[str, Any]:
96
- raw_info = result.info or {}
97
- if reset:
98
- return {
99
- "observation": _serialize_observation(result.observation),
100
- "reward": None,
101
- "done": False,
102
- "info": {},
103
- }
104
- return {
105
- "observation": _serialize_observation(result.observation),
106
- "reward": result.reward,
107
- "done": result.done,
108
- "info": {
109
- "grader_breakdown": raw_info.get("grader"),
110
- "event": raw_info.get("event"),
111
- "passed": raw_info.get("passed"),
112
- },
113
- }
114
-
115
-
116
- def _is_bad_action_event(event: str | None) -> bool:
117
- return event in {
118
- "chunk_not_found",
119
- }
120
-
121
-
122
- def _tokenize(text: str) -> set[str]:
123
- import re
124
-
125
- return set(re.findall(r"[a-z0-9]+", text.lower()))
126
-
127
-
128
- def _content_terms(text: str) -> set[str]:
129
- return {term for term in _tokenize(text) if len(term) > 2 and term not in _PROMPT_STOPWORDS}
130
-
131
-
132
- def _clean_output_text(text: str) -> str:
133
- import re
134
-
135
- cleaned = text.replace("```", " ").replace("---", " ")
136
- cleaned = re.sub(r"\s+", " ", cleaned).strip()
137
- cleaned = re.sub(r"[#*_`]+", "", cleaned)
138
- cleaned = re.sub(r'\b(title|emoji|colorfrom|colorto|sdk|app_file|pinned)\s*:\s*', "", cleaned, flags=re.IGNORECASE)
139
- return cleaned.strip(" -:")
140
-
141
-
142
- def _compact_text(text: str, max_words: int = 28) -> str:
143
- words = text.split()
144
- if len(words) <= max_words:
145
- return text
146
- return " ".join(words[:max_words]).rstrip(" ,;:") + " ..."
147
-
148
-
149
- _PROMPT_STOPWORDS = {
150
- "a","an","and","are","as","at","be","but","by","can","could","do","does","did",
151
- "for","from","had","has","have","how","i","if","in","into","is","it","its","me",
152
- "my","of","on","or","our","should","so","than","that","the","their","them","then",
153
- "there","these","they","this","to","too","use","using","was","we","were","what",
154
- "when","where","which","while","with","without","would","you","your",
155
- }
156
-
157
-
158
- def _approx_tokens(text: str) -> int:
159
- return max(1, len(text.strip()) // 4) if text.strip() else 0
160
-
161
-
162
- def _compress_prompt_text(prompt: str, target_tokens: int) -> str:
163
- import re
164
-
165
- raw = " ".join(prompt.strip().split())
166
- if not raw:
167
- return ""
168
-
169
- tokens = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_/]*", raw)
170
- kept: list[str] = []
171
- seen: set[str] = set()
172
-
173
- # Keep “meaningful” tokens: numbers, identifiers, longer words, and acronyms. Drop stopwords.
174
- for tok in tokens:
175
- low = tok.lower()
176
- is_number = low.isdigit()
177
- is_identifier = any(ch in tok for ch in ("_", "-", "/")) and len(tok) >= 4
178
- is_acronym = tok.isupper() and len(tok) <= 8
179
- is_meaningful = is_number or is_identifier or is_acronym or len(low) >= 4
180
- if not is_meaningful:
181
- continue
182
- if low in _PROMPT_STOPWORDS:
183
- continue
184
- if low in seen:
185
- continue
186
- seen.add(low)
187
- kept.append(tok)
188
- if len(kept) >= max(10, target_tokens):
189
- break
190
-
191
- if not kept:
192
- # Fallback: truncated raw prompt.
193
- words = raw.split()
194
- return " ".join(words[: max(8, target_tokens)]).rstrip(" ,;:") + (" ..." if len(words) > target_tokens else "")
195
-
196
- # Turn the token list into a copy-paste-ready “goal” sentence.
197
- goal = " ".join(kept)
198
- goal = re.sub(r"\s+", " ", goal).strip()
199
- return goal
200
-
201
-
202
- _INSTRUCTION_PRIORITY_TERMS = {
203
- "must","should","only","not","never","always","include","exclude","cite","answer",
204
- "return","draft","write","summarize","compare","explain","verify","preserve","focus",
205
- "keep","avoid","report","escalate","rollback","refund","incident","customer","security",
206
- }
207
-
208
-
209
- def _trim_sentence(sentence: str, max_terms: int) -> str:
210
- import re
211
-
212
- words = re.findall(r"[A-Za-z0-9][A-Za-z0-9\\-_/]*|[,:;()]", sentence)
213
- if not words:
214
- return ""
215
- kept: list[str] = []
216
-
217
- for index, token in enumerate(words):
218
- normalized = re.sub(r"[^A-Za-z0-9]+", "", token).lower()
219
- if token in {",", ":", ";", "(", ")"}:
220
- if kept and kept[-1] not in {",", ":", ";", "("}:
221
- kept.append(token)
222
- continue
223
- is_priority = normalized in _INSTRUCTION_PRIORITY_TERMS
224
- is_meaningful = (
225
- normalized.isdigit()
226
- or any(ch in token for ch in ("_", "-", "/"))
227
- or len(normalized) >= 4
228
- or is_priority
229
- or index < 3
230
- )
231
- if not is_meaningful:
232
- continue
233
- if normalized in _PROMPT_STOPWORDS and not is_priority and index >= 3:
234
- continue
235
- kept.append(token)
236
- if len([word for word in kept if word not in {",", ":", ";", "(", ")"}]) >= max_terms:
237
- break
238
-
239
- text = " ".join(kept)
240
- text = re.sub(r"\s+([,:;)])", r"\1", text)
241
- text = re.sub(r"(\()\s+", r"\1", text)
242
- return text.strip(" ,;:")
243
-
244
-
245
- def _rewrite_prompt_text(prompt: str, target_tokens: int) -> str:
246
- import re
247
-
248
- raw = " ".join(prompt.strip().split())
249
- if not raw:
250
- return ""
251
-
252
- sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+|\n+", raw) if segment.strip()]
253
- if not sentences:
254
- sentences = [raw]
255
-
256
- rewritten: list[str] = []
257
- used_terms = 0
258
- max_terms = max(8, target_tokens)
259
- for index, sentence in enumerate(sentences):
260
- remaining = max_terms - used_terms
261
- if remaining <= 0:
262
- break
263
- compact = _trim_sentence(sentence, max(4, remaining if index == 0 else min(remaining, 10)))
264
- if not compact:
265
- continue
266
- rewritten.append(compact)
267
- used_terms += len(compact.split())
268
- if used_terms >= max_terms:
269
- break
270
-
271
- if not rewritten:
272
- fallback = _trim_sentence(raw, max_terms)
273
- return fallback or raw[: max(16, target_tokens * 4)].strip()
274
-
275
- output = ". ".join(rewritten).strip()
276
- if len(rewritten) == 1 and not output.endswith("."):
277
- output += "."
278
- return output
279
-
280
-
281
- def _fit_citations_into_prompt(base_prompt: str, citation_ids: list[str], input_tokens: int, target_tokens: int, source_prompt: str) -> tuple[str, bool, str | None]:
282
- if not citation_ids:
283
- return base_prompt, False, "No high-confidence evidence anchors were selected."
284
-
285
- citation_suffix = " Evidence: " + " ".join(f"[{chunk_id}]" for chunk_id in citation_ids[:3])
286
- with_all = (base_prompt.rstrip(".") + "." + citation_suffix).strip()
287
- if _approx_tokens(with_all) < input_tokens:
288
- return with_all, True, None
289
-
290
- one_citation_suffix = " Evidence: " + f"[{citation_ids[0]}]"
291
- with_one = (base_prompt.rstrip(".") + "." + one_citation_suffix).strip()
292
- if _approx_tokens(with_one) < input_tokens:
293
- return with_one, True, None
294
-
295
- tighter_target = max(8, target_tokens - 3)
296
- tighter_prompt = _rewrite_prompt_text(source_prompt, tighter_target)
297
- tighter_with_one = (tighter_prompt.rstrip(".") + "." + one_citation_suffix).strip()
298
- if _approx_tokens(tighter_with_one) < input_tokens:
299
- return tighter_with_one, True, None
300
-
301
- return base_prompt, False, "Citations were omitted to keep the optimized prompt shorter than the original. Use the evidence notes below if explicit anchors are required."
302
-
303
-
304
- def _summarize_chunk_for_output(chunk: Any, effective_text: str) -> str:
305
- if getattr(chunk, "domain", "").startswith("Project"):
306
- keywords = ", ".join(chunk.keywords[:5])
307
- domain = chunk.domain.replace("Project ", "").lower()
308
- return _compact_text(f"This benchmark's {domain} covers {keywords}.", 24)
309
- ranked_sentences = _sentence_rank(" ".join(chunk.keywords), _clean_output_text(effective_text))
310
- if ranked_sentences:
311
- return _compact_text(_clean_output_text(ranked_sentences[0]))
312
- return _compact_text(_clean_output_text(effective_text))
313
-
314
-
315
- def _sentence_rank(query: str, text: str) -> list[str]:
316
- import re
317
-
318
- query_terms = _tokenize(query)
319
- sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", text) if segment.strip()]
320
- if not sentences:
321
- return []
322
-
323
- ranked: list[tuple[float, str]] = []
324
- for index, sentence in enumerate(sentences):
325
- sentence_terms = _tokenize(sentence)
326
- overlap = len(query_terms & sentence_terms)
327
- score = (overlap * 2.0) + (0.25 if index == 0 else 0.0)
328
- ranked.append((score, sentence))
329
-
330
- ranked.sort(key=lambda item: (-item[0], len(item[1])))
331
- return [sentence for _score, sentence in ranked]
332
-
333
-
334
  async def _optimize_prompt_backend(
335
  prompt: str,
336
  corpus_family: str | None = None,
@@ -346,172 +183,8 @@ async def _optimize_prompt_backend(
346
  "selected_keywords": result.selected_keywords,
347
  "optimization_mode": result.optimization_mode,
348
  }
349
- clean_prompt = prompt.strip()
350
- env = RagContextOptimizerEnv(
351
- task_name="single_domain_qa",
352
- query_override=clean_prompt,
353
- token_budget_override=800,
354
- max_steps_override=6,
355
- corpus_family_override=corpus_family,
356
- )
357
- await env.reset()
358
-
359
- tuning = env._last_tuning or env.context_tuner.tune(clean_prompt, env._available_chunks)
360
-
361
- ranked_candidates = []
362
- for chunk in env._available_chunks:
363
- tuned = tuning.tuned_scores.get(chunk.chunk_id)
364
- score = tuned.final_score if tuned is not None else env.retriever.hybrid_score(clean_prompt, chunk)
365
- if score < 0.16:
366
- continue
367
- ranked_candidates.append((chunk, score, tuned))
368
- ranked_candidates.sort(
369
- key=lambda item: (
370
- -(item[1] / max(item[0].tokens, 1)),
371
- -(item[2].citation_prior if item[2] is not None else 0.0),
372
- -item[1],
373
- item[0].chunk_id,
374
- )
375
- )
376
 
377
- selected_ids: list[str] = []
378
- token_cap = 360
379
- running_tokens = 0
380
- for chunk, score, tuned in ranked_candidates:
381
- if len(selected_ids) >= 4:
382
- break
383
- if score < 0.22 and selected_ids:
384
- break
385
- projected = running_tokens + chunk.tokens
386
- if projected > token_cap and selected_ids:
387
- continue
388
- selected_ids.append(chunk.chunk_id)
389
- env._selected_chunks.append(chunk.chunk_id)
390
- running_tokens += chunk.tokens
391
-
392
- if not selected_ids and ranked_candidates:
393
- best_chunk = ranked_candidates[0][0]
394
- selected_ids.append(best_chunk.chunk_id)
395
- env._selected_chunks.append(best_chunk.chunk_id)
396
-
397
- for chunk_id in list(selected_ids):
398
- chunk = env._chunk_map().get(chunk_id)
399
- if chunk is None:
400
- continue
401
- tuned = tuning.tuned_scores.get(chunk_id)
402
- score = tuned.final_score if tuned is not None else env.retriever.hybrid_score(clean_prompt, chunk)
403
- ratio = tuned.compression_ratio if tuned is not None else 0.5
404
- if score >= 0.75:
405
- ratio = max(ratio, 0.6)
406
- env._compression_ratios[chunk_id] = ratio
407
-
408
- input_tokens = _approx_tokens(clean_prompt)
409
- # Target: strictly shorter than input, while preserving more structure for longer prompts.
410
- if input_tokens <= 24:
411
- target_ratio = 0.85
412
- elif input_tokens <= 60:
413
- target_ratio = 0.75
414
- elif input_tokens <= 120:
415
- target_ratio = 0.68
416
- else:
417
- target_ratio = 0.62
418
- target_tokens = max(12, int(input_tokens * target_ratio))
419
- target_tokens = min(target_tokens, 80)
420
-
421
- compressed_goal = _rewrite_prompt_text(clean_prompt, target_tokens=target_tokens)
422
-
423
- # Optionally add a tiny amount of distilled context, but only if it still stays shorter overall.
424
- distilled_points: list[tuple[str, str]] = []
425
- for chunk_id in env._selected_chunks:
426
- chunk = env._chunk_map().get(chunk_id)
427
- if chunk is None:
428
- continue
429
- best = _summarize_chunk_for_output(chunk, env._effective_chunk_text(chunk_id))
430
- if best and all(existing_point != best for _existing_chunk_id, existing_point in distilled_points):
431
- distilled_points.append((chunk_id, best))
432
- if len(distilled_points) >= (2 if input_tokens < 80 else 3):
433
- break
434
-
435
- lines: list[str] = []
436
- lines.append(compressed_goal if compressed_goal else clean_prompt)
437
- if distilled_points and input_tokens >= 80:
438
- lines.append("")
439
- lines.append("Context:")
440
- lines.extend([f"- [{chunk_id}] {point}" for chunk_id, point in distilled_points])
441
- optimized_prompt = "\n".join(lines).strip()
442
-
443
- # Hard guarantee: never return an “optimized” prompt longer than the input.
444
- if input_tokens > 0 and _approx_tokens(optimized_prompt) >= input_tokens:
445
- # Enforce by character budget (tokens ~= chars/4).
446
- max_chars = max(12, (input_tokens - 1) * 4)
447
- optimized_prompt = optimized_prompt[:max_chars].rstrip(" ,;:\n\t")
448
- if optimized_prompt and not optimized_prompt.endswith("..."):
449
- optimized_prompt = optimized_prompt + " ..."
450
- # If still not strictly smaller (very small inputs), trim until it is.
451
- while input_tokens > 1 and _approx_tokens(optimized_prompt) >= input_tokens and len(optimized_prompt) > 12:
452
- optimized_prompt = optimized_prompt[:-6].rstrip(" ,;:\n\t") + " ..."
453
- if input_tokens > 1 and _approx_tokens(optimized_prompt) >= input_tokens:
454
- optimized_prompt = _rewrite_prompt_text(clean_prompt, target_tokens=max(5, input_tokens - 1))
455
- if optimized_prompt and not optimized_prompt.endswith("...") and _approx_tokens(optimized_prompt) >= input_tokens:
456
- optimized_prompt = optimized_prompt[: max(8, (input_tokens - 1) * 4)].strip() + " ..."
457
-
458
- optimized_prompt, citation_ready, citation_guidance = _fit_citations_into_prompt(
459
- optimized_prompt,
460
- tuning.suggested_citations or list(env._selected_chunks),
461
- input_tokens,
462
- target_tokens,
463
- clean_prompt,
464
- )
465
 
466
- original_prompt_tokens = input_tokens
467
- optimized_prompt_tokens = _approx_tokens(optimized_prompt)
468
- source_tokens = sum(env._chunk_map()[chunk_id].tokens for chunk_id in env._selected_chunks if chunk_id in env._chunk_map())
469
- compressed_tokens = sum(env._effective_chunk_tokens(chunk_id) for chunk_id in env._selected_chunks)
470
- evidence_terms = _content_terms(" ".join(env._effective_chunk_text(chunk_id) for chunk_id in env._selected_chunks))
471
- prompt_terms = _content_terms(optimized_prompt)
472
- inline_citations = set(re.findall(r"\[([a-z0-9_]+)\]", optimized_prompt.lower()))
473
- grounded_overlap = (len(prompt_terms & evidence_terms) / len(prompt_terms)) if prompt_terms else 0.0
474
-
475
- return {
476
- "optimized_prompt": optimized_prompt,
477
- "stats": {
478
- "selected_chunks": len(env._selected_chunks),
479
- "source_tokens": source_tokens,
480
- "compressed_context_tokens": compressed_tokens,
481
- "original_prompt_tokens": original_prompt_tokens,
482
- "optimized_prompt_tokens": optimized_prompt_tokens,
483
- "compression_gain": max(0, source_tokens - compressed_tokens),
484
- },
485
- "grounding": {
486
- "citations": tuning.suggested_citations or list(env._selected_chunks),
487
- "citation_ready": citation_ready and bool(inline_citations),
488
- "citation_guidance": citation_guidance,
489
- "grounded_overlap": round(grounded_overlap, 3),
490
- "evidence_notes": [
491
- {"chunk_id": chunk_id, "note": note}
492
- for chunk_id, note in distilled_points
493
- ],
494
- },
495
- "context_tuning": {
496
- "mode": tuning.mode,
497
- "top_demo_cases": tuning.top_demo_cases,
498
- "suggested_citations": tuning.suggested_citations,
499
- "token_dropout": tuning.token_dropout,
500
- "leave_one_out": tuning.leave_one_out,
501
- },
502
- "corpus_family": env._corpus_family,
503
- "selected_keywords": [
504
- keyword
505
- for chunk_id in env._selected_chunks
506
- for keyword in (env._chunk_map().get(chunk_id).keywords if env._chunk_map().get(chunk_id) else [])
507
- ][:10],
508
- }
509
-
510
-
511
-
512
-
513
-
514
-
515
  def _suggest_action(env: RagContextOptimizerEnv) -> dict[str, Any]:
516
  observation = env._build_observation()
517
  selected = set(observation.selected_chunks)
@@ -546,8 +219,7 @@ def _suggest_action(env: RagContextOptimizerEnv) -> dict[str, Any]:
546
  if chunk.keywords:
547
  chosen_phrases.append(f"[{chunk.chunk_id}] " + ", ".join(chunk.keywords[:2]))
548
  answer = (
549
- "Grounded answer based on selected evidence: "
550
- + "; ".join(chosen_phrases[:3])
551
  if chosen_phrases
552
  else "Grounded answer based on the currently selected evidence."
553
  )
@@ -559,37 +231,37 @@ def _suggest_action(env: RagContextOptimizerEnv) -> dict[str, Any]:
559
  for chunk in sorted(
560
  available,
561
  key=lambda chunk: (
562
- -(score_map.get(chunk.chunk_id).final_score if score_map.get(chunk.chunk_id) else 0.0)
563
- / max(chunk.tokens, 1),
564
  chunk.tokens,
565
  chunk.chunk_id,
566
  ),
567
  ):
568
  if chunk.tokens <= remaining_budget:
569
  return {"action_type": "select_chunk", "chunk_id": chunk.chunk_id}
570
-
571
- if selected_chunks:
572
- return {
573
- "action_type": "submit_answer",
574
- "answer": "Optimized answer based on the currently selected evidence.",
575
- }
576
- if available:
577
- smallest_chunk = min(available, key=lambda chunk: (chunk.tokens, chunk.chunk_id))
578
- return {
579
- "action_type": "submit_answer",
580
- "answer": (
581
- "No chunk fits within the current token budget. "
582
- f"Increase the budget to at least {smallest_chunk.tokens} tokens or choose a broader budget."
583
- ),
584
- }
585
- return {"action_type": "submit_answer", "answer": "No usable evidence was available."}
586
-
587
-
588
  @app.post("/reset")
589
  async def reset_endpoint(payload: ResetRequest | None = Body(default=None)):
590
  payload = payload or ResetRequest()
591
  if payload.task_name not in TASKS_BY_NAME:
592
  raise HTTPException(status_code=400, detail="Unknown task_name.")
 
593
  env = RagContextOptimizerEnv(
594
  task_name=payload.task_name,
595
  query_override=payload.custom_query,
@@ -597,49 +269,46 @@ async def reset_endpoint(payload: ResetRequest | None = Body(default=None)):
597
  max_steps_override=payload.max_steps,
598
  corpus_family_override=payload.corpus_family,
599
  )
600
- app.state.env = env
601
- result = await env.reset()
602
- return _serialize_step_result(result, reset=True)
603
-
604
-
605
- @app.post("/step")
606
- async def step_endpoint(action: RagAction):
607
- env = getattr(app.state, "env", None)
608
- if env is None:
609
- raise HTTPException(status_code=400, detail="Environment is not initialized. Call /reset first.")
610
-
611
- result = await env.step(action)
612
- event = (result.info or {}).get("event")
613
- if _is_bad_action_event(event):
614
- raise HTTPException(status_code=400, detail=event)
615
- return _serialize_step_result(result, reset=False)
616
-
617
-
618
- @app.get("/state")
619
- async def state_endpoint():
620
- env = getattr(app.state, "env", None)
621
- if env is None:
622
- raise HTTPException(status_code=400, detail="Environment is not initialized.")
623
- return await env.state()
624
-
625
-
626
- @app.get("/health")
627
- async def health_endpoint():
628
- return {"status": "ok", "tasks": [task.name for task in ALL_TASKS]}
629
-
630
-
631
- @app.get("/tasks")
632
  async def tasks_endpoint():
633
- return [
634
- {
635
- "name": task.name,
636
- "description": task.description,
637
- "difficulty": task.difficulty,
638
- "token_budget": task.token_budget,
639
- "query": task.query,
640
- "max_steps": task.max_steps,
641
- }
642
- for task in ALL_TASKS
643
  ]
644
 
645
 
@@ -649,13 +318,11 @@ async def corpus_families_endpoint():
649
 
650
 
651
  @app.post("/optimize-step")
652
- async def optimize_step_endpoint():
653
- env = getattr(app.state, "env", None)
654
- if env is None:
655
- raise HTTPException(status_code=400, detail="Environment is not initialized. Call /reset first.")
656
- return _suggest_action(env)
657
-
658
-
659
  @app.post("/optimize-prompt")
660
  async def optimize_prompt_endpoint(payload: OptimizePromptRequest):
661
  if not payload.prompt.strip():
@@ -665,9 +332,9 @@ async def optimize_prompt_endpoint(payload: OptimizePromptRequest):
665
  corpus_family=payload.corpus_family,
666
  compression_mode=payload.compression_mode,
667
  )
668
-
669
-
670
- if __name__ == "__main__":
671
- import uvicorn
672
-
673
- uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
 
1
+ """
2
+ FastAPI server exposing the rag-context-optimizer OpenEnv HTTP API.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
  from contextlib import asynccontextmanager
8
  from dataclasses import asdict, is_dataclass
9
+ import os
10
  from pathlib import Path
11
  from typing import Any, Literal
12
+ from uuid import uuid4
13
+
14
  from fastapi import Body, FastAPI, HTTPException, Request
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from fastapi.responses import HTMLResponse
17
+ from pydantic import BaseModel
18
+
19
+ from env.corpus import list_corpus_families
20
  from env.environment import RagContextOptimizerEnv
21
  from env.models import RagAction
 
22
  from env.prompt_optimizer import CompressionMode, optimize_prompt
23
  from env.tasks import ALL_TASKS, TASKS_BY_NAME
24
+
25
+
26
  class ResetRequest(BaseModel):
27
  task_name: Literal["single_domain_qa", "cross_domain_synthesis", "adversarial_compression"] = "single_domain_qa"
28
  custom_query: str | None = None
29
  token_budget: int | None = None
30
  max_steps: int | None = None
31
  corpus_family: str | None = None
32
+
33
+
34
  class OptimizePromptRequest(BaseModel):
35
  prompt: str
36
  corpus_family: str | None = None
37
  compression_mode: CompressionMode = "balanced"
38
+
39
+
40
+ class EpisodeStore:
41
+ def __init__(self, max_episodes: int = 16):
42
+ self._episodes: dict[str, RagContextOptimizerEnv] = {}
43
+ self._order: list[str] = []
44
+ self.latest_episode_id: str | None = None
45
+ self._max_episodes = max_episodes
46
+
47
+ async def close_all(self) -> None:
48
+ for env in self._episodes.values():
49
+ await env.close()
50
+ self._episodes.clear()
51
+ self._order.clear()
52
+ self.latest_episode_id = None
53
+
54
+ async def create(self, env: RagContextOptimizerEnv) -> str:
55
+ episode_id = uuid4().hex
56
+ self._episodes[episode_id] = env
57
+ self._order.append(episode_id)
58
+ self.latest_episode_id = episode_id
59
+
60
+ while len(self._order) > self._max_episodes:
61
+ stale_id = self._order.pop(0)
62
+ stale_env = self._episodes.pop(stale_id, None)
63
+ if stale_env is not None:
64
+ await stale_env.close()
65
+ if self.latest_episode_id == stale_id:
66
+ self.latest_episode_id = self._order[-1] if self._order else None
67
+ return episode_id
68
+
69
+ def get(self, episode_id: str | None) -> tuple[str, RagContextOptimizerEnv]:
70
+ resolved_id = episode_id or self.latest_episode_id
71
+ if resolved_id is None or resolved_id not in self._episodes:
72
+ raise KeyError("episode_not_found")
73
+ return resolved_id, self._episodes[resolved_id]
74
+
75
+
76
+ def _request_logging_enabled() -> bool:
77
+ return os.getenv("DEBUG_LOG_REQUESTS", "").strip().lower() in {"1", "true", "yes"}
78
+
79
+
80
+ @asynccontextmanager
81
+ async def lifespan(app: FastAPI):
82
+ app.state.episodes = EpisodeStore()
83
+ yield
84
+ await app.state.episodes.close_all()
85
+
86
+
87
+ app = FastAPI(
88
+ title="rag-context-optimizer",
89
+ version="1.0.0",
90
+ description="RAG pipeline optimization environment - minimize tokens, maximize answer quality",
91
+ lifespan=lifespan,
92
+ )
93
+
94
+ app.add_middleware(
95
+ CORSMiddleware,
96
+ allow_origins=["*"],
97
+ allow_credentials=False,
98
+ allow_methods=["*"],
99
+ allow_headers=["*"],
100
+ )
101
+
102
+ UI_TEMPLATE_PATH = Path(__file__).resolve().parent / "server" / "templates" / "ui.html"
103
+
104
+
105
+ @app.middleware("http")
106
+ async def log_requests(request: Request, call_next):
107
+ should_log = _request_logging_enabled()
108
+ if should_log:
109
+ print(f"[request] {request.method} {request.url.path}")
110
+ response = await call_next(request)
111
+ if should_log:
112
+ print(f"[response] {request.method} {request.url.path} -> {response.status_code}")
113
+ return response
114
+
115
+
116
+ @app.get("/", response_class=HTMLResponse)
117
+ async def home_page():
118
+ return HTMLResponse(
119
+ UI_TEMPLATE_PATH.read_text(encoding="utf-8"),
120
+ headers={
121
+ "Cache-Control": "no-store, max-age=0",
122
+ "Pragma": "no-cache",
123
+ },
124
+ )
125
+
126
+
127
+ def _serialize_observation(observation: Any) -> dict[str, Any]:
128
+ if hasattr(observation, "model_dump"):
129
+ return observation.model_dump()
130
+ if is_dataclass(observation):
131
+ return asdict(observation)
132
+ return dict(observation)
133
+
134
+
135
+ def _serialize_step_result(result: Any, reset: bool = False, episode_id: str | None = None) -> dict[str, Any]:
136
+ raw_info = result.info or {}
137
+ payload = {
138
+ "observation": _serialize_observation(result.observation),
139
+ "reward": None if reset else result.reward,
140
+ "done": False if reset else result.done,
141
+ "info": {} if reset else {
142
+ "grader_breakdown": raw_info.get("grader"),
143
+ "event": raw_info.get("event"),
144
+ "passed": raw_info.get("passed"),
145
+ },
146
+ }
147
+ if episode_id is not None:
148
+ payload["episode_id"] = episode_id
149
+ return payload
150
+
151
+
152
+ def _is_bad_action_event(event: str | None) -> bool:
153
+ return event in {"chunk_not_found"}
154
+
155
+
156
+ def _episode_store() -> EpisodeStore:
157
+ episodes = getattr(app.state, "episodes", None)
158
+ if episodes is None:
159
+ episodes = EpisodeStore()
160
+ app.state.episodes = episodes
161
+ return episodes
162
+
163
+
164
+ def _resolve_env(episode_id: str | None) -> tuple[str, RagContextOptimizerEnv]:
165
+ try:
166
+ return _episode_store().get(episode_id)
167
+ except KeyError as exc:
168
+ raise HTTPException(status_code=404, detail="Episode not found. Call /reset first.") from exc
169
+
170
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  async def _optimize_prompt_backend(
172
  prompt: str,
173
  corpus_family: str | None = None,
 
183
  "selected_keywords": result.selected_keywords,
184
  "optimization_mode": result.optimization_mode,
185
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  def _suggest_action(env: RagContextOptimizerEnv) -> dict[str, Any]:
189
  observation = env._build_observation()
190
  selected = set(observation.selected_chunks)
 
219
  if chunk.keywords:
220
  chosen_phrases.append(f"[{chunk.chunk_id}] " + ", ".join(chunk.keywords[:2]))
221
  answer = (
222
+ "Grounded answer based on selected evidence: " + "; ".join(chosen_phrases[:3])
 
223
  if chosen_phrases
224
  else "Grounded answer based on the currently selected evidence."
225
  )
 
231
  for chunk in sorted(
232
  available,
233
  key=lambda chunk: (
234
+ -(score_map.get(chunk.chunk_id).final_score if score_map.get(chunk.chunk_id) else 0.0) / max(chunk.tokens, 1),
 
235
  chunk.tokens,
236
  chunk.chunk_id,
237
  ),
238
  ):
239
  if chunk.tokens <= remaining_budget:
240
  return {"action_type": "select_chunk", "chunk_id": chunk.chunk_id}
241
+
242
+ if selected_chunks:
243
+ return {
244
+ "action_type": "submit_answer",
245
+ "answer": "Optimized answer based on the currently selected evidence.",
246
+ }
247
+ if available:
248
+ smallest_chunk = min(available, key=lambda chunk: (chunk.tokens, chunk.chunk_id))
249
+ return {
250
+ "action_type": "submit_answer",
251
+ "answer": (
252
+ "No chunk fits within the current token budget. "
253
+ f"Increase the budget to at least {smallest_chunk.tokens} tokens or choose a broader budget."
254
+ ),
255
+ }
256
+ return {"action_type": "submit_answer", "answer": "No usable evidence was available."}
257
+
258
+
259
  @app.post("/reset")
260
  async def reset_endpoint(payload: ResetRequest | None = Body(default=None)):
261
  payload = payload or ResetRequest()
262
  if payload.task_name not in TASKS_BY_NAME:
263
  raise HTTPException(status_code=400, detail="Unknown task_name.")
264
+
265
  env = RagContextOptimizerEnv(
266
  task_name=payload.task_name,
267
  query_override=payload.custom_query,
 
269
  max_steps_override=payload.max_steps,
270
  corpus_family_override=payload.corpus_family,
271
  )
272
+ result = await env.reset()
273
+ episode_id = await _episode_store().create(env)
274
+ return _serialize_step_result(result, reset=True, episode_id=episode_id)
275
+
276
+
277
+ @app.post("/step")
278
+ async def step_endpoint(action: RagAction, episode_id: str | None = None):
279
+ resolved_episode_id, env = _resolve_env(episode_id)
280
+ result = await env.step(action)
281
+ event = (result.info or {}).get("event")
282
+ if _is_bad_action_event(event):
283
+ raise HTTPException(status_code=400, detail=event)
284
+ return _serialize_step_result(result, reset=False, episode_id=resolved_episode_id)
285
+
286
+
287
+ @app.get("/state")
288
+ async def state_endpoint(episode_id: str | None = None):
289
+ resolved_episode_id, env = _resolve_env(episode_id)
290
+ state = await env.state()
291
+ state["episode_id"] = resolved_episode_id
292
+ return state
293
+
294
+
295
+ @app.get("/health")
296
+ async def health_endpoint():
297
+ return {"status": "ok", "tasks": [task.name for task in ALL_TASKS]}
298
+
299
+
300
+ @app.get("/tasks")
 
 
 
301
  async def tasks_endpoint():
302
+ return [
303
+ {
304
+ "name": task.name,
305
+ "description": task.description,
306
+ "difficulty": task.difficulty,
307
+ "token_budget": task.token_budget,
308
+ "query": task.query,
309
+ "max_steps": task.max_steps,
310
+ }
311
+ for task in ALL_TASKS
312
  ]
313
 
314
 
 
318
 
319
 
320
  @app.post("/optimize-step")
321
+ async def optimize_step_endpoint(episode_id: str | None = None):
322
+ _resolved_episode_id, env = _resolve_env(episode_id)
323
+ return _suggest_action(env)
324
+
325
+
 
 
326
  @app.post("/optimize-prompt")
327
  async def optimize_prompt_endpoint(payload: OptimizePromptRequest):
328
  if not payload.prompt.strip():
 
332
  corpus_family=payload.corpus_family,
333
  compression_mode=payload.compression_mode,
334
  )
335
+
336
+
337
+ if __name__ == "__main__":
338
+ import uvicorn
339
+
340
+ uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
env/environment.py CHANGED
@@ -1,532 +1,533 @@
1
- """
2
- Main OpenEnv-style environment for rag-context-optimizer.
3
- """
4
-
5
- from __future__ import annotations
6
-
7
- from dataclasses import asdict, dataclass, is_dataclass, replace
8
- import os
9
- from pathlib import Path
10
- import re
11
- from typing import Any
12
-
13
- from env.corpus import Chunk, load_corpus, resolve_corpus_path
14
- from env.context_tuner import ContextTunedPlanner
15
- from env.graders import TaskGrader
16
- from env.models import ChunkSummary, RagAction, RagObservation
17
- from env.retriever import HybridRetriever
18
- from env.tasks import ALL_TASKS, TASKS_BY_NAME, Task
19
-
20
-
21
- @dataclass(slots=True)
22
- class StepResult:
23
- observation: RagObservation
24
- reward: float
25
- done: bool
26
- info: dict[str, Any]
27
-
28
-
29
- class RagContextOptimizerEnv:
30
- _PROJECT_STOPWORDS = {
31
- "the", "and", "for", "with", "that", "this", "from", "into", "your", "have", "will",
32
- "using", "used", "use", "into", "they", "them", "their", "about", "while", "where",
33
- "when", "what", "which", "should", "would", "could", "there", "here", "then", "than",
34
- "each", "such", "only", "also", "been", "being", "does", "did", "done", "just", "more",
35
- "most", "very", "over", "under", "like", "same", "across", "because", "through", "make",
36
- "made", "many", "much", "some", "into", "onto", "must", "need", "needs", "task", "tasks",
37
- "chunk", "chunks", "query", "prompt", "environment", "agent", "agents", "model", "models",
38
- }
39
- _PROJECT_QUERY_HINTS = {
40
- "openenv", "benchmark", "rag-context-optimizer", "readme", "docker", "fastapi", "api",
41
- "endpoint", "inference.py", "app.py", "tasks.py", "graders.py", "environment.py", "repo",
42
- "repository", "codebase", "ui", "frontend", "backend", "space", "validator",
43
- }
44
-
45
- def __init__(
46
- self,
47
- task_name: str = "single_domain_qa",
48
- query_override: str | None = None,
49
- token_budget_override: int | None = None,
50
- max_steps_override: int | None = None,
51
- corpus_family_override: str | None = None,
52
- ):
53
- if task_name not in TASKS_BY_NAME:
54
- raise ValueError(f"Unknown task_name: {task_name}")
55
-
56
  self._corpus_family = corpus_family_override or os.getenv("RAG_CORPUS_FAMILY") or "enterprise_v1"
57
  explicit_path = os.getenv("RAG_CORPUS_PATH")
58
  self._corpus_path = resolve_corpus_path(explicit_path, family=None if explicit_path else self._corpus_family)
59
  self._all_chunks = load_corpus(self._corpus_path)
60
  self._query_overridden = bool(query_override and query_override.strip())
61
- self._project_chunks = self._load_project_chunks()
 
62
  self.retriever = HybridRetriever(self._all_chunks + self._project_chunks)
63
- self.context_tuner = ContextTunedPlanner(
64
- self.retriever,
65
- self._all_chunks + self._project_chunks,
66
- list(ALL_TASKS),
67
- )
68
- self.grader = TaskGrader()
69
- self.task: Task = self._build_task(
70
- TASKS_BY_NAME[task_name],
71
- query_override=query_override,
72
- token_budget_override=token_budget_override,
73
- max_steps_override=max_steps_override,
74
- )
75
-
76
- self._available_chunks: list[Chunk] = []
77
- self._selected_chunks: list[str] = []
78
- self._compression_ratios: dict[str, float] = {}
79
- self._step_number = 0
80
- self._done = False
81
- self._last_action_feedback: str | None = None
82
- self._last_answer = ""
83
- self._last_tuning = None
84
-
85
- @staticmethod
86
- def _build_task(
87
- base_task: Task,
88
- query_override: str | None = None,
89
- token_budget_override: int | None = None,
90
- max_steps_override: int | None = None,
91
- ) -> Task:
92
- updated_task = base_task
93
- if query_override and query_override.strip():
94
- updated_task = replace(updated_task, query=query_override.strip(), domain_filter=None)
95
- if token_budget_override is not None and token_budget_override > 0:
96
- updated_task = replace(updated_task, token_budget=token_budget_override)
97
- if max_steps_override is not None and max_steps_override > 0:
98
- updated_task = replace(updated_task, max_steps=max_steps_override)
99
- return updated_task
100
-
101
- async def reset(self) -> StepResult:
102
- candidate_chunks = self._filter_chunks_for_task(self.task)
103
- self._available_chunks = self._rank_chunks_for_query(self.task.query, candidate_chunks)
104
- if not self._query_overridden:
105
- chunk_by_id = {chunk.chunk_id: chunk for chunk in candidate_chunks}
106
- for chunk_id in self.task.required_chunk_ids:
107
- chunk = chunk_by_id.get(chunk_id)
108
- if chunk and all(existing.chunk_id != chunk_id for existing in self._available_chunks):
109
- self._available_chunks.append(chunk)
110
- self._selected_chunks = []
111
- self._compression_ratios = {}
112
- self._step_number = 0
113
- self._done = False
114
- self._last_action_feedback = None
115
- self._last_answer = ""
116
-
117
- observation = self._build_observation()
118
- return StepResult(
119
- observation=observation,
120
- reward=0.0,
121
- done=False,
122
- info={"task": self.task.name, "event": "reset"},
123
- )
124
-
125
- async def step(self, action: RagAction) -> StepResult:
126
- if self._done:
127
- return StepResult(
128
- observation=self._build_observation(),
129
- reward=0.0,
130
- done=True,
131
- info={"task": self.task.name, "event": "episode_already_done"},
132
- )
133
-
134
- reward = 0.0
135
- info: dict[str, Any] = {"task": self.task.name, "action_type": action.action_type}
136
-
137
- if action.action_type == "select_chunk":
138
- reward, info = self._handle_select(action.chunk_id or "")
139
- elif action.action_type == "deselect_chunk":
140
- reward, info = self._handle_deselect(action.chunk_id or "")
141
- elif action.action_type == "compress_chunk":
142
- reward, info = self._handle_compress(action.chunk_id or "", float(action.compression_ratio or 0.0))
143
- elif action.action_type == "submit_answer":
144
- self._last_answer = action.answer or ""
145
- result = self._finalize_submission(reason="submit_answer")
146
- self._step_number += 1
147
- result.observation.step_number = self._step_number
148
- return result
149
-
150
- self._step_number += 1
151
-
152
- if self._step_number >= self.task.max_steps:
153
- return self._finalize_submission(reason="max_steps_reached")
154
-
155
- observation = self._build_observation()
156
- return StepResult(
157
- observation=observation,
158
- reward=reward,
159
- done=False,
160
- info=info,
161
- )
162
-
163
- async def state(self) -> dict:
164
- selected_chunk_details = []
165
- for chunk_id in self._selected_chunks:
166
- chunk = self._chunk_map().get(chunk_id)
167
- if chunk is None:
168
- continue
169
- selected_chunk_details.append(
170
- {
171
- "chunk_id": chunk.chunk_id,
172
- "domain": chunk.domain,
173
- "original_tokens": chunk.tokens,
174
- "effective_tokens": self._effective_chunk_tokens(chunk_id),
175
- "compression_ratio": round(self._compression_ratios.get(chunk_id, 1.0), 3),
176
- "text": self._effective_chunk_text(chunk_id),
177
- "keywords": chunk.keywords,
178
- }
179
- )
180
- optimized_prompt = self._build_optimized_prompt()
181
- return {
182
- "task": asdict(self.task) if is_dataclass(self.task) else self.task,
183
- "step_number": self._step_number,
184
- "done": self._done,
185
- "selected_chunks": list(self._selected_chunks),
186
- "compression_ratios": dict(self._compression_ratios),
187
- "total_tokens_used": self._total_tokens_used(),
188
- "token_budget": self.task.token_budget,
189
- "last_action_feedback": self._last_action_feedback,
190
- "last_answer": self._last_answer,
191
- "corpus_family": self._corpus_family,
192
- "corpus_path": str(self._corpus_path),
193
- "available_chunk_ids": [chunk.chunk_id for chunk in self._available_chunks],
194
- "selected_chunk_details": selected_chunk_details,
195
- "optimized_prompt_preview": optimized_prompt,
196
- "optimized_prompt_tokens": max(1, len(optimized_prompt) // 4) if optimized_prompt else 0,
197
- "context_tuning": (
198
- {
199
- "mode": self._last_tuning.mode,
200
- "top_demo_cases": self._last_tuning.top_demo_cases,
201
- "suggested_citations": self._last_tuning.suggested_citations,
202
- "token_dropout": self._last_tuning.token_dropout,
203
- "leave_one_out": self._last_tuning.leave_one_out,
204
- }
205
- if self._last_tuning is not None
206
- else None
207
- ),
208
- }
209
-
210
- async def close(self):
211
- self._done = True
212
-
213
- def _filter_chunks_for_task(self, task: Task) -> list[Chunk]:
214
- domain_mapping = {
215
- "customer_support_operations": "Customer Support Operations",
216
- "incident_response_playbooks": "Incident Response Playbooks",
217
- "platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
218
  }
219
  if self._query_overridden:
220
- if self._is_project_query(task.query):
221
  return list(self._all_chunks) + list(self._project_chunks)
222
  return list(self._all_chunks)
223
- if task.domain_filter is None:
224
- return list(self._all_chunks)
225
- normalized = domain_mapping.get(task.domain_filter, task.domain_filter)
226
- return [chunk for chunk in self._all_chunks if chunk.domain == normalized]
227
-
228
- def _is_project_query(self, query: str) -> bool:
229
- lowered = query.lower()
230
- return any(hint in lowered for hint in self._PROJECT_QUERY_HINTS)
231
-
232
- def _rank_chunks_for_query(self, query: str, chunks: list[Chunk], top_k: int = 20) -> list[Chunk]:
233
- tuning = self.context_tuner.tune(query, chunks)
234
- self._last_tuning = tuning
235
- scored = []
236
  for chunk in chunks:
237
  tuned = tuning.tuned_scores.get(chunk.chunk_id)
238
  score = tuned.final_score if tuned is not None else self.retriever.hybrid_score(query, chunk)
239
- if self._query_overridden and chunk.domain.startswith("Project"):
240
  score = min(1.0, score + 0.08)
241
  scored.append((chunk, score))
242
- scored.sort(key=lambda item: (-item[1], item[0].tokens, item[0].chunk_id))
243
- if not scored:
244
- return []
245
-
246
- capped = scored[: max(1, min(top_k * 2, len(scored)))]
247
- best_score = capped[0][1]
248
- floor = max(0.12, best_score * 0.38)
249
- filtered_pairs = [(chunk, score) for chunk, score in capped if score >= floor]
250
-
251
- if self._query_overridden:
252
  project_pairs = [(chunk, score) for chunk, score in filtered_pairs if chunk.domain.startswith("Project")]
253
  if len(project_pairs) >= 4:
254
  filtered_pairs = project_pairs + [
255
- (chunk, score)
256
- for chunk, score in filtered_pairs
257
- if not chunk.domain.startswith("Project")
258
- ]
259
-
260
- filtered = [chunk for chunk, _score in filtered_pairs]
261
- if not filtered:
262
- filtered = [chunk for chunk, _score in capped[: max(1, min(top_k, len(capped)))]]
263
-
264
- return filtered[: max(1, min(top_k, len(filtered)))]
265
-
266
- def _load_project_chunks(self) -> list[Chunk]:
267
- root = Path(__file__).resolve().parent.parent
268
- chunks: list[Chunk] = []
269
- file_specs = [
270
- ("Project Documentation", root / "README.md", ["project_docs", "readme"]),
271
- ("Project Configuration", root / "openenv.yaml", ["project_docs", "config", "openenv_spec"]),
272
- ("Project API", root / "app.py", ["project_docs", "api", "server"]),
273
- ("Project Baseline", root / "inference.py", ["project_docs", "baseline", "inference"]),
274
- ("Project Environment", root / "env" / "environment.py", ["project_docs", "environment", "state_management"]),
275
- ("Project Retrieval", root / "env" / "retriever.py", ["project_docs", "retrieval", "ranking"]),
276
- ("Project Grading", root / "env" / "graders.py", ["project_docs", "grading", "reward_design"]),
277
- ("Project Tasks", root / "env" / "tasks.py", ["project_docs", "tasks", "difficulty"]),
278
- ("Project Validation", root / "validate.py", ["project_docs", "validation", "testing"]),
279
- ]
280
-
281
- for domain, path, tags in file_specs:
282
- if not path.exists():
283
- continue
284
- raw_text = path.read_text(encoding="utf-8", errors="ignore")
285
- sections = self._chunk_project_text(raw_text)
286
- stem = re.sub(r"[^a-z0-9]+", "_", path.stem.lower()).strip("_") or "file"
287
- for index, section in enumerate(sections, start=1):
288
- keywords = self._extract_project_keywords(section)
289
- if not keywords:
290
- keywords = [stem, domain.lower()]
291
- chunks.append(
292
- Chunk(
293
- chunk_id=f"project_{stem}_{index:03d}",
294
- domain=domain,
295
- text=section,
296
- tokens=max(30, len(section) // 4),
297
- keywords=keywords[:5],
298
- relevance_tags=tags,
299
- )
300
- )
301
- return chunks
302
-
303
- def _chunk_project_text(self, raw_text: str, chunk_words: int = 140, stride_words: int = 100) -> list[str]:
304
- cleaned = " ".join(raw_text.split())
305
- words = cleaned.split()
306
- if not words:
307
- return []
308
- if len(words) <= chunk_words:
309
- return [" ".join(words)]
310
-
311
- chunks: list[str] = []
312
- start = 0
313
- while start < len(words):
314
- window = words[start : start + chunk_words]
315
- if not window:
316
- break
317
- chunks.append(" ".join(window))
318
- if start + chunk_words >= len(words):
319
- break
320
- start += stride_words
321
- return chunks
322
-
323
- def _extract_project_keywords(self, text: str) -> list[str]:
324
- terms = re.findall(r"[a-z0-9_]+", text.lower())
325
- counts: dict[str, int] = {}
326
- for term in terms:
327
- if len(term) < 4 or term in self._PROJECT_STOPWORDS:
328
- continue
329
- counts[term] = counts.get(term, 0) + 1
330
- ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
331
- return [term.replace("_", " ") for term, _count in ranked[:8]]
332
-
333
- def _build_observation(self) -> RagObservation:
334
- return RagObservation(
335
- query=self.task.query,
336
- available_chunks=[
337
- ChunkSummary(
338
- chunk_id=chunk.chunk_id,
339
- domain=chunk.domain,
340
- tokens=self._effective_chunk_tokens(chunk.chunk_id),
341
- keywords=chunk.keywords,
342
- )
343
- for chunk in self._available_chunks
344
- ],
345
- selected_chunks=list(self._selected_chunks),
346
- total_tokens_used=self._total_tokens_used(),
347
- token_budget=self.task.token_budget,
348
- step_number=self._step_number,
349
- task_name=self.task.name,
350
- last_action_feedback=self._last_action_feedback,
351
- )
352
-
353
- def _chunk_map(self) -> dict[str, Chunk]:
354
- return {chunk.chunk_id: chunk for chunk in self._available_chunks}
355
-
356
- def _effective_chunk_tokens(self, chunk_id: str) -> int:
357
- chunk = self._chunk_map().get(chunk_id)
358
- if chunk is None:
359
- return 0
360
- ratio = self._compression_ratios.get(chunk_id, 1.0)
361
- return max(1, int(round(chunk.tokens * ratio)))
362
-
363
- def _total_tokens_used(self) -> int:
364
- return sum(self._effective_chunk_tokens(chunk_id) for chunk_id in self._selected_chunks)
365
-
366
- def _effective_chunk_text(self, chunk_id: str) -> str:
367
- chunk = self._chunk_map().get(chunk_id)
368
- if chunk is None:
369
- return ""
370
- ratio = self._compression_ratios.get(chunk_id, 1.0)
371
- text = " ".join(chunk.text.split())
372
- if ratio >= 0.999:
373
- return text
374
-
375
- query_terms = self._query_terms(self.task.query)
376
- keyword_terms = self._query_terms(" ".join(chunk.keywords))
377
- sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", text) if segment.strip()]
378
- if not sentences:
379
- return self._truncate_words(text, ratio)
380
-
381
- ranked_sentences: list[tuple[int, float, int, str]] = []
382
- for index, sentence in enumerate(sentences):
383
- sentence_terms = self._query_terms(sentence)
384
- overlap = len(sentence_terms & query_terms)
385
- keyword_overlap = len(sentence_terms & keyword_terms)
386
- score = (overlap * 2.0) + keyword_overlap + (0.25 if index == 0 else 0.0)
387
- ranked_sentences.append((index, score, len(sentence.split()), sentence))
388
-
389
- target_words = max(20, int(len(text.split()) * ratio))
390
- chosen: list[tuple[int, str]] = []
391
- used_words = 0
392
- for index, _score, word_count, sentence in sorted(
393
- ranked_sentences,
394
- key=lambda item: (-item[1], item[2], item[0]),
395
- ):
396
- if used_words >= target_words:
397
- break
398
- chosen.append((index, sentence))
399
- used_words += word_count
400
-
401
- if not chosen:
402
- return self._truncate_words(text, ratio)
403
-
404
- chosen.sort(key=lambda item: item[0])
405
- compressed = " ".join(sentence for _index, sentence in chosen)
406
- return self._truncate_words(compressed, ratio)
407
-
408
- @staticmethod
409
- def _truncate_words(text: str, ratio: float) -> str:
410
- words = text.split()
411
- if not words:
412
- return ""
413
- keep = max(12, int(len(words) * ratio))
414
- truncated = " ".join(words[:keep])
415
- if keep < len(words):
416
- return truncated + " ..."
417
- return truncated
418
-
419
- @staticmethod
420
- def _query_terms(text: str) -> set[str]:
421
- return {token for token in re.findall(r"[a-z0-9]+", text.lower()) if len(token) > 2}
422
-
423
- def _build_optimized_prompt(self) -> str:
424
- if not self._selected_chunks:
425
- return ""
426
- sections = [f"Question: {self.task.query}", "", "Optimized Context:"]
427
- for chunk_id in self._selected_chunks:
428
- chunk = self._chunk_map().get(chunk_id)
429
- if chunk is None:
430
- continue
431
- sections.append(
432
- f"[{chunk.chunk_id} | {self._effective_chunk_tokens(chunk_id)} tokens] {self._effective_chunk_text(chunk_id)}"
433
- )
434
- return "\n".join(sections).strip()
435
-
436
- def _is_relevant(self, chunk_id: str) -> tuple[bool, float]:
437
- chunk = self._chunk_map().get(chunk_id)
438
- if chunk is None:
439
- return False, 0.0
440
- score = self.retriever.hybrid_score(self.task.query, chunk)
441
- return score >= 0.3, score
442
-
443
- def _handle_select(self, chunk_id: str) -> tuple[float, dict[str, Any]]:
444
- chunk = self._chunk_map().get(chunk_id)
445
- if chunk is None:
446
- self._last_action_feedback = "chunk_not_found"
447
- return -0.1, {"event": "chunk_not_found"}
448
- if chunk_id in self._selected_chunks:
449
- self._last_action_feedback = "chunk_already_selected"
450
- return 0.0, {"event": "chunk_already_selected"}
451
-
452
- projected_tokens = self._total_tokens_used() + self._effective_chunk_tokens(chunk_id)
453
- if projected_tokens > self.task.token_budget:
454
- self._last_action_feedback = "exceeded_budget"
455
- return -0.1, {"event": "exceeded_budget", "chunk_id": chunk_id}
456
-
457
- self._selected_chunks.append(chunk_id)
458
- _, score = self._is_relevant(chunk_id)
459
- self._last_action_feedback = "chunk_selected"
460
- return score * 0.2, {"event": "chunk_selected", "chunk_id": chunk_id, "hybrid_score": score}
461
-
462
- def _handle_deselect(self, chunk_id: str) -> tuple[float, dict[str, Any]]:
463
- if chunk_id not in self._selected_chunks:
464
- self._last_action_feedback = "chunk_not_selected"
465
- return 0.0, {"event": "chunk_not_selected", "chunk_id": chunk_id}
466
-
467
- self._selected_chunks.remove(chunk_id)
468
- is_relevant, score = self._is_relevant(chunk_id)
469
- self._last_action_feedback = "chunk_deselected"
470
- reward = 0.0 if is_relevant else 0.05
471
- return reward, {"event": "chunk_deselected", "chunk_id": chunk_id, "hybrid_score": score}
472
-
473
- def _handle_compress(self, chunk_id: str, compression_ratio: float) -> tuple[float, dict[str, Any]]:
474
- chunk = self._chunk_map().get(chunk_id)
475
- if chunk is None:
476
- self._last_action_feedback = "chunk_not_found"
477
- return -0.1, {"event": "chunk_not_found", "chunk_id": chunk_id}
478
-
479
- self._compression_ratios[chunk_id] = compression_ratio
480
- is_relevant, score = self._is_relevant(chunk_id)
481
- reward = 0.03 if is_relevant else 0.0
482
- if score >= 0.6 and compression_ratio < 0.4:
483
- reward -= 0.05
484
- self._last_action_feedback = "overcompressed_relevant_chunk"
485
- return reward, {
486
- "event": "overcompressed_relevant_chunk",
487
- "chunk_id": chunk_id,
488
- "hybrid_score": score,
489
- "compression_ratio": compression_ratio,
490
- }
491
-
492
- self._last_action_feedback = "chunk_compressed"
493
- return reward, {
494
- "event": "chunk_compressed",
495
- "chunk_id": chunk_id,
496
- "hybrid_score": score,
497
- "compression_ratio": compression_ratio,
498
- }
499
-
500
- def _finalize_submission(self, reason: str) -> StepResult:
501
- self._done = True
502
-
503
- if not self._selected_chunks:
504
- self._last_action_feedback = "no_chunks_selected"
505
- observation = self._build_observation()
506
- return StepResult(
507
- observation=observation,
508
- reward=0.0,
509
- done=True,
510
- info={"event": reason, "grader": None, "passed": False},
511
- )
512
-
513
- grader_result = self.grader.grade(
514
- selected_chunk_ids=list(self._selected_chunks),
515
- answer=self._last_answer,
516
- token_budget=self.task.token_budget,
517
- total_tokens_used=self._total_tokens_used(),
518
- retriever=self.retriever,
519
- task=self.task,
520
- )
521
- self._last_action_feedback = reason
522
- observation = self._build_observation()
523
- return StepResult(
524
- observation=observation,
525
- reward=grader_result.score,
526
- done=True,
527
- info={
528
- "event": reason,
529
- "grader": grader_result.breakdown,
530
- "passed": grader_result.passed,
531
- },
532
- )
 
1
+ """
2
+ Main OpenEnv-style environment for rag-context-optimizer.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import asdict, dataclass, is_dataclass, replace
8
+ import os
9
+ from pathlib import Path
10
+ import re
11
+ from typing import Any
12
+
13
+ from env.corpus import Chunk, load_corpus, resolve_corpus_path
14
+ from env.context_tuner import ContextTunedPlanner
15
+ from env.graders import TaskGrader
16
+ from env.models import ChunkSummary, RagAction, RagObservation
17
+ from env.retriever import HybridRetriever
18
+ from env.tasks import ALL_TASKS, TASKS_BY_NAME, Task
19
+
20
+
21
+ @dataclass(slots=True)
22
+ class StepResult:
23
+ observation: RagObservation
24
+ reward: float
25
+ done: bool
26
+ info: dict[str, Any]
27
+
28
+
29
+ class RagContextOptimizerEnv:
30
+ _PROJECT_STOPWORDS = {
31
+ "the", "and", "for", "with", "that", "this", "from", "into", "your", "have", "will",
32
+ "using", "used", "use", "into", "they", "them", "their", "about", "while", "where",
33
+ "when", "what", "which", "should", "would", "could", "there", "here", "then", "than",
34
+ "each", "such", "only", "also", "been", "being", "does", "did", "done", "just", "more",
35
+ "most", "very", "over", "under", "like", "same", "across", "because", "through", "make",
36
+ "made", "many", "much", "some", "into", "onto", "must", "need", "needs", "task", "tasks",
37
+ "chunk", "chunks", "query", "prompt", "environment", "agent", "agents", "model", "models",
38
+ }
39
+ _PROJECT_QUERY_HINTS = {
40
+ "openenv", "benchmark", "rag-context-optimizer", "readme", "docker", "fastapi", "api",
41
+ "endpoint", "inference.py", "app.py", "tasks.py", "graders.py", "environment.py", "repo",
42
+ "repository", "codebase", "ui", "frontend", "backend", "space", "validator",
43
+ }
44
+
45
+ def __init__(
46
+ self,
47
+ task_name: str = "single_domain_qa",
48
+ query_override: str | None = None,
49
+ token_budget_override: int | None = None,
50
+ max_steps_override: int | None = None,
51
+ corpus_family_override: str | None = None,
52
+ ):
53
+ if task_name not in TASKS_BY_NAME:
54
+ raise ValueError(f"Unknown task_name: {task_name}")
55
+
56
  self._corpus_family = corpus_family_override or os.getenv("RAG_CORPUS_FAMILY") or "enterprise_v1"
57
  explicit_path = os.getenv("RAG_CORPUS_PATH")
58
  self._corpus_path = resolve_corpus_path(explicit_path, family=None if explicit_path else self._corpus_family)
59
  self._all_chunks = load_corpus(self._corpus_path)
60
  self._query_overridden = bool(query_override and query_override.strip())
61
+ self._include_project_chunks = os.getenv("ENABLE_PROJECT_CORPUS", "").strip().lower() in {"1", "true", "yes"}
62
+ self._project_chunks = self._load_project_chunks() if self._include_project_chunks else []
63
  self.retriever = HybridRetriever(self._all_chunks + self._project_chunks)
64
+ self.context_tuner = ContextTunedPlanner(
65
+ self.retriever,
66
+ self._all_chunks + self._project_chunks,
67
+ list(ALL_TASKS),
68
+ )
69
+ self.grader = TaskGrader()
70
+ self.task: Task = self._build_task(
71
+ TASKS_BY_NAME[task_name],
72
+ query_override=query_override,
73
+ token_budget_override=token_budget_override,
74
+ max_steps_override=max_steps_override,
75
+ )
76
+
77
+ self._available_chunks: list[Chunk] = []
78
+ self._selected_chunks: list[str] = []
79
+ self._compression_ratios: dict[str, float] = {}
80
+ self._step_number = 0
81
+ self._done = False
82
+ self._last_action_feedback: str | None = None
83
+ self._last_answer = ""
84
+ self._last_tuning = None
85
+
86
+ @staticmethod
87
+ def _build_task(
88
+ base_task: Task,
89
+ query_override: str | None = None,
90
+ token_budget_override: int | None = None,
91
+ max_steps_override: int | None = None,
92
+ ) -> Task:
93
+ updated_task = base_task
94
+ if query_override and query_override.strip():
95
+ updated_task = replace(updated_task, query=query_override.strip(), domain_filter=None)
96
+ if token_budget_override is not None and token_budget_override > 0:
97
+ updated_task = replace(updated_task, token_budget=token_budget_override)
98
+ if max_steps_override is not None and max_steps_override > 0:
99
+ updated_task = replace(updated_task, max_steps=max_steps_override)
100
+ return updated_task
101
+
102
+ async def reset(self) -> StepResult:
103
+ candidate_chunks = self._filter_chunks_for_task(self.task)
104
+ self._available_chunks = self._rank_chunks_for_query(self.task.query, candidate_chunks)
105
+ if not self._query_overridden:
106
+ chunk_by_id = {chunk.chunk_id: chunk for chunk in candidate_chunks}
107
+ for chunk_id in self.task.required_chunk_ids:
108
+ chunk = chunk_by_id.get(chunk_id)
109
+ if chunk and all(existing.chunk_id != chunk_id for existing in self._available_chunks):
110
+ self._available_chunks.append(chunk)
111
+ self._selected_chunks = []
112
+ self._compression_ratios = {}
113
+ self._step_number = 0
114
+ self._done = False
115
+ self._last_action_feedback = None
116
+ self._last_answer = ""
117
+
118
+ observation = self._build_observation()
119
+ return StepResult(
120
+ observation=observation,
121
+ reward=0.0,
122
+ done=False,
123
+ info={"task": self.task.name, "event": "reset"},
124
+ )
125
+
126
+ async def step(self, action: RagAction) -> StepResult:
127
+ if self._done:
128
+ return StepResult(
129
+ observation=self._build_observation(),
130
+ reward=0.0,
131
+ done=True,
132
+ info={"task": self.task.name, "event": "episode_already_done"},
133
+ )
134
+
135
+ reward = 0.0
136
+ info: dict[str, Any] = {"task": self.task.name, "action_type": action.action_type}
137
+
138
+ if action.action_type == "select_chunk":
139
+ reward, info = self._handle_select(action.chunk_id or "")
140
+ elif action.action_type == "deselect_chunk":
141
+ reward, info = self._handle_deselect(action.chunk_id or "")
142
+ elif action.action_type == "compress_chunk":
143
+ reward, info = self._handle_compress(action.chunk_id or "", float(action.compression_ratio or 0.0))
144
+ elif action.action_type == "submit_answer":
145
+ self._last_answer = action.answer or ""
146
+ result = self._finalize_submission(reason="submit_answer")
147
+ self._step_number += 1
148
+ result.observation.step_number = self._step_number
149
+ return result
150
+
151
+ self._step_number += 1
152
+
153
+ if self._step_number >= self.task.max_steps:
154
+ return self._finalize_submission(reason="max_steps_reached")
155
+
156
+ observation = self._build_observation()
157
+ return StepResult(
158
+ observation=observation,
159
+ reward=reward,
160
+ done=False,
161
+ info=info,
162
+ )
163
+
164
+ async def state(self) -> dict:
165
+ selected_chunk_details = []
166
+ for chunk_id in self._selected_chunks:
167
+ chunk = self._chunk_map().get(chunk_id)
168
+ if chunk is None:
169
+ continue
170
+ selected_chunk_details.append(
171
+ {
172
+ "chunk_id": chunk.chunk_id,
173
+ "domain": chunk.domain,
174
+ "original_tokens": chunk.tokens,
175
+ "effective_tokens": self._effective_chunk_tokens(chunk_id),
176
+ "compression_ratio": round(self._compression_ratios.get(chunk_id, 1.0), 3),
177
+ "text": self._effective_chunk_text(chunk_id),
178
+ "keywords": chunk.keywords,
179
+ }
180
+ )
181
+ optimized_prompt = self._build_optimized_prompt()
182
+ return {
183
+ "task": asdict(self.task) if is_dataclass(self.task) else self.task,
184
+ "step_number": self._step_number,
185
+ "done": self._done,
186
+ "selected_chunks": list(self._selected_chunks),
187
+ "compression_ratios": dict(self._compression_ratios),
188
+ "total_tokens_used": self._total_tokens_used(),
189
+ "token_budget": self.task.token_budget,
190
+ "last_action_feedback": self._last_action_feedback,
191
+ "last_answer": self._last_answer,
192
+ "corpus_family": self._corpus_family,
193
+ "corpus_path": str(self._corpus_path),
194
+ "available_chunk_ids": [chunk.chunk_id for chunk in self._available_chunks],
195
+ "selected_chunk_details": selected_chunk_details,
196
+ "optimized_prompt_preview": optimized_prompt,
197
+ "optimized_prompt_tokens": max(1, len(optimized_prompt) // 4) if optimized_prompt else 0,
198
+ "context_tuning": (
199
+ {
200
+ "mode": self._last_tuning.mode,
201
+ "top_demo_cases": self._last_tuning.top_demo_cases,
202
+ "suggested_citations": self._last_tuning.suggested_citations,
203
+ "token_dropout": self._last_tuning.token_dropout,
204
+ "leave_one_out": self._last_tuning.leave_one_out,
205
+ }
206
+ if self._last_tuning is not None
207
+ else None
208
+ ),
209
+ }
210
+
211
+ async def close(self):
212
+ self._done = True
213
+
214
+ def _filter_chunks_for_task(self, task: Task) -> list[Chunk]:
215
+ domain_mapping = {
216
+ "customer_support_operations": "Customer Support Operations",
217
+ "incident_response_playbooks": "Incident Response Playbooks",
218
+ "platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
219
  }
220
  if self._query_overridden:
221
+ if self._include_project_chunks and self._is_project_query(task.query):
222
  return list(self._all_chunks) + list(self._project_chunks)
223
  return list(self._all_chunks)
224
+ if task.domain_filter is None:
225
+ return list(self._all_chunks)
226
+ normalized = domain_mapping.get(task.domain_filter, task.domain_filter)
227
+ return [chunk for chunk in self._all_chunks if chunk.domain == normalized]
228
+
229
+ def _is_project_query(self, query: str) -> bool:
230
+ lowered = query.lower()
231
+ return any(hint in lowered for hint in self._PROJECT_QUERY_HINTS)
232
+
233
def _rank_chunks_for_query(self, query: str, chunks: list[Chunk], top_k: int = 20) -> list[Chunk]:
    """Rank ``chunks`` for ``query`` and return at most ``top_k`` of them.

    Scores come from the context tuner when it produced one for a chunk and
    from ``self.retriever.hybrid_score`` otherwise.  The tuning result is
    stored on ``self._last_tuning``.  Candidates are ordered by descending
    score, then fewer tokens, then chunk id; anything scoring below
    ``max(0.12, 0.38 * best)`` is dropped.

    When project chunks are enabled and the query was overridden, chunks
    whose domain starts with ``"Project"`` get a +0.08 boost (capped at 1.0)
    and, if at least four of them survive the floor, are moved ahead of the
    remaining chunks.
    """
    tuning = self.context_tuner.tune(query, chunks)
    self._last_tuning = tuning
    favour_project = self._include_project_chunks and self._query_overridden

    def score_of(chunk):
        tuned = tuning.tuned_scores.get(chunk.chunk_id)
        if tuned is None:
            value = self.retriever.hybrid_score(query, chunk)
        else:
            value = tuned.final_score
        if favour_project and chunk.domain.startswith("Project"):
            value = min(1.0, value + 0.08)
        return value

    ranked = sorted(
        [(chunk, score_of(chunk)) for chunk in chunks],
        key=lambda pair: (-pair[1], pair[0].tokens, pair[0].chunk_id),
    )
    if not ranked:
        return []

    # Only the top 2*top_k candidates (at least one) compete for the floor.
    shortlist = ranked[: max(1, min(top_k * 2, len(ranked)))]
    cutoff = max(0.12, shortlist[0][1] * 0.38)
    survivors = [pair for pair in shortlist if pair[1] >= cutoff]

    if favour_project:
        project_side = [pair for pair in survivors if pair[0].domain.startswith("Project")]
        if len(project_side) >= 4:
            other_side = [pair for pair in survivors if not pair[0].domain.startswith("Project")]
            survivors = project_side + other_side

    picked = [pair[0] for pair in survivors]
    if not picked:
        # Every candidate fell below the floor: fall back to the raw top of the list.
        picked = [pair[0] for pair in shortlist[: max(1, min(top_k, len(shortlist)))]]

    return picked[: max(1, min(top_k, len(picked)))]
266
+
267
def _load_project_chunks(self) -> list[Chunk]:
    """Build ``Chunk`` objects from this repository's own docs and sources.

    Each file listed below (resolved against the directory two levels above
    this module) is split with ``_chunk_project_text``; every window becomes
    one chunk with id ``project_<stem>_<NNN>``.  Entries that are missing,
    are not regular files, or cannot be read are skipped.

    Returns:
        Chunks in file-spec order, then window order within each file.
    """
    root = Path(__file__).resolve().parent.parent
    file_specs = [
        ("Project Documentation", root / "README.md", ["project_docs", "readme"]),
        ("Project Configuration", root / "openenv.yaml", ["project_docs", "config", "openenv_spec"]),
        ("Project API", root / "app.py", ["project_docs", "api", "server"]),
        ("Project Baseline", root / "inference.py", ["project_docs", "baseline", "inference"]),
        ("Project Environment", root / "env" / "environment.py", ["project_docs", "environment", "state_management"]),
        ("Project Retrieval", root / "env" / "retriever.py", ["project_docs", "retrieval", "ranking"]),
        ("Project Grading", root / "env" / "graders.py", ["project_docs", "grading", "reward_design"]),
        ("Project Tasks", root / "env" / "tasks.py", ["project_docs", "tasks", "difficulty"]),
        ("Project Validation", root / "validate.py", ["project_docs", "validation", "testing"]),
    ]

    chunks: list[Chunk] = []
    for domain, path, tags in file_specs:
        # is_file() rather than exists(): a directory with one of these
        # names would otherwise make read_text() raise.
        if not path.is_file():
            continue
        try:
            raw_text = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            # Best-effort, like the missing-file case: skip unreadable files.
            continue

        stem = re.sub(r"[^a-z0-9]+", "_", path.stem.lower()).strip("_") or "file"
        for index, section in enumerate(self._chunk_project_text(raw_text), start=1):
            keywords = self._extract_project_keywords(section)
            if not keywords:
                keywords = [stem, domain.lower()]
            chunks.append(
                Chunk(
                    chunk_id=f"project_{stem}_{index:03d}",
                    domain=domain,
                    text=section,
                    # Presumably a ~4 characters-per-token estimate, floored at 30.
                    tokens=max(30, len(section) // 4),
                    keywords=keywords[:5],
                    # Copy so chunks from the same file do not share one list.
                    relevance_tags=list(tags),
                )
            )
    return chunks
303
+
304
def _chunk_project_text(self, raw_text: str, chunk_words: int = 140, stride_words: int = 100) -> list[str]:
    """Split text into overlapping windows of whitespace-separated words.

    Args:
        raw_text: Text to split; all runs of whitespace collapse to one space.
        chunk_words: Maximum number of words per window.
        stride_words: Number of words the window start advances each step.

    Returns:
        ``[]`` for text without words, a single window when the text fits in
        ``chunk_words``, otherwise successive windows; the last window is the
        first one that reaches the end of the text.

    Raises:
        ValueError: If ``chunk_words`` or ``stride_words`` is not positive
            (a non-positive stride would never advance the window).
    """
    if chunk_words <= 0 or stride_words <= 0:
        raise ValueError("chunk_words and stride_words must be positive")

    words = raw_text.split()
    if not words:
        return []
    if len(words) <= chunk_words:
        return [" ".join(words)]

    chunks: list[str] = []
    for start in range(0, len(words), stride_words):
        chunks.append(" ".join(words[start : start + chunk_words]))
        # Stop once a window has covered the tail; later windows would only
        # repeat words that are already included.
        if start + chunk_words >= len(words):
            break
    return chunks
323
+
324
def _extract_project_keywords(self, text: str) -> list[str]:
    """Return up to eight of the most frequent terms in ``text``.

    Terms are lower-case runs of letters, digits and underscores.  Terms
    shorter than four characters or listed in ``self._PROJECT_STOPWORDS``
    are ignored.  Ties are broken alphabetically, and underscores are
    replaced with spaces in the result.
    """
    frequency: dict[str, int] = {}
    for token in re.findall(r"[a-z0-9_]+", text.lower()):
        if len(token) >= 4 and token not in self._PROJECT_STOPWORDS:
            frequency[token] = frequency.get(token, 0) + 1
    ordered = sorted(frequency, key=lambda token: (-frequency[token], token))
    return [token.replace("_", " ") for token in ordered[:8]]
333
+
334
def _build_observation(self) -> RagObservation:
    """Snapshot the current episode state as a ``RagObservation``.

    Every available chunk is summarised with its effective token count
    (as reported by ``_effective_chunk_tokens``); the selected-chunk list
    is copied so the observation does not alias internal state.
    """
    summaries = []
    for chunk in self._available_chunks:
        summaries.append(
            ChunkSummary(
                chunk_id=chunk.chunk_id,
                domain=chunk.domain,
                tokens=self._effective_chunk_tokens(chunk.chunk_id),
                keywords=chunk.keywords,
            )
        )

    task = self.task
    return RagObservation(
        query=task.query,
        available_chunks=summaries,
        selected_chunks=list(self._selected_chunks),
        total_tokens_used=self._total_tokens_used(),
        token_budget=task.token_budget,
        step_number=self._step_number,
        task_name=task.name,
        last_action_feedback=self._last_action_feedback,
    )
353
+
354
def _chunk_map(self) -> dict[str, Chunk]:
    """Index the currently available chunks by their ``chunk_id``."""
    lookup = {}
    for chunk in self._available_chunks:
        lookup[chunk.chunk_id] = chunk
    return lookup
356
+
357
def _effective_chunk_tokens(self, chunk_id: str) -> int:
    """Return the chunk's token count after its compression ratio is applied.

    Unknown chunk ids count as 0 tokens.  A chunk without a recorded ratio
    uses 1.0, and a known chunk never counts as fewer than one token.
    """
    chunk = self._chunk_map().get(chunk_id)
    if chunk is None:
        return 0
    scaled = int(round(chunk.tokens * self._compression_ratios.get(chunk_id, 1.0)))
    return scaled if scaled > 1 else 1
363
+
364
def _total_tokens_used(self) -> int:
    """Add up the effective token counts of all currently selected chunks."""
    total = 0
    for chunk_id in self._selected_chunks:
        total += self._effective_chunk_tokens(chunk_id)
    return total
366
+
367
def _effective_chunk_text(self, chunk_id: str) -> str:
    """Return the chunk's text after its compression ratio has been applied.

    Unknown ids yield ``""``.  With a ratio of (effectively) 1.0 the
    whitespace-normalised text is returned unchanged.  Otherwise sentences
    are scored by overlap with the task query (weight 2) and with the
    chunk's keywords (weight 1), with a small bonus for the first sentence.
    The best-scoring sentences are kept in their original order until
    ``target_words`` (``ratio`` of the chunk's words, at least 20) is
    reached.

    The result is capped at ``target_words`` words.  The cap is applied to
    the word budget itself, not by multiplying the already-compressed text
    by ``ratio`` a second time, which would shrink the text to roughly
    ``ratio ** 2`` of the original.
    """
    chunk = self._chunk_map().get(chunk_id)
    if chunk is None:
        return ""
    ratio = self._compression_ratios.get(chunk_id, 1.0)
    text = " ".join(chunk.text.split())
    if ratio >= 0.999:
        return text

    sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", text) if segment.strip()]
    if not sentences:
        return self._truncate_words(text, ratio)

    query_terms = self._query_terms(self.task.query)
    keyword_terms = self._query_terms(" ".join(chunk.keywords))

    ranked_sentences: list[tuple[int, float, int, str]] = []
    for index, sentence in enumerate(sentences):
        sentence_terms = self._query_terms(sentence)
        overlap = len(sentence_terms & query_terms)
        keyword_overlap = len(sentence_terms & keyword_terms)
        score = (overlap * 2.0) + keyword_overlap + (0.25 if index == 0 else 0.0)
        ranked_sentences.append((index, score, len(sentence.split()), sentence))

    target_words = max(20, int(len(text.split()) * ratio))
    chosen: list[tuple[int, str]] = []
    used_words = 0
    # Best score first; shorter sentences and earlier positions break ties.
    for index, _score, word_count, sentence in sorted(
        ranked_sentences,
        key=lambda item: (-item[1], item[2], item[0]),
    ):
        if used_words >= target_words:
            break
        chosen.append((index, sentence))
        used_words += word_count

    # target_words >= 20, so at least one sentence is always chosen here.
    chosen.sort(key=lambda item: item[0])
    compressed = " ".join(sentence for _index, sentence in chosen)

    # The last sentence added may overshoot the budget; trim to the budget.
    words = compressed.split()
    if len(words) <= target_words:
        return compressed
    return " ".join(words[:target_words]) + " ..."
408
+
409
+ @staticmethod
410
+ def _truncate_words(text: str, ratio: float) -> str:
411
+ words = text.split()
412
+ if not words:
413
+ return ""
414
+ keep = max(12, int(len(words) * ratio))
415
+ truncated = " ".join(words[:keep])
416
+ if keep < len(words):
417
+ return truncated + " ..."
418
+ return truncated
419
+
420
+ @staticmethod
421
+ def _query_terms(text: str) -> set[str]:
422
+ return {token for token in re.findall(r"[a-z0-9]+", text.lower()) if len(token) > 2}
423
+
424
def _build_optimized_prompt(self) -> str:
    """Render the question plus every selected chunk as one prompt string.

    Returns ``""`` when nothing is selected.  Selected ids that are no
    longer among the available chunks are skipped.  Each chunk line shows
    the chunk id, its effective token count and its effective text.
    """
    if not self._selected_chunks:
        return ""

    lines = [f"Question: {self.task.query}", "", "Optimized Context:"]
    for chunk_id in self._selected_chunks:
        chunk = self._chunk_map().get(chunk_id)
        if chunk is not None:
            token_count = self._effective_chunk_tokens(chunk_id)
            body = self._effective_chunk_text(chunk_id)
            lines.append(f"[{chunk.chunk_id} | {token_count} tokens] {body}")
    return "\n".join(lines).strip()
436
+
437
def _is_relevant(self, chunk_id: str) -> tuple[bool, float]:
    """Score a chunk against the task query.

    Returns:
        ``(relevant, score)`` where ``score`` is the retriever's hybrid
        score and ``relevant`` is ``score >= 0.3``.  Unknown chunk ids
        yield ``(False, 0.0)``.
    """
    chunk = self._chunk_map().get(chunk_id)
    if chunk is not None:
        relevance = self.retriever.hybrid_score(self.task.query, chunk)
        return relevance >= 0.3, relevance
    return False, 0.0
443
+
444
def _handle_select(self, chunk_id: str) -> tuple[float, dict[str, Any]]:
    """Add a chunk to the selection if it exists, is new, and fits the budget.

    Returns:
        ``(reward, info)``.  Unknown ids and budget overruns cost ``-0.1``,
        re-selecting a chunk is a no-op worth ``0.0``, and a successful
        selection earns ``0.2 * hybrid_score``.  ``info`` always carries the
        ``event`` name and the ``chunk_id`` it refers to, matching the other
        action handlers.
    """
    chunk = self._chunk_map().get(chunk_id)
    if chunk is None:
        self._last_action_feedback = "chunk_not_found"
        return -0.1, {"event": "chunk_not_found", "chunk_id": chunk_id}
    if chunk_id in self._selected_chunks:
        self._last_action_feedback = "chunk_already_selected"
        return 0.0, {"event": "chunk_already_selected", "chunk_id": chunk_id}

    # Reject before mutating state so a refused selection leaves no trace.
    projected_tokens = self._total_tokens_used() + self._effective_chunk_tokens(chunk_id)
    if projected_tokens > self.task.token_budget:
        self._last_action_feedback = "exceeded_budget"
        return -0.1, {"event": "exceeded_budget", "chunk_id": chunk_id}

    self._selected_chunks.append(chunk_id)
    _, score = self._is_relevant(chunk_id)
    self._last_action_feedback = "chunk_selected"
    return score * 0.2, {"event": "chunk_selected", "chunk_id": chunk_id, "hybrid_score": score}
462
+
463
def _handle_deselect(self, chunk_id: str) -> tuple[float, dict[str, Any]]:
    """Remove a chunk from the selection.

    Returns:
        ``(reward, info)``.  Dropping a chunk that is not relevant earns
        ``0.05``; dropping a relevant one, or naming a chunk that is not
        selected, earns ``0.0``.
    """
    if chunk_id in self._selected_chunks:
        self._selected_chunks.remove(chunk_id)
        relevant, score = self._is_relevant(chunk_id)
        self._last_action_feedback = "chunk_deselected"
        reward = 0.05 if not relevant else 0.0
        return reward, {"event": "chunk_deselected", "chunk_id": chunk_id, "hybrid_score": score}

    self._last_action_feedback = "chunk_not_selected"
    return 0.0, {"event": "chunk_not_selected", "chunk_id": chunk_id}
473
+
474
def _handle_compress(self, chunk_id: str, compression_ratio: float) -> tuple[float, dict[str, Any]]:
    """Record a compression ratio for a chunk and score the decision.

    Returns:
        ``(reward, info)``.  Unknown ids cost ``-0.1``.  Compressing a
        relevant chunk earns ``0.03``; when the chunk's hybrid score is at
        least 0.6 and the ratio is below 0.4 an extra ``0.05`` is deducted
        and the event is reported as ``overcompressed_relevant_chunk``.
    """
    if self._chunk_map().get(chunk_id) is None:
        self._last_action_feedback = "chunk_not_found"
        return -0.1, {"event": "chunk_not_found", "chunk_id": chunk_id}

    self._compression_ratios[chunk_id] = compression_ratio
    relevant, score = self._is_relevant(chunk_id)

    reward = 0.03 if relevant else 0.0
    overcompressed = score >= 0.6 and compression_ratio < 0.4
    if overcompressed:
        reward -= 0.05

    event = "overcompressed_relevant_chunk" if overcompressed else "chunk_compressed"
    self._last_action_feedback = event
    return reward, {
        "event": event,
        "chunk_id": chunk_id,
        "hybrid_score": score,
        "compression_ratio": compression_ratio,
    }
500
+
501
def _finalize_submission(self, reason: str) -> StepResult:
    """End the episode and produce the terminal ``StepResult``.

    With no chunks selected the grader is not consulted: the reward is
    ``0.0`` and the feedback is ``"no_chunks_selected"``.  Otherwise the
    grader scores the selection and last answer, and its score, breakdown
    and pass flag are reported.  ``reason`` is used as the ``event`` in
    both cases.
    """
    self._done = True

    if self._selected_chunks:
        verdict = self.grader.grade(
            selected_chunk_ids=list(self._selected_chunks),
            answer=self._last_answer,
            token_budget=self.task.token_budget,
            total_tokens_used=self._total_tokens_used(),
            retriever=self.retriever,
            task=self.task,
        )
        self._last_action_feedback = reason
        reward = verdict.score
        info = {
            "event": reason,
            "grader": verdict.breakdown,
            "passed": verdict.passed,
        }
    else:
        self._last_action_feedback = "no_chunks_selected"
        reward = 0.0
        info = {"event": reason, "grader": None, "passed": False}

    # Built after the feedback is set so the observation reflects it.
    return StepResult(
        observation=self._build_observation(),
        reward=reward,
        done=True,
        info=info,
    )
env/graders.py CHANGED
@@ -1,124 +1,125 @@
1
- """
2
- Deterministic graders for rag-context-optimizer tasks.
3
- """
4
-
5
- from __future__ import annotations
6
-
7
- import re
8
- from dataclasses import dataclass
9
-
10
- from env.corpus import Chunk
11
- from env.retriever import HybridRetriever
12
- from env.tasks import Task
13
-
14
-
15
- _STOPWORDS = {
16
- "a", "an", "and", "are", "as", "at", "be", "because", "by", "for", "from", "how",
17
- "if", "in", "into", "is", "it", "its", "of", "on", "or", "that", "the", "their",
18
- "them", "there", "these", "this", "to", "was", "were", "what", "when", "where",
19
- "which", "while", "with", "within", "without", "you", "your",
20
- }
21
-
22
-
23
- def _tokenize(text: str) -> set[str]:
24
- return set(re.findall(r"[a-z0-9]+", text.lower()))
25
-
26
-
27
- def _content_terms(text: str) -> set[str]:
28
- return {term for term in _tokenize(text) if len(term) > 2 and term not in _STOPWORDS}
29
-
30
-
31
- def _extract_citations(text: str) -> list[str]:
32
- return re.findall(r"\[([a-z0-9_]+)\]", text.lower())
33
-
34
-
35
- def _normalize_chunk_id(chunk_id: str) -> str:
36
- chunk_id = chunk_id.strip()
37
- return chunk_id
38
-
39
-
40
- def _normalize_domain_filter(domain_filter: str | None) -> str | None:
41
- if domain_filter is None:
42
- return None
43
- mapping = {
44
- "customer_support_operations": "Customer Support Operations",
45
- "incident_response_playbooks": "Incident Response Playbooks",
46
- "platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
47
- }
48
- return mapping.get(domain_filter, domain_filter)
49
-
50
-
51
- def _f1_score(selected: set[str], relevant: set[str]) -> float:
52
- if not selected and not relevant:
53
- return 1.0
54
- if not selected or not relevant:
55
- return 0.0
56
- overlap = len(selected & relevant)
57
- if overlap == 0:
58
- return 0.0
59
- precision = overlap / len(selected)
60
- recall = overlap / len(relevant)
61
- return 2 * precision * recall / (precision + recall)
62
-
63
-
64
- @dataclass(frozen=True, slots=True)
65
- class GraderResult:
66
- score: float
67
- breakdown: dict[str, float]
68
- passed: bool
69
-
70
-
71
- class TaskGrader:
72
- def _filter_relevant_by_domain(self, relevant_ids: set[str], retriever: HybridRetriever, task: Task) -> set[str]:
73
- normalized_domain = _normalize_domain_filter(task.domain_filter)
74
- if normalized_domain is None:
75
- return relevant_ids
76
- allowed_ids = {chunk.chunk_id for chunk in retriever.corpus if chunk.domain == normalized_domain}
77
- return relevant_ids & allowed_ids
78
-
79
- def _required_chunks(self, retriever: HybridRetriever, task: Task) -> list[Chunk]:
80
- normalized_required = {_normalize_chunk_id(chunk_id) for chunk_id in task.required_chunk_ids}
81
- return [chunk for chunk in retriever.corpus if chunk.chunk_id in normalized_required]
82
-
83
  def _answer_quality(self, answer: str, required_chunks: list[Chunk]) -> float:
84
  answer_terms = _content_terms(answer)
85
  required_terms = _content_terms(" ".join(chunk.text for chunk in required_chunks))
 
86
  if not answer_terms or not required_terms:
87
  return 0.0
88
  union = answer_terms | required_terms
89
- if not union:
90
- return 0.0
91
- return len(answer_terms & required_terms) / len(union)
92
-
93
- def _citation_accuracy(self, answer: str, selected_chunk_ids: set[str], expected_citation_ids: set[str]) -> float:
94
- citations = {_normalize_chunk_id(chunk_id) for chunk_id in _extract_citations(answer)}
95
- if not citations:
96
- return 0.0
97
- valid_citations = citations & selected_chunk_ids
98
- precision = len(valid_citations) / len(citations)
99
- recall = len(valid_citations & expected_citation_ids) / len(expected_citation_ids) if expected_citation_ids else 1.0
100
- return (precision + recall) / 2.0
101
-
102
- def _unsupported_claim_rate(self, answer: str, evidence_chunks: list[Chunk]) -> float:
103
- answer_terms = _content_terms(re.sub(r"\[[a-z0-9_]+\]", " ", answer.lower()))
104
- evidence_terms = _content_terms(" ".join(chunk.text for chunk in evidence_chunks))
105
- if not answer_terms:
106
- return 0.0
107
- unsupported = answer_terms - evidence_terms
108
- return len(unsupported) / len(answer_terms)
109
-
110
- def grade(
111
- self,
112
- selected_chunk_ids: list[str],
113
- answer: str,
114
- token_budget: int,
115
- total_tokens_used: int,
116
- retriever: HybridRetriever,
117
- task: Task,
118
  ) -> GraderResult:
119
  normalized_selected = {_normalize_chunk_id(chunk_id) for chunk_id in selected_chunk_ids}
120
- relevant = retriever.get_ground_truth_relevant(task.query, threshold=0.3)
121
- relevant = self._filter_relevant_by_domain(relevant, retriever, task)
122
 
123
  retrieval_precision = _f1_score(normalized_selected, relevant)
124
  token_efficiency = 1.0 - (total_tokens_used / token_budget) if total_tokens_used <= token_budget else 0.0
@@ -127,41 +128,41 @@ class TaskGrader:
127
  required_chunks = self._required_chunks(retriever, task)
128
  answer_quality = self._answer_quality(answer, required_chunks)
129
 
130
- normalized_required = {_normalize_chunk_id(chunk_id) for chunk_id in task.required_chunk_ids}
131
  normalized_expected_citations = {
132
  _normalize_chunk_id(chunk_id)
133
  for chunk_id in (task.expected_citation_ids or task.required_chunk_ids)
134
- }
135
- required_chunks_hit = (
136
- len(normalized_selected & normalized_required) / len(normalized_required)
137
- if normalized_required
138
- else 1.0
139
- )
140
-
141
  selected_chunks = [
142
  chunk for chunk in retriever.corpus if chunk.chunk_id in normalized_selected
143
  ]
 
144
  citation_accuracy = self._citation_accuracy(answer, normalized_selected, normalized_expected_citations)
145
- unsupported_claim_rate = self._unsupported_claim_rate(answer, selected_chunks)
146
- hallucination_penalty = min(1.0, unsupported_claim_rate)
147
-
148
- base_score = (
149
- 0.25 * retrieval_precision
150
- + 0.25 * token_efficiency
151
- + 0.35 * answer_quality
152
- + 0.15 * required_chunks_hit
153
- )
154
- score = base_score + (0.10 * citation_accuracy) - (0.15 * hallucination_penalty)
155
- score = max(0.0, min(1.0, score))
156
-
157
- breakdown = {
158
- "retrieval_precision": retrieval_precision,
159
- "token_efficiency": token_efficiency,
160
- "answer_quality": answer_quality,
161
- "required_chunks_hit": required_chunks_hit,
162
- "citation_accuracy": citation_accuracy,
163
- "unsupported_claim_rate": unsupported_claim_rate,
164
- "hallucination_penalty": hallucination_penalty,
165
- }
166
- passed = score >= 0.7
167
- return GraderResult(score=score, breakdown=breakdown, passed=passed)
 
1
+ """
2
+ Deterministic graders for rag-context-optimizer tasks.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import re
8
+ from dataclasses import dataclass
9
+
10
+ from env.corpus import Chunk
11
+ from env.retriever import HybridRetriever
12
+ from env.tasks import Task
13
+
14
+
15
+ _STOPWORDS = {
16
+ "a", "an", "and", "are", "as", "at", "be", "because", "by", "for", "from", "how",
17
+ "if", "in", "into", "is", "it", "its", "of", "on", "or", "that", "the", "their",
18
+ "them", "there", "these", "this", "to", "was", "were", "what", "when", "where",
19
+ "which", "while", "with", "within", "without", "you", "your",
20
+ }
21
+
22
+
23
+ def _tokenize(text: str) -> set[str]:
24
+ return set(re.findall(r"[a-z0-9]+", text.lower()))
25
+
26
+
27
+ def _content_terms(text: str) -> set[str]:
28
+ return {term for term in _tokenize(text) if len(term) > 2 and term not in _STOPWORDS}
29
+
30
+
31
def _extract_citations(text: str) -> list[str]:
    """Return every bracketed id such as ``[support_003]`` in ``text``, lower-cased and in order."""
    citation_pattern = r"\[([a-z0-9_]+)\]"
    lowered = text.lower()
    return re.findall(citation_pattern, lowered)
33
+
34
+
35
def _normalize_chunk_id(chunk_id: str) -> str:
    """Return ``chunk_id`` with surrounding whitespace removed (case is left as is)."""
    return chunk_id.strip()
38
+
39
+
40
def _normalize_domain_filter(domain_filter: str | None) -> str | None:
    """Translate a snake_case domain key into its display name.

    ``None`` stays ``None`` and any value without a known translation is
    returned unchanged.
    """
    if domain_filter is None:
        return None

    display_names = {
        "customer_support_operations": "Customer Support Operations",
        "incident_response_playbooks": "Incident Response Playbooks",
        "platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
    }
    if domain_filter in display_names:
        return display_names[domain_filter]
    return domain_filter
49
+
50
+
51
def _f1_score(selected: set[str], relevant: set[str]) -> float:
    """Return the F1 score of ``selected`` against ``relevant``.

    Two empty sets count as a perfect match (1.0); no overlap at all,
    including exactly one empty set, scores 0.0.
    """
    if not selected and not relevant:
        return 1.0

    # An empty set on either side also produces an empty intersection.
    shared = len(selected & relevant)
    if shared == 0:
        return 0.0

    precision = shared / len(selected)
    recall = shared / len(relevant)
    return 2 * precision * recall / (precision + recall)
62
+
63
+
64
+ @dataclass(frozen=True, slots=True)
65
+ class GraderResult:
66
+ score: float
67
+ breakdown: dict[str, float]
68
+ passed: bool
69
+
70
+
71
+ class TaskGrader:
72
def _filter_relevant_by_domain(self, relevant_ids: set[str], retriever: HybridRetriever, task: Task) -> set[str]:
    """Restrict ``relevant_ids`` to chunks in the task's domain.

    When the task has no domain filter the ids are returned unchanged;
    otherwise only ids of corpus chunks whose domain equals the normalised
    filter are kept.
    """
    domain = _normalize_domain_filter(task.domain_filter)
    if domain is None:
        return relevant_ids

    in_domain = set()
    for chunk in retriever.corpus:
        if chunk.domain == domain:
            in_domain.add(chunk.chunk_id)
    return relevant_ids & in_domain
78
+
79
def _required_chunks(self, retriever: HybridRetriever, task: Task) -> list[Chunk]:
    """Return the corpus chunks the task lists as required, in corpus order."""
    wanted = set()
    for chunk_id in task.required_chunk_ids:
        wanted.add(_normalize_chunk_id(chunk_id))

    found = []
    for chunk in retriever.corpus:
        if chunk.chunk_id in wanted:
            found.append(chunk)
    return found
82
+
83
def _answer_quality(self, answer: str, required_chunks: list[Chunk]) -> float:
    """Score the answer by Jaccard overlap with the required chunks' vocabulary.

    The reference vocabulary is the content terms of the required chunks'
    text together with the content terms of their keywords.

    Returns:
        ``|answer ∩ reference| / |answer ∪ reference|``, or ``0.0`` when
        either side has no content terms.
    """
    answer_terms = _content_terms(answer)
    required_terms = _content_terms(" ".join(chunk.text for chunk in required_chunks))
    required_terms |= _content_terms(" ".join(" ".join(chunk.keywords) for chunk in required_chunks))
    if not answer_terms or not required_terms:
        return 0.0
    # Both sets are non-empty here, so the union is too and the division is safe.
    union = answer_terms | required_terms
    return len(answer_terms & required_terms) / len(union)
93
+
94
def _citation_accuracy(self, answer: str, selected_chunk_ids: set[str], expected_citation_ids: set[str]) -> float:
    """Score the bracketed citations in ``answer``.

    Returns:
        The mean of precision (share of cited ids that were selected) and
        recall (share of expected ids that were cited and selected; 1.0
        when nothing is expected).  An answer without citations scores 0.0.
    """
    cited = set()
    for raw_id in _extract_citations(answer):
        cited.add(_normalize_chunk_id(raw_id))
    if not cited:
        return 0.0

    backed = cited & selected_chunk_ids
    precision = len(backed) / len(cited)
    if expected_citation_ids:
        recall = len(backed & expected_citation_ids) / len(expected_citation_ids)
    else:
        recall = 1.0
    return (precision + recall) / 2.0
102
+
103
def _unsupported_claim_rate(self, answer: str, evidence_chunks: list[Chunk]) -> float:
    """Return the share of answer content terms absent from the evidence text.

    Bracketed citations are removed from the answer first.  An answer
    without content terms has a rate of 0.0.
    """
    without_citations = re.sub(r"\[[a-z0-9_]+\]", " ", answer.lower())
    answer_terms = _content_terms(without_citations)
    evidence_text = " ".join(chunk.text for chunk in evidence_chunks)
    evidence_terms = _content_terms(evidence_text)
    if not answer_terms:
        return 0.0
    return len(answer_terms - evidence_terms) / len(answer_terms)
110
+
111
+ def grade(
112
+ self,
113
+ selected_chunk_ids: list[str],
114
+ answer: str,
115
+ token_budget: int,
116
+ total_tokens_used: int,
117
+ retriever: HybridRetriever,
118
+ task: Task,
119
  ) -> GraderResult:
120
  normalized_selected = {_normalize_chunk_id(chunk_id) for chunk_id in selected_chunk_ids}
121
+ normalized_required = {_normalize_chunk_id(chunk_id) for chunk_id in task.required_chunk_ids}
122
+ relevant = self._filter_relevant_by_domain(normalized_required, retriever, task)
123
 
124
  retrieval_precision = _f1_score(normalized_selected, relevant)
125
  token_efficiency = 1.0 - (total_tokens_used / token_budget) if total_tokens_used <= token_budget else 0.0
 
128
  required_chunks = self._required_chunks(retriever, task)
129
  answer_quality = self._answer_quality(answer, required_chunks)
130
 
 
131
  normalized_expected_citations = {
132
  _normalize_chunk_id(chunk_id)
133
  for chunk_id in (task.expected_citation_ids or task.required_chunk_ids)
134
+ }
135
+ required_chunks_hit = (
136
+ len(normalized_selected & normalized_required) / len(normalized_required)
137
+ if normalized_required
138
+ else 1.0
139
+ )
140
+
141
  selected_chunks = [
142
  chunk for chunk in retriever.corpus if chunk.chunk_id in normalized_selected
143
  ]
144
+ evidence_chunks = selected_chunks or required_chunks
145
  citation_accuracy = self._citation_accuracy(answer, normalized_selected, normalized_expected_citations)
146
+ unsupported_claim_rate = self._unsupported_claim_rate(answer, evidence_chunks)
147
+ hallucination_penalty = min(1.0, unsupported_claim_rate)
148
+
149
+ base_score = (
150
+ 0.25 * retrieval_precision
151
+ + 0.25 * token_efficiency
152
+ + 0.35 * answer_quality
153
+ + 0.15 * required_chunks_hit
154
+ )
155
+ score = base_score + (0.10 * citation_accuracy) - (0.15 * hallucination_penalty)
156
+ score = max(0.0, min(1.0, score))
157
+
158
+ breakdown = {
159
+ "retrieval_precision": retrieval_precision,
160
+ "token_efficiency": token_efficiency,
161
+ "answer_quality": answer_quality,
162
+ "required_chunks_hit": required_chunks_hit,
163
+ "citation_accuracy": citation_accuracy,
164
+ "unsupported_claim_rate": unsupported_claim_rate,
165
+ "hallucination_penalty": hallucination_penalty,
166
+ }
167
+ passed = score >= 0.7
168
+ return GraderResult(score=score, breakdown=breakdown, passed=passed)
inference.py CHANGED
@@ -26,16 +26,16 @@ TASK_SEQUENCE = [
26
  "adversarial_compression",
27
  ]
28
 
29
- SYSTEM_PROMPT = """You are a baseline RAG context optimizer.
30
- Read the query and available chunks using chunk_id, keywords, tokens, and domain.
31
  Select chunks that maximize keyword overlap with the query.
32
  Stay under the token budget.
33
  Compress chunks that are mildly relevant but token-heavy.
34
  Submit a concise answer once enough useful chunks are selected.
35
  When you submit an answer, cite selected chunks inline like [support_003] or [incident_002].
36
  Return only valid JSON matching one of these forms:
37
- {"action_type":"select_chunk","chunk_id":"support_003"}
38
- {"action_type":"deselect_chunk","chunk_id":"support_003"}
39
  {"action_type":"compress_chunk","chunk_id":"support_003","compression_ratio":0.5}
40
  {"action_type":"submit_answer","answer":"Verify outage evidence and the billing ledger before refunding [support_001] [support_003]."}"""
41
 
@@ -222,12 +222,12 @@ async def _post_json(http_client: httpx.AsyncClient, path: str, payload: dict[st
222
  return response.json()
223
 
224
 
225
- async def _run_task_http(task_name: str) -> tuple[float, list[float], int]:
226
  rewards: list[float] = []
227
  steps = 0
228
  success = False
229
- score = 0.0
230
- terminal_error: str | None = None
231
  fallback_reason: str | None = None
232
  model_name = _model_name()
233
 
@@ -253,7 +253,7 @@ async def _run_task_http(task_name: str) -> tuple[float, list[float], int]:
253
  flush=True,
254
  )
255
  print("[END] success=false steps=0 score=0.000 rewards=")
256
- return 0.0, [], 0
257
 
258
  try:
259
  async with httpx.AsyncClient(timeout=30.0) as http_client:
@@ -276,7 +276,7 @@ async def _run_task_http(task_name: str) -> tuple[float, list[float], int]:
276
  print(
277
  f"[END] success=false steps={steps} score={_clamp_score(score):.3f} rewards={_format_rewards(rewards)}",
278
  )
279
- return score, rewards, steps
280
  print(
281
  f"[warn] Falling back to deterministic policy for {task_name}: {fallback_reason}",
282
  file=sys.stderr,
@@ -313,31 +313,34 @@ async def _run_task_http(task_name: str) -> tuple[float, list[float], int]:
313
  success = terminal_error is None and fallback_reason is None
314
  break
315
 
316
- score = _clamp_score(score)
317
- print(
318
- f"[END] success={_format_bool(success)} steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}"
319
- )
320
- return score, rewards, steps
321
- except Exception:
322
- score = _clamp_score(score)
323
- print(
324
- f"[END] success=false steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}"
325
- )
326
- return score, rewards, steps
327
-
328
-
329
- def run_task(task_name: str) -> tuple[float, list[float], int]:
330
- return asyncio.run(_run_task_http(task_name))
331
-
332
-
333
- def main() -> None:
334
- if RAG_ENV_TASK in TASK_SEQUENCE:
335
- tasks = [RAG_ENV_TASK] + [task for task in TASK_SEQUENCE if task != RAG_ENV_TASK]
336
- else:
337
- tasks = list(TASK_SEQUENCE)
338
- for task_name in tasks:
339
- run_task(task_name)
340
-
341
-
342
- if __name__ == "__main__":
343
- main()
 
 
 
 
26
  "adversarial_compression",
27
  ]
28
 
29
+ SYSTEM_PROMPT = """You are a baseline RAG context optimizer.
30
+ Read the query and available chunks using chunk_id, keywords, tokens, and domain.
31
  Select chunks that maximize keyword overlap with the query.
32
  Stay under the token budget.
33
  Compress chunks that are mildly relevant but token-heavy.
34
  Submit a concise answer once enough useful chunks are selected.
35
  When you submit an answer, cite selected chunks inline like [support_003] or [incident_002].
36
  Return only valid JSON matching one of these forms:
37
+ {"action_type":"select_chunk","chunk_id":"support_003"}
38
+ {"action_type":"deselect_chunk","chunk_id":"support_003"}
39
  {"action_type":"compress_chunk","chunk_id":"support_003","compression_ratio":0.5}
40
  {"action_type":"submit_answer","answer":"Verify outage evidence and the billing ledger before refunding [support_001] [support_003]."}"""
41
 
 
222
  return response.json()
223
 
224
 
225
+ async def _run_task_http(task_name: str) -> tuple[float, list[float], int, bool]:
226
  rewards: list[float] = []
227
  steps = 0
228
  success = False
229
+ score = 0.0
230
+ terminal_error: str | None = None
231
  fallback_reason: str | None = None
232
  model_name = _model_name()
233
 
 
253
  flush=True,
254
  )
255
  print("[END] success=false steps=0 score=0.000 rewards=")
256
+ return 0.0, [], 0, False
257
 
258
  try:
259
  async with httpx.AsyncClient(timeout=30.0) as http_client:
 
276
  print(
277
  f"[END] success=false steps={steps} score={_clamp_score(score):.3f} rewards={_format_rewards(rewards)}",
278
  )
279
+ return score, rewards, steps, False
280
  print(
281
  f"[warn] Falling back to deterministic policy for {task_name}: {fallback_reason}",
282
  file=sys.stderr,
 
313
  success = terminal_error is None and fallback_reason is None
314
  break
315
 
316
+ score = _clamp_score(score)
317
+ print(
318
+ f"[END] success={_format_bool(success)} steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}"
319
+ )
320
+ return score, rewards, steps, success
321
+ except Exception:
322
+ score = _clamp_score(score)
323
+ print(
324
+ f"[END] success=false steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}"
325
+ )
326
+ return score, rewards, steps, False
327
+
328
+
329
def run_task(task_name: str) -> tuple[float, list[float], int, bool]:
    """Run one task episode to completion from synchronous code.

    Returns whatever ``_run_task_http`` returns for ``task_name``, annotated
    as ``(score, rewards, steps, success)``.
    """
    episode = _run_task_http(task_name)
    return asyncio.run(episode)
331
+
332
+
333
def main() -> int:
    """Run every task and return a process exit code.

    The task named by ``RAG_ENV_TASK`` runs first when it is part of
    ``TASK_SEQUENCE``.  All tasks are run even after a failure.

    Returns:
        ``0`` when every task reported success, otherwise ``1``.
    """
    if RAG_ENV_TASK in TASK_SEQUENCE:
        tasks = [RAG_ENV_TASK, *(task for task in TASK_SEQUENCE if task != RAG_ENV_TASK)]
    else:
        tasks = list(TASK_SEQUENCE)

    all_success = True
    for task_name in tasks:
        _score, _rewards, _steps, succeeded = run_task(task_name)
        all_success = all_success & succeeded
    return 1 if not all_success else 0
343
+
344
+
345
+ if __name__ == "__main__":
346
+ raise SystemExit(main())
tests/test_api.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ from fastapi.testclient import TestClient
7
+
8
+ ROOT = Path(__file__).resolve().parents[1]
9
+ if str(ROOT) not in sys.path:
10
+ sys.path.insert(0, str(ROOT))
11
+
12
+ from app import app
13
+
14
+
15
+ client = TestClient(app)
16
+
17
+
18
def test_reset_accepts_empty_body():
    """POST /reset without a body must start a fresh, unfinished episode."""
    reply = client.post("/reset")
    assert reply.status_code == 200

    payload = reply.json()
    assert "episode_id" in payload
    assert payload["done"] is False
    assert "observation" in payload
25
+
26
+
27
def test_episode_state_is_isolated():
    """Selecting a chunk in one episode must not show up in another episode."""
    reset_a = client.post("/reset", json={"task_name": "single_domain_qa"})
    reset_b = client.post("/reset", json={"task_name": "cross_domain_synthesis"})
    assert reset_a.status_code == 200
    assert reset_b.status_code == 200

    payload_a = reset_a.json()
    episode_a = payload_a["episode_id"]
    episode_b = reset_b.json()["episode_id"]
    assert episode_a != episode_b

    # Act only on the first episode.
    chunk_id = payload_a["observation"]["available_chunks"][0]["chunk_id"]
    step = client.post(
        f"/step?episode_id={episode_a}",
        json={"action_type": "select_chunk", "chunk_id": chunk_id},
    )
    assert step.status_code == 200
    assert step.json()["episode_id"] == episode_a

    state_a = client.get(f"/state?episode_id={episode_a}")
    state_b = client.get(f"/state?episode_id={episode_b}")
    assert state_a.status_code == 200
    assert state_b.status_code == 200
    assert chunk_id in state_a.json()["selected_chunks"]
    assert state_b.json()["selected_chunks"] == []
tests/test_inference_proxy.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import socket
6
+ import subprocess
7
+ import sys
8
+ import threading
9
+ import time
10
+ from http.server import BaseHTTPRequestHandler, HTTPServer
11
+ from pathlib import Path
12
+
13
+ import httpx
14
+
15
+
16
# Repository root: the parent of the directory that contains this test file.
ROOT = Path(__file__).resolve().parents[1]
# Launch child processes with the interpreter that is running the tests.
# A hard-coded ".venv/Scripts/python.exe" only exists in a Windows virtualenv
# created at the repository root, so the test could not start anywhere else.
PYTHON = Path(sys.executable)
18
+
19
+
20
+ def _free_port() -> int:
21
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
22
+ sock.bind(("127.0.0.1", 0))
23
+ return int(sock.getsockname()[1])
24
+
25
+
26
def test_inference_uses_proxy_api_key():
    """Check that inference.py authenticates to API_BASE_URL with API_KEY.

    A local stub HTTP server records every POST it receives and answers with a
    canned chat-completion payload. The app is started under uvicorn,
    inference.py is run against both, and the recorded request plus the
    script's stdout are asserted on. HF_TOKEN is set to a different value so
    the test fails if it is sent instead of API_KEY.
    """
    app_port = _free_port()
    proxy_port = _free_port()
    requests_seen: list[dict[str, str | None]] = []

    class ProxyHandler(BaseHTTPRequestHandler):
        """Stub chat-completions endpoint that records each POST."""

        def do_POST(self):
            length = int(self.headers.get("Content-Length", "0"))
            body = self.rfile.read(length).decode("utf-8")
            requests_seen.append(
                {
                    "path": self.path,
                    "authorization": self.headers.get("Authorization"),
                    "body": body,
                }
            )
            payload = {
                "id": "chatcmpl-test",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "proxy-test-model",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": json.dumps(
                                {
                                    "action_type": "submit_answer",
                                    "answer": "Proxy verified [support_003]",
                                }
                            ),
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
            encoded = json.dumps(payload).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(encoded)))
            self.end_headers()
            self.wfile.write(encoded)

        def log_message(self, format: str, *args):
            # Suppress the handler's default per-request logging.
            return

    proxy_server = HTTPServer(("127.0.0.1", proxy_port), ProxyHandler)
    proxy_thread = threading.Thread(target=proxy_server.serve_forever, daemon=True)
    proxy_thread.start()

    # Created inside the try block so the proxy is torn down even when the
    # app process cannot be launched.
    app_process = None
    try:
        app_process = subprocess.Popen(
            [str(PYTHON), "-m", "uvicorn", "app:app", "--host", "127.0.0.1", "--port", str(app_port)],
            cwd=ROOT,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

        # Poll /health until the app answers 200. Sleep between every attempt
        # (not only after an exception) and stop early if the process died.
        health_url = f"http://127.0.0.1:{app_port}/health"
        deadline = time.monotonic() + 20
        app_ready = False
        while time.monotonic() < deadline and app_process.poll() is None:
            try:
                app_ready = httpx.get(health_url, timeout=2).status_code == 200
            except httpx.HTTPError:
                app_ready = False
            if app_ready:
                break
            time.sleep(0.5)
        assert app_ready, f"app did not become healthy at {health_url}"

        env = os.environ.copy()
        env["RAG_ENV_URL"] = f"http://127.0.0.1:{app_port}"
        env["RAG_ENV_TASK"] = "single_domain_qa"
        env["API_BASE_URL"] = f"http://127.0.0.1:{proxy_port}/v1"
        env["API_KEY"] = "proxy-check-token"
        env["HF_TOKEN"] = "legacy-should-not-win"
        result = subprocess.run(
            [str(PYTHON), "inference.py"],
            cwd=ROOT,
            env=env,
            capture_output=True,
            text=True,
            timeout=60,
        )
        assert result.returncode == 0, result.stderr
        assert requests_seen
        assert requests_seen[0]["path"] == "/v1/chat/completions"
        assert requests_seen[0]["authorization"] == "Bearer proxy-check-token"
        assert any(line.startswith("[END]") and "score=" in line for line in result.stdout.splitlines())
    finally:
        proxy_server.shutdown()
        proxy_server.server_close()
        proxy_thread.join(timeout=5)
        if app_process is not None:
            app_process.terminate()
            try:
                app_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                app_process.kill()
                # Reap the killed child so it does not linger as a zombie.
                app_process.wait()
validate.py CHANGED
@@ -1,13 +1,15 @@
1
  from __future__ import annotations
2
 
3
- import json
4
- import os
5
- import signal
6
- import socket
7
- import subprocess
8
- import sys
9
- import time
10
- from pathlib import Path
 
 
11
 
12
  import httpx
13
 
@@ -117,24 +119,80 @@ def run_task(client: httpx.Client, base_url: str, task_name: str) -> tuple[bool,
117
 
118
 
119
  def run_inference_script(base_url: str) -> bool:
120
- env = os.environ.copy()
121
- env["RAG_ENV_URL"] = base_url
122
- env["ALLOW_BASELINE_FALLBACK"] = "1"
123
- env["API_BASE_URL"] = "http://127.0.0.1:9/v1"
124
- env["API_KEY"] = "offline-validation-token"
125
- process = subprocess.run(
126
- [sys.executable, "inference.py"],
127
- cwd=PROJECT_ROOT,
128
- capture_output=True,
129
- text=True,
130
- timeout=120,
131
- env=env,
132
- )
133
- stdout = process.stdout or ""
134
- has_start = "[START]" in stdout
135
- has_end = "[END]" in stdout
136
- end_has_score = " score=" in stdout
137
- return process.returncode == 0 and has_start and has_end and end_has_score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
 
140
  def main() -> int:
 
1
  from __future__ import annotations
2
 
3
+ import json
4
+ import os
5
+ import signal
6
+ import socket
7
+ import subprocess
8
+ import sys
9
+ import threading
10
+ import time
11
+ from http.server import BaseHTTPRequestHandler, HTTPServer
12
+ from pathlib import Path
13
 
14
  import httpx
15
 
 
119
 
120
 
121
def run_inference_script(base_url: str) -> bool:
    """Run inference.py against a local stub LLM proxy and validate the run.

    Args:
        base_url: URL of the environment server, passed to the child process
            as RAG_ENV_URL.

    Returns:
        True when inference.py exits with code 0, prints "[START]" and an
        "[END]" line that carries a score, and at least one request reached
        /v1/chat/completions with the API_KEY bearer token. False otherwise.
    """
    proxy_port = find_free_port()
    requests_seen: list[dict[str, str | None]] = []

    class ProxyHandler(BaseHTTPRequestHandler):
        """Stub chat-completions endpoint that records each POST."""

        def do_POST(self):
            length = int(self.headers.get("Content-Length", "0"))
            body = self.rfile.read(length).decode("utf-8")
            requests_seen.append(
                {
                    "path": self.path,
                    "authorization": self.headers.get("Authorization"),
                    "body": body,
                }
            )
            payload = {
                "id": "chatcmpl-validate",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "validator-proxy",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": json.dumps(
                                {
                                    "action_type": "submit_answer",
                                    "answer": "Validated via proxy [support_003]",
                                }
                            ),
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
            encoded = json.dumps(payload).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(encoded)))
            self.end_headers()
            self.wfile.write(encoded)

        def log_message(self, format: str, *args):
            # Suppress the handler's default per-request logging.
            return

    proxy_server = HTTPServer(("127.0.0.1", proxy_port), ProxyHandler)
    proxy_thread = threading.Thread(target=proxy_server.serve_forever, daemon=True)
    proxy_thread.start()

    try:
        env = os.environ.copy()
        env["RAG_ENV_URL"] = base_url
        # Remove the fallback switch so the script has to go through the proxy.
        env.pop("ALLOW_BASELINE_FALLBACK", None)
        env["API_BASE_URL"] = f"http://127.0.0.1:{proxy_port}/v1"
        env["API_KEY"] = "offline-validation-token"
        env["HF_TOKEN"] = "legacy-should-not-win"
        process = subprocess.run(
            [sys.executable, "inference.py"],
            cwd=PROJECT_ROOT,
            capture_output=True,
            text=True,
            timeout=120,
            env=env,
        )
        stdout = process.stdout or ""
        has_start = "[START]" in stdout
        has_end = "[END]" in stdout
        # The score must appear on the [END] line itself; a " score=" token on
        # any other output line does not count.
        end_has_score = any(
            "[END]" in line and " score=" in line for line in stdout.splitlines()
        )
        # Path and bearer token must match on the same recorded request.
        proxy_auth_ok = any(
            request["path"] == "/v1/chat/completions"
            and request["authorization"] == "Bearer offline-validation-token"
            for request in requests_seen
        )
        return (
            process.returncode == 0
            and has_start
            and has_end
            and end_has_score
            and proxy_auth_ok
        )
    finally:
        proxy_server.shutdown()
        proxy_server.server_close()
        proxy_thread.join(timeout=5)
196
 
197
 
198
  def main() -> int: