# Source header (scrape residue converted to comments so the module parses):
# context-prune / rag_optimizer_env / context_tuner.py
# Author: prithic07 — "Upgrade RAG project with advanced Context Optimizer environment and baseline inference"
# Commit: 0b89610
"""
PyTorch-backed context tuning for retrieval and prompt optimization.
This module adapts the paper's core idea to this benchmark:
- initialize a lightweight context representation from task-specific demonstrations
- optimize that context rather than the underlying language model
- apply leave-one-out masking and token dropout for regularized tuning
"""
from __future__ import annotations
from dataclasses import dataclass
import hashlib
import math
import re
from typing import Iterable
from rag_optimizer_env.corpus import Chunk
from rag_optimizer_env.retriever import HybridRetriever
from rag_optimizer_env.tasks import Task
try:
import torch
import torch.nn.functional as F
except ModuleNotFoundError: # pragma: no cover - optional dependency at runtime
torch = None
F = None
# One-hot domain indicator flags consumed as features in
# ContextTunedPlanner._feature_vector (positions 5-7 of the feature vector).
_DOMAIN_TO_FLAGS = {
    "Customer Support Operations": (1.0, 0.0, 0.0),
    "Incident Response Playbooks": (0.0, 1.0, 0.0),
    "Platform Reliability & Release Engineering": (0.0, 0.0, 1.0),
}
# Maps a task's snake_case domain filter to the display-name domain strings
# used on corpus chunks (and as keys of _DOMAIN_TO_FLAGS above).
_DOMAIN_FILTER_MAP = {
    "customer_support_operations": "Customer Support Operations",
    "incident_response_playbooks": "Incident Response Playbooks",
    "platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
}
# Common English function words dropped by _tokenize before similarity scoring.
_STOPWORDS = {
    "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "how", "if", "in",
    "into", "is", "it", "its", "of", "on", "or", "that", "the", "their", "them", "there",
    "these", "this", "to", "was", "were", "what", "when", "where", "which", "while", "with",
    "without", "you", "your",
}
def _tokenize(text: str) -> set[str]:
    """Split *text* into a set of lowercase alphanumeric terms.

    Stopwords and tokens of two characters or fewer are discarded so that
    downstream Jaccard comparisons focus on content-bearing words.
    """
    terms: set[str] = set()
    for token in re.findall(r"[a-z0-9]+", text.lower()):
        if len(token) > 2 and token not in _STOPWORDS:
            terms.add(token)
    return terms
def _jaccard(left: Iterable[str], right: Iterable[str]) -> float:
left_set = set(left)
right_set = set(right)
if not left_set or not right_set:
return 0.0
union = left_set | right_set
if not union:
return 0.0
return len(left_set & right_set) / len(union)
@dataclass(frozen=True, slots=True)
class DemoCase:
    """A demonstration example used to warm-start and supervise context tuning."""
    # Unique label, e.g. "<task>_gold" or "<task>_variant_1".
    name: str
    # Demonstration query text (gold task query or a hand-written paraphrase).
    query: str
    # Chunk ids treated as relevant (positive labels) for this demo.
    positive_chunk_ids: tuple[str, ...]
    # Chunk ids an answer to this demo should cite (feeds _citation_prior).
    expected_citations: tuple[str, ...]
    # Corpus domain display names this demo favors (feeds _domain_prior).
    preferred_domains: tuple[str, ...]
@dataclass(frozen=True, slots=True)
class TunedChunkScore:
    """Per-chunk scoring breakdown produced by ContextTunedPlanner.tune."""
    chunk_id: str
    # Raw retriever hybrid score, before tuning.
    base_score: float
    # Sigmoid of the tuned linear model's logit for this chunk.
    tuned_score: float
    # Blend of base and tuned scores, clamped to [0, 1].
    final_score: float
    # Weighted fraction of selected demos that cite this chunk.
    citation_prior: float
    # Suggested text-compression ratio (lower = keep more), clamped to [0.38, 0.84].
    compression_ratio: float
