Harden env isolation and proxy validation
Files changed:
- app.py +222 -555
- env/environment.py +515 -514
- env/graders.py +145 -144
- inference.py +40 -37
- tests/test_api.py +47 -0
- tests/test_inference_proxy.py +119 -0
- validate.py +84 -26
app.py
CHANGED
|
@@ -1,336 +1,173 @@
|
|
| 1 |
-
"""
|
| 2 |
-
FastAPI server exposing the rag-context-optimizer OpenEnv HTTP API.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
from __future__ import annotations
|
| 6 |
-
|
| 7 |
from contextlib import asynccontextmanager
|
| 8 |
from dataclasses import asdict, is_dataclass
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Any, Literal
|
| 11 |
-
|
|
|
|
| 12 |
from fastapi import Body, FastAPI, HTTPException, Request
|
| 13 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
-
from fastapi.responses import HTMLResponse
|
| 15 |
-
from pydantic import BaseModel
|
| 16 |
-
|
|
|
|
| 17 |
from env.environment import RagContextOptimizerEnv
|
| 18 |
from env.models import RagAction
|
| 19 |
-
from env.corpus import list_corpus_families
|
| 20 |
from env.prompt_optimizer import CompressionMode, optimize_prompt
|
| 21 |
from env.tasks import ALL_TASKS, TASKS_BY_NAME
|
| 22 |
-
|
| 23 |
-
|
| 24 |
class ResetRequest(BaseModel):
|
| 25 |
task_name: Literal["single_domain_qa", "cross_domain_synthesis", "adversarial_compression"] = "single_domain_qa"
|
| 26 |
custom_query: str | None = None
|
| 27 |
token_budget: int | None = None
|
| 28 |
max_steps: int | None = None
|
| 29 |
corpus_family: str | None = None
|
| 30 |
-
|
| 31 |
-
|
| 32 |
class OptimizePromptRequest(BaseModel):
|
| 33 |
prompt: str
|
| 34 |
corpus_family: str | None = None
|
| 35 |
compression_mode: CompressionMode = "balanced"
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
"
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
return
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
def
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
tokens = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_/]*", raw)
|
| 170 |
-
kept: list[str] = []
|
| 171 |
-
seen: set[str] = set()
|
| 172 |
-
|
| 173 |
-
# Keep “meaningful” tokens: numbers, identifiers, longer words, and acronyms. Drop stopwords.
|
| 174 |
-
for tok in tokens:
|
| 175 |
-
low = tok.lower()
|
| 176 |
-
is_number = low.isdigit()
|
| 177 |
-
is_identifier = any(ch in tok for ch in ("_", "-", "/")) and len(tok) >= 4
|
| 178 |
-
is_acronym = tok.isupper() and len(tok) <= 8
|
| 179 |
-
is_meaningful = is_number or is_identifier or is_acronym or len(low) >= 4
|
| 180 |
-
if not is_meaningful:
|
| 181 |
-
continue
|
| 182 |
-
if low in _PROMPT_STOPWORDS:
|
| 183 |
-
continue
|
| 184 |
-
if low in seen:
|
| 185 |
-
continue
|
| 186 |
-
seen.add(low)
|
| 187 |
-
kept.append(tok)
|
| 188 |
-
if len(kept) >= max(10, target_tokens):
|
| 189 |
-
break
|
| 190 |
-
|
| 191 |
-
if not kept:
|
| 192 |
-
# Fallback: truncated raw prompt.
|
| 193 |
-
words = raw.split()
|
| 194 |
-
return " ".join(words[: max(8, target_tokens)]).rstrip(" ,;:") + (" ..." if len(words) > target_tokens else "")
|
| 195 |
-
|
| 196 |
-
# Turn the token list into a copy-paste-ready “goal” sentence.
|
| 197 |
-
goal = " ".join(kept)
|
| 198 |
-
goal = re.sub(r"\s+", " ", goal).strip()
|
| 199 |
-
return goal
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
_INSTRUCTION_PRIORITY_TERMS = {
|
| 203 |
-
"must","should","only","not","never","always","include","exclude","cite","answer",
|
| 204 |
-
"return","draft","write","summarize","compare","explain","verify","preserve","focus",
|
| 205 |
-
"keep","avoid","report","escalate","rollback","refund","incident","customer","security",
|
| 206 |
-
}
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
def _trim_sentence(sentence: str, max_terms: int) -> str:
|
| 210 |
-
import re
|
| 211 |
-
|
| 212 |
-
words = re.findall(r"[A-Za-z0-9][A-Za-z0-9\\-_/]*|[,:;()]", sentence)
|
| 213 |
-
if not words:
|
| 214 |
-
return ""
|
| 215 |
-
kept: list[str] = []
|
| 216 |
-
|
| 217 |
-
for index, token in enumerate(words):
|
| 218 |
-
normalized = re.sub(r"[^A-Za-z0-9]+", "", token).lower()
|
| 219 |
-
if token in {",", ":", ";", "(", ")"}:
|
| 220 |
-
if kept and kept[-1] not in {",", ":", ";", "("}:
|
| 221 |
-
kept.append(token)
|
| 222 |
-
continue
|
| 223 |
-
is_priority = normalized in _INSTRUCTION_PRIORITY_TERMS
|
| 224 |
-
is_meaningful = (
|
| 225 |
-
normalized.isdigit()
|
| 226 |
-
or any(ch in token for ch in ("_", "-", "/"))
|
| 227 |
-
or len(normalized) >= 4
|
| 228 |
-
or is_priority
|
| 229 |
-
or index < 3
|
| 230 |
-
)
|
| 231 |
-
if not is_meaningful:
|
| 232 |
-
continue
|
| 233 |
-
if normalized in _PROMPT_STOPWORDS and not is_priority and index >= 3:
|
| 234 |
-
continue
|
| 235 |
-
kept.append(token)
|
| 236 |
-
if len([word for word in kept if word not in {",", ":", ";", "(", ")"}]) >= max_terms:
|
| 237 |
-
break
|
| 238 |
-
|
| 239 |
-
text = " ".join(kept)
|
| 240 |
-
text = re.sub(r"\s+([,:;)])", r"\1", text)
|
| 241 |
-
text = re.sub(r"(\()\s+", r"\1", text)
|
| 242 |
-
return text.strip(" ,;:")
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
def _rewrite_prompt_text(prompt: str, target_tokens: int) -> str:
|
| 246 |
-
import re
|
| 247 |
-
|
| 248 |
-
raw = " ".join(prompt.strip().split())
|
| 249 |
-
if not raw:
|
| 250 |
-
return ""
|
| 251 |
-
|
| 252 |
-
sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+|\n+", raw) if segment.strip()]
|
| 253 |
-
if not sentences:
|
| 254 |
-
sentences = [raw]
|
| 255 |
-
|
| 256 |
-
rewritten: list[str] = []
|
| 257 |
-
used_terms = 0
|
| 258 |
-
max_terms = max(8, target_tokens)
|
| 259 |
-
for index, sentence in enumerate(sentences):
|
| 260 |
-
remaining = max_terms - used_terms
|
| 261 |
-
if remaining <= 0:
|
| 262 |
-
break
|
| 263 |
-
compact = _trim_sentence(sentence, max(4, remaining if index == 0 else min(remaining, 10)))
|
| 264 |
-
if not compact:
|
| 265 |
-
continue
|
| 266 |
-
rewritten.append(compact)
|
| 267 |
-
used_terms += len(compact.split())
|
| 268 |
-
if used_terms >= max_terms:
|
| 269 |
-
break
|
| 270 |
-
|
| 271 |
-
if not rewritten:
|
| 272 |
-
fallback = _trim_sentence(raw, max_terms)
|
| 273 |
-
return fallback or raw[: max(16, target_tokens * 4)].strip()
|
| 274 |
-
|
| 275 |
-
output = ". ".join(rewritten).strip()
|
| 276 |
-
if len(rewritten) == 1 and not output.endswith("."):
|
| 277 |
-
output += "."
|
| 278 |
-
return output
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
def _fit_citations_into_prompt(base_prompt: str, citation_ids: list[str], input_tokens: int, target_tokens: int, source_prompt: str) -> tuple[str, bool, str | None]:
|
| 282 |
-
if not citation_ids:
|
| 283 |
-
return base_prompt, False, "No high-confidence evidence anchors were selected."
|
| 284 |
-
|
| 285 |
-
citation_suffix = " Evidence: " + " ".join(f"[{chunk_id}]" for chunk_id in citation_ids[:3])
|
| 286 |
-
with_all = (base_prompt.rstrip(".") + "." + citation_suffix).strip()
|
| 287 |
-
if _approx_tokens(with_all) < input_tokens:
|
| 288 |
-
return with_all, True, None
|
| 289 |
-
|
| 290 |
-
one_citation_suffix = " Evidence: " + f"[{citation_ids[0]}]"
|
| 291 |
-
with_one = (base_prompt.rstrip(".") + "." + one_citation_suffix).strip()
|
| 292 |
-
if _approx_tokens(with_one) < input_tokens:
|
| 293 |
-
return with_one, True, None
|
| 294 |
-
|
| 295 |
-
tighter_target = max(8, target_tokens - 3)
|
| 296 |
-
tighter_prompt = _rewrite_prompt_text(source_prompt, tighter_target)
|
| 297 |
-
tighter_with_one = (tighter_prompt.rstrip(".") + "." + one_citation_suffix).strip()
|
| 298 |
-
if _approx_tokens(tighter_with_one) < input_tokens:
|
| 299 |
-
return tighter_with_one, True, None
|
| 300 |
-
|
| 301 |
-
return base_prompt, False, "Citations were omitted to keep the optimized prompt shorter than the original. Use the evidence notes below if explicit anchors are required."
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
def _summarize_chunk_for_output(chunk: Any, effective_text: str) -> str:
    """Produce a short note describing ``chunk`` for the response payload.

    Chunks whose ``domain`` starts with ``"Project"`` get a templated sentence
    built from their first five keywords. Any other chunk is summarised by
    the sentence of ``effective_text`` that ``_sentence_rank`` places first
    for the chunk's keywords, or by the whole cleaned text when ranking
    yields nothing.
    """
    if getattr(chunk, "domain", "").startswith("Project"):
        keyword_list = ", ".join(chunk.keywords[:5])
        topic = chunk.domain.replace("Project ", "").lower()
        return _compact_text(f"This benchmark's {topic} covers {keyword_list}.", 24)

    keyword_query = " ".join(chunk.keywords)
    candidates = _sentence_rank(keyword_query, _clean_output_text(effective_text))
    chosen = candidates[0] if candidates else effective_text
    return _compact_text(_clean_output_text(chosen))
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
def _sentence_rank(query: str, text: str) -> list[str]:
    """Return the sentences of ``text`` ordered by relevance to ``query``.

    Each sentence scores two points per term it shares with the query
    (terms come from ``_tokenize``), plus 0.25 if it is the first sentence.
    Ties are broken in favour of the shorter sentence. Returns an empty list
    when ``text`` contains no sentences.
    """
    import re

    query_terms = _tokenize(query)
    sentences = [part.strip() for part in re.split(r"(?<=[.!?])\s+", text) if part.strip()]
    if not sentences:
        return []

    def _relevance(position: int, sentence: str) -> float:
        shared = len(query_terms & _tokenize(sentence))
        lead_bonus = 0.25 if position == 0 else 0.0
        return (shared * 2.0) + lead_bonus

    scored = [(_relevance(position, sentence), sentence) for position, sentence in enumerate(sentences)]
    # Highest score first, then shortest sentence; the sort is stable.
    ordered = sorted(scored, key=lambda pair: (-pair[0], len(pair[1])))
    return [sentence for _, sentence in ordered]
|
| 332 |
-
|
| 333 |
-
|
| 334 |
async def _optimize_prompt_backend(
|
| 335 |
prompt: str,
|
| 336 |
corpus_family: str | None = None,
|
|
@@ -346,172 +183,8 @@ async def _optimize_prompt_backend(
|
|
| 346 |
"selected_keywords": result.selected_keywords,
|
| 347 |
"optimization_mode": result.optimization_mode,
|
| 348 |
}
|
| 349 |
-
clean_prompt = prompt.strip()
|
| 350 |
-
env = RagContextOptimizerEnv(
|
| 351 |
-
task_name="single_domain_qa",
|
| 352 |
-
query_override=clean_prompt,
|
| 353 |
-
token_budget_override=800,
|
| 354 |
-
max_steps_override=6,
|
| 355 |
-
corpus_family_override=corpus_family,
|
| 356 |
-
)
|
| 357 |
-
await env.reset()
|
| 358 |
-
|
| 359 |
-
tuning = env._last_tuning or env.context_tuner.tune(clean_prompt, env._available_chunks)
|
| 360 |
-
|
| 361 |
-
ranked_candidates = []
|
| 362 |
-
for chunk in env._available_chunks:
|
| 363 |
-
tuned = tuning.tuned_scores.get(chunk.chunk_id)
|
| 364 |
-
score = tuned.final_score if tuned is not None else env.retriever.hybrid_score(clean_prompt, chunk)
|
| 365 |
-
if score < 0.16:
|
| 366 |
-
continue
|
| 367 |
-
ranked_candidates.append((chunk, score, tuned))
|
| 368 |
-
ranked_candidates.sort(
|
| 369 |
-
key=lambda item: (
|
| 370 |
-
-(item[1] / max(item[0].tokens, 1)),
|
| 371 |
-
-(item[2].citation_prior if item[2] is not None else 0.0),
|
| 372 |
-
-item[1],
|
| 373 |
-
item[0].chunk_id,
|
| 374 |
-
)
|
| 375 |
-
)
|
| 376 |
|
| 377 |
-
selected_ids: list[str] = []
|
| 378 |
-
token_cap = 360
|
| 379 |
-
running_tokens = 0
|
| 380 |
-
for chunk, score, tuned in ranked_candidates:
|
| 381 |
-
if len(selected_ids) >= 4:
|
| 382 |
-
break
|
| 383 |
-
if score < 0.22 and selected_ids:
|
| 384 |
-
break
|
| 385 |
-
projected = running_tokens + chunk.tokens
|
| 386 |
-
if projected > token_cap and selected_ids:
|
| 387 |
-
continue
|
| 388 |
-
selected_ids.append(chunk.chunk_id)
|
| 389 |
-
env._selected_chunks.append(chunk.chunk_id)
|
| 390 |
-
running_tokens += chunk.tokens
|
| 391 |
-
|
| 392 |
-
if not selected_ids and ranked_candidates:
|
| 393 |
-
best_chunk = ranked_candidates[0][0]
|
| 394 |
-
selected_ids.append(best_chunk.chunk_id)
|
| 395 |
-
env._selected_chunks.append(best_chunk.chunk_id)
|
| 396 |
-
|
| 397 |
-
for chunk_id in list(selected_ids):
|
| 398 |
-
chunk = env._chunk_map().get(chunk_id)
|
| 399 |
-
if chunk is None:
|
| 400 |
-
continue
|
| 401 |
-
tuned = tuning.tuned_scores.get(chunk_id)
|
| 402 |
-
score = tuned.final_score if tuned is not None else env.retriever.hybrid_score(clean_prompt, chunk)
|
| 403 |
-
ratio = tuned.compression_ratio if tuned is not None else 0.5
|
| 404 |
-
if score >= 0.75:
|
| 405 |
-
ratio = max(ratio, 0.6)
|
| 406 |
-
env._compression_ratios[chunk_id] = ratio
|
| 407 |
-
|
| 408 |
-
input_tokens = _approx_tokens(clean_prompt)
|
| 409 |
-
# Target: strictly shorter than input, while preserving more structure for longer prompts.
|
| 410 |
-
if input_tokens <= 24:
|
| 411 |
-
target_ratio = 0.85
|
| 412 |
-
elif input_tokens <= 60:
|
| 413 |
-
target_ratio = 0.75
|
| 414 |
-
elif input_tokens <= 120:
|
| 415 |
-
target_ratio = 0.68
|
| 416 |
-
else:
|
| 417 |
-
target_ratio = 0.62
|
| 418 |
-
target_tokens = max(12, int(input_tokens * target_ratio))
|
| 419 |
-
target_tokens = min(target_tokens, 80)
|
| 420 |
-
|
| 421 |
-
compressed_goal = _rewrite_prompt_text(clean_prompt, target_tokens=target_tokens)
|
| 422 |
-
|
| 423 |
-
# Optionally add a tiny amount of distilled context, but only if it still stays shorter overall.
|
| 424 |
-
distilled_points: list[tuple[str, str]] = []
|
| 425 |
-
for chunk_id in env._selected_chunks:
|
| 426 |
-
chunk = env._chunk_map().get(chunk_id)
|
| 427 |
-
if chunk is None:
|
| 428 |
-
continue
|
| 429 |
-
best = _summarize_chunk_for_output(chunk, env._effective_chunk_text(chunk_id))
|
| 430 |
-
if best and all(existing_point != best for _existing_chunk_id, existing_point in distilled_points):
|
| 431 |
-
distilled_points.append((chunk_id, best))
|
| 432 |
-
if len(distilled_points) >= (2 if input_tokens < 80 else 3):
|
| 433 |
-
break
|
| 434 |
-
|
| 435 |
-
lines: list[str] = []
|
| 436 |
-
lines.append(compressed_goal if compressed_goal else clean_prompt)
|
| 437 |
-
if distilled_points and input_tokens >= 80:
|
| 438 |
-
lines.append("")
|
| 439 |
-
lines.append("Context:")
|
| 440 |
-
lines.extend([f"- [{chunk_id}] {point}" for chunk_id, point in distilled_points])
|
| 441 |
-
optimized_prompt = "\n".join(lines).strip()
|
| 442 |
-
|
| 443 |
-
# Hard guarantee: never return an “optimized” prompt longer than the input.
|
| 444 |
-
if input_tokens > 0 and _approx_tokens(optimized_prompt) >= input_tokens:
|
| 445 |
-
# Enforce by character budget (tokens ~= chars/4).
|
| 446 |
-
max_chars = max(12, (input_tokens - 1) * 4)
|
| 447 |
-
optimized_prompt = optimized_prompt[:max_chars].rstrip(" ,;:\n\t")
|
| 448 |
-
if optimized_prompt and not optimized_prompt.endswith("..."):
|
| 449 |
-
optimized_prompt = optimized_prompt + " ..."
|
| 450 |
-
# If still not strictly smaller (very small inputs), trim until it is.
|
| 451 |
-
while input_tokens > 1 and _approx_tokens(optimized_prompt) >= input_tokens and len(optimized_prompt) > 12:
|
| 452 |
-
optimized_prompt = optimized_prompt[:-6].rstrip(" ,;:\n\t") + " ..."
|
| 453 |
-
if input_tokens > 1 and _approx_tokens(optimized_prompt) >= input_tokens:
|
| 454 |
-
optimized_prompt = _rewrite_prompt_text(clean_prompt, target_tokens=max(5, input_tokens - 1))
|
| 455 |
-
if optimized_prompt and not optimized_prompt.endswith("...") and _approx_tokens(optimized_prompt) >= input_tokens:
|
| 456 |
-
optimized_prompt = optimized_prompt[: max(8, (input_tokens - 1) * 4)].strip() + " ..."
|
| 457 |
-
|
| 458 |
-
optimized_prompt, citation_ready, citation_guidance = _fit_citations_into_prompt(
|
| 459 |
-
optimized_prompt,
|
| 460 |
-
tuning.suggested_citations or list(env._selected_chunks),
|
| 461 |
-
input_tokens,
|
| 462 |
-
target_tokens,
|
| 463 |
-
clean_prompt,
|
| 464 |
-
)
|
| 465 |
|
| 466 |
-
original_prompt_tokens = input_tokens
|
| 467 |
-
optimized_prompt_tokens = _approx_tokens(optimized_prompt)
|
| 468 |
-
source_tokens = sum(env._chunk_map()[chunk_id].tokens for chunk_id in env._selected_chunks if chunk_id in env._chunk_map())
|
| 469 |
-
compressed_tokens = sum(env._effective_chunk_tokens(chunk_id) for chunk_id in env._selected_chunks)
|
| 470 |
-
evidence_terms = _content_terms(" ".join(env._effective_chunk_text(chunk_id) for chunk_id in env._selected_chunks))
|
| 471 |
-
prompt_terms = _content_terms(optimized_prompt)
|
| 472 |
-
inline_citations = set(re.findall(r"\[([a-z0-9_]+)\]", optimized_prompt.lower()))
|
| 473 |
-
grounded_overlap = (len(prompt_terms & evidence_terms) / len(prompt_terms)) if prompt_terms else 0.0
|
| 474 |
-
|
| 475 |
-
return {
|
| 476 |
-
"optimized_prompt": optimized_prompt,
|
| 477 |
-
"stats": {
|
| 478 |
-
"selected_chunks": len(env._selected_chunks),
|
| 479 |
-
"source_tokens": source_tokens,
|
| 480 |
-
"compressed_context_tokens": compressed_tokens,
|
| 481 |
-
"original_prompt_tokens": original_prompt_tokens,
|
| 482 |
-
"optimized_prompt_tokens": optimized_prompt_tokens,
|
| 483 |
-
"compression_gain": max(0, source_tokens - compressed_tokens),
|
| 484 |
-
},
|
| 485 |
-
"grounding": {
|
| 486 |
-
"citations": tuning.suggested_citations or list(env._selected_chunks),
|
| 487 |
-
"citation_ready": citation_ready and bool(inline_citations),
|
| 488 |
-
"citation_guidance": citation_guidance,
|
| 489 |
-
"grounded_overlap": round(grounded_overlap, 3),
|
| 490 |
-
"evidence_notes": [
|
| 491 |
-
{"chunk_id": chunk_id, "note": note}
|
| 492 |
-
for chunk_id, note in distilled_points
|
| 493 |
-
],
|
| 494 |
-
},
|
| 495 |
-
"context_tuning": {
|
| 496 |
-
"mode": tuning.mode,
|
| 497 |
-
"top_demo_cases": tuning.top_demo_cases,
|
| 498 |
-
"suggested_citations": tuning.suggested_citations,
|
| 499 |
-
"token_dropout": tuning.token_dropout,
|
| 500 |
-
"leave_one_out": tuning.leave_one_out,
|
| 501 |
-
},
|
| 502 |
-
"corpus_family": env._corpus_family,
|
| 503 |
-
"selected_keywords": [
|
| 504 |
-
keyword
|
| 505 |
-
for chunk_id in env._selected_chunks
|
| 506 |
-
for keyword in (env._chunk_map().get(chunk_id).keywords if env._chunk_map().get(chunk_id) else [])
|
| 507 |
-
][:10],
|
| 508 |
-
}
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
def _suggest_action(env: RagContextOptimizerEnv) -> dict[str, Any]:
|
| 516 |
observation = env._build_observation()
|
| 517 |
selected = set(observation.selected_chunks)
|
|
@@ -546,8 +219,7 @@ def _suggest_action(env: RagContextOptimizerEnv) -> dict[str, Any]:
|
|
| 546 |
if chunk.keywords:
|
| 547 |
chosen_phrases.append(f"[{chunk.chunk_id}] " + ", ".join(chunk.keywords[:2]))
|
| 548 |
answer = (
|
| 549 |
-
"Grounded answer based on selected evidence: "
|
| 550 |
-
+ "; ".join(chosen_phrases[:3])
|
| 551 |
if chosen_phrases
|
| 552 |
else "Grounded answer based on the currently selected evidence."
|
| 553 |
)
|
|
@@ -559,37 +231,37 @@ def _suggest_action(env: RagContextOptimizerEnv) -> dict[str, Any]:
|
|
| 559 |
for chunk in sorted(
|
| 560 |
available,
|
| 561 |
key=lambda chunk: (
|
| 562 |
-
-(score_map.get(chunk.chunk_id).final_score if score_map.get(chunk.chunk_id) else 0.0)
|
| 563 |
-
/ max(chunk.tokens, 1),
|
| 564 |
chunk.tokens,
|
| 565 |
chunk.chunk_id,
|
| 566 |
),
|
| 567 |
):
|
| 568 |
if chunk.tokens <= remaining_budget:
|
| 569 |
return {"action_type": "select_chunk", "chunk_id": chunk.chunk_id}
|
| 570 |
-
|
| 571 |
-
if selected_chunks:
|
| 572 |
-
return {
|
| 573 |
-
"action_type": "submit_answer",
|
| 574 |
-
"answer": "Optimized answer based on the currently selected evidence.",
|
| 575 |
-
}
|
| 576 |
-
if available:
|
| 577 |
-
smallest_chunk = min(available, key=lambda chunk: (chunk.tokens, chunk.chunk_id))
|
| 578 |
-
return {
|
| 579 |
-
"action_type": "submit_answer",
|
| 580 |
-
"answer": (
|
| 581 |
-
"No chunk fits within the current token budget. "
|
| 582 |
-
f"Increase the budget to at least {smallest_chunk.tokens} tokens or choose a broader budget."
|
| 583 |
-
),
|
| 584 |
-
}
|
| 585 |
-
return {"action_type": "submit_answer", "answer": "No usable evidence was available."}
|
| 586 |
-
|
| 587 |
-
|
| 588 |
@app.post("/reset")
|
| 589 |
async def reset_endpoint(payload: ResetRequest | None = Body(default=None)):
|
| 590 |
payload = payload or ResetRequest()
|
| 591 |
if payload.task_name not in TASKS_BY_NAME:
|
| 592 |
raise HTTPException(status_code=400, detail="Unknown task_name.")
|
|
|
|
| 593 |
env = RagContextOptimizerEnv(
|
| 594 |
task_name=payload.task_name,
|
| 595 |
query_override=payload.custom_query,
|
|
@@ -597,49 +269,46 @@ async def reset_endpoint(payload: ResetRequest | None = Body(default=None)):
|
|
| 597 |
max_steps_override=payload.max_steps,
|
| 598 |
corpus_family_override=payload.corpus_family,
|
| 599 |
)
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
return _serialize_step_result(result, reset=True)
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
@app.post("/step")
|
| 606 |
-
async def step_endpoint(action: RagAction):
|
| 607 |
-
env =
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
@app.get("/tasks")
|
| 632 |
async def tasks_endpoint():
|
| 633 |
-
return [
|
| 634 |
-
{
|
| 635 |
-
"name": task.name,
|
| 636 |
-
"description": task.description,
|
| 637 |
-
"difficulty": task.difficulty,
|
| 638 |
-
"token_budget": task.token_budget,
|
| 639 |
-
"query": task.query,
|
| 640 |
-
"max_steps": task.max_steps,
|
| 641 |
-
}
|
| 642 |
-
for task in ALL_TASKS
|
| 643 |
]
|
| 644 |
|
| 645 |
|
|
@@ -649,13 +318,11 @@ async def corpus_families_endpoint():
|
|
| 649 |
|
| 650 |
|
| 651 |
@app.post("/optimize-step")
|
| 652 |
-
async def optimize_step_endpoint():
|
| 653 |
-
env =
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
@app.post("/optimize-prompt")
|
| 660 |
async def optimize_prompt_endpoint(payload: OptimizePromptRequest):
|
| 661 |
if not payload.prompt.strip():
|
|
@@ -665,9 +332,9 @@ async def optimize_prompt_endpoint(payload: OptimizePromptRequest):
|
|
| 665 |
corpus_family=payload.corpus_family,
|
| 666 |
compression_mode=payload.compression_mode,
|
| 667 |
)
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
if __name__ == "__main__":
|
| 671 |
-
import uvicorn
|
| 672 |
-
|
| 673 |
-
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI server exposing the rag-context-optimizer OpenEnv HTTP API.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
from contextlib import asynccontextmanager
|
| 8 |
from dataclasses import asdict, is_dataclass
|
| 9 |
+
import os
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Any, Literal
|
| 12 |
+
from uuid import uuid4
|
| 13 |
+
|
| 14 |
from fastapi import Body, FastAPI, HTTPException, Request
|
| 15 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
+
from fastapi.responses import HTMLResponse
|
| 17 |
+
from pydantic import BaseModel
|
| 18 |
+
|
| 19 |
+
from env.corpus import list_corpus_families
|
| 20 |
from env.environment import RagContextOptimizerEnv
|
| 21 |
from env.models import RagAction
|
|
|
|
| 22 |
from env.prompt_optimizer import CompressionMode, optimize_prompt
|
| 23 |
from env.tasks import ALL_TASKS, TASKS_BY_NAME
|
| 24 |
+
|
| 25 |
+
|
| 26 |
class ResetRequest(BaseModel):
    """Optional JSON body of ``POST /reset``; every field has a default."""

    # Task to start; limited to the three literal task names below.
    task_name: Literal["single_domain_qa", "cross_domain_synthesis", "adversarial_compression"] = "single_domain_qa"
    # Passed to the environment as ``query_override``.
    custom_query: str | None = None
    # NOTE(review): presumably passed as a token-budget override; the
    # forwarding line is not visible in this view - confirm in reset_endpoint.
    token_budget: int | None = None
    # Passed to the environment as ``max_steps_override``.
    max_steps: int | None = None
    # Passed to the environment as ``corpus_family_override``.
    corpus_family: str | None = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
class OptimizePromptRequest(BaseModel):
    """JSON body of ``POST /optimize-prompt``."""

    # Prompt text to optimize; the endpoint special-cases a blank prompt.
    prompt: str
    # Forwarded unchanged as ``corpus_family``.
    corpus_family: str | None = None
    # Forwarded unchanged as ``compression_mode``; defaults to "balanced".
    compression_mode: CompressionMode = "balanced"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class EpisodeStore:
    """Bounded in-memory registry of live environment episodes.

    Each registered environment gets a random hex id. Once more than
    ``max_episodes`` are held, the oldest ones are removed and closed.
    ``latest_episode_id`` tracks the most recently created episode and is
    used when a caller does not supply an id.
    """

    def __init__(self, max_episodes: int = 16):
        self._episodes: dict[str, RagContextOptimizerEnv] = {}
        # Episode ids in creation order, oldest first; drives eviction.
        self._order: list[str] = []
        self.latest_episode_id: str | None = None
        self._max_episodes = max_episodes

    async def close_all(self) -> None:
        """Close every stored environment and empty the store.

        The store is emptied before any ``close()`` is awaited, so it never
        iterates a dict that another coroutine may be mutating and it is
        empty even if a close fails. Every environment is attempted; the
        first failure, if any, is re-raised afterwards.
        """
        pending = list(self._episodes.values())
        self._episodes.clear()
        self._order.clear()
        self.latest_episode_id = None

        first_error: Exception | None = None
        for env in pending:
            try:
                await env.close()
            except Exception as exc:  # keep closing the remaining environments
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    async def create(self, env: RagContextOptimizerEnv) -> str:
        """Register ``env`` under a new id, evict the oldest overflow, return the id."""
        episode_id = uuid4().hex
        self._episodes[episode_id] = env
        self._order.append(episode_id)
        self.latest_episode_id = episode_id

        while len(self._order) > self._max_episodes:
            stale_id = self._order.pop(0)
            stale_env = self._episodes.pop(stale_id, None)
            if stale_env is not None:
                await stale_env.close()
            # Only reachable when the limit is below one and the episode just
            # created is itself evicted.
            if self.latest_episode_id == stale_id:
                self.latest_episode_id = self._order[-1] if self._order else None
        return episode_id

    def get(self, episode_id: str | None) -> tuple[str, RagContextOptimizerEnv]:
        """Return ``(id, env)`` for ``episode_id``, or for the latest episode when it is falsy.

        Raises:
            KeyError: If no matching episode is stored.
        """
        resolved_id = episode_id or self.latest_episode_id
        if resolved_id is None or resolved_id not in self._episodes:
            raise KeyError("episode_not_found")
        return resolved_id, self._episodes[resolved_id]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _request_logging_enabled() -> bool:
|
| 77 |
+
return os.getenv("DEBUG_LOG_REQUESTS", "").strip().lower() in {"1", "true", "yes"}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Install a fresh ``EpisodeStore`` on ``app.state`` and close it on exit.

    The cleanup sits in ``finally`` so stored environments are also closed
    when an exception or cancellation is thrown in at the ``yield`` point,
    not only on an orderly shutdown.
    """
    app.state.episodes = EpisodeStore()
    try:
        yield
    finally:
        await app.state.episodes.close_all()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Application instance; ``lifespan`` manages the per-process episode store.
