# Page metadata: file size 18,582 bytes, ref 0b89610.
"""
Prompt optimization utilities for the HTTP app.
This module keeps prompt-rewriting and evidence-packaging logic out of `app.py`
so the FastAPI layer stays thinner and easier to review.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any, Literal
from rag_optimizer_env.environment import RagContextOptimizerEnv
from rag_optimizer_env.llm_runtime import estimate_tokens, llm_configured
from rag_optimizer_env.llm_services import rewrite_prompt as rewrite_prompt_with_llm
CompressionMode = Literal["balanced", "aggressive", "grounded"]
_PROMPT_STOPWORDS = {
"a", "an", "and", "are", "as", "at", "be", "but", "by", "can", "could", "do", "does", "did",
"for", "from", "had", "has", "have", "how", "i", "if", "in", "into", "is", "it", "its", "me",
"my", "of", "on", "or", "our", "should", "so", "than", "that", "the", "their", "them", "then",
"there", "these", "they", "this", "to", "too", "use", "using", "was", "we", "were", "what",
"when", "where", "which", "while", "with", "without", "would", "you", "your",
}
_INSTRUCTION_PRIORITY_TERMS = {
"must", "should", "only", "not", "never", "always", "include", "exclude", "cite", "answer",
"return", "draft", "write", "summarize", "compare", "explain", "verify", "preserve", "focus",
"keep", "avoid", "report", "escalate", "rollback", "refund", "incident", "customer", "security",
}
@dataclass(frozen=True, slots=True)
class PromptOptimizationResult:
optimized_prompt: str
stats: dict[str, int]
grounding: dict[str, Any]
context_tuning: dict[str, Any]
corpus_family: str
selected_keywords: list[str]
optimization_mode: CompressionMode
def _tokenize(text: str) -> set[str]:
return set(re.findall(r"[a-z0-9]+", text.lower()))
def _content_terms(text: str) -> set[str]:
    """Return tokens of *text* longer than two characters that are not stopwords."""
    long_tokens = {token for token in _tokenize(text) if len(token) > 2}
    return long_tokens - _PROMPT_STOPWORDS
def _clean_output_text(text: str) -> str:
cleaned = text.replace("```", " ").replace("---", " ")
cleaned = re.sub(r"\s+", " ", cleaned).strip()
cleaned = re.sub(r"[#*_`]+", "", cleaned)
cleaned = re.sub(r'\b(title|emoji|colorfrom|colorto|sdk|app_file|pinned)\s*:\s*', "", cleaned, flags=re.IGNORECASE)
return cleaned.strip(" -:")
def _compact_text(text: str, max_words: int = 28) -> str:
words = text.split()
if len(words) <= max_words:
return text
return " ".join(words[:max_words]).rstrip(" ,;:") + " ..."
def _approx_tokens(text: str) -> int:
return max(1, len(text.strip()) // 4) if text.strip() else 0
def _truncate_to_word_boundary(text: str, max_chars: int, add_ellipsis: bool = True) -> str:
raw = text.strip()
if not raw or len(raw) <= max_chars:
return raw
candidate = raw[:max_chars].rstrip(" ,;:\n\t")
if max_chars < len(raw) and max_chars > 0 and not raw[max_chars - 1].isspace():
last_space = candidate.rfind(" ")
if last_space >= max(4, max_chars // 3):
candidate = candidate[:last_space].rstrip(" ,;:\n\t")
if not candidate:
candidate = raw[:max_chars].rstrip(" ,;:\n\t")
if add_ellipsis and candidate and not candidate.endswith("..."):
candidate = candidate + " ..."
return candidate
def _trim_sentence(sentence: str, max_terms: int) -> str:
words = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_\/]*|[,:;()]", sentence)
if not words:
return ""
kept: list[str] = []
for index, token in enumerate(words):
normalized = re.sub(r"[^A-Za-z0-9]+", "", token).lower()
if token in {",", ":", ";", "(", ")"}:
if kept and kept[-1] not in {",", ":", ";", "("}:
kept.append(token)
continue
is_priority = normalized in _INSTRUCTION_PRIORITY_TERMS
is_meaningful = (
normalized.isdigit()
or any(ch in token for ch in ("_", "-", "/"))
or len(normalized) >= 4
or is_priority
or index < 3
)
if not is_meaningful:
continue
if normalized in _PROMPT_STOPWORDS and not is_priority and index >= 3:
continue
kept.append(token)
if len([word for word in kept if word not in {",", ":", ";", "(", ")"}]) >= max_terms:
break
text = " ".join(kept)
text = re.sub(r"\s+([,:;)])", r"\1", text)
text = re.sub(r"(\()\s+", r"\1", text)
return text.strip(" ,;:")
def _rewrite_prompt_text(prompt: str, target_tokens: int) -> str:
raw = " ".join(prompt.strip().split())
if not raw:
return ""
sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+|\n+", raw) if segment.strip()]
if not sentences:
sentences = [raw]
rewritten: list[str] = []
used_terms = 0
max_terms = max(8, target_tokens)
for index, sentence in enumerate(sentences):
remaining = max_terms - used_terms
if remaining <= 0:
break
compact = _trim_sentence(sentence, max(4, remaining if index == 0 else min(remaining, 12)))
if not compact:
continue
rewritten.append(compact)
used_terms += len(compact.split())
if used_terms >= max_terms:
break
if not rewritten:
fallback = _trim_sentence(raw, max_terms)
return fallback or raw[: max(16, target_tokens * 4)].strip()
output = ". ".join(rewritten).strip()
if len(rewritten) == 1 and not output.endswith("."):
output += "."
return output
def _lightweight_short_prompt_rewrite(prompt: str) -> str:
raw = " ".join(prompt.strip().split())
if not raw:
return ""
cleaned = raw
cleaned = re.sub(r"\b[Pp]lease\s+", "", cleaned)
cleaned = re.sub(r"\bhelp me to\b", "help me", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r"\bhelp me\b", "Help me", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r"\bi want to\b", "I want to", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r"\bcan you help me\b", "Help me", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r"\s+", " ", cleaned).strip()
if cleaned:
cleaned = cleaned[0].upper() + cleaned[1:]
return cleaned
def _sentence_rank(query: str, text: str) -> list[str]:
    """Order the sentences of *text* by token overlap with *query*.

    Each sentence scores two points per shared token, and the first sentence
    gets a 0.25 bonus. Higher scores come first; ties go to the shorter
    sentence. Returns an empty list when *text* has no sentences.
    """
    query_terms = _tokenize(query)
    pieces = [part.strip() for part in re.split(r"(?<=[.!?])\s+", text)]
    sentences = [piece for piece in pieces if piece]
    if not sentences:
        return []

    scored = [
        (
            (len(query_terms & _tokenize(candidate)) * 2.0)
            + (0.25 if position == 0 else 0.0),
            candidate,
        )
        for position, candidate in enumerate(sentences)
    ]
    ordered = sorted(scored, key=lambda pair: (-pair[0], len(pair[1])))
    return [candidate for _, candidate in ordered]
def _summarize_chunk_for_output(chunk: Any, effective_text: str) -> str:
    """Produce a one-line evidence note for *chunk*.

    Chunks whose ``domain`` starts with "Project" get a templated sentence
    built from the domain name and up to five keywords (capped at 24 words).
    Any other chunk is represented by the sentence of its cleaned text that
    best matches its own keywords, or by the cleaned text itself when no
    sentence is found.
    """
    if getattr(chunk, "domain", "").startswith("Project"):
        keyword_list = ", ".join(chunk.keywords[:5])
        topic = chunk.domain.replace("Project ", "").lower()
        return _compact_text(f"This benchmark's {topic} covers {keyword_list}.", 24)

    keyword_query = " ".join(chunk.keywords)
    ranked = _sentence_rank(keyword_query, _clean_output_text(effective_text))
    best_source = ranked[0] if ranked else effective_text
    return _compact_text(_clean_output_text(best_source))
def _target_ratio(input_tokens: int, mode: CompressionMode) -> float:
if mode == "aggressive":
if input_tokens <= 24:
return 0.78
if input_tokens <= 60:
return 0.66
if input_tokens <= 120:
return 0.58
return 0.52
if mode == "grounded":
if input_tokens <= 24:
return 0.98
if input_tokens <= 60:
return 0.90
if input_tokens <= 120:
return 0.84
return 0.78
if input_tokens <= 24:
return 0.85
if input_tokens <= 60:
return 0.75
if input_tokens <= 120:
return 0.68
return 0.62
def _fit_citations_into_prompt(
base_prompt: str,
citation_ids: list[str],
input_tokens: int,
target_tokens: int,
source_prompt: str,
mode: CompressionMode,
) -> tuple[str, bool, str | None]:
if not citation_ids:
return base_prompt, False, "No high-confidence evidence anchors were selected."
prioritized = citation_ids[: (3 if mode == "grounded" else 2)]
suffix = " Evidence: " + " ".join(f"[{chunk_id}]" for chunk_id in prioritized)
with_all = (base_prompt.rstrip(".") + "." + suffix).strip()
if mode == "grounded" and _approx_tokens(with_all) <= max(input_tokens, target_tokens + 4):
return with_all, True, None
if _approx_tokens(with_all) < input_tokens:
return with_all, True, None
one_suffix = " Evidence: " + f"[{citation_ids[0]}]"
with_one = (base_prompt.rstrip(".") + "." + one_suffix).strip()
if mode == "grounded" and _approx_tokens(with_one) <= max(input_tokens, target_tokens + 2):
return with_one, True, None
if _approx_tokens(with_one) < input_tokens:
return with_one, True, None
tighter_target = max(8, target_tokens - (2 if mode == "grounded" else 3))
tighter_prompt = _rewrite_prompt_text(source_prompt, tighter_target)
tighter_with_one = (tighter_prompt.rstrip(".") + "." + one_suffix).strip()
if mode == "grounded" and _approx_tokens(tighter_with_one) <= max(input_tokens, target_tokens + 2):
return tighter_with_one, True, None
if _approx_tokens(tighter_with_one) < input_tokens:
return tighter_with_one, True, None
if mode == "grounded":
forced = (tighter_prompt.rstrip(".") + "." + one_suffix).strip()
return forced, True, "Grounded mode preserved at least one inline citation, even at the cost of a slightly longer prompt."
return base_prompt, False, "Citations were omitted to keep the optimized prompt shorter than the original. Use grounded mode or the evidence notes below if explicit anchors are required."
async def optimize_prompt(
    prompt: str,
    corpus_family: str | None = None,
    mode: CompressionMode = "balanced",
) -> PromptOptimizationResult:
    """Rewrite *prompt* into a shorter form and package supporting evidence.

    Steps, in order:

    1. Build and reset a `RagContextOptimizerEnv` for the stripped prompt.
    2. Score the environment's available chunks, select a few of them, and
       record a compression ratio for each selected chunk.
    3. Derive a target size from the prompt's estimated token count.
    4. Distil short evidence notes from the selected chunks (skipped for
       very short prompts outside aggressive mode).
    5. Rewrite the prompt through the LLM service when `llm_configured()`
       is true, otherwise with the local heuristics in this module.
    6. Collect statistics and grounding details into the result.

    Args:
        prompt: User prompt to optimize; surrounding whitespace is ignored.
        corpus_family: Passed to the environment as `corpus_family_override`.
        mode: Compression preset; it changes thresholds throughout.

    Returns:
        A `PromptOptimizationResult` with the rewritten prompt and metadata.
    """
    clean_prompt = prompt.strip()
    # The task name, token budget and step limit are fixed here; only the
    # query and the corpus family come from the caller.
    env = RagContextOptimizerEnv(
        task_name="refund_triage_easy",
        query_override=clean_prompt,
        token_budget_override=800,
        max_steps_override=6,
        corpus_family_override=corpus_family,
    )
    await env.reset()
    # Use the environment's `_last_tuning` when it is truthy, otherwise run
    # the context tuner on the available chunks.
    tuning = env._last_tuning or env.context_tuner.tune(clean_prompt, env._available_chunks)
    ranked_candidates = []
    for chunk in env._available_chunks:
        tuned = tuning.tuned_scores.get(chunk.chunk_id)
        # Prefer the tuned score; fall back to the retriever's hybrid score.
        score = tuned.final_score if tuned is not None else env.retriever.hybrid_score(clean_prompt, chunk)
        if score < 0.16:
            continue
        ranked_candidates.append((chunk, score, tuned))
    # Sort keys: citation prior (grounded mode only), then score per token,
    # then raw score, then chunk id as a deterministic tie-breaker.
    ranked_candidates.sort(
        key=lambda item: (
            -(item[2].citation_prior if item[2] is not None else 0.0) if mode == "grounded" else 0.0,
            -(item[1] / max(item[0].tokens, 1)),
            -item[1],
            item[0].chunk_id,
        )
    )
    selected_ids: list[str] = []
    token_cap = 420 if mode == "grounded" else 360
    running_tokens = 0
    for chunk, score, _tuned in ranked_candidates:
        # At most 5 chunks in grounded mode, 4 otherwise.
        if len(selected_ids) >= (5 if mode == "grounded" else 4):
            break
        # Once something is selected, stop at the first low-scoring chunk.
        if score < (0.18 if mode == "grounded" else 0.22) and selected_ids:
            break
        projected = running_tokens + chunk.tokens
        # Skip chunks that would exceed the cap, but never reject the first one.
        if projected > token_cap and selected_ids:
            continue
        selected_ids.append(chunk.chunk_id)
        env._selected_chunks.append(chunk.chunk_id)
        running_tokens += chunk.tokens
    # NOTE(review): the loop above always accepts its first candidate, so this
    # fallback looks unreachable — confirm before relying on it.
    if not selected_ids and ranked_candidates:
        best_chunk = ranked_candidates[0][0]
        selected_ids.append(best_chunk.chunk_id)
        env._selected_chunks.append(best_chunk.chunk_id)
    # Record a compression ratio for every selected chunk.
    for chunk_id in list(selected_ids):
        chunk = env._chunk_map().get(chunk_id)
        if chunk is None:
            continue
        tuned = tuning.tuned_scores.get(chunk_id)
        score = tuned.final_score if tuned is not None else env.retriever.hybrid_score(clean_prompt, chunk)
        ratio = tuned.compression_ratio if tuned is not None else 0.5
        # Grounded mode and high-scoring chunks get a floor on the ratio.
        if mode == "grounded":
            ratio = max(ratio, 0.68 if score >= 0.55 else 0.58)
        elif score >= 0.75:
            ratio = max(ratio, 0.6)
        env._compression_ratios[chunk_id] = ratio
    input_tokens = await estimate_tokens(clean_prompt)
    # Target size: at least 12 tokens, at most 120 (grounded) or 80.
    target_tokens = max(12, int(input_tokens * _target_ratio(input_tokens, mode)))
    target_tokens = min(target_tokens, 120 if mode == "grounded" else 80)
    # Very short prompts are only lightly tidied unless aggressive mode is on.
    preserve_short_prompt = mode != "aggressive" and input_tokens <= 12 and len(clean_prompt.split()) <= 8
    distilled_points: list[tuple[str, str]] = []
    if not preserve_short_prompt:
        for chunk_id in env._selected_chunks:
            chunk = env._chunk_map().get(chunk_id)
            if chunk is None:
                continue
            best = _summarize_chunk_for_output(chunk, env._effective_chunk_text(chunk_id))
            # Skip empty notes and exact duplicates.
            if best and all(existing != best for _cid, existing in distilled_points):
                distilled_points.append((chunk_id, best))
            # Note limit: 3 in grounded mode, else 2 for inputs under 80 tokens and 3 otherwise.
            if len(distilled_points) >= (3 if mode == "grounded" else (2 if input_tokens < 80 else 3)):
                break
    citation_ids = tuning.suggested_citations or list(env._selected_chunks)
    if llm_configured():
        # LLM path: the service returns the prompt, citation details and token estimate.
        llm_result = await rewrite_prompt_with_llm(
            prompt=clean_prompt,
            mode=mode,
            target_tokens=target_tokens,
            evidence_notes=[
                {"chunk_id": chunk_id, "note": note}
                for chunk_id, note in distilled_points
            ],
            citation_ids=citation_ids,
        )
        optimized_prompt = llm_result["optimized_prompt"] or clean_prompt
        citation_ready = llm_result["citation_ready"]
        citation_guidance = llm_result["citation_guidance"]
        optimized_prompt_tokens = llm_result["estimated_tokens"]
    else:
        # Heuristic path: local rewrite, optional context bullets, then citations.
        rewritten = _rewrite_prompt_text(clean_prompt, target_tokens=target_tokens)
        short_prompt_rewrite = _lightweight_short_prompt_rewrite(clean_prompt) if preserve_short_prompt else ""
        # First line: the light rewrite for preserved short prompts, else the
        # heuristic rewrite, falling back to the original prompt.
        lines: list[str] = [
            short_prompt_rewrite if preserve_short_prompt and short_prompt_rewrite else (
                clean_prompt if preserve_short_prompt else (rewritten if rewritten else clean_prompt)
            )
        ]
        # Evidence bullets are added only in grounded mode or for long inputs.
        if distilled_points and (mode == "grounded" or input_tokens >= 80):
            lines.append("")
            lines.append("Context:")
            lines.extend([f"- [{chunk_id}] {point}" for chunk_id, point in distilled_points])
        optimized_prompt = "\n".join(lines).strip()
        if preserve_short_prompt and not distilled_points:
            optimized_prompt = short_prompt_rewrite if short_prompt_rewrite and short_prompt_rewrite != clean_prompt else clean_prompt
        elif mode != "grounded" and input_tokens > 0 and _approx_tokens(optimized_prompt) >= input_tokens:
            # Outside grounded mode the result must be shorter than the input:
            # truncate, keep shaving, then retry with a tighter rewrite.
            max_chars = max(12, (input_tokens - 1) * 4)
            optimized_prompt = _truncate_to_word_boundary(optimized_prompt, max_chars)
            while input_tokens > 1 and _approx_tokens(optimized_prompt) >= input_tokens and len(optimized_prompt) > 12:
                optimized_prompt = _truncate_to_word_boundary(optimized_prompt, max(8, len(optimized_prompt) - 6))
            if input_tokens > 1 and _approx_tokens(optimized_prompt) >= input_tokens:
                optimized_prompt = _rewrite_prompt_text(clean_prompt, target_tokens=max(5, input_tokens - 1))
                # NOTE(review): nesting of this check under the retry above is
                # reconstructed — confirm against the original layout.
                if optimized_prompt and not optimized_prompt.endswith("...") and _approx_tokens(optimized_prompt) >= input_tokens:
                    optimized_prompt = _truncate_to_word_boundary(optimized_prompt, max(8, (input_tokens - 1) * 4))
        # NOTE(review): kept inside the heuristic branch because the LLM branch
        # assigns the same three names itself — confirm against the original layout.
        optimized_prompt, citation_ready, citation_guidance = _fit_citations_into_prompt(
            optimized_prompt,
            citation_ids,
            input_tokens,
            target_tokens,
            clean_prompt,
            mode,
        )
        optimized_prompt_tokens = await estimate_tokens(optimized_prompt)
    original_prompt_tokens = input_tokens
    # Token totals for the selected chunks, before and after compression.
    source_tokens = sum(env._chunk_map()[chunk_id].tokens for chunk_id in env._selected_chunks if chunk_id in env._chunk_map())
    compressed_tokens = sum(env._effective_chunk_tokens(chunk_id) for chunk_id in env._selected_chunks)
    evidence_terms = _content_terms(" ".join(env._effective_chunk_text(chunk_id) for chunk_id in env._selected_chunks))
    prompt_terms = _content_terms(optimized_prompt)
    # Bracketed ids that actually appear in the final prompt.
    inline_citations = set(re.findall(r"\[([a-z0-9_]+)\]", optimized_prompt.lower()))
    # Share of the prompt's content terms that also occur in the evidence text.
    grounded_overlap = (len(prompt_terms & evidence_terms) / len(prompt_terms)) if prompt_terms else 0.0
    return PromptOptimizationResult(
        optimized_prompt=optimized_prompt,
        stats={
            "selected_chunks": len(env._selected_chunks),
            "source_tokens": source_tokens,
            "compressed_context_tokens": compressed_tokens,
            "original_prompt_tokens": original_prompt_tokens,
            "optimized_prompt_tokens": optimized_prompt_tokens,
            "compression_gain": max(0, source_tokens - compressed_tokens),
        },
        grounding={
            "citations": tuning.suggested_citations or list(env._selected_chunks),
            # Ready only if the rewrite step said so AND an inline id is present.
            "citation_ready": citation_ready and bool(inline_citations),
            "citation_guidance": citation_guidance,
            "grounded_overlap": round(grounded_overlap, 3),
            "evidence_notes": [
                {"chunk_id": chunk_id, "note": note}
                for chunk_id, note in distilled_points
            ],
        },
        context_tuning={
            "mode": tuning.mode,
            "top_demo_cases": tuning.top_demo_cases,
            "suggested_citations": tuning.suggested_citations,
            "token_dropout": tuning.token_dropout,
            "leave_one_out": tuning.leave_one_out,
        },
        corpus_family=env._corpus_family,
        # Up to ten keywords, in selection order.
        selected_keywords=[
            keyword
            for chunk_id in env._selected_chunks
            for keyword in (env._chunk_map().get(chunk_id).keywords if env._chunk_map().get(chunk_id) else [])
        ][:10],
        optimization_mode=mode,
    )