context-prune / rag_optimizer_env /prompt_optimizer.py
prithic07's picture
Upgrade RAG project with advanced Context Optimizer environment and baseline inference
0b89610
"""
Prompt optimization utilities for the HTTP app.
This module keeps prompt-rewriting and evidence-packaging logic out of `app.py`
so the FastAPI layer stays thinner and easier to review.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any, Literal
from rag_optimizer_env.environment import RagContextOptimizerEnv
from rag_optimizer_env.llm_runtime import estimate_tokens, llm_configured
from rag_optimizer_env.llm_services import rewrite_prompt as rewrite_prompt_with_llm
CompressionMode = Literal["balanced", "aggressive", "grounded"]
_PROMPT_STOPWORDS = {
"a", "an", "and", "are", "as", "at", "be", "but", "by", "can", "could", "do", "does", "did",
"for", "from", "had", "has", "have", "how", "i", "if", "in", "into", "is", "it", "its", "me",
"my", "of", "on", "or", "our", "should", "so", "than", "that", "the", "their", "them", "then",
"there", "these", "they", "this", "to", "too", "use", "using", "was", "we", "were", "what",
"when", "where", "which", "while", "with", "without", "would", "you", "your",
}
_INSTRUCTION_PRIORITY_TERMS = {
"must", "should", "only", "not", "never", "always", "include", "exclude", "cite", "answer",
"return", "draft", "write", "summarize", "compare", "explain", "verify", "preserve", "focus",
"keep", "avoid", "report", "escalate", "rollback", "refund", "incident", "customer", "security",
}
@dataclass(frozen=True, slots=True)
class PromptOptimizationResult:
optimized_prompt: str
stats: dict[str, int]
grounding: dict[str, Any]
context_tuning: dict[str, Any]
corpus_family: str
selected_keywords: list[str]
optimization_mode: CompressionMode
def _tokenize(text: str) -> set[str]:
return set(re.findall(r"[a-z0-9]+", text.lower()))
def _content_terms(text: str) -> set[str]:
return {term for term in _tokenize(text) if len(term) > 2 and term not in _PROMPT_STOPWORDS}
def _clean_output_text(text: str) -> str:
cleaned = text.replace("```", " ").replace("---", " ")
cleaned = re.sub(r"\s+", " ", cleaned).strip()
cleaned = re.sub(r"[#*_`]+", "", cleaned)
cleaned = re.sub(r'\b(title|emoji|colorfrom|colorto|sdk|app_file|pinned)\s*:\s*', "", cleaned, flags=re.IGNORECASE)
return cleaned.strip(" -:")
def _compact_text(text: str, max_words: int = 28) -> str:
words = text.split()
if len(words) <= max_words:
return text
return " ".join(words[:max_words]).rstrip(" ,;:") + " ..."
def _approx_tokens(text: str) -> int:
return max(1, len(text.strip()) // 4) if text.strip() else 0
def _truncate_to_word_boundary(text: str, max_chars: int, add_ellipsis: bool = True) -> str:
raw = text.strip()
if not raw or len(raw) <= max_chars:
return raw
candidate = raw[:max_chars].rstrip(" ,;:\n\t")
if max_chars < len(raw) and max_chars > 0 and not raw[max_chars - 1].isspace():
last_space = candidate.rfind(" ")
if last_space >= max(4, max_chars // 3):
candidate = candidate[:last_space].rstrip(" ,;:\n\t")
if not candidate:
candidate = raw[:max_chars].rstrip(" ,;:\n\t")
if add_ellipsis and candidate and not candidate.endswith("..."):
candidate = candidate + " ..."
return candidate
def _trim_sentence(sentence: str, max_terms: int) -> str:
words = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_\/]*|[,:;()]", sentence)
if not words:
return ""
kept: list[str] = []
for index, token in enumerate(words):
normalized = re.sub(r"[^A-Za-z0-9]+", "", token).lower()
if token in {",", ":", ";", "(", ")"}:
if kept and kept[-1] not in {",", ":", ";", "("}:
kept.append(token)
continue
is_priority = normalized in _INSTRUCTION_PRIORITY_TERMS
is_meaningful = (
normalized.isdigit()
or any(ch in token for ch in ("_", "-", "/"))
or len(normalized) >= 4
or is_priority
or index < 3
)
if not is_meaningful:
continue
if normalized in _PROMPT_STOPWORDS and not is_priority and index >= 3:
continue
kept.append(token)
if len([word for word in kept if word not in {",", ":", ";", "(", ")"}]) >= max_terms:
break
text = " ".join(kept)
text = re.sub(r"\s+([,:;)])", r"\1", text)
text = re.sub(r"(\()\s+", r"\1", text)
return text.strip(" ,;:")
def _rewrite_prompt_text(prompt: str, target_tokens: int) -> str:
raw = " ".join(prompt.strip().split())
if not raw:
return ""
sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+|\n+", raw) if segment.strip()]
if not sentences:
sentences = [raw]
rewritten: list[str] = []
used_terms = 0
max_terms = max(8, target_tokens)
for index, sentence in enumerate(sentences):
remaining = max_terms - used_terms
if remaining <= 0:
break
compact = _trim_sentence(sentence, max(4, remaining if index == 0 else min(remaining, 12)))
if not compact:
continue
rewritten.append(compact)
used_terms += len(compact.split())
if used_terms >= max_terms:
break
if not rewritten:
fallback = _trim_sentence(raw, max_terms)
return fallback or raw[: max(16, target_tokens * 4)].strip()
output = ". ".join(rewritten).strip()
if len(rewritten) == 1 and not output.endswith("."):
output += "."
return output
def _lightweight_short_prompt_rewrite(prompt: str) -> str:
raw = " ".join(prompt.strip().split())
if not raw:
return ""
cleaned = raw
cleaned = re.sub(r"\b[Pp]lease\s+", "", cleaned)
cleaned = re.sub(r"\bhelp me to\b", "help me", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r"\bhelp me\b", "Help me", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r"\bi want to\b", "I want to", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r"\bcan you help me\b", "Help me", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r"\s+", " ", cleaned).strip()
if cleaned:
cleaned = cleaned[0].upper() + cleaned[1:]
return cleaned
def _sentence_rank(query: str, text: str) -> list[str]:
query_terms = _tokenize(query)
sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", text) if segment.strip()]
if not sentences:
return []
ranked: list[tuple[float, str]] = []
for index, sentence in enumerate(sentences):
sentence_terms = _tokenize(sentence)
overlap = len(query_terms & sentence_terms)
score = (overlap * 2.0) + (0.25 if index == 0 else 0.0)
ranked.append((score, sentence))
ranked.sort(key=lambda item: (-item[0], len(item[1])))
return [sentence for _score, sentence in ranked]
def _summarize_chunk_for_output(chunk: Any, effective_text: str) -> str:
if getattr(chunk, "domain", "").startswith("Project"):
keywords = ", ".join(chunk.keywords[:5])
domain = chunk.domain.replace("Project ", "").lower()
return _compact_text(f"This benchmark's {domain} covers {keywords}.", 24)
ranked_sentences = _sentence_rank(" ".join(chunk.keywords), _clean_output_text(effective_text))
if ranked_sentences:
return _compact_text(_clean_output_text(ranked_sentences[0]))
return _compact_text(_clean_output_text(effective_text))
def _target_ratio(input_tokens: int, mode: CompressionMode) -> float:
if mode == "aggressive":
if input_tokens <= 24:
return 0.78
if input_tokens <= 60:
return 0.66
if input_tokens <= 120:
return 0.58
return 0.52
if mode == "grounded":
if input_tokens <= 24:
return 0.98
if input_tokens <= 60:
return 0.90
if input_tokens <= 120:
return 0.84
return 0.78
if input_tokens <= 24:
return 0.85
if input_tokens <= 60:
return 0.75
if input_tokens <= 120:
return 0.68
return 0.62
def _fit_citations_into_prompt(
base_prompt: str,
citation_ids: list[str],
input_tokens: int,
target_tokens: int,
source_prompt: str,
mode: CompressionMode,
) -> tuple[str, bool, str | None]:
if not citation_ids:
return base_prompt, False, "No high-confidence evidence anchors were selected."
prioritized = citation_ids[: (3 if mode == "grounded" else 2)]
suffix = " Evidence: " + " ".join(f"[{chunk_id}]" for chunk_id in prioritized)
with_all = (base_prompt.rstrip(".") + "." + suffix).strip()
if mode == "grounded" and _approx_tokens(with_all) <= max(input_tokens, target_tokens + 4):
return with_all, True, None
if _approx_tokens(with_all) < input_tokens:
return with_all, True, None
one_suffix = " Evidence: " + f"[{citation_ids[0]}]"
with_one = (base_prompt.rstrip(".") + "." + one_suffix).strip()
if mode == "grounded" and _approx_tokens(with_one) <= max(input_tokens, target_tokens + 2):
return with_one, True, None
if _approx_tokens(with_one) < input_tokens:
return with_one, True, None
tighter_target = max(8, target_tokens - (2 if mode == "grounded" else 3))
tighter_prompt = _rewrite_prompt_text(source_prompt, tighter_target)
tighter_with_one = (tighter_prompt.rstrip(".") + "." + one_suffix).strip()
if mode == "grounded" and _approx_tokens(tighter_with_one) <= max(input_tokens, target_tokens + 2):
return tighter_with_one, True, None
if _approx_tokens(tighter_with_one) < input_tokens:
return tighter_with_one, True, None
if mode == "grounded":
forced = (tighter_prompt.rstrip(".") + "." + one_suffix).strip()
return forced, True, "Grounded mode preserved at least one inline citation, even at the cost of a slightly longer prompt."
return base_prompt, False, "Citations were omitted to keep the optimized prompt shorter than the original. Use grounded mode or the evidence notes below if explicit anchors are required."
async def optimize_prompt(
prompt: str,
corpus_family: str | None = None,
mode: CompressionMode = "balanced",
) -> PromptOptimizationResult:
clean_prompt = prompt.strip()
env = RagContextOptimizerEnv(
task_name="refund_triage_easy",
query_override=clean_prompt,
token_budget_override=800,
max_steps_override=6,
corpus_family_override=corpus_family,
)
await env.reset()
tuning = env._last_tuning or env.context_tuner.tune(clean_prompt, env._available_chunks)
ranked_candidates = []
for chunk in env._available_chunks:
tuned = tuning.tuned_scores.get(chunk.chunk_id)
score = tuned.final_score if tuned is not None else env.retriever.hybrid_score(clean_prompt, chunk)
if score < 0.16:
continue
ranked_candidates.append((chunk, score, tuned))
ranked_candidates.sort(
key=lambda item: (
-(item[2].citation_prior if item[2] is not None else 0.0) if mode == "grounded" else 0.0,
-(item[1] / max(item[0].tokens, 1)),
-item[1],
item[0].chunk_id,
)
)
selected_ids: list[str] = []
token_cap = 420 if mode == "grounded" else 360
running_tokens = 0
for chunk, score, _tuned in ranked_candidates:
if len(selected_ids) >= (5 if mode == "grounded" else 4):
break
if score < (0.18 if mode == "grounded" else 0.22) and selected_ids:
break
projected = running_tokens + chunk.tokens
if projected > token_cap and selected_ids:
continue
selected_ids.append(chunk.chunk_id)
env._selected_chunks.append(chunk.chunk_id)
running_tokens += chunk.tokens
if not selected_ids and ranked_candidates:
best_chunk = ranked_candidates[0][0]
selected_ids.append(best_chunk.chunk_id)
env._selected_chunks.append(best_chunk.chunk_id)
for chunk_id in list(selected_ids):
chunk = env._chunk_map().get(chunk_id)
if chunk is None:
continue
tuned = tuning.tuned_scores.get(chunk_id)
score = tuned.final_score if tuned is not None else env.retriever.hybrid_score(clean_prompt, chunk)
ratio = tuned.compression_ratio if tuned is not None else 0.5
if mode == "grounded":
ratio = max(ratio, 0.68 if score >= 0.55 else 0.58)
elif score >= 0.75:
ratio = max(ratio, 0.6)
env._compression_ratios[chunk_id] = ratio
input_tokens = await estimate_tokens(clean_prompt)
target_tokens = max(12, int(input_tokens * _target_ratio(input_tokens, mode)))
target_tokens = min(target_tokens, 120 if mode == "grounded" else 80)
preserve_short_prompt = mode != "aggressive" and input_tokens <= 12 and len(clean_prompt.split()) <= 8
distilled_points: list[tuple[str, str]] = []
if not preserve_short_prompt:
for chunk_id in env._selected_chunks:
chunk = env._chunk_map().get(chunk_id)
if chunk is None:
continue
best = _summarize_chunk_for_output(chunk, env._effective_chunk_text(chunk_id))
if best and all(existing != best for _cid, existing in distilled_points):
distilled_points.append((chunk_id, best))
if len(distilled_points) >= (3 if mode == "grounded" else (2 if input_tokens < 80 else 3)):
break
citation_ids = tuning.suggested_citations or list(env._selected_chunks)
if llm_configured():
llm_result = await rewrite_prompt_with_llm(
prompt=clean_prompt,
mode=mode,
target_tokens=target_tokens,
evidence_notes=[
{"chunk_id": chunk_id, "note": note}
for chunk_id, note in distilled_points
],
citation_ids=citation_ids,
)
optimized_prompt = llm_result["optimized_prompt"] or clean_prompt
citation_ready = llm_result["citation_ready"]
citation_guidance = llm_result["citation_guidance"]
optimized_prompt_tokens = llm_result["estimated_tokens"]
else:
rewritten = _rewrite_prompt_text(clean_prompt, target_tokens=target_tokens)
short_prompt_rewrite = _lightweight_short_prompt_rewrite(clean_prompt) if preserve_short_prompt else ""
lines: list[str] = [
short_prompt_rewrite if preserve_short_prompt and short_prompt_rewrite else (
clean_prompt if preserve_short_prompt else (rewritten if rewritten else clean_prompt)
)
]
if distilled_points and (mode == "grounded" or input_tokens >= 80):
lines.append("")
lines.append("Context:")
lines.extend([f"- [{chunk_id}] {point}" for chunk_id, point in distilled_points])
optimized_prompt = "\n".join(lines).strip()
if preserve_short_prompt and not distilled_points:
optimized_prompt = short_prompt_rewrite if short_prompt_rewrite and short_prompt_rewrite != clean_prompt else clean_prompt
elif mode != "grounded" and input_tokens > 0 and _approx_tokens(optimized_prompt) >= input_tokens:
max_chars = max(12, (input_tokens - 1) * 4)
optimized_prompt = _truncate_to_word_boundary(optimized_prompt, max_chars)
while input_tokens > 1 and _approx_tokens(optimized_prompt) >= input_tokens and len(optimized_prompt) > 12:
optimized_prompt = _truncate_to_word_boundary(optimized_prompt, max(8, len(optimized_prompt) - 6))
if input_tokens > 1 and _approx_tokens(optimized_prompt) >= input_tokens:
optimized_prompt = _rewrite_prompt_text(clean_prompt, target_tokens=max(5, input_tokens - 1))
if optimized_prompt and not optimized_prompt.endswith("...") and _approx_tokens(optimized_prompt) >= input_tokens:
optimized_prompt = _truncate_to_word_boundary(optimized_prompt, max(8, (input_tokens - 1) * 4))
optimized_prompt, citation_ready, citation_guidance = _fit_citations_into_prompt(
optimized_prompt,
citation_ids,
input_tokens,
target_tokens,
clean_prompt,
mode,
)
optimized_prompt_tokens = await estimate_tokens(optimized_prompt)
original_prompt_tokens = input_tokens
source_tokens = sum(env._chunk_map()[chunk_id].tokens for chunk_id in env._selected_chunks if chunk_id in env._chunk_map())
compressed_tokens = sum(env._effective_chunk_tokens(chunk_id) for chunk_id in env._selected_chunks)
evidence_terms = _content_terms(" ".join(env._effective_chunk_text(chunk_id) for chunk_id in env._selected_chunks))
prompt_terms = _content_terms(optimized_prompt)
inline_citations = set(re.findall(r"\[([a-z0-9_]+)\]", optimized_prompt.lower()))
grounded_overlap = (len(prompt_terms & evidence_terms) / len(prompt_terms)) if prompt_terms else 0.0
return PromptOptimizationResult(
optimized_prompt=optimized_prompt,
stats={
"selected_chunks": len(env._selected_chunks),
"source_tokens": source_tokens,
"compressed_context_tokens": compressed_tokens,
"original_prompt_tokens": original_prompt_tokens,
"optimized_prompt_tokens": optimized_prompt_tokens,
"compression_gain": max(0, source_tokens - compressed_tokens),
},
grounding={
"citations": tuning.suggested_citations or list(env._selected_chunks),
"citation_ready": citation_ready and bool(inline_citations),
"citation_guidance": citation_guidance,
"grounded_overlap": round(grounded_overlap, 3),
"evidence_notes": [
{"chunk_id": chunk_id, "note": note}
for chunk_id, note in distilled_points
],
},
context_tuning={
"mode": tuning.mode,
"top_demo_cases": tuning.top_demo_cases,
"suggested_citations": tuning.suggested_citations,
"token_dropout": tuning.token_dropout,
"leave_one_out": tuning.leave_one_out,
},
corpus_family=env._corpus_family,
selected_keywords=[
keyword
for chunk_id in env._selected_chunks
for keyword in (env._chunk_map().get(chunk_id).keywords if env._chunk_map().get(chunk_id) else [])
][:10],
optimization_mode=mode,
)