app = FastAPI(
    title="rag-context-optimizer",
    version="1.0.0",
    description="RAG pipeline optimization environment - minimize tokens, maximize answer quality",
    lifespan=lifespan,
)

# CORS is fully open (any origin, method and header) with credentials disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# HTML template for the "/" page, resolved relative to this file's directory.
UI_TEMPLATE_PATH = Path(__file__).resolve().parent / "server" / "templates" / "ui.html"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Print one line before and one after each request when debug logging is on.

    The logging switch is evaluated once per request, before the request is
    handled. When it is off the request is passed straight through.
    """
    if not _request_logging_enabled():
        return await call_next(request)

    print(f"[request] {request.method} {request.url.path}")
    response = await call_next(request)
    print(f"[response] {request.method} {request.url.path} -> {response.status_code}")
    return response
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@app.get("/", response_class=HTMLResponse)
async def home_page():
    """Serve the UI template, re-read from disk on every request, with caching disabled."""
    markup = UI_TEMPLATE_PATH.read_text(encoding="utf-8")
    no_cache_headers = {
        "Cache-Control": "no-store, max-age=0",
        "Pragma": "no-cache",
    }
    return HTMLResponse(markup, headers=no_cache_headers)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _serialize_observation(observation: Any) -> dict[str, Any]:
|
| 128 |
+
if hasattr(observation, "model_dump"):
|
| 129 |
+
return observation.model_dump()
|
| 130 |
+
if is_dataclass(observation):
|
| 131 |
+
return asdict(observation)
|
| 132 |
+
return dict(observation)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _serialize_step_result(result: Any, reset: bool = False, episode_id: str | None = None) -> dict[str, Any]:
    """Build the JSON payload for a reset or step result.

    For a reset, ``reward`` is ``None``, ``done`` is ``False`` and ``info`` is
    empty regardless of what ``result`` carries. Otherwise ``info`` exposes
    the ``grader``, ``event`` and ``passed`` entries of ``result.info``, with
    ``grader`` renamed to ``grader_breakdown``. ``episode_id`` is added only
    when provided.
    """
    info_source = result.info or {}
    observation = _serialize_observation(result.observation)

    if reset:
        reward, done, info = None, False, {}
    else:
        reward = result.reward
        done = result.done
        info = {
            "grader_breakdown": info_source.get("grader"),
            "event": info_source.get("event"),
            "passed": info_source.get("passed"),
        }

    payload = {
        "observation": observation,
        "reward": reward,
        "done": done,
        "info": info,
    }
    if episode_id is not None:
        payload["episode_id"] = episode_id
    return payload
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _is_bad_action_event(event: str | None) -> bool:
|
| 153 |
+
return event in {"chunk_not_found"}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _episode_store() -> EpisodeStore:
    """Return the store held on ``app.state``, creating and attaching one if it is missing."""
    store = getattr(app.state, "episodes", None)
    if store is not None:
        return store

    store = EpisodeStore()
    app.state.episodes = store
    return store
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _resolve_env(episode_id: str | None) -> tuple[str, RagContextOptimizerEnv]:
    """Look up an episode and translate a miss into an HTTP 404.

    Returns the ``(episode_id, env)`` pair from the episode store. A
    ``KeyError`` from the lookup becomes ``HTTPException(404)``, chained to
    the original error.
    """
    try:
        resolved = _episode_store().get(episode_id)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail="Episode not found. Call /reset first.") from missing
    return resolved
|
| 169 |
+
|
| 170 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
async def _optimize_prompt_backend(
|
| 172 |
prompt: str,
|
| 173 |
corpus_family: str | None = None,
|
|
|
|
| 183 |
"selected_keywords": result.selected_keywords,
|
| 184 |
"optimization_mode": result.optimization_mode,
|
| 185 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
def _suggest_action(env: RagContextOptimizerEnv) -> dict[str, Any]:
|
| 189 |
observation = env._build_observation()
|
| 190 |
selected = set(observation.selected_chunks)
|
|
|
|
| 219 |
if chunk.keywords:
|
| 220 |
chosen_phrases.append(f"[{chunk.chunk_id}] " + ", ".join(chunk.keywords[:2]))
|
| 221 |
answer = (
|
| 222 |
+
"Grounded answer based on selected evidence: " + "; ".join(chosen_phrases[:3])
|
|
|
|
| 223 |
if chosen_phrases
|
| 224 |
else "Grounded answer based on the currently selected evidence."
|
| 225 |
)
|
|
|
|
| 231 |
for chunk in sorted(
|
| 232 |
available,
|
| 233 |
key=lambda chunk: (
|
| 234 |
+
-(score_map.get(chunk.chunk_id).final_score if score_map.get(chunk.chunk_id) else 0.0) / max(chunk.tokens, 1),
|
|
|
|
| 235 |
chunk.tokens,
|
| 236 |
chunk.chunk_id,
|
| 237 |
),
|
| 238 |
):
|
| 239 |
if chunk.tokens <= remaining_budget:
|
| 240 |
return {"action_type": "select_chunk", "chunk_id": chunk.chunk_id}
|
| 241 |
+
|
| 242 |
+
if selected_chunks:
|
| 243 |
+
return {
|
| 244 |
+
"action_type": "submit_answer",
|
| 245 |
+
"answer": "Optimized answer based on the currently selected evidence.",
|
| 246 |
+
}
|
| 247 |
+
if available:
|
| 248 |
+
smallest_chunk = min(available, key=lambda chunk: (chunk.tokens, chunk.chunk_id))
|
| 249 |
+
return {
|
| 250 |
+
"action_type": "submit_answer",
|
| 251 |
+
"answer": (
|
| 252 |
+
"No chunk fits within the current token budget. "
|
| 253 |
+
f"Increase the budget to at least {smallest_chunk.tokens} tokens or choose a broader budget."
|
| 254 |
+
),
|
| 255 |
+
}
|
| 256 |
+
return {"action_type": "submit_answer", "answer": "No usable evidence was available."}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
@app.post("/reset")
|
| 260 |
async def reset_endpoint(payload: ResetRequest | None = Body(default=None)):
|
| 261 |
payload = payload or ResetRequest()
|
| 262 |
if payload.task_name not in TASKS_BY_NAME:
|
| 263 |
raise HTTPException(status_code=400, detail="Unknown task_name.")
|
| 264 |
+
|
| 265 |
env = RagContextOptimizerEnv(
|
| 266 |
task_name=payload.task_name,
|
| 267 |
query_override=payload.custom_query,
|
|
|
|
| 269 |
max_steps_override=payload.max_steps,
|
| 270 |
corpus_family_override=payload.corpus_family,
|
| 271 |
)
|
| 272 |
+
result = await env.reset()
|
| 273 |
+
episode_id = await _episode_store().create(env)
|
| 274 |
+
return _serialize_step_result(result, reset=True, episode_id=episode_id)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
@app.post("/step")
|
| 278 |
+
async def step_endpoint(action: RagAction, episode_id: str | None = None):
|
| 279 |
+
resolved_episode_id, env = _resolve_env(episode_id)
|
| 280 |
+
result = await env.step(action)
|
| 281 |
+
event = (result.info or {}).get("event")
|
| 282 |
+
if _is_bad_action_event(event):
|
| 283 |
+
raise HTTPException(status_code=400, detail=event)
|
| 284 |
+
return _serialize_step_result(result, reset=False, episode_id=resolved_episode_id)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@app.get("/state")
|
| 288 |
+
async def state_endpoint(episode_id: str | None = None):
|
| 289 |
+
resolved_episode_id, env = _resolve_env(episode_id)
|
| 290 |
+
state = await env.state()
|
| 291 |
+
state["episode_id"] = resolved_episode_id
|
| 292 |
+
return state
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@app.get("/health")
|
| 296 |
+
async def health_endpoint():
|
| 297 |
+
return {"status": "ok", "tasks": [task.name for task in ALL_TASKS]}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@app.get("/tasks")
|
|
|
|
|
|
|
|
|
|
| 301 |
async def tasks_endpoint():
|
| 302 |
+
return [
|
| 303 |
+
{
|
| 304 |
+
"name": task.name,
|
| 305 |
+
"description": task.description,
|
| 306 |
+
"difficulty": task.difficulty,
|
| 307 |
+
"token_budget": task.token_budget,
|
| 308 |
+
"query": task.query,
|
| 309 |
+
"max_steps": task.max_steps,
|
| 310 |
+
}
|
| 311 |
+
for task in ALL_TASKS
|
| 312 |
]
|
| 313 |
|
| 314 |
|
|
|
|
| 318 |
|
| 319 |
|
| 320 |
@app.post("/optimize-step")
|
| 321 |
+
async def optimize_step_endpoint(episode_id: str | None = None):
|
| 322 |
+
_resolved_episode_id, env = _resolve_env(episode_id)
|
| 323 |
+
return _suggest_action(env)
|
| 324 |
+
|
| 325 |
+
|
|
|
|
|
|
|
| 326 |
@app.post("/optimize-prompt")
|
| 327 |
async def optimize_prompt_endpoint(payload: OptimizePromptRequest):
|
| 328 |
if not payload.prompt.strip():
|
|
|
|
| 332 |
corpus_family=payload.corpus_family,
|
| 333 |
compression_mode=payload.compression_mode,
|
| 334 |
)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
if __name__ == "__main__":
|
| 338 |
+
import uvicorn
|
| 339 |
+
|
| 340 |
+
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
|
env/environment.py
CHANGED
|
@@ -1,532 +1,533 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Main OpenEnv-style environment for rag-context-optimizer.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
from __future__ import annotations
|
| 6 |
-
|
| 7 |
-
from dataclasses import asdict, dataclass, is_dataclass, replace
|
| 8 |
-
import os
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
import re
|
| 11 |
-
from typing import Any
|
| 12 |
-
|
| 13 |
-
from env.corpus import Chunk, load_corpus, resolve_corpus_path
|
| 14 |
-
from env.context_tuner import ContextTunedPlanner
|
| 15 |
-
from env.graders import TaskGrader
|
| 16 |
-
from env.models import ChunkSummary, RagAction, RagObservation
|
| 17 |
-
from env.retriever import HybridRetriever
|
| 18 |
-
from env.tasks import ALL_TASKS, TASKS_BY_NAME, Task
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
@dataclass(slots=True)
|
| 22 |
-
class StepResult:
|
| 23 |
-
observation: RagObservation
|
| 24 |
-
reward: float
|
| 25 |
-
done: bool
|
| 26 |
-
info: dict[str, Any]
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class RagContextOptimizerEnv:
|
| 30 |
-
_PROJECT_STOPWORDS = {
|
| 31 |
-
"the", "and", "for", "with", "that", "this", "from", "into", "your", "have", "will",
|
| 32 |
-
"using", "used", "use", "into", "they", "them", "their", "about", "while", "where",
|
| 33 |
-
"when", "what", "which", "should", "would", "could", "there", "here", "then", "than",
|
| 34 |
-
"each", "such", "only", "also", "been", "being", "does", "did", "done", "just", "more",
|
| 35 |
-
"most", "very", "over", "under", "like", "same", "across", "because", "through", "make",
|
| 36 |
-
"made", "many", "much", "some", "into", "onto", "must", "need", "needs", "task", "tasks",
|
| 37 |
-
"chunk", "chunks", "query", "prompt", "environment", "agent", "agents", "model", "models",
|
| 38 |
-
}
|
| 39 |
-
_PROJECT_QUERY_HINTS = {
|
| 40 |
-
"openenv", "benchmark", "rag-context-optimizer", "readme", "docker", "fastapi", "api",
|
| 41 |
-
"endpoint", "inference.py", "app.py", "tasks.py", "graders.py", "environment.py", "repo",
|
| 42 |
-
"repository", "codebase", "ui", "frontend", "backend", "space", "validator",
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
def __init__(
|
| 46 |
-
self,
|
| 47 |
-
task_name: str = "single_domain_qa",
|
| 48 |
-
query_override: str | None = None,
|
| 49 |
-
token_budget_override: int | None = None,
|
| 50 |
-
max_steps_override: int | None = None,
|
| 51 |
-
corpus_family_override: str | None = None,
|
| 52 |
-
):
|
| 53 |
-
if task_name not in TASKS_BY_NAME:
|
| 54 |
-
raise ValueError(f"Unknown task_name: {task_name}")
|
| 55 |
-
|
| 56 |
self._corpus_family = corpus_family_override or os.getenv("RAG_CORPUS_FAMILY") or "enterprise_v1"
|
| 57 |
explicit_path = os.getenv("RAG_CORPUS_PATH")
|
| 58 |
self._corpus_path = resolve_corpus_path(explicit_path, family=None if explicit_path else self._corpus_family)
|
| 59 |
self._all_chunks = load_corpus(self._corpus_path)
|
| 60 |
self._query_overridden = bool(query_override and query_override.strip())
|
| 61 |
-
self.
|
|
|
|
| 62 |
self.retriever = HybridRetriever(self._all_chunks + self._project_chunks)
|
| 63 |
-
self.context_tuner = ContextTunedPlanner(
|
| 64 |
-
self.retriever,
|
| 65 |
-
self._all_chunks + self._project_chunks,
|
| 66 |
-
list(ALL_TASKS),
|
| 67 |
-
)
|
| 68 |
-
self.grader = TaskGrader()
|
| 69 |
-
self.task: Task = self._build_task(
|
| 70 |
-
TASKS_BY_NAME[task_name],
|
| 71 |
-
query_override=query_override,
|
| 72 |
-
token_budget_override=token_budget_override,
|
| 73 |
-
max_steps_override=max_steps_override,
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
self._available_chunks: list[Chunk] = []
|
| 77 |
-
self._selected_chunks: list[str] = []
|
| 78 |
-
self._compression_ratios: dict[str, float] = {}
|
| 79 |
-
self._step_number = 0
|
| 80 |
-
self._done = False
|
| 81 |
-
self._last_action_feedback: str | None = None
|
| 82 |
-
self._last_answer = ""
|
| 83 |
-
self._last_tuning = None
|
| 84 |
-
|
| 85 |
-
@staticmethod
|
| 86 |
-
def _build_task(
|
| 87 |
-
base_task: Task,
|
| 88 |
-
query_override: str | None = None,
|
| 89 |
-
token_budget_override: int | None = None,
|
| 90 |
-
max_steps_override: int | None = None,
|
| 91 |
-
) -> Task:
|
| 92 |
-
updated_task = base_task
|
| 93 |
-
if query_override and query_override.strip():
|
| 94 |
-
updated_task = replace(updated_task, query=query_override.strip(), domain_filter=None)
|
| 95 |
-
if token_budget_override is not None and token_budget_override > 0:
|
| 96 |
-
updated_task = replace(updated_task, token_budget=token_budget_override)
|
| 97 |
-
if max_steps_override is not None and max_steps_override > 0:
|
| 98 |
-
updated_task = replace(updated_task, max_steps=max_steps_override)
|
| 99 |
-
return updated_task
|
| 100 |
-
|
| 101 |
-
async def reset(self) -> StepResult:
|
| 102 |
-
candidate_chunks = self._filter_chunks_for_task(self.task)
|
| 103 |
-
self._available_chunks = self._rank_chunks_for_query(self.task.query, candidate_chunks)
|
| 104 |
-
if not self._query_overridden:
|
| 105 |
-
chunk_by_id = {chunk.chunk_id: chunk for chunk in candidate_chunks}
|
| 106 |
-
for chunk_id in self.task.required_chunk_ids:
|
| 107 |
-
chunk = chunk_by_id.get(chunk_id)
|
| 108 |
-
if chunk and all(existing.chunk_id != chunk_id for existing in self._available_chunks):
|
| 109 |
-
self._available_chunks.append(chunk)
|
| 110 |
-
self._selected_chunks = []
|
| 111 |
-
self._compression_ratios = {}
|
| 112 |
-
self._step_number = 0
|
| 113 |
-
self._done = False
|
| 114 |
-
self._last_action_feedback = None
|
| 115 |
-
self._last_answer = ""
|
| 116 |
-
|
| 117 |
-
observation = self._build_observation()
|
| 118 |
-
return StepResult(
|
| 119 |
-
observation=observation,
|
| 120 |
-
reward=0.0,
|
| 121 |
-
done=False,
|
| 122 |
-
info={"task": self.task.name, "event": "reset"},
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
-
async def step(self, action: RagAction) -> StepResult:
|
| 126 |
-
if self._done:
|
| 127 |
-
return StepResult(
|
| 128 |
-
observation=self._build_observation(),
|
| 129 |
-
reward=0.0,
|
| 130 |
-
done=True,
|
| 131 |
-
info={"task": self.task.name, "event": "episode_already_done"},
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
reward = 0.0
|
| 135 |
-
info: dict[str, Any] = {"task": self.task.name, "action_type": action.action_type}
|
| 136 |
-
|
| 137 |
-
if action.action_type == "select_chunk":
|
| 138 |
-
reward, info = self._handle_select(action.chunk_id or "")
|
| 139 |
-
elif action.action_type == "deselect_chunk":
|
| 140 |
-
reward, info = self._handle_deselect(action.chunk_id or "")
|
| 141 |
-
elif action.action_type == "compress_chunk":
|
| 142 |
-
reward, info = self._handle_compress(action.chunk_id or "", float(action.compression_ratio or 0.0))
|
| 143 |
-
elif action.action_type == "submit_answer":
|
| 144 |
-
self._last_answer = action.answer or ""
|
| 145 |
-
result = self._finalize_submission(reason="submit_answer")
|
| 146 |
-
self._step_number += 1
|
| 147 |
-
result.observation.step_number = self._step_number
|
| 148 |
-
return result
|
| 149 |
-
|
| 150 |
-
self._step_number += 1
|
| 151 |
-
|
| 152 |
-
if self._step_number >= self.task.max_steps:
|
| 153 |
-
return self._finalize_submission(reason="max_steps_reached")
|
| 154 |
-
|
| 155 |
-
observation = self._build_observation()
|
| 156 |
-
return StepResult(
|
| 157 |
-
observation=observation,
|
| 158 |
-
reward=reward,
|
| 159 |
-
done=False,
|
| 160 |
-
info=info,
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
async def state(self) -> dict:
|
| 164 |
-
selected_chunk_details = []
|
| 165 |
-
for chunk_id in self._selected_chunks:
|
| 166 |
-
chunk = self._chunk_map().get(chunk_id)
|
| 167 |
-
if chunk is None:
|
| 168 |
-
continue
|
| 169 |
-
selected_chunk_details.append(
|
| 170 |
-
{
|
| 171 |
-
"chunk_id": chunk.chunk_id,
|
| 172 |
-
"domain": chunk.domain,
|
| 173 |
-
"original_tokens": chunk.tokens,
|
| 174 |
-
"effective_tokens": self._effective_chunk_tokens(chunk_id),
|
| 175 |
-
"compression_ratio": round(self._compression_ratios.get(chunk_id, 1.0), 3),
|
| 176 |
-
"text": self._effective_chunk_text(chunk_id),
|
| 177 |
-
"keywords": chunk.keywords,
|
| 178 |
-
}
|
| 179 |
-
)
|
| 180 |
-
optimized_prompt = self._build_optimized_prompt()
|
| 181 |
-
return {
|
| 182 |
-
"task": asdict(self.task) if is_dataclass(self.task) else self.task,
|
| 183 |
-
"step_number": self._step_number,
|
| 184 |
-
"done": self._done,
|
| 185 |
-
"selected_chunks": list(self._selected_chunks),
|
| 186 |
-
"compression_ratios": dict(self._compression_ratios),
|
| 187 |
-
"total_tokens_used": self._total_tokens_used(),
|
| 188 |
-
"token_budget": self.task.token_budget,
|
| 189 |
-
"last_action_feedback": self._last_action_feedback,
|
| 190 |
-
"last_answer": self._last_answer,
|
| 191 |
-
"corpus_family": self._corpus_family,
|
| 192 |
-
"corpus_path": str(self._corpus_path),
|
| 193 |
-
"available_chunk_ids": [chunk.chunk_id for chunk in self._available_chunks],
|
| 194 |
-
"selected_chunk_details": selected_chunk_details,
|
| 195 |
-
"optimized_prompt_preview": optimized_prompt,
|
| 196 |
-
"optimized_prompt_tokens": max(1, len(optimized_prompt) // 4) if optimized_prompt else 0,
|
| 197 |
-
"context_tuning": (
|
| 198 |
-
{
|
| 199 |
-
"mode": self._last_tuning.mode,
|
| 200 |
-
"top_demo_cases": self._last_tuning.top_demo_cases,
|
| 201 |
-
"suggested_citations": self._last_tuning.suggested_citations,
|
| 202 |
-
"token_dropout": self._last_tuning.token_dropout,
|
| 203 |
-
"leave_one_out": self._last_tuning.leave_one_out,
|
| 204 |
-
}
|
| 205 |
-
if self._last_tuning is not None
|
| 206 |
-
else None
|
| 207 |
-
),
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
async def close(self):
|
| 211 |
-
self._done = True
|
| 212 |
-
|
| 213 |
-
def _filter_chunks_for_task(self, task: Task) -> list[Chunk]:
|
| 214 |
-
domain_mapping = {
|
| 215 |
-
"customer_support_operations": "Customer Support Operations",
|
| 216 |
-
"incident_response_playbooks": "Incident Response Playbooks",
|
| 217 |
-
"platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
|
| 218 |
}
|
| 219 |
if self._query_overridden:
|
| 220 |
-
if self._is_project_query(task.query):
|
| 221 |
return list(self._all_chunks) + list(self._project_chunks)
|
| 222 |
return list(self._all_chunks)
|
| 223 |
-
if task.domain_filter is None:
|
| 224 |
-
return list(self._all_chunks)
|
| 225 |
-
normalized = domain_mapping.get(task.domain_filter, task.domain_filter)
|
| 226 |
-
return [chunk for chunk in self._all_chunks if chunk.domain == normalized]
|
| 227 |
-
|
| 228 |
-
def _is_project_query(self, query: str) -> bool:
|
| 229 |
-
lowered = query.lower()
|
| 230 |
-
return any(hint in lowered for hint in self._PROJECT_QUERY_HINTS)
|
| 231 |
-
|
| 232 |
-
def _rank_chunks_for_query(self, query: str, chunks: list[Chunk], top_k: int = 20) -> list[Chunk]:
|
| 233 |
-
tuning = self.context_tuner.tune(query, chunks)
|
| 234 |
-
self._last_tuning = tuning
|
| 235 |
-
scored = []
|
| 236 |
for chunk in chunks:
|
| 237 |
tuned = tuning.tuned_scores.get(chunk.chunk_id)
|
| 238 |
score = tuned.final_score if tuned is not None else self.retriever.hybrid_score(query, chunk)
|
| 239 |
-
if self._query_overridden and chunk.domain.startswith("Project"):
|
| 240 |
score = min(1.0, score + 0.08)
|
| 241 |
scored.append((chunk, score))
|
| 242 |
-
scored.sort(key=lambda item: (-item[1], item[0].tokens, item[0].chunk_id))
|
| 243 |
-
if not scored:
|
| 244 |
-
return []
|
| 245 |
-
|
| 246 |
-
capped = scored[: max(1, min(top_k * 2, len(scored)))]
|
| 247 |
-
best_score = capped[0][1]
|
| 248 |
-
floor = max(0.12, best_score * 0.38)
|
| 249 |
-
filtered_pairs = [(chunk, score) for chunk, score in capped if score >= floor]
|
| 250 |
-
|
| 251 |
-
if self._query_overridden:
|
| 252 |
project_pairs = [(chunk, score) for chunk, score in filtered_pairs if chunk.domain.startswith("Project")]
|
| 253 |
if len(project_pairs) >= 4:
|
| 254 |
filtered_pairs = project_pairs + [
|
| 255 |
-
(chunk, score)
|
| 256 |
-
for chunk, score in filtered_pairs
|
| 257 |
-
if not chunk.domain.startswith("Project")
|
| 258 |
-
]
|
| 259 |
-
|
| 260 |
-
filtered = [chunk for chunk, _score in filtered_pairs]
|
| 261 |
-
if not filtered:
|
| 262 |
-
filtered = [chunk for chunk, _score in capped[: max(1, min(top_k, len(capped)))]]
|
| 263 |
-
|
| 264 |
-
return filtered[: max(1, min(top_k, len(filtered)))]
|
| 265 |
-
|
| 266 |
-
def _load_project_chunks(self) -> list[Chunk]:
|
| 267 |
-
root = Path(__file__).resolve().parent.parent
|
| 268 |
-
chunks: list[Chunk] = []
|
| 269 |
-
file_specs = [
|
| 270 |
-
("Project Documentation", root / "README.md", ["project_docs", "readme"]),
|
| 271 |
-
("Project Configuration", root / "openenv.yaml", ["project_docs", "config", "openenv_spec"]),
|
| 272 |
-
("Project API", root / "app.py", ["project_docs", "api", "server"]),
|
| 273 |
-
("Project Baseline", root / "inference.py", ["project_docs", "baseline", "inference"]),
|
| 274 |
-
("Project Environment", root / "env" / "environment.py", ["project_docs", "environment", "state_management"]),
|
| 275 |
-
("Project Retrieval", root / "env" / "retriever.py", ["project_docs", "retrieval", "ranking"]),
|
| 276 |
-
("Project Grading", root / "env" / "graders.py", ["project_docs", "grading", "reward_design"]),
|
| 277 |
-
("Project Tasks", root / "env" / "tasks.py", ["project_docs", "tasks", "difficulty"]),
|
| 278 |
-
("Project Validation", root / "validate.py", ["project_docs", "validation", "testing"]),
|
| 279 |
-
]
|
| 280 |
-
|
| 281 |
-
for domain, path, tags in file_specs:
|
| 282 |
-
if not path.exists():
|
| 283 |
-
continue
|
| 284 |
-
raw_text = path.read_text(encoding="utf-8", errors="ignore")
|
| 285 |
-
sections = self._chunk_project_text(raw_text)
|
| 286 |
-
stem = re.sub(r"[^a-z0-9]+", "_", path.stem.lower()).strip("_") or "file"
|
| 287 |
-
for index, section in enumerate(sections, start=1):
|
| 288 |
-
keywords = self._extract_project_keywords(section)
|
| 289 |
-
if not keywords:
|
| 290 |
-
keywords = [stem, domain.lower()]
|
| 291 |
-
chunks.append(
|
| 292 |
-
Chunk(
|
| 293 |
-
chunk_id=f"project_{stem}_{index:03d}",
|
| 294 |
-
domain=domain,
|
| 295 |
-
text=section,
|
| 296 |
-
tokens=max(30, len(section) // 4),
|
| 297 |
-
keywords=keywords[:5],
|
| 298 |
-
relevance_tags=tags,
|
| 299 |
-
)
|
| 300 |
-
)
|
| 301 |
-
return chunks
|
| 302 |
-
|
| 303 |
-
def _chunk_project_text(self, raw_text: str, chunk_words: int = 140, stride_words: int = 100) -> list[str]:
|
| 304 |
-
cleaned = " ".join(raw_text.split())
|
| 305 |
-
words = cleaned.split()
|
| 306 |
-
if not words:
|
| 307 |
-
return []
|
| 308 |
-
if len(words) <= chunk_words:
|
| 309 |
-
return [" ".join(words)]
|
| 310 |
-
|
| 311 |
-
chunks: list[str] = []
|
| 312 |
-
start = 0
|
| 313 |
-
while start < len(words):
|
| 314 |
-
window = words[start : start + chunk_words]
|
| 315 |
-
if not window:
|
| 316 |
-
break
|
| 317 |
-
chunks.append(" ".join(window))
|
| 318 |
-
if start + chunk_words >= len(words):
|
| 319 |
-
break
|
| 320 |
-
start += stride_words
|
| 321 |
-
return chunks
|
| 322 |
-
|
| 323 |
-
def _extract_project_keywords(self, text: str) -> list[str]:
|
| 324 |
-
terms = re.findall(r"[a-z0-9_]+", text.lower())
|
| 325 |
-
counts: dict[str, int] = {}
|
| 326 |
-
for term in terms:
|
| 327 |
-
if len(term) < 4 or term in self._PROJECT_STOPWORDS:
|
| 328 |
-
continue
|
| 329 |
-
counts[term] = counts.get(term, 0) + 1
|
| 330 |
-
ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
|
| 331 |
-
return [term.replace("_", " ") for term, _count in ranked[:8]]
|
| 332 |
-
|
| 333 |
-
def _build_observation(self) -> RagObservation:
|
| 334 |
-
return RagObservation(
|
| 335 |
-
query=self.task.query,
|
| 336 |
-
available_chunks=[
|
| 337 |
-
ChunkSummary(
|
| 338 |
-
chunk_id=chunk.chunk_id,
|
| 339 |
-
domain=chunk.domain,
|
| 340 |
-
tokens=self._effective_chunk_tokens(chunk.chunk_id),
|
| 341 |
-
keywords=chunk.keywords,
|
| 342 |
-
)
|
| 343 |
-
for chunk in self._available_chunks
|
| 344 |
-
],
|
| 345 |
-
selected_chunks=list(self._selected_chunks),
|
| 346 |
-
total_tokens_used=self._total_tokens_used(),
|
| 347 |
-
token_budget=self.task.token_budget,
|
| 348 |
-
step_number=self._step_number,
|
| 349 |
-
task_name=self.task.name,
|
| 350 |
-
last_action_feedback=self._last_action_feedback,
|
| 351 |
-
)
|
| 352 |
-
|
| 353 |
-
def _chunk_map(self) -> dict[str, Chunk]:
|
| 354 |
-
return {chunk.chunk_id: chunk for chunk in self._available_chunks}
|
| 355 |
-
|
| 356 |
-
def _effective_chunk_tokens(self, chunk_id: str) -> int:
|
| 357 |
-
chunk = self._chunk_map().get(chunk_id)
|
| 358 |
-
if chunk is None:
|
| 359 |
-
return 0
|
| 360 |
-
ratio = self._compression_ratios.get(chunk_id, 1.0)
|
| 361 |
-
return max(1, int(round(chunk.tokens * ratio)))
|
| 362 |
-
|
| 363 |
-
def _total_tokens_used(self) -> int:
|
| 364 |
-
return sum(self._effective_chunk_tokens(chunk_id) for chunk_id in self._selected_chunks)
|
| 365 |
-
|
| 366 |
-
def _effective_chunk_text(self, chunk_id: str) -> str:
|
| 367 |
-
chunk = self._chunk_map().get(chunk_id)
|
| 368 |
-
if chunk is None:
|
| 369 |
-
return ""
|
| 370 |
-
ratio = self._compression_ratios.get(chunk_id, 1.0)
|
| 371 |
-
text = " ".join(chunk.text.split())
|
| 372 |
-
if ratio >= 0.999:
|
| 373 |
-
return text
|
| 374 |
-
|
| 375 |
-
query_terms = self._query_terms(self.task.query)
|
| 376 |
-
keyword_terms = self._query_terms(" ".join(chunk.keywords))
|
| 377 |
-
sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", text) if segment.strip()]
|
| 378 |
-
if not sentences:
|
| 379 |
-
return self._truncate_words(text, ratio)
|
| 380 |
-
|
| 381 |
-
ranked_sentences: list[tuple[int, float, int, str]] = []
|
| 382 |
-
for index, sentence in enumerate(sentences):
|
| 383 |
-
sentence_terms = self._query_terms(sentence)
|
| 384 |
-
overlap = len(sentence_terms & query_terms)
|
| 385 |
-
keyword_overlap = len(sentence_terms & keyword_terms)
|
| 386 |
-
score = (overlap * 2.0) + keyword_overlap + (0.25 if index == 0 else 0.0)
|
| 387 |
-
ranked_sentences.append((index, score, len(sentence.split()), sentence))
|
| 388 |
-
|
| 389 |
-
target_words = max(20, int(len(text.split()) * ratio))
|
| 390 |
-
chosen: list[tuple[int, str]] = []
|
| 391 |
-
used_words = 0
|
| 392 |
-
for index, _score, word_count, sentence in sorted(
|
| 393 |
-
ranked_sentences,
|
| 394 |
-
key=lambda item: (-item[1], item[2], item[0]),
|
| 395 |
-
):
|
| 396 |
-
if used_words >= target_words:
|
| 397 |
-
break
|
| 398 |
-
chosen.append((index, sentence))
|
| 399 |
-
used_words += word_count
|
| 400 |
-
|
| 401 |
-
if not chosen:
|
| 402 |
-
return self._truncate_words(text, ratio)
|
| 403 |
-
|
| 404 |
-
chosen.sort(key=lambda item: item[0])
|
| 405 |
-
compressed = " ".join(sentence for _index, sentence in chosen)
|
| 406 |
-
return self._truncate_words(compressed, ratio)
|
| 407 |
-
|
| 408 |
-
@staticmethod
|
| 409 |
-
def _truncate_words(text: str, ratio: float) -> str:
|
| 410 |
-
words = text.split()
|
| 411 |
-
if not words:
|
| 412 |
-
return ""
|
| 413 |
-
keep = max(12, int(len(words) * ratio))
|
| 414 |
-
truncated = " ".join(words[:keep])
|
| 415 |
-
if keep < len(words):
|
| 416 |
-
return truncated + " ..."
|
| 417 |
-
return truncated
|
| 418 |
-
|
| 419 |
-
@staticmethod
|
| 420 |
-
def _query_terms(text: str) -> set[str]:
|
| 421 |
-
return {token for token in re.findall(r"[a-z0-9]+", text.lower()) if len(token) > 2}
|
| 422 |
-
|
| 423 |
-
def _build_optimized_prompt(self) -> str:
|
| 424 |
-
if not self._selected_chunks:
|
| 425 |
-
return ""
|
| 426 |
-
sections = [f"Question: {self.task.query}", "", "Optimized Context:"]
|
| 427 |
-
for chunk_id in self._selected_chunks:
|
| 428 |
-
chunk = self._chunk_map().get(chunk_id)
|
| 429 |
-
if chunk is None:
|
| 430 |
-
continue
|
| 431 |
-
sections.append(
|
| 432 |
-
f"[{chunk.chunk_id} | {self._effective_chunk_tokens(chunk_id)} tokens] {self._effective_chunk_text(chunk_id)}"
|
| 433 |
-
)
|
| 434 |
-
return "\n".join(sections).strip()
|
| 435 |
-
|
| 436 |
-
def _is_relevant(self, chunk_id: str) -> tuple[bool, float]:
|
| 437 |
-
chunk = self._chunk_map().get(chunk_id)
|
| 438 |
-
if chunk is None:
|
| 439 |
-
return False, 0.0
|
| 440 |
-
score = self.retriever.hybrid_score(self.task.query, chunk)
|
| 441 |
-
return score >= 0.3, score
|
| 442 |
-
|
| 443 |
-
def _handle_select(self, chunk_id: str) -> tuple[float, dict[str, Any]]:
|
| 444 |
-
chunk = self._chunk_map().get(chunk_id)
|
| 445 |
-
if chunk is None:
|
| 446 |
-
self._last_action_feedback = "chunk_not_found"
|
| 447 |
-
return -0.1, {"event": "chunk_not_found"}
|
| 448 |
-
if chunk_id in self._selected_chunks:
|
| 449 |
-
self._last_action_feedback = "chunk_already_selected"
|
| 450 |
-
return 0.0, {"event": "chunk_already_selected"}
|
| 451 |
-
|
| 452 |
-
projected_tokens = self._total_tokens_used() + self._effective_chunk_tokens(chunk_id)
|
| 453 |
-
if projected_tokens > self.task.token_budget:
|
| 454 |
-
self._last_action_feedback = "exceeded_budget"
|
| 455 |
-
return -0.1, {"event": "exceeded_budget", "chunk_id": chunk_id}
|
| 456 |
-
|
| 457 |
-
self._selected_chunks.append(chunk_id)
|
| 458 |
-
_, score = self._is_relevant(chunk_id)
|
| 459 |
-
self._last_action_feedback = "chunk_selected"
|
| 460 |
-
return score * 0.2, {"event": "chunk_selected", "chunk_id": chunk_id, "hybrid_score": score}
|
| 461 |
-
|
| 462 |
-
def _handle_deselect(self, chunk_id: str) -> tuple[float, dict[str, Any]]:
|
| 463 |
-
if chunk_id not in self._selected_chunks:
|
| 464 |
-
self._last_action_feedback = "chunk_not_selected"
|
| 465 |
-
return 0.0, {"event": "chunk_not_selected", "chunk_id": chunk_id}
|
| 466 |
-
|
| 467 |
-
self._selected_chunks.remove(chunk_id)
|
| 468 |
-
is_relevant, score = self._is_relevant(chunk_id)
|
| 469 |
-
self._last_action_feedback = "chunk_deselected"
|
| 470 |
-
reward = 0.0 if is_relevant else 0.05
|
| 471 |
-
return reward, {"event": "chunk_deselected", "chunk_id": chunk_id, "hybrid_score": score}
|
| 472 |
-
|
| 473 |
-
def _handle_compress(self, chunk_id: str, compression_ratio: float) -> tuple[float, dict[str, Any]]:
|
| 474 |
-
chunk = self._chunk_map().get(chunk_id)
|
| 475 |
-
if chunk is None:
|
| 476 |
-
self._last_action_feedback = "chunk_not_found"
|
| 477 |
-
return -0.1, {"event": "chunk_not_found", "chunk_id": chunk_id}
|
| 478 |
-
|
| 479 |
-
self._compression_ratios[chunk_id] = compression_ratio
|
| 480 |
-
is_relevant, score = self._is_relevant(chunk_id)
|
| 481 |
-
reward = 0.03 if is_relevant else 0.0
|
| 482 |
-
if score >= 0.6 and compression_ratio < 0.4:
|
| 483 |
-
reward -= 0.05
|
| 484 |
-
self._last_action_feedback = "overcompressed_relevant_chunk"
|
| 485 |
-
return reward, {
|
| 486 |
-
"event": "overcompressed_relevant_chunk",
|
| 487 |
-
"chunk_id": chunk_id,
|
| 488 |
-
"hybrid_score": score,
|
| 489 |
-
"compression_ratio": compression_ratio,
|
| 490 |
-
}
|
| 491 |
-
|
| 492 |
-
self._last_action_feedback = "chunk_compressed"
|
| 493 |
-
return reward, {
|
| 494 |
-
"event": "chunk_compressed",
|
| 495 |
-
"chunk_id": chunk_id,
|
| 496 |
-
"hybrid_score": score,
|
| 497 |
-
"compression_ratio": compression_ratio,
|
| 498 |
-
}
|
| 499 |
-
|
| 500 |
-
def _finalize_submission(self, reason: str) -> StepResult:
|
| 501 |
-
self._done = True
|
| 502 |
-
|
| 503 |
-
if not self._selected_chunks:
|
| 504 |
-
self._last_action_feedback = "no_chunks_selected"
|
| 505 |
-
observation = self._build_observation()
|
| 506 |
-
return StepResult(
|
| 507 |
-
observation=observation,
|
| 508 |
-
reward=0.0,
|
| 509 |
-
done=True,
|
| 510 |
-
info={"event": reason, "grader": None, "passed": False},
|
| 511 |
-
)
|
| 512 |
-
|
| 513 |
-
grader_result = self.grader.grade(
|
| 514 |
-
selected_chunk_ids=list(self._selected_chunks),
|
| 515 |
-
answer=self._last_answer,
|
| 516 |
-
token_budget=self.task.token_budget,
|
| 517 |
-
total_tokens_used=self._total_tokens_used(),
|
| 518 |
-
retriever=self.retriever,
|
| 519 |
-
task=self.task,
|
| 520 |
-
)
|
| 521 |
-
self._last_action_feedback = reason
|
| 522 |
-
observation = self._build_observation()
|
| 523 |
-
return StepResult(
|
| 524 |
-
observation=observation,
|
| 525 |
-
reward=grader_result.score,
|
| 526 |
-
done=True,
|
| 527 |
-
info={
|
| 528 |
-
"event": reason,
|
| 529 |
-
"grader": grader_result.breakdown,
|
| 530 |
-
"passed": grader_result.passed,
|
| 531 |
-
},
|
| 532 |
-
)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main OpenEnv-style environment for rag-context-optimizer.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from dataclasses import asdict, dataclass, is_dataclass, replace
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import re
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from env.corpus import Chunk, load_corpus, resolve_corpus_path
|
| 14 |
+
from env.context_tuner import ContextTunedPlanner
|
| 15 |
+
from env.graders import TaskGrader
|
| 16 |
+
from env.models import ChunkSummary, RagAction, RagObservation
|
| 17 |
+
from env.retriever import HybridRetriever
|
| 18 |
+
from env.tasks import ALL_TASKS, TASKS_BY_NAME, Task
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass(slots=True)
|
| 22 |
+
class StepResult:
|
| 23 |
+
observation: RagObservation
|
| 24 |
+
reward: float
|
| 25 |
+
done: bool
|
| 26 |
+
info: dict[str, Any]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class RagContextOptimizerEnv:
|
| 30 |
+
_PROJECT_STOPWORDS = {
|
| 31 |
+
"the", "and", "for", "with", "that", "this", "from", "into", "your", "have", "will",
|
| 32 |
+
"using", "used", "use", "into", "they", "them", "their", "about", "while", "where",
|
| 33 |
+
"when", "what", "which", "should", "would", "could", "there", "here", "then", "than",
|
| 34 |
+
"each", "such", "only", "also", "been", "being", "does", "did", "done", "just", "more",
|
| 35 |
+
"most", "very", "over", "under", "like", "same", "across", "because", "through", "make",
|
| 36 |
+
"made", "many", "much", "some", "into", "onto", "must", "need", "needs", "task", "tasks",
|
| 37 |
+
"chunk", "chunks", "query", "prompt", "environment", "agent", "agents", "model", "models",
|
| 38 |
+
}
|
| 39 |
+
_PROJECT_QUERY_HINTS = {
|
| 40 |
+
"openenv", "benchmark", "rag-context-optimizer", "readme", "docker", "fastapi", "api",
|
| 41 |
+
"endpoint", "inference.py", "app.py", "tasks.py", "graders.py", "environment.py", "repo",
|
| 42 |
+
"repository", "codebase", "ui", "frontend", "backend", "space", "validator",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
def __init__(
    self,
    task_name: str = "single_domain_qa",
    query_override: str | None = None,
    token_budget_override: int | None = None,
    max_steps_override: int | None = None,
    corpus_family_override: str | None = None,
):
    """Load the corpus, build the retrieval stack and prepare the task.

    Args:
        task_name: Key into ``TASKS_BY_NAME``.
        query_override: Optional custom query; blank strings are ignored.
        token_budget_override: Optional replacement token budget.
        max_steps_override: Optional replacement step limit.
        corpus_family_override: Corpus family; falls back to the
            ``RAG_CORPUS_FAMILY`` environment variable, then "enterprise_v1".

    Raises:
        ValueError: If ``task_name`` is not a known task.
    """
    try:
        base_task = TASKS_BY_NAME[task_name]
    except KeyError:
        raise ValueError(f"Unknown task_name: {task_name}") from None

    family = corpus_family_override or os.getenv("RAG_CORPUS_FAMILY") or "enterprise_v1"
    self._corpus_family = family
    # An explicit RAG_CORPUS_PATH wins over the family-based lookup.
    explicit_path = os.getenv("RAG_CORPUS_PATH")
    self._corpus_path = resolve_corpus_path(explicit_path, family=None if explicit_path else family)
    self._all_chunks = load_corpus(self._corpus_path)
    self._query_overridden = bool(query_override and query_override.strip())

    # Project files are only indexed when explicitly enabled via the environment.
    project_flag = os.getenv("ENABLE_PROJECT_CORPUS", "").strip().lower()
    self._include_project_chunks = project_flag in {"1", "true", "yes"}
    if self._include_project_chunks:
        self._project_chunks = self._load_project_chunks()
    else:
        self._project_chunks = []

    self.retriever = HybridRetriever(self._all_chunks + self._project_chunks)
    self.context_tuner = ContextTunedPlanner(
        self.retriever,
        self._all_chunks + self._project_chunks,
        list(ALL_TASKS),
    )
    self.grader = TaskGrader()
    self.task: Task = self._build_task(
        base_task,
        query_override=query_override,
        token_budget_override=token_budget_override,
        max_steps_override=max_steps_override,
    )

    # Per-episode state; reset() re-initializes most of these.
    self._available_chunks: list[Chunk] = []
    self._selected_chunks: list[str] = []
    self._compression_ratios: dict[str, float] = {}
    self._step_number = 0
    self._done = False
    self._last_action_feedback: str | None = None
    self._last_answer = ""
    self._last_tuning = None
|
| 85 |
+
|
| 86 |
+
@staticmethod
|
| 87 |
+
def _build_task(
|
| 88 |
+
base_task: Task,
|
| 89 |
+
query_override: str | None = None,
|
| 90 |
+
token_budget_override: int | None = None,
|
| 91 |
+
max_steps_override: int | None = None,
|
| 92 |
+
) -> Task:
|
| 93 |
+
updated_task = base_task
|
| 94 |
+
if query_override and query_override.strip():
|
| 95 |
+
updated_task = replace(updated_task, query=query_override.strip(), domain_filter=None)
|
| 96 |
+
if token_budget_override is not None and token_budget_override > 0:
|
| 97 |
+
updated_task = replace(updated_task, token_budget=token_budget_override)
|
| 98 |
+
if max_steps_override is not None and max_steps_override > 0:
|
| 99 |
+
updated_task = replace(updated_task, max_steps=max_steps_override)
|
| 100 |
+
return updated_task
|
| 101 |
+
|
| 102 |
+
async def reset(self) -> StepResult:
    """Start a fresh episode and return the initial observation.

    Ranks the task's candidate chunks and, unless the query was overridden,
    appends any required chunk the ranking left out.
    """
    candidates = self._filter_chunks_for_task(self.task)
    ranked = self._rank_chunks_for_query(self.task.query, candidates)
    self._available_chunks = ranked

    if not self._query_overridden:
        by_id = {candidate.chunk_id: candidate for candidate in candidates}
        for required_id in self.task.required_chunk_ids:
            required = by_id.get(required_id)
            if not required:
                continue
            already_listed = any(existing.chunk_id == required_id for existing in self._available_chunks)
            if not already_listed:
                self._available_chunks.append(required)

    # Clear all per-episode bookkeeping.
    self._selected_chunks = []
    self._compression_ratios = {}
    self._step_number = 0
    self._done = False
    self._last_action_feedback = None
    self._last_answer = ""

    return StepResult(
        observation=self._build_observation(),
        reward=0.0,
        done=False,
        info={"task": self.task.name, "event": "reset"},
    )
|
| 125 |
+
|
| 126 |
+
async def step(self, action: RagAction) -> StepResult:
    """Apply one agent action and return the resulting transition.

    A finished episode yields a zero-reward "episode_already_done" result.
    "submit_answer" finalizes immediately; any other action advances the
    step counter and finalizes once ``task.max_steps`` is reached.
    """
    if self._done:
        return StepResult(
            observation=self._build_observation(),
            reward=0.0,
            done=True,
            info={"task": self.task.name, "event": "episode_already_done"},
        )

    kind = action.action_type

    if kind == "submit_answer":
        self._last_answer = action.answer or ""
        final = self._finalize_submission(reason="submit_answer")
        # The observation was built before the counter moved; patch it afterwards.
        self._step_number += 1
        final.observation.step_number = self._step_number
        return final

    if kind == "select_chunk":
        reward, info = self._handle_select(action.chunk_id or "")
    elif kind == "deselect_chunk":
        reward, info = self._handle_deselect(action.chunk_id or "")
    elif kind == "compress_chunk":
        reward, info = self._handle_compress(action.chunk_id or "", float(action.compression_ratio or 0.0))
    else:
        # Unrecognized action types are a no-op that still consumes a step.
        reward = 0.0
        info = {"task": self.task.name, "action_type": kind}

    self._step_number += 1
    if self._step_number >= self.task.max_steps:
        return self._finalize_submission(reason="max_steps_reached")

    return StepResult(
        observation=self._build_observation(),
        reward=reward,
        done=False,
        info=info,
    )
|
| 163 |
+
|
| 164 |
+
async def state(self) -> dict:
    """Return a dictionary snapshot of the episode state.

    Selected chunk ids that are no longer in the available set are left
    out of ``selected_chunk_details``.
    """
    lookup = self._chunk_map()
    details = [
        {
            "chunk_id": chunk.chunk_id,
            "domain": chunk.domain,
            "original_tokens": chunk.tokens,
            "effective_tokens": self._effective_chunk_tokens(chunk_id),
            "compression_ratio": round(self._compression_ratios.get(chunk_id, 1.0), 3),
            "text": self._effective_chunk_text(chunk_id),
            "keywords": chunk.keywords,
        }
        for chunk_id in self._selected_chunks
        if (chunk := lookup.get(chunk_id)) is not None
    ]

    prompt = self._build_optimized_prompt()
    # Rough size estimate: four characters per token, at least one token.
    prompt_tokens = max(1, len(prompt) // 4) if prompt else 0

    tuning = self._last_tuning
    if tuning is None:
        tuning_view = None
    else:
        tuning_view = {
            "mode": tuning.mode,
            "top_demo_cases": tuning.top_demo_cases,
            "suggested_citations": tuning.suggested_citations,
            "token_dropout": tuning.token_dropout,
            "leave_one_out": tuning.leave_one_out,
        }

    return {
        "task": asdict(self.task) if is_dataclass(self.task) else self.task,
        "step_number": self._step_number,
        "done": self._done,
        "selected_chunks": list(self._selected_chunks),
        "compression_ratios": dict(self._compression_ratios),
        "total_tokens_used": self._total_tokens_used(),
        "token_budget": self.task.token_budget,
        "last_action_feedback": self._last_action_feedback,
        "last_answer": self._last_answer,
        "corpus_family": self._corpus_family,
        "corpus_path": str(self._corpus_path),
        "available_chunk_ids": [chunk.chunk_id for chunk in self._available_chunks],
        "selected_chunk_details": details,
        "optimized_prompt_preview": prompt,
        "optimized_prompt_tokens": prompt_tokens,
        "context_tuning": tuning_view,
    }
|
| 210 |
+
|
| 211 |
+
async def close(self):
    """Flag the episode as finished."""
    self._done = True
|
| 213 |
+
|
| 214 |
+
def _filter_chunks_for_task(self, task: Task) -> list[Chunk]:
|
| 215 |
+
domain_mapping = {
|
| 216 |
+
"customer_support_operations": "Customer Support Operations",
|
| 217 |
+
"incident_response_playbooks": "Incident Response Playbooks",
|
| 218 |
+
"platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
|
| 219 |
}
|
| 220 |
if self._query_overridden:
|
| 221 |
+
if self._include_project_chunks and self._is_project_query(task.query):
|
| 222 |
return list(self._all_chunks) + list(self._project_chunks)
|
| 223 |
return list(self._all_chunks)
|
| 224 |
+
if task.domain_filter is None:
|
| 225 |
+
return list(self._all_chunks)
|
| 226 |
+
normalized = domain_mapping.get(task.domain_filter, task.domain_filter)
|
| 227 |
+
return [chunk for chunk in self._all_chunks if chunk.domain == normalized]
|
| 228 |
+
|
| 229 |
+
def _is_project_query(self, query: str) -> bool:
|
| 230 |
+
lowered = query.lower()
|
| 231 |
+
return any(hint in lowered for hint in self._PROJECT_QUERY_HINTS)
|
| 232 |
+
|
| 233 |
+
def _rank_chunks_for_query(self, query: str, chunks: list[Chunk], top_k: int = 20) -> list[Chunk]:
|
| 234 |
+
tuning = self.context_tuner.tune(query, chunks)
|
| 235 |
+
self._last_tuning = tuning
|
| 236 |
+
scored = []
|
| 237 |
for chunk in chunks:
|
| 238 |
tuned = tuning.tuned_scores.get(chunk.chunk_id)
|
| 239 |
score = tuned.final_score if tuned is not None else self.retriever.hybrid_score(query, chunk)
|
| 240 |
+
if self._include_project_chunks and self._query_overridden and chunk.domain.startswith("Project"):
|
| 241 |
score = min(1.0, score + 0.08)
|
| 242 |
scored.append((chunk, score))
|
| 243 |
+
scored.sort(key=lambda item: (-item[1], item[0].tokens, item[0].chunk_id))
|
| 244 |
+
if not scored:
|
| 245 |
+
return []
|
| 246 |
+
|
| 247 |
+
capped = scored[: max(1, min(top_k * 2, len(scored)))]
|
| 248 |
+
best_score = capped[0][1]
|
| 249 |
+
floor = max(0.12, best_score * 0.38)
|
| 250 |
+
filtered_pairs = [(chunk, score) for chunk, score in capped if score >= floor]
|
| 251 |
+
|
| 252 |
+
if self._include_project_chunks and self._query_overridden:
|
| 253 |
project_pairs = [(chunk, score) for chunk, score in filtered_pairs if chunk.domain.startswith("Project")]
|
| 254 |
if len(project_pairs) >= 4:
|
| 255 |
filtered_pairs = project_pairs + [
|
| 256 |
+
(chunk, score)
|
| 257 |
+
for chunk, score in filtered_pairs
|
| 258 |
+
if not chunk.domain.startswith("Project")
|
| 259 |
+
]
|
| 260 |
+
|
| 261 |
+
filtered = [chunk for chunk, _score in filtered_pairs]
|
| 262 |
+
if not filtered:
|
| 263 |
+
filtered = [chunk for chunk, _score in capped[: max(1, min(top_k, len(capped)))]]
|
| 264 |
+
|
| 265 |
+
return filtered[: max(1, min(top_k, len(filtered)))]
|
| 266 |
+
|
| 267 |
+
def _load_project_chunks(self) -> list[Chunk]:
    """Turn the repository's own files into retrievable chunks.

    Every listed file that exists is read, split into overlapping word
    windows and wrapped in ``Chunk`` objects with ids of the form
    ``project_<stem>_<NNN>``. Missing files are skipped.
    """
    root = Path(__file__).resolve().parent.parent
    env_dir = root / "env"
    file_specs = [
        ("Project Documentation", root / "README.md", ["project_docs", "readme"]),
        ("Project Configuration", root / "openenv.yaml", ["project_docs", "config", "openenv_spec"]),
        ("Project API", root / "app.py", ["project_docs", "api", "server"]),
        ("Project Baseline", root / "inference.py", ["project_docs", "baseline", "inference"]),
        ("Project Environment", env_dir / "environment.py", ["project_docs", "environment", "state_management"]),
        ("Project Retrieval", env_dir / "retriever.py", ["project_docs", "retrieval", "ranking"]),
        ("Project Grading", env_dir / "graders.py", ["project_docs", "grading", "reward_design"]),
        ("Project Tasks", env_dir / "tasks.py", ["project_docs", "tasks", "difficulty"]),
        ("Project Validation", root / "validate.py", ["project_docs", "validation", "testing"]),
    ]

    collected: list[Chunk] = []
    for domain, path, tags in file_specs:
        if not path.exists():
            continue
        # Undecodable bytes are dropped rather than failing the whole load.
        contents = path.read_text(encoding="utf-8", errors="ignore")
        stem = re.sub(r"[^a-z0-9]+", "_", path.stem.lower()).strip("_") or "file"
        for position, section in enumerate(self._chunk_project_text(contents), start=1):
            # Fall back to the file stem and domain when no keyword survives filtering.
            keywords = self._extract_project_keywords(section) or [stem, domain.lower()]
            collected.append(
                Chunk(
                    chunk_id=f"project_{stem}_{position:03d}",
                    domain=domain,
                    text=section,
                    # Approximate token count: four characters per token, minimum 30.
                    tokens=max(30, len(section) // 4),
                    keywords=keywords[:5],
                    relevance_tags=tags,
                )
            )
    return collected
|
| 303 |
+
|
| 304 |
+
def _chunk_project_text(self, raw_text: str, chunk_words: int = 140, stride_words: int = 100) -> list[str]:
|
| 305 |
+
cleaned = " ".join(raw_text.split())
|
| 306 |
+
words = cleaned.split()
|
| 307 |
+
if not words:
|
| 308 |
+
return []
|
| 309 |
+
if len(words) <= chunk_words:
|
| 310 |
+
return [" ".join(words)]
|
| 311 |
+
|
| 312 |
+
chunks: list[str] = []
|
| 313 |
+
start = 0
|
| 314 |
+
while start < len(words):
|
| 315 |
+
window = words[start : start + chunk_words]
|
| 316 |
+
if not window:
|
| 317 |
+
break
|
| 318 |
+
chunks.append(" ".join(window))
|
| 319 |
+
if start + chunk_words >= len(words):
|
| 320 |
+
break
|
| 321 |
+
start += stride_words
|
| 322 |
+
return chunks
|
| 323 |
+
|
| 324 |
+
def _extract_project_keywords(self, text: str) -> list[str]:
|
| 325 |
+
terms = re.findall(r"[a-z0-9_]+", text.lower())
|
| 326 |
+
counts: dict[str, int] = {}
|
| 327 |
+
for term in terms:
|
| 328 |
+
if len(term) < 4 or term in self._PROJECT_STOPWORDS:
|
| 329 |
+
continue
|
| 330 |
+
counts[term] = counts.get(term, 0) + 1
|
| 331 |
+
ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
|
| 332 |
+
return [term.replace("_", " ") for term, _count in ranked[:8]]
|
| 333 |
+
|
| 334 |
+
def _build_observation(self) -> RagObservation:
    """Assemble the agent-facing observation from the current episode state."""
    summaries = []
    for chunk in self._available_chunks:
        summaries.append(
            ChunkSummary(
                chunk_id=chunk.chunk_id,
                domain=chunk.domain,
                # Token counts reflect any compression applied so far.
                tokens=self._effective_chunk_tokens(chunk.chunk_id),
                keywords=chunk.keywords,
            )
        )

    return RagObservation(
        query=self.task.query,
        available_chunks=summaries,
        selected_chunks=list(self._selected_chunks),
        total_tokens_used=self._total_tokens_used(),
        token_budget=self.task.token_budget,
        step_number=self._step_number,
        task_name=self.task.name,
        last_action_feedback=self._last_action_feedback,
    )
|
| 353 |
+
|
| 354 |
+
def _chunk_map(self) -> dict[str, Chunk]:
|
| 355 |
+
return {chunk.chunk_id: chunk for chunk in self._available_chunks}
|
| 356 |
+
|
| 357 |
+
def _effective_chunk_tokens(self, chunk_id: str) -> int:
|
| 358 |
+
chunk = self._chunk_map().get(chunk_id)
|
| 359 |
+
if chunk is None:
|
| 360 |
+
return 0
|
| 361 |
+
ratio = self._compression_ratios.get(chunk_id, 1.0)
|
| 362 |
+
return max(1, int(round(chunk.tokens * ratio)))
|
| 363 |
+
|
| 364 |
+
def _total_tokens_used(self) -> int:
|
| 365 |
+
return sum(self._effective_chunk_tokens(chunk_id) for chunk_id in self._selected_chunks)
|
| 366 |
+
|
| 367 |
+
def _effective_chunk_text(self, chunk_id: str) -> str:
|
| 368 |
+
chunk = self._chunk_map().get(chunk_id)
|
| 369 |
+
if chunk is None:
|
| 370 |
+
return ""
|
| 371 |
+
ratio = self._compression_ratios.get(chunk_id, 1.0)
|
| 372 |
+
text = " ".join(chunk.text.split())
|
| 373 |
+
if ratio >= 0.999:
|
| 374 |
+
return text
|
| 375 |
+
|
| 376 |
+
query_terms = self._query_terms(self.task.query)
|
| 377 |
+
keyword_terms = self._query_terms(" ".join(chunk.keywords))
|
| 378 |
+
sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", text) if segment.strip()]
|
| 379 |
+
if not sentences:
|
| 380 |
+
return self._truncate_words(text, ratio)
|
| 381 |
+
|
| 382 |
+
ranked_sentences: list[tuple[int, float, int, str]] = []
|
| 383 |
+
for index, sentence in enumerate(sentences):
|
| 384 |
+
sentence_terms = self._query_terms(sentence)
|
| 385 |
+
overlap = len(sentence_terms & query_terms)
|
| 386 |
+
keyword_overlap = len(sentence_terms & keyword_terms)
|
| 387 |
+
score = (overlap * 2.0) + keyword_overlap + (0.25 if index == 0 else 0.0)
|
| 388 |
+
ranked_sentences.append((index, score, len(sentence.split()), sentence))
|
| 389 |
+
|
| 390 |
+
target_words = max(20, int(len(text.split()) * ratio))
|
| 391 |
+
chosen: list[tuple[int, str]] = []
|
| 392 |
+
used_words = 0
|
| 393 |
+
for index, _score, word_count, sentence in sorted(
|
| 394 |
+
ranked_sentences,
|
| 395 |
+
key=lambda item: (-item[1], item[2], item[0]),
|
| 396 |
+
):
|
| 397 |
+
if used_words >= target_words:
|
| 398 |
+
break
|
| 399 |
+
chosen.append((index, sentence))
|
| 400 |
+
used_words += word_count
|
| 401 |
+
|
| 402 |
+
if not chosen:
|
| 403 |
+
return self._truncate_words(text, ratio)
|
| 404 |
+
|
| 405 |
+
chosen.sort(key=lambda item: item[0])
|
| 406 |
+
compressed = " ".join(sentence for _index, sentence in chosen)
|
| 407 |
+
return self._truncate_words(compressed, ratio)
|
| 408 |
+
|
| 409 |
+
@staticmethod
|
| 410 |
+
def _truncate_words(text: str, ratio: float) -> str:
|
| 411 |
+
words = text.split()
|
| 412 |
+
if not words:
|
| 413 |
+
return ""
|
| 414 |
+
keep = max(12, int(len(words) * ratio))
|
| 415 |
+
truncated = " ".join(words[:keep])
|
| 416 |
+
if keep < len(words):
|
| 417 |
+
return truncated + " ..."
|
| 418 |
+
return truncated
|
| 419 |
+
|
| 420 |
+
@staticmethod
|
| 421 |
+
def _query_terms(text: str) -> set[str]:
|
| 422 |
+
return {token for token in re.findall(r"[a-z0-9]+", text.lower()) if len(token) > 2}
|
| 423 |
+
|
| 424 |
+
def _build_optimized_prompt(self) -> str:
|
| 425 |
+
if not self._selected_chunks:
|
| 426 |
+
return ""
|
| 427 |
+
sections = [f"Question: {self.task.query}", "", "Optimized Context:"]
|
| 428 |
+
for chunk_id in self._selected_chunks:
|
| 429 |
+
chunk = self._chunk_map().get(chunk_id)
|
| 430 |
+
if chunk is None:
|
| 431 |
+
continue
|
| 432 |
+
sections.append(
|
| 433 |
+
f"[{chunk.chunk_id} | {self._effective_chunk_tokens(chunk_id)} tokens] {self._effective_chunk_text(chunk_id)}"
|
| 434 |
+
)
|
| 435 |
+
return "\n".join(sections).strip()
|
| 436 |
+
|
| 437 |
+
def _is_relevant(self, chunk_id: str) -> tuple[bool, float]:
|
| 438 |
+
chunk = self._chunk_map().get(chunk_id)
|
| 439 |
+
if chunk is None:
|
| 440 |
+
return False, 0.0
|
| 441 |
+
score = self.retriever.hybrid_score(self.task.query, chunk)
|
| 442 |
+
return score >= 0.3, score
|
| 443 |
+
|
| 444 |
+
def _handle_select(self, chunk_id: str) -> tuple[float, dict[str, Any]]:
|
| 445 |
+
chunk = self._chunk_map().get(chunk_id)
|
| 446 |
+
if chunk is None:
|
| 447 |
+
self._last_action_feedback = "chunk_not_found"
|
| 448 |
+
return -0.1, {"event": "chunk_not_found"}
|
| 449 |
+
if chunk_id in self._selected_chunks:
|
| 450 |
+
self._last_action_feedback = "chunk_already_selected"
|
| 451 |
+
return 0.0, {"event": "chunk_already_selected"}
|
| 452 |
+
|
| 453 |
+
projected_tokens = self._total_tokens_used() + self._effective_chunk_tokens(chunk_id)
|
| 454 |
+
if projected_tokens > self.task.token_budget:
|
| 455 |
+
self._last_action_feedback = "exceeded_budget"
|
| 456 |
+
return -0.1, {"event": "exceeded_budget", "chunk_id": chunk_id}
|
| 457 |
+
|
| 458 |
+
self._selected_chunks.append(chunk_id)
|
| 459 |
+
_, score = self._is_relevant(chunk_id)
|
| 460 |
+
self._last_action_feedback = "chunk_selected"
|
| 461 |
+
return score * 0.2, {"event": "chunk_selected", "chunk_id": chunk_id, "hybrid_score": score}
|
| 462 |
+
|
| 463 |
+
def _handle_deselect(self, chunk_id: str) -> tuple[float, dict[str, Any]]:
|
| 464 |
+
if chunk_id not in self._selected_chunks:
|
| 465 |
+
self._last_action_feedback = "chunk_not_selected"
|
| 466 |
+
return 0.0, {"event": "chunk_not_selected", "chunk_id": chunk_id}
|
| 467 |
+
|
| 468 |
+
self._selected_chunks.remove(chunk_id)
|
| 469 |
+
is_relevant, score = self._is_relevant(chunk_id)
|
| 470 |
+
self._last_action_feedback = "chunk_deselected"
|
| 471 |
+
reward = 0.0 if is_relevant else 0.05
|
| 472 |
+
return reward, {"event": "chunk_deselected", "chunk_id": chunk_id, "hybrid_score": score}
|
| 473 |
+
|
| 474 |
+
def _handle_compress(self, chunk_id: str, compression_ratio: float) -> tuple[float, dict[str, Any]]:
|
| 475 |
+
chunk = self._chunk_map().get(chunk_id)
|
| 476 |
+
if chunk is None:
|
| 477 |
+
self._last_action_feedback = "chunk_not_found"
|
| 478 |
+
return -0.1, {"event": "chunk_not_found", "chunk_id": chunk_id}
|
| 479 |
+
|
| 480 |
+
self._compression_ratios[chunk_id] = compression_ratio
|
| 481 |
+
is_relevant, score = self._is_relevant(chunk_id)
|
| 482 |
+
reward = 0.03 if is_relevant else 0.0
|
| 483 |
+
if score >= 0.6 and compression_ratio < 0.4:
|
| 484 |
+
reward -= 0.05
|
| 485 |
+
self._last_action_feedback = "overcompressed_relevant_chunk"
|
| 486 |
+
return reward, {
|
| 487 |
+
"event": "overcompressed_relevant_chunk",
|
| 488 |
+
"chunk_id": chunk_id,
|
| 489 |
+
"hybrid_score": score,
|
| 490 |
+
"compression_ratio": compression_ratio,
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
self._last_action_feedback = "chunk_compressed"
|
| 494 |
+
return reward, {
|
| 495 |
+
"event": "chunk_compressed",
|
| 496 |
+
"chunk_id": chunk_id,
|
| 497 |
+
"hybrid_score": score,
|
| 498 |
+
"compression_ratio": compression_ratio,
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
def _finalize_submission(self, reason: str) -> StepResult:
    """End the episode and produce the terminal step result.

    With no chunk selected the reward is 0.0 and no grading happens.
    Otherwise the grader's score becomes the reward and its breakdown and
    pass flag are reported in ``info``.
    """
    self._done = True

    if self._selected_chunks:
        verdict = self.grader.grade(
            selected_chunk_ids=list(self._selected_chunks),
            answer=self._last_answer,
            token_budget=self.task.token_budget,
            total_tokens_used=self._total_tokens_used(),
            retriever=self.retriever,
            task=self.task,
        )
        self._last_action_feedback = reason
        reward = verdict.score
        info = {
            "event": reason,
            "grader": verdict.breakdown,
            "passed": verdict.passed,
        }
    else:
        self._last_action_feedback = "no_chunks_selected"
        reward = 0.0
        info = {"event": reason, "grader": None, "passed": False}

    return StepResult(
        observation=self._build_observation(),
        reward=reward,
        done=True,
        info=info,
    )
|
env/graders.py
CHANGED
|
@@ -1,124 +1,125 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Deterministic graders for rag-context-optimizer tasks.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
from __future__ import annotations
|
| 6 |
-
|
| 7 |
-
import re
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
-
|
| 10 |
-
from env.corpus import Chunk
|
| 11 |
-
from env.retriever import HybridRetriever
|
| 12 |
-
from env.tasks import Task
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
_STOPWORDS = {
|
| 16 |
-
"a", "an", "and", "are", "as", "at", "be", "because", "by", "for", "from", "how",
|
| 17 |
-
"if", "in", "into", "is", "it", "its", "of", "on", "or", "that", "the", "their",
|
| 18 |
-
"them", "there", "these", "this", "to", "was", "were", "what", "when", "where",
|
| 19 |
-
"which", "while", "with", "within", "without", "you", "your",
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def _tokenize(text: str) -> set[str]:
|
| 24 |
-
return set(re.findall(r"[a-z0-9]+", text.lower()))
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def _content_terms(text: str) -> set[str]:
|
| 28 |
-
return {term for term in _tokenize(text) if len(term) > 2 and term not in _STOPWORDS}
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def _extract_citations(text: str) -> list[str]:
|
| 32 |
-
return re.findall(r"\[([a-z0-9_]+)\]", text.lower())
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def _normalize_chunk_id(chunk_id: str) -> str:
|
| 36 |
-
chunk_id = chunk_id.strip()
|
| 37 |
-
return chunk_id
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def _normalize_domain_filter(domain_filter: str | None) -> str | None:
|
| 41 |
-
if domain_filter is None:
|
| 42 |
-
return None
|
| 43 |
-
mapping = {
|
| 44 |
-
"customer_support_operations": "Customer Support Operations",
|
| 45 |
-
"incident_response_playbooks": "Incident Response Playbooks",
|
| 46 |
-
"platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
|
| 47 |
-
}
|
| 48 |
-
return mapping.get(domain_filter, domain_filter)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def _f1_score(selected: set[str], relevant: set[str]) -> float:
|
| 52 |
-
if not selected and not relevant:
|
| 53 |
-
return 1.0
|
| 54 |
-
if not selected or not relevant:
|
| 55 |
-
return 0.0
|
| 56 |
-
overlap = len(selected & relevant)
|
| 57 |
-
if overlap == 0:
|
| 58 |
-
return 0.0
|
| 59 |
-
precision = overlap / len(selected)
|
| 60 |
-
recall = overlap / len(relevant)
|
| 61 |
-
return 2 * precision * recall / (precision + recall)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
@dataclass(frozen=True, slots=True)
|
| 65 |
-
class GraderResult:
|
| 66 |
-
score: float
|
| 67 |
-
breakdown: dict[str, float]
|
| 68 |
-
passed: bool
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
class TaskGrader:
|
| 72 |
-
def _filter_relevant_by_domain(self, relevant_ids: set[str], retriever: HybridRetriever, task: Task) -> set[str]:
|
| 73 |
-
normalized_domain = _normalize_domain_filter(task.domain_filter)
|
| 74 |
-
if normalized_domain is None:
|
| 75 |
-
return relevant_ids
|
| 76 |
-
allowed_ids = {chunk.chunk_id for chunk in retriever.corpus if chunk.domain == normalized_domain}
|
| 77 |
-
return relevant_ids & allowed_ids
|
| 78 |
-
|
| 79 |
-
def _required_chunks(self, retriever: HybridRetriever, task: Task) -> list[Chunk]:
|
| 80 |
-
normalized_required = {_normalize_chunk_id(chunk_id) for chunk_id in task.required_chunk_ids}
|
| 81 |
-
return [chunk for chunk in retriever.corpus if chunk.chunk_id in normalized_required]
|
| 82 |
-
|
| 83 |
def _answer_quality(self, answer: str, required_chunks: list[Chunk]) -> float:
|
| 84 |
answer_terms = _content_terms(answer)
|
| 85 |
required_terms = _content_terms(" ".join(chunk.text for chunk in required_chunks))
|
|
|
|
| 86 |
if not answer_terms or not required_terms:
|
| 87 |
return 0.0
|
| 88 |
union = answer_terms | required_terms
|
| 89 |
-
if not union:
|
| 90 |
-
return 0.0
|
| 91 |
-
return len(answer_terms & required_terms) / len(union)
|
| 92 |
-
|
| 93 |
-
def _citation_accuracy(self, answer: str, selected_chunk_ids: set[str], expected_citation_ids: set[str]) -> float:
|
| 94 |
-
citations = {_normalize_chunk_id(chunk_id) for chunk_id in _extract_citations(answer)}
|
| 95 |
-
if not citations:
|
| 96 |
-
return 0.0
|
| 97 |
-
valid_citations = citations & selected_chunk_ids
|
| 98 |
-
precision = len(valid_citations) / len(citations)
|
| 99 |
-
recall = len(valid_citations & expected_citation_ids) / len(expected_citation_ids) if expected_citation_ids else 1.0
|
| 100 |
-
return (precision + recall) / 2.0
|
| 101 |
-
|
| 102 |
-
def _unsupported_claim_rate(self, answer: str, evidence_chunks: list[Chunk]) -> float:
|
| 103 |
-
answer_terms = _content_terms(re.sub(r"\[[a-z0-9_]+\]", " ", answer.lower()))
|
| 104 |
-
evidence_terms = _content_terms(" ".join(chunk.text for chunk in evidence_chunks))
|
| 105 |
-
if not answer_terms:
|
| 106 |
-
return 0.0
|
| 107 |
-
unsupported = answer_terms - evidence_terms
|
| 108 |
-
return len(unsupported) / len(answer_terms)
|
| 109 |
-
|
| 110 |
-
def grade(
|
| 111 |
-
self,
|
| 112 |
-
selected_chunk_ids: list[str],
|
| 113 |
-
answer: str,
|
| 114 |
-
token_budget: int,
|
| 115 |
-
total_tokens_used: int,
|
| 116 |
-
retriever: HybridRetriever,
|
| 117 |
-
task: Task,
|
| 118 |
) -> GraderResult:
|
| 119 |
normalized_selected = {_normalize_chunk_id(chunk_id) for chunk_id in selected_chunk_ids}
|
| 120 |
-
|
| 121 |
-
relevant = self._filter_relevant_by_domain(
|
| 122 |
|
| 123 |
retrieval_precision = _f1_score(normalized_selected, relevant)
|
| 124 |
token_efficiency = 1.0 - (total_tokens_used / token_budget) if total_tokens_used <= token_budget else 0.0
|
|
@@ -127,41 +128,41 @@ class TaskGrader:
|
|
| 127 |
required_chunks = self._required_chunks(retriever, task)
|
| 128 |
answer_quality = self._answer_quality(answer, required_chunks)
|
| 129 |
|
| 130 |
-
normalized_required = {_normalize_chunk_id(chunk_id) for chunk_id in task.required_chunk_ids}
|
| 131 |
normalized_expected_citations = {
|
| 132 |
_normalize_chunk_id(chunk_id)
|
| 133 |
for chunk_id in (task.expected_citation_ids or task.required_chunk_ids)
|
| 134 |
-
}
|
| 135 |
-
required_chunks_hit = (
|
| 136 |
-
len(normalized_selected & normalized_required) / len(normalized_required)
|
| 137 |
-
if normalized_required
|
| 138 |
-
else 1.0
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
selected_chunks = [
|
| 142 |
chunk for chunk in retriever.corpus if chunk.chunk_id in normalized_selected
|
| 143 |
]
|
|
|
|
| 144 |
citation_accuracy = self._citation_accuracy(answer, normalized_selected, normalized_expected_citations)
|
| 145 |
-
unsupported_claim_rate = self._unsupported_claim_rate(answer,
|
| 146 |
-
hallucination_penalty = min(1.0, unsupported_claim_rate)
|
| 147 |
-
|
| 148 |
-
base_score = (
|
| 149 |
-
0.25 * retrieval_precision
|
| 150 |
-
+ 0.25 * token_efficiency
|
| 151 |
-
+ 0.35 * answer_quality
|
| 152 |
-
+ 0.15 * required_chunks_hit
|
| 153 |
-
)
|
| 154 |
-
score = base_score + (0.10 * citation_accuracy) - (0.15 * hallucination_penalty)
|
| 155 |
-
score = max(0.0, min(1.0, score))
|
| 156 |
-
|
| 157 |
-
breakdown = {
|
| 158 |
-
"retrieval_precision": retrieval_precision,
|
| 159 |
-
"token_efficiency": token_efficiency,
|
| 160 |
-
"answer_quality": answer_quality,
|
| 161 |
-
"required_chunks_hit": required_chunks_hit,
|
| 162 |
-
"citation_accuracy": citation_accuracy,
|
| 163 |
-
"unsupported_claim_rate": unsupported_claim_rate,
|
| 164 |
-
"hallucination_penalty": hallucination_penalty,
|
| 165 |
-
}
|
| 166 |
-
passed = score >= 0.7
|
| 167 |
-
return GraderResult(score=score, breakdown=breakdown, passed=passed)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Deterministic graders for rag-context-optimizer tasks.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
|
| 10 |
+
from env.corpus import Chunk
|
| 11 |
+
from env.retriever import HybridRetriever
|
| 12 |
+
from env.tasks import Task
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_STOPWORDS = {
|
| 16 |
+
"a", "an", "and", "are", "as", "at", "be", "because", "by", "for", "from", "how",
|
| 17 |
+
"if", "in", "into", "is", "it", "its", "of", "on", "or", "that", "the", "their",
|
| 18 |
+
"them", "there", "these", "this", "to", "was", "were", "what", "when", "where",
|
| 19 |
+
"which", "while", "with", "within", "without", "you", "your",
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _tokenize(text: str) -> set[str]:
|
| 24 |
+
return set(re.findall(r"[a-z0-9]+", text.lower()))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _content_terms(text: str) -> set[str]:
    """Return the tokens of ``text`` that are longer than two characters and not stopwords."""
    terms = set()
    for term in _tokenize(text):
        if len(term) <= 2 or term in _STOPWORDS:
            continue
        terms.add(term)
    return terms
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _extract_citations(text: str) -> list[str]:
|
| 32 |
+
return re.findall(r"\[([a-z0-9_]+)\]", text.lower())
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _normalize_chunk_id(chunk_id: str) -> str:
|
| 36 |
+
chunk_id = chunk_id.strip()
|
| 37 |
+
return chunk_id
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _normalize_domain_filter(domain_filter: str | None) -> str | None:
|
| 41 |
+
if domain_filter is None:
|
| 42 |
+
return None
|
| 43 |
+
mapping = {
|
| 44 |
+
"customer_support_operations": "Customer Support Operations",
|
| 45 |
+
"incident_response_playbooks": "Incident Response Playbooks",
|
| 46 |
+
"platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
|
| 47 |
+
}
|
| 48 |
+
return mapping.get(domain_filter, domain_filter)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _f1_score(selected: set[str], relevant: set[str]) -> float:
|
| 52 |
+
if not selected and not relevant:
|
| 53 |
+
return 1.0
|
| 54 |
+
if not selected or not relevant:
|
| 55 |
+
return 0.0
|
| 56 |
+
overlap = len(selected & relevant)
|
| 57 |
+
if overlap == 0:
|
| 58 |
+
return 0.0
|
| 59 |
+
precision = overlap / len(selected)
|
| 60 |
+
recall = overlap / len(relevant)
|
| 61 |
+
return 2 * precision * recall / (precision + recall)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass(frozen=True, slots=True)
|
| 65 |
+
class GraderResult:
|
| 66 |
+
score: float
|
| 67 |
+
breakdown: dict[str, float]
|
| 68 |
+
passed: bool
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TaskGrader:
|
| 72 |
+
def _filter_relevant_by_domain(self, relevant_ids: set[str], retriever: HybridRetriever, task: Task) -> set[str]:
    """Restrict *relevant_ids* to chunks in the task's domain.

    When the task has no domain filter the ids are returned untouched;
    otherwise only ids of corpus chunks whose ``domain`` equals the
    normalized filter are kept.
    """
    domain = _normalize_domain_filter(task.domain_filter)
    if domain is None:
        return relevant_ids
    in_domain: set[str] = set()
    for chunk in retriever.corpus:
        if chunk.domain == domain:
            in_domain.add(chunk.chunk_id)
    return relevant_ids & in_domain
|
| 78 |
+
|
| 79 |
+
def _required_chunks(self, retriever: HybridRetriever, task: Task) -> list[Chunk]:
    """Return the corpus chunks whose ids appear in ``task.required_chunk_ids``, in corpus order."""
    wanted = {_normalize_chunk_id(raw_id) for raw_id in task.required_chunk_ids}
    matches: list[Chunk] = []
    for chunk in retriever.corpus:
        if chunk.chunk_id in wanted:
            matches.append(chunk)
    return matches
|
| 82 |
+
|
| 83 |
def _answer_quality(self, answer: str, required_chunks: list[Chunk]) -> float:
    """Score *answer* by term overlap with the required chunks.

    The reference vocabulary is the content terms of every required chunk's
    text plus its keywords. The result is the Jaccard ratio (shared terms
    divided by all distinct terms) between that vocabulary and the answer's
    content terms, or 0.0 when either side has no content terms.
    """
    answer_terms = _content_terms(answer)
    required_terms = _content_terms(" ".join(chunk.text for chunk in required_chunks))
    required_terms |= _content_terms(" ".join(" ".join(chunk.keywords) for chunk in required_chunks))
    if not answer_terms or not required_terms:
        return 0.0
    # Both sets are non-empty here, so the union cannot be empty and the
    # division is always defined.
    return len(answer_terms & required_terms) / len(answer_terms | required_terms)
|
| 93 |
+
|
| 94 |
+
def _citation_accuracy(self, answer: str, selected_chunk_ids: set[str], expected_citation_ids: set[str]) -> float:
    """Return the mean of citation precision and recall for *answer*.

    Precision is the share of cited ids that were actually selected. Recall
    is the share of expected ids that were both cited and selected, and is
    1.0 when nothing is expected. An answer with no citations scores 0.0.
    """
    cited = {_normalize_chunk_id(raw_id) for raw_id in _extract_citations(answer)}
    if not cited:
        return 0.0
    grounded = cited & selected_chunk_ids
    precision = len(grounded) / len(cited)
    if expected_citation_ids:
        recall = len(grounded & expected_citation_ids) / len(expected_citation_ids)
    else:
        recall = 1.0
    return (precision + recall) / 2.0
|
| 102 |
+
|
| 103 |
+
def _unsupported_claim_rate(self, answer: str, evidence_chunks: list[Chunk]) -> float:
    """Return the fraction of the answer's content terms absent from the evidence text.

    Bracketed citation markers are blanked out of the answer before its terms
    are extracted. An answer with no content terms yields 0.0.
    """
    without_citations = re.sub(r"\[[a-z0-9_]+\]", " ", answer.lower())
    claimed = _content_terms(without_citations)
    evidence_text = " ".join(chunk.text for chunk in evidence_chunks)
    supported = _content_terms(evidence_text)
    if not claimed:
        return 0.0
    return len(claimed - supported) / len(claimed)
|
| 110 |
+
|
| 111 |
+
def grade(
|
| 112 |
+
self,
|
| 113 |
+
selected_chunk_ids: list[str],
|
| 114 |
+
answer: str,
|
| 115 |
+
token_budget: int,
|
| 116 |
+
total_tokens_used: int,
|
| 117 |
+
retriever: HybridRetriever,
|
| 118 |
+
task: Task,
|
| 119 |
) -> GraderResult:
|
| 120 |
normalized_selected = {_normalize_chunk_id(chunk_id) for chunk_id in selected_chunk_ids}
|
| 121 |
+
normalized_required = {_normalize_chunk_id(chunk_id) for chunk_id in task.required_chunk_ids}
|
| 122 |
+
relevant = self._filter_relevant_by_domain(normalized_required, retriever, task)
|
| 123 |
|
| 124 |
retrieval_precision = _f1_score(normalized_selected, relevant)
|
| 125 |
token_efficiency = 1.0 - (total_tokens_used / token_budget) if total_tokens_used <= token_budget else 0.0
|
|
|
|
| 128 |
required_chunks = self._required_chunks(retriever, task)
|
| 129 |
answer_quality = self._answer_quality(answer, required_chunks)
|
| 130 |
|
|
|
|
| 131 |
normalized_expected_citations = {
|
| 132 |
_normalize_chunk_id(chunk_id)
|
| 133 |
for chunk_id in (task.expected_citation_ids or task.required_chunk_ids)
|
| 134 |
+
}
|
| 135 |
+
required_chunks_hit = (
|
| 136 |
+
len(normalized_selected & normalized_required) / len(normalized_required)
|
| 137 |
+
if normalized_required
|
| 138 |
+
else 1.0
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
selected_chunks = [
|
| 142 |
chunk for chunk in retriever.corpus if chunk.chunk_id in normalized_selected
|
| 143 |
]
|
| 144 |
+
evidence_chunks = selected_chunks or required_chunks
|
| 145 |
citation_accuracy = self._citation_accuracy(answer, normalized_selected, normalized_expected_citations)
|
| 146 |
+
unsupported_claim_rate = self._unsupported_claim_rate(answer, evidence_chunks)
|
| 147 |
+
hallucination_penalty = min(1.0, unsupported_claim_rate)
|
| 148 |
+
|
| 149 |
+
base_score = (
|
| 150 |
+
0.25 * retrieval_precision
|
| 151 |
+
+ 0.25 * token_efficiency
|
| 152 |
+
+ 0.35 * answer_quality
|
| 153 |
+
+ 0.15 * required_chunks_hit
|
| 154 |
+
)
|
| 155 |
+
score = base_score + (0.10 * citation_accuracy) - (0.15 * hallucination_penalty)
|
| 156 |
+
score = max(0.0, min(1.0, score))
|
| 157 |
+
|
| 158 |
+
breakdown = {
|
| 159 |
+
"retrieval_precision": retrieval_precision,
|
| 160 |
+
"token_efficiency": token_efficiency,
|
| 161 |
+
"answer_quality": answer_quality,
|
| 162 |
+
"required_chunks_hit": required_chunks_hit,
|
| 163 |
+
"citation_accuracy": citation_accuracy,
|
| 164 |
+
"unsupported_claim_rate": unsupported_claim_rate,
|
| 165 |
+
"hallucination_penalty": hallucination_penalty,
|
| 166 |
+
}
|
| 167 |
+
passed = score >= 0.7
|
| 168 |
+
return GraderResult(score=score, breakdown=breakdown, passed=passed)
|
inference.py
CHANGED
|
@@ -26,16 +26,16 @@ TASK_SEQUENCE = [
|
|
| 26 |
"adversarial_compression",
|
| 27 |
]
|
| 28 |
|
| 29 |
-
SYSTEM_PROMPT = """You are a baseline RAG context optimizer.
|
| 30 |
-
Read the query and available chunks using chunk_id, keywords, tokens, and domain.
|
| 31 |
Select chunks that maximize keyword overlap with the query.
|
| 32 |
Stay under the token budget.
|
| 33 |
Compress chunks that are mildly relevant but token-heavy.
|
| 34 |
Submit a concise answer once enough useful chunks are selected.
|
| 35 |
When you submit an answer, cite selected chunks inline like [support_003] or [incident_002].
|
| 36 |
Return only valid JSON matching one of these forms:
|
| 37 |
-
{"action_type":"select_chunk","chunk_id":"support_003"}
|
| 38 |
-
{"action_type":"deselect_chunk","chunk_id":"support_003"}
|
| 39 |
{"action_type":"compress_chunk","chunk_id":"support_003","compression_ratio":0.5}
|
| 40 |
{"action_type":"submit_answer","answer":"Verify outage evidence and the billing ledger before refunding [support_001] [support_003]."}"""
|
| 41 |
|
|
@@ -222,12 +222,12 @@ async def _post_json(http_client: httpx.AsyncClient, path: str, payload: dict[st
|
|
| 222 |
return response.json()
|
| 223 |
|
| 224 |
|
| 225 |
-
async def _run_task_http(task_name: str) -> tuple[float, list[float], int]:
|
| 226 |
rewards: list[float] = []
|
| 227 |
steps = 0
|
| 228 |
success = False
|
| 229 |
-
score = 0.0
|
| 230 |
-
terminal_error: str | None = None
|
| 231 |
fallback_reason: str | None = None
|
| 232 |
model_name = _model_name()
|
| 233 |
|
|
@@ -253,7 +253,7 @@ async def _run_task_http(task_name: str) -> tuple[float, list[float], int]:
|
|
| 253 |
flush=True,
|
| 254 |
)
|
| 255 |
print("[END] success=false steps=0 score=0.000 rewards=")
|
| 256 |
-
return 0.0, [], 0
|
| 257 |
|
| 258 |
try:
|
| 259 |
async with httpx.AsyncClient(timeout=30.0) as http_client:
|
|
@@ -276,7 +276,7 @@ async def _run_task_http(task_name: str) -> tuple[float, list[float], int]:
|
|
| 276 |
print(
|
| 277 |
f"[END] success=false steps={steps} score={_clamp_score(score):.3f} rewards={_format_rewards(rewards)}",
|
| 278 |
)
|
| 279 |
-
return score, rewards, steps
|
| 280 |
print(
|
| 281 |
f"[warn] Falling back to deterministic policy for {task_name}: {fallback_reason}",
|
| 282 |
file=sys.stderr,
|
|
@@ -313,31 +313,34 @@ async def _run_task_http(task_name: str) -> tuple[float, list[float], int]:
|
|
| 313 |
success = terminal_error is None and fallback_reason is None
|
| 314 |
break
|
| 315 |
|
| 316 |
-
score = _clamp_score(score)
|
| 317 |
-
print(
|
| 318 |
-
f"[END] success={_format_bool(success)} steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}"
|
| 319 |
-
)
|
| 320 |
-
return score, rewards, steps
|
| 321 |
-
except Exception:
|
| 322 |
-
score = _clamp_score(score)
|
| 323 |
-
print(
|
| 324 |
-
f"[END] success=false steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}"
|
| 325 |
-
)
|
| 326 |
-
return score, rewards, steps
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
def run_task(task_name: str) -> tuple[float, list[float], int]:
|
| 330 |
-
return asyncio.run(_run_task_http(task_name))
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
def main() ->
|
| 334 |
-
if RAG_ENV_TASK in TASK_SEQUENCE:
|
| 335 |
-
tasks = [RAG_ENV_TASK] + [task for task in TASK_SEQUENCE if task != RAG_ENV_TASK]
|
| 336 |
-
else:
|
| 337 |
-
tasks = list(TASK_SEQUENCE)
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
if
|
| 343 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
"adversarial_compression",
|
| 27 |
]
|
| 28 |
|
| 29 |
+
SYSTEM_PROMPT = """You are a baseline RAG context optimizer.
|
| 30 |
+
Read the query and available chunks using chunk_id, keywords, tokens, and domain.
|
| 31 |
Select chunks that maximize keyword overlap with the query.
|
| 32 |
Stay under the token budget.
|
| 33 |
Compress chunks that are mildly relevant but token-heavy.
|
| 34 |
Submit a concise answer once enough useful chunks are selected.
|
| 35 |
When you submit an answer, cite selected chunks inline like [support_003] or [incident_002].
|
| 36 |
Return only valid JSON matching one of these forms:
|
| 37 |
+
{"action_type":"select_chunk","chunk_id":"support_003"}
|
| 38 |
+
{"action_type":"deselect_chunk","chunk_id":"support_003"}
|
| 39 |
{"action_type":"compress_chunk","chunk_id":"support_003","compression_ratio":0.5}
|
| 40 |
{"action_type":"submit_answer","answer":"Verify outage evidence and the billing ledger before refunding [support_001] [support_003]."}"""
|
| 41 |
|
|
|
|
| 222 |
return response.json()
|
| 223 |
|
| 224 |
|
| 225 |
+
async def _run_task_http(task_name: str) -> tuple[float, list[float], int, bool]:
|
| 226 |
rewards: list[float] = []
|
| 227 |
steps = 0
|
| 228 |
success = False
|
| 229 |
+
score = 0.0
|
| 230 |
+
terminal_error: str | None = None
|
| 231 |
fallback_reason: str | None = None
|
| 232 |
model_name = _model_name()
|
| 233 |
|
|
|
|
| 253 |
flush=True,
|
| 254 |
)
|
| 255 |
print("[END] success=false steps=0 score=0.000 rewards=")
|
| 256 |
+
return 0.0, [], 0, False
|
| 257 |
|
| 258 |
try:
|
| 259 |
async with httpx.AsyncClient(timeout=30.0) as http_client:
|
|
|
|
| 276 |
print(
|
| 277 |
f"[END] success=false steps={steps} score={_clamp_score(score):.3f} rewards={_format_rewards(rewards)}",
|
| 278 |
)
|
| 279 |
+
return score, rewards, steps, False
|
| 280 |
print(
|
| 281 |
f"[warn] Falling back to deterministic policy for {task_name}: {fallback_reason}",
|
| 282 |
file=sys.stderr,
|
|
|
|
| 313 |
success = terminal_error is None and fallback_reason is None
|
| 314 |
break
|
| 315 |
|
| 316 |
+
score = _clamp_score(score)
|
| 317 |
+
print(
|
| 318 |
+
f"[END] success={_format_bool(success)} steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}"
|
| 319 |
+
)
|
| 320 |
+
return score, rewards, steps, success
|
| 321 |
+
except Exception:
|
| 322 |
+
score = _clamp_score(score)
|
| 323 |
+
print(
|
| 324 |
+
f"[END] success=false steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}"
|
| 325 |
+
)
|
| 326 |
+
return score, rewards, steps, False
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def run_task(task_name: str) -> tuple[float, list[float], int, bool]:
    """Run the async task runner for *task_name* to completion and return its result tuple."""
    episode = _run_task_http(task_name)
    return asyncio.run(episode)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def main() -> int:
    """Run every task and return a process exit code.

    When ``RAG_ENV_TASK`` names a known task it runs first, followed by the
    remaining tasks in ``TASK_SEQUENCE`` order. Returns 0 only if every task
    reported success, otherwise 1.
    """
    ordered = list(TASK_SEQUENCE)
    if RAG_ENV_TASK in TASK_SEQUENCE:
        ordered = [RAG_ENV_TASK] + [name for name in TASK_SEQUENCE if name != RAG_ENV_TASK]

    exit_code = 0
    for name in ordered:
        # Every task is run even after a failure; only the exit code changes.
        _score, _rewards, _steps, success = run_task(name)
        if not success:
            exit_code = 1
    return exit_code
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main())
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from fastapi.testclient import TestClient
|
| 7 |
+
|
| 8 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 9 |
+
if str(ROOT) not in sys.path:
|
| 10 |
+
sys.path.insert(0, str(ROOT))
|
| 11 |
+
|
| 12 |
+
from app import app
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
client = TestClient(app)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_reset_accepts_empty_body():
    """POST /reset without a body should succeed and return a fresh, unfinished episode."""
    reply = client.post("/reset")
    assert reply.status_code == 200

    payload = reply.json()
    assert "episode_id" in payload
    assert payload["done"] is False
    assert "observation" in payload
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_episode_state_is_isolated():
    """Selecting a chunk in one episode must not appear in another episode's state."""
    reset_a = client.post("/reset", json={"task_name": "single_domain_qa"})
    reset_b = client.post("/reset", json={"task_name": "cross_domain_synthesis"})
    assert reset_a.status_code == 200
    assert reset_b.status_code == 200

    episode_a = reset_a.json()["episode_id"]
    episode_b = reset_b.json()["episode_id"]
    assert episode_a != episode_b

    # Act only on the first episode.
    chunk_id = reset_a.json()["observation"]["available_chunks"][0]["chunk_id"]
    action = {"action_type": "select_chunk", "chunk_id": chunk_id}
    step_reply = client.post(f"/step?episode_id={episode_a}", json=action)
    assert step_reply.status_code == 200
    assert step_reply.json()["episode_id"] == episode_a

    state_a = client.get(f"/state?episode_id={episode_a}")
    state_b = client.get(f"/state?episode_id={episode_b}")
    assert state_a.status_code == 200
    assert state_b.status_code == 200
    assert chunk_id in state_a.json()["selected_chunks"]
    assert state_b.json()["selected_chunks"] == []