@dataclass(frozen=True, slots=True)
class ContextTuningResult:
    """Aggregate output of a single ContextTunedPlanner.tune call."""
    # Either "context_tuned_pytorch" or "context_tuned_analytic" (torch missing).
    mode: str
    # Names of the demo cases selected for this query.
    top_demo_cases: list[str]
    # Up to three chunk ids ranked highest with final_score >= 0.35.
    suggested_citations: list[str]
    # Per-chunk score breakdown, keyed by chunk_id.
    tuned_scores: dict[str, TunedChunkScore]
    # Dropout rate applied during tuning (reported for transparency).
    token_dropout: float
    # Whether leave-one-out demo masking was applied (always True here).
    leave_one_out: bool
class ContextTunedPlanner:
    """Warm-start a retrieval policy from demonstrations and optimize context weights.

    The planner never touches a language model. It learns a small linear
    weight vector ("context", one weight per feature in _feature_vector)
    over hand-crafted retrieval features. The vector is warm-started from
    demonstration cases derived from the benchmark tasks and, when torch is
    installed, refined with leave-one-out demo masking plus inverted token
    dropout (see _optimize_with_torch). Without torch the analytic
    warm-start is used directly.
    """
    def __init__(self, retriever: HybridRetriever, corpus: list[Chunk], tasks: list[Task]):
        # Scoring backend; _feature_vector assumes it exposes hybrid_score,
        # bm25_score and keyword_overlap_score(query, chunk) -> float.
        self.retriever = retriever
        self.corpus = list(corpus)  # defensive copy
        self._chunk_map = {chunk.chunk_id: chunk for chunk in self.corpus}
        self._demo_cases = self._build_demo_cases(tasks)
        # Tuning hyperparameters consumed by _optimize_with_torch.
        self._token_dropout = 0.14  # fraction of context dims zeroed per step
        self._train_steps = 28  # Adam steps per tuning call
        self._feature_count = 10  # must equal len(_feature_vector(...))
    def _build_demo_cases(self, tasks: list[Task]) -> list[DemoCase]:
        """Create one gold demo per task plus hand-written paraphrase variants.

        Variants share the task's positive chunks / expected citations and
        differ only in query wording, giving _select_demo_cases more lexical
        anchors to match incoming queries against.
        """
        demo_cases: list[DemoCase] = []
        # Paraphrases keyed by task name; tasks without an entry get no variants.
        query_variants = {
            "refund_triage_easy": [
                "Refund triage memo after a confirmed outage with billing checks and finance review steps.",
                "Business customer outage escalation: verify the billing ledger, incident evidence, and compensation path.",
            ],
            "cross_function_brief_medium": [
                "Cross-functional outage brief linking support handling, incident command discipline, and release rollback safeguards.",
                "Major outage coordination memo for support, incident response, and release engineering teams.",
            ],
            "executive_escalation_hard": [
                "Executive escalation note for compromised admin account response with customer protection and change freeze controls.",
                "Severe security incident brief balancing customer harm reduction, evidence preservation, and release safeguards.",
            ],
        }
        for task in tasks:
            # Translate the task's snake_case filter to a display-name domain;
            # an unmapped filter falls back to the raw filter string.
            normalized_domain = _DOMAIN_FILTER_MAP.get(task.domain_filter or "", task.domain_filter or "")
            # With no explicit filter, infer preferred domains from the
            # domains of the task's required chunks (sorted for determinism).
            preferred_domains = (
                (normalized_domain,) if normalized_domain else tuple(sorted({
                    self._chunk_map[chunk_id].domain
                    for chunk_id in task.required_artifact_ids
                    if chunk_id in self._chunk_map
                }))
            )
            base = DemoCase(
                name=f"{task.name}_gold",
                query=task.query,
                positive_chunk_ids=tuple(task.required_artifact_ids),
                # Fall back to required artifacts when no explicit citations exist.
                expected_citations=tuple(task.expected_citation_ids or task.required_artifact_ids),
                preferred_domains=preferred_domains,
            )
            demo_cases.append(base)
            for index, variant in enumerate(query_variants.get(task.name, []), start=1):
                demo_cases.append(
                    DemoCase(
                        name=f"{task.name}_variant_{index}",
                        query=variant,
                        positive_chunk_ids=tuple(task.required_artifact_ids),
                        expected_citations=tuple(task.expected_citation_ids or task.required_artifact_ids),
                        preferred_domains=preferred_domains,
                    )
                )
        return demo_cases
    def _demo_similarity(self, query: str, demo: DemoCase) -> float:
        """Jaccard similarity between the tokenized query and the demo's query."""
        query_terms = _tokenize(query)
        demo_terms = _tokenize(demo.query)
        return _jaccard(query_terms, demo_terms)
    def _select_demo_cases(self, query: str, limit: int = 4) -> list[DemoCase]:
        """Pick the demos most lexically similar to *query*.

        Ranks by descending similarity with name as a deterministic
        tie-break, keeps between 2 and *limit* demos, and — when nothing
        overlaps at all — falls back to the first *limit* demos in
        definition order rather than returning an arbitrary ranking.
        """
        ranked = sorted(
            self._demo_cases,
            key=lambda demo: (-self._demo_similarity(query, demo), demo.name),
        )
        chosen = ranked[: max(2, min(limit, len(ranked)))]
        if all(self._demo_similarity(query, demo) == 0.0 for demo in chosen):
            return self._demo_cases[: min(limit, len(self._demo_cases))]
        return chosen
    def _citation_prior(self, chunk_id: str, demos: list[DemoCase], weights: list[float]) -> float:
        """Weighted fraction of demo mass whose expected citations include chunk_id."""
        if not demos:
            return 0.0
        matched = 0.0
        total = sum(weights) or 1.0  # guard against an all-zero weight list
        for demo, weight in zip(demos, weights, strict=False):
            if chunk_id in demo.expected_citations:
                matched += weight
        return matched / total
    def _domain_prior(self, chunk: Chunk, demos: list[DemoCase], weights: list[float]) -> float:
        """Weighted fraction of demo mass whose preferred domains include the chunk's domain."""
        if not demos:
            return 0.0
        matched = 0.0
        total = sum(weights) or 1.0  # guard against an all-zero weight list
        for demo, weight in zip(demos, weights, strict=False):
            if chunk.domain in demo.preferred_domains:
                matched += weight
        return matched / total
    def _query_chunk_overlap(self, query: str, chunk: Chunk) -> float:
        """Jaccard overlap between the query terms and the chunk's text + keywords."""
        query_terms = _tokenize(query)
        chunk_terms = _tokenize(chunk.text) | _tokenize(" ".join(chunk.keywords))
        return _jaccard(query_terms, chunk_terms)
    def _feature_vector(self, query: str, chunk: Chunk, demos: list[DemoCase], weights: list[float]) -> list[float]:
        """Build the 10-dim feature vector for one (query, chunk) pair.

        Order matters: it must stay in sync with self._feature_count and
        with the accumulators in _context_init. Features: hybrid score,
        BM25, keyword overlap, lexical overlap, token efficiency, three
        one-hot domain flags, citation prior, domain prior.
        """
        base = self.retriever.hybrid_score(query, chunk)
        bm25 = self.retriever.bm25_score(query, chunk)
        keyword = self.retriever.keyword_overlap_score(query, chunk)
        # Shorter chunks score higher; chunks of >= 700 tokens score 0.
        token_efficiency = 1.0 - min(chunk.tokens, 700) / 700.0
        domain_flags = _DOMAIN_TO_FLAGS.get(chunk.domain, (0.0, 0.0, 0.0))
        return [
            base,
            bm25,
            keyword,
            self._query_chunk_overlap(query, chunk),
            token_efficiency,
            domain_flags[0],
            domain_flags[1],
            domain_flags[2],
            self._citation_prior(chunk.chunk_id, demos, weights),
            self._domain_prior(chunk, demos, weights),
        ]
    def _context_init(self, query: str, chunks: list[Chunk], demos: list[DemoCase]) -> list[float]:
        """Analytic warm-start: positive-mean minus negative-mean feature vector.

        Each demo contributes with weight 0.25 + similarity(query, demo), so
        demos closer to the live query dominate the initialization. Features
        of chunks a demo marks positive accumulate on one side, all other
        chunks on the other; the difference of the weighted means is a crude
        linear separator used as the initial context vector.
        """
        weights = [0.25 + self._demo_similarity(query, demo) for demo in demos]
        positive_acc = [0.0] * self._feature_count
        negative_acc = [0.0] * self._feature_count
        positive_mass = 0.0
        negative_mass = 0.0
        for demo, demo_weight in zip(demos, weights, strict=False):
            for chunk in chunks:
                features = self._feature_vector(demo.query, chunk, demos, weights)
                if chunk.chunk_id in demo.positive_chunk_ids:
                    positive_acc = [value + (demo_weight * feature) for value, feature in zip(positive_acc, features, strict=False)]
                    positive_mass += demo_weight
                else:
                    negative_acc = [value + (demo_weight * feature) for value, feature in zip(negative_acc, features, strict=False)]
                    negative_mass += demo_weight
        # 1e-6 floors avoid division by zero when one side has no examples.
        positive_mean = [value / max(positive_mass, 1e-6) for value in positive_acc]
        negative_mean = [value / max(negative_mass, 1e-6) for value in negative_acc]
        return [positive - negative for positive, negative in zip(positive_mean, negative_mean, strict=False)]
    def _stable_seed(self, query: str) -> int:
        """Deterministic per-query RNG seed: first 8 hex chars of SHA-256(query)."""
        digest = hashlib.sha256(query.encode("utf-8")).hexdigest()[:8]
        return int(digest, 16)
    def _optimize_with_torch(self, query: str, chunks: list[Chunk], demos: list[DemoCase]) -> list[float]:
        """Refine the warm-start context vector with a small Adam loop.

        Per step, each demo is held out (leave-one-out) to build a masked
        warm-start; inverted dropout zeroes random context dims; the
        effective vector (a 0.55/0.45 blend of the learned theta and the
        masked init) scores every chunk and a class-weighted BCE loss pulls
        logits toward the demo's positive labels. Returns the analytic init
        unchanged when torch is unavailable or there are no chunks.

        NOTE(review): masked_init, matrix and labels do not depend on theta
        and could be hoisted out of the training loop — recomputed every
        step here; confirm this is just accepted overhead at this scale.
        """
        init = self._context_init(query, chunks, demos)
        if torch is None or F is None or not chunks:
            return init
        seed = self._stable_seed(query)
        torch.manual_seed(seed)  # deterministic dropout masks per query
        theta = torch.nn.Parameter(torch.tensor(init, dtype=torch.float32))
        optimizer = torch.optim.Adam([theta], lr=0.12)
        for _ in range(self._train_steps):
            optimizer.zero_grad()
            total_loss = torch.tensor(0.0, dtype=torch.float32)
            for demo_index, demo in enumerate(demos):
                # Leave-one-out: warm-start is rebuilt without the demo being fit.
                masked_demos = [item for index, item in enumerate(demos) if index != demo_index]
                if not masked_demos:
                    masked_demos = demos  # single demo: masking would leave nothing
                masked_init = torch.tensor(
                    self._context_init(demo.query, chunks, masked_demos),
                    dtype=torch.float32,
                )
                # Inverted dropout: surviving dims are rescaled to keep expectation.
                drop_mask = (torch.rand_like(masked_init) > self._token_dropout).float()
                drop_mask = drop_mask / max(1e-6, 1.0 - self._token_dropout)
                effective_theta = (0.55 * theta + 0.45 * masked_init) * drop_mask
                weights = [0.25 + self._demo_similarity(demo.query, item) for item in masked_demos]
                matrix = torch.tensor(
                    [self._feature_vector(demo.query, chunk, masked_demos, weights) for chunk in chunks],
                    dtype=torch.float32,
                )
                labels = torch.tensor(
                    [1.0 if chunk.chunk_id in demo.positive_chunk_ids else 0.0 for chunk in chunks],
                    dtype=torch.float32,
                )
                logits = matrix @ effective_theta
                # NOTE(review): this is 1 + positives/negatives; the conventional
                # class-balance pos_weight is negatives/positives — confirm intended.
                positive_weight = 1.0 + (labels.sum().item() / max(1.0, len(labels) - labels.sum().item()))
                total_loss = total_loss + F.binary_cross_entropy_with_logits(
                    logits,
                    labels,
                    pos_weight=torch.tensor(positive_weight, dtype=torch.float32),
                )
            total_loss.backward()
            optimizer.step()
        return theta.detach().tolist()
    def tune(self, query: str, candidate_chunks: list[Chunk]) -> ContextTuningResult:
        """Score candidate chunks for *query* with the tuned context vector.

        Selects demos, optimizes (or analytically initializes) theta, maps
        each chunk's features through sigmoid(features . theta), blends 40%
        retriever score with 60% tuned score, derives a compression ratio,
        and suggests up to three citations with final_score >= 0.35.
        """
        chunks = list(candidate_chunks)
        demos = self._select_demo_cases(query)
        demo_weights = [0.25 + self._demo_similarity(query, demo) for demo in demos]
        theta_values = self._optimize_with_torch(query, chunks, demos)
        # NOTE(review): mode only checks `torch is not None`, but the optimizer
        # also falls back when F is None or chunks is empty — in those cases
        # mode still reports "context_tuned_pytorch"; confirm acceptable.
        mode = "context_tuned_pytorch" if torch is not None else "context_tuned_analytic"
        if torch is not None:
            theta_tensor = torch.tensor(theta_values, dtype=torch.float32)
            matrix_tensor = torch.tensor(
                [self._feature_vector(query, chunk, demos, demo_weights) for chunk in chunks],
                dtype=torch.float32,
            )
            tuned_values = torch.sigmoid(matrix_tensor @ theta_tensor).tolist()
        else:
            # Pure-Python sigmoid path mirrors the torch branch above.
            tuned_values = []
            for chunk in chunks:
                features = self._feature_vector(query, chunk, demos, demo_weights)
                raw = sum(weight * feature for weight, feature in zip(theta_values, features, strict=False))
                tuned_values.append(1.0 / (1.0 + math.exp(-raw)))
        tuned_scores: dict[str, TunedChunkScore] = {}
        for chunk, tuned_score in zip(chunks, tuned_values, strict=False):
            base_score = self.retriever.hybrid_score(query, chunk)
            citation_prior = self._citation_prior(chunk.chunk_id, demos, demo_weights)
            # 40/60 blend of retriever and tuned scores, clamped to [0, 1].
            final_score = max(0.0, min(1.0, (0.40 * base_score) + (0.60 * tuned_score)))
            # Higher-value chunks get a lower compression ratio (kept more intact).
            compression_ratio = 0.82 - (0.34 * final_score) - (0.14 * citation_prior)
            compression_ratio = max(0.38, min(0.84, compression_ratio))
            tuned_scores[chunk.chunk_id] = TunedChunkScore(
                chunk_id=chunk.chunk_id,
                base_score=round(base_score, 4),
                tuned_score=round(float(tuned_score), 4),
                final_score=round(final_score, 4),
                citation_prior=round(citation_prior, 4),
                compression_ratio=round(compression_ratio, 2),
            )
        # Rank by final score, then citation prior, with chunk_id as a
        # deterministic tie-break.
        ranked = sorted(
            tuned_scores.values(),
            key=lambda item: (-item.final_score, -item.citation_prior, item.chunk_id),
        )
        suggested_citations = [item.chunk_id for item in ranked[:3] if item.final_score >= 0.35]
        return ContextTuningResult(
            mode=mode,
            top_demo_cases=[demo.name for demo in demos],
            suggested_citations=suggested_citations,
            tuned_scores=tuned_scores,
            token_dropout=self._token_dropout,
            leave_one_out=True,
        )