|
tests/test_inference_proxy.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import socket
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
import threading
|
| 9 |
+
import time
|
| 10 |
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import httpx
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Repository root: the parent of the directory containing this test module.
ROOT = Path(__file__).resolve().parents[1]
# Prefer the project's Windows-style virtualenv interpreter when it exists;
# otherwise fall back to the interpreter running the tests, so the suite also
# works where there is no ``.venv/Scripts/python.exe`` (POSIX layouts, CI).
_VENV_PYTHON = ROOT / ".venv" / "Scripts" / "python.exe"
PYTHON = _VENV_PYTHON if _VENV_PYTHON.exists() else Path(sys.executable)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _free_port() -> int:
|
| 21 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
| 22 |
+
sock.bind(("127.0.0.1", 0))
|
| 23 |
+
return int(sock.getsockname()[1])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_inference_uses_proxy_api_key():
    """End-to-end check that ``inference.py`` calls the configured proxy with ``API_KEY``.

    A stub chat-completions server records every POST it receives and always
    answers with a ``submit_answer`` action. The app is started with uvicorn in
    a subprocess, ``inference.py`` is run against both, and the test asserts
    that the first proxy request hit ``/v1/chat/completions`` with the
    ``API_KEY`` bearer token (not ``HF_TOKEN``) and that an ``[END]`` line with
    a score was printed.
    """
    app_port = _free_port()
    proxy_port = _free_port()
    requests_seen: list[dict[str, str | None]] = []

    class ProxyHandler(BaseHTTPRequestHandler):
        """Stub endpoint that records requests and returns a canned completion."""

        def do_POST(self):
            length = int(self.headers.get("Content-Length", "0"))
            body = self.rfile.read(length).decode("utf-8")
            requests_seen.append(
                {
                    "path": self.path,
                    "authorization": self.headers.get("Authorization"),
                    "body": body,
                }
            )
            payload = {
                "id": "chatcmpl-test",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "proxy-test-model",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": json.dumps(
                                {
                                    "action_type": "submit_answer",
                                    "answer": "Proxy verified [support_003]",
                                }
                            ),
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
            encoded = json.dumps(payload).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(encoded)))
            self.end_headers()
            self.wfile.write(encoded)

        def log_message(self, format: str, *args):
            # Silence the default per-request logging to stderr.
            return

    proxy_server = HTTPServer(("127.0.0.1", proxy_port), ProxyHandler)
    proxy_thread = threading.Thread(target=proxy_server.serve_forever, daemon=True)
    proxy_thread.start()

    # Started inside the ``try`` so the proxy is still shut down if the
    # interpreter cannot be launched.
    app_process = None
    try:
        app_process = subprocess.Popen(
            [str(PYTHON), "-m", "uvicorn", "app:app", "--host", "127.0.0.1", "--port", str(app_port)],
            cwd=ROOT,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

        # Poll /health until the app answers 200, pausing between every attempt
        # (including non-200 replies) and failing loudly if it never comes up.
        healthy = False
        deadline = time.monotonic() + 20
        while time.monotonic() < deadline:
            try:
                healthy = httpx.get(f"http://127.0.0.1:{app_port}/health", timeout=2).status_code == 200
            except httpx.HTTPError:
                healthy = False
            if healthy:
                break
            time.sleep(0.5)
        assert healthy, "app server did not become healthy within 20 seconds"

        env = os.environ.copy()
        env["RAG_ENV_URL"] = f"http://127.0.0.1:{app_port}"
        env["RAG_ENV_TASK"] = "single_domain_qa"
        env["API_BASE_URL"] = f"http://127.0.0.1:{proxy_port}/v1"
        env["API_KEY"] = "proxy-check-token"
        # Set alongside API_KEY so the assertion below proves API_KEY wins.
        env["HF_TOKEN"] = "legacy-should-not-win"
        result = subprocess.run(
            [str(PYTHON), "inference.py"],
            cwd=ROOT,
            env=env,
            capture_output=True,
            text=True,
            timeout=60,
        )
        assert result.returncode == 0
        assert requests_seen
        assert requests_seen[0]["path"] == "/v1/chat/completions"
        assert requests_seen[0]["authorization"] == "Bearer proxy-check-token"
        assert any(line.startswith("[END]") and "score=" in line for line in result.stdout.splitlines())
    finally:
        proxy_server.shutdown()
        proxy_server.server_close()
        proxy_thread.join(timeout=5)
        if app_process is not None:
            app_process.terminate()
            try:
                app_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Escalate, then reap so no zombie process is left behind.
                app_process.kill()
                app_process.wait()
|
validate.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
import json
|
| 4 |
-
import os
|
| 5 |
-
import signal
|
| 6 |
-
import socket
|
| 7 |
-
import subprocess
|
| 8 |
-
import sys
|
| 9 |
-
import
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
|
| 12 |
import httpx
|
| 13 |
|
|
@@ -117,24 +119,80 @@ def run_task(client: httpx.Client, base_url: str, task_name: str) -> tuple[bool,
|
|
| 117 |
|
| 118 |
|
| 119 |
def run_inference_script(base_url: str) -> bool:
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
def main() -> int:
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import signal
|
| 6 |
+
import socket
|
| 7 |
+
import subprocess
|
| 8 |
+
import sys
|
| 9 |
+
import threading
|
| 10 |
+
import time
|
| 11 |
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 12 |
+
from pathlib import Path
|
| 13 |
|
| 14 |
import httpx
|
| 15 |
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def run_inference_script(base_url: str) -> bool:
    """Run ``inference.py`` against *base_url* through a local stub LLM proxy.

    A stub chat-completions server records each POST and always replies with a
    ``submit_answer`` action. ``inference.py`` is run in a subprocess whose
    environment points ``RAG_ENV_URL`` at *base_url* and ``API_BASE_URL`` at
    the stub.

    Returns:
        True only if the script exited with code 0, printed ``[START]`` and an
        ``[END]`` line carrying a score, and the stub saw a request to
        ``/v1/chat/completions`` authorized with the ``API_KEY`` bearer token.

    Raises:
        subprocess.TimeoutExpired: if the script runs longer than 120 seconds.
    """
    proxy_port = find_free_port()
    requests_seen: list[dict[str, str | None]] = []

    class ProxyHandler(BaseHTTPRequestHandler):
        """Stub endpoint that records requests and returns a canned completion."""

        def do_POST(self):
            length = int(self.headers.get("Content-Length", "0"))
            body = self.rfile.read(length).decode("utf-8")
            requests_seen.append(
                {
                    "path": self.path,
                    "authorization": self.headers.get("Authorization"),
                    "body": body,
                }
            )
            payload = {
                "id": "chatcmpl-validate",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "validator-proxy",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": json.dumps(
                                {
                                    "action_type": "submit_answer",
                                    "answer": "Validated via proxy [support_003]",
                                }
                            ),
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
            encoded = json.dumps(payload).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(encoded)))
            self.end_headers()
            self.wfile.write(encoded)

        def log_message(self, format: str, *args):
            # Silence the default per-request logging to stderr.
            return

    proxy_server = HTTPServer(("127.0.0.1", proxy_port), ProxyHandler)
    proxy_thread = threading.Thread(target=proxy_server.serve_forever, daemon=True)
    proxy_thread.start()

    try:
        env = os.environ.copy()
        env["RAG_ENV_URL"] = base_url
        # NOTE(review): presumably removed so a baseline fallback cannot mask a
        # proxy failure — confirm against inference.py.
        env.pop("ALLOW_BASELINE_FALLBACK", None)
        env["API_BASE_URL"] = f"http://127.0.0.1:{proxy_port}/v1"
        env["API_KEY"] = "offline-validation-token"
        # Set alongside API_KEY so the auth check below proves API_KEY wins.
        env["HF_TOKEN"] = "legacy-should-not-win"
        process = subprocess.run(
            [sys.executable, "inference.py"],
            cwd=PROJECT_ROOT,
            capture_output=True,
            text=True,
            timeout=120,
            env=env,
        )
        stdout = process.stdout or ""
        has_start = "[START]" in stdout
        has_end = "[END]" in stdout
        # The score must be on an [END] line, not merely somewhere in stdout.
        end_has_score = any(
            line.startswith("[END]") and " score=" in line for line in stdout.splitlines()
        )
        proxy_called = any(request["path"] == "/v1/chat/completions" for request in requests_seen)
        auth_ok = any(request["authorization"] == "Bearer offline-validation-token" for request in requests_seen)
        return process.returncode == 0 and has_start and has_end and end_has_score and proxy_called and auth_ok
    finally:
        proxy_server.shutdown()
        proxy_server.server_close()
        # serve_forever() returns after shutdown(); wait for the thread to finish.
        proxy_thread.join(timeout=5)
|
| 196 |
|
| 197 |
|
| 198 |
def main() -> int:
|