# doc-ingestion/src/core/context_optimizer.py
# Vamshi Pokala
# chore(ci): fix Ruff lint and Mypy for GitHub Actions
# 40db081
"""Token-budget context packing for RAG prompts."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence, Union
from src.core.reranker import RankedResult
from src.core.retrieval_result import RetrievalResult
from transformers import AutoTokenizer
@dataclass
class OptimizedContext:
    """Outcome of packing chunks into a token budget."""

    # Entries for the chunks that were kept (one dict per chunk).
    documents: List[Dict[str, Any]] = field(default_factory=list)
    # Token count reported for the packed context.
    total_tokens: int = field(default=0)
    # Whether the packed context is incomplete relative to the input.
    was_truncated: bool = field(default=False)
    # How many input chunks were left out.
    dropped_count: int = field(default=0)
def _unwrap_chunk(
    item: Union[RetrievalResult, RankedResult, Dict[str, Any]],
) -> RetrievalResult:
    """Normalize a supported chunk representation to a ``RetrievalResult``.

    A ``RankedResult`` yields its ``result`` attribute, a ``RetrievalResult``
    is returned unchanged, and anything else is read as a mapping with the
    keys ``id``, ``text`` and optionally ``metadata``, ``score``, ``sources``
    and ``confidence``.

    Raises:
        KeyError: if a mapping input has no ``id`` or ``text`` key.
    """
    if isinstance(item, RankedResult):
        return item.result
    if not isinstance(item, RetrievalResult):
        # Legacy dict path: required keys first, then optional ones where a
        # missing or falsy value falls back to an empty/zero default.
        chunk_id = str(item["id"])
        chunk_text = str(item["text"])
        raw_metadata = item.get("metadata")
        metadata = dict(raw_metadata) if raw_metadata else {}
        raw_score = item.get("score")
        fusion_score = float(raw_score) if raw_score else 0.0
        raw_sources = item.get("sources")
        sources = list(raw_sources) if raw_sources else []
        raw_confidence = item.get("confidence")
        confidence = float(raw_confidence) if raw_confidence else 0.0
        return RetrievalResult(
            id=chunk_id,
            text=chunk_text,
            metadata=metadata,
            fusion_score=fusion_score,
            sources=sources,
            confidence=confidence,
        )
    return item
class ContextOptimizer:
    """Pack retrieved chunks into a prompt-sized context using a HF tokenizer.

    The tokenizer is loaded on first use via ``AutoTokenizer.from_pretrained``
    and then reused for every token count.
    """

    # Marker appended to a chunk whose text was cut to fit the budget.
    _TRUNCATION_TRAILER = "\n\n[... truncated for context length ...]"

    def __init__(self, max_context_tokens: int = 4000, tokenizer_name: str = "gpt2") -> None:
        """Store the token budget and the name of the tokenizer to load lazily."""
        self.max_context_tokens = max_context_tokens
        self.tokenizer_name = tokenizer_name
        self._tokenizer: Any = None

    @property
    def tokenizer(self) -> Any:
        """Return the tokenizer, loading it on first access."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        return self._tokenizer

    def _count(self, text: str) -> int:
        """Return the number of tokens in *text*, excluding special tokens."""
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def _longest_fitting_prefix(self, words: Sequence[str], max_tokens: int, suffix: str = "") -> int:
        """Return the largest k such that the first k words plus *suffix* fit.

        Binary search over the word count; returns 0 when not even one word
        (followed by *suffix*) fits in *max_tokens*.
        """
        low, high = 1, len(words)
        kept = 0
        while low <= high:
            mid = (low + high) // 2
            if self._count(" ".join(words[:mid]) + suffix) <= max_tokens:
                kept = mid
                low = mid + 1
            else:
                high = mid - 1
        return kept

    def compress_document(self, text: str, max_tokens: int) -> str:
        """Truncate *text* on word boundaries so it fits in *max_tokens*.

        Whitespace runs are collapsed to single spaces. When words had to be
        cut, a truncation trailer is appended and its tokens are counted
        against the budget, so the returned string (trailer included) still
        fits. The trailer is omitted when nothing was cut, or when the budget
        is too small to hold any word together with the trailer.

        Args:
            text: Document text to shorten.
            max_tokens: Token budget for the returned string.

        Returns:
            The shortened text; ``""`` when *max_tokens* is not positive.
        """
        if max_tokens <= 0:
            return ""
        words = text.split()
        if words:
            whole = " ".join(words)
            if self._count(whole) <= max_tokens:
                # Nothing has to be cut, so no truncation marker is added.
                return whole
            trailer = self._TRUNCATION_TRAILER
            # Reserve room for the trailer up front; searching for the largest
            # bare prefix first would leave no tokens for the marker.
            kept = self._longest_fitting_prefix(words, max_tokens, suffix=trailer)
            if kept:
                return " ".join(words[:kept]) + trailer
            kept = self._longest_fitting_prefix(words, max_tokens)
            if kept:
                return " ".join(words[:kept])
        # No whole word fits (or there are no words): cut at token level.
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)[:max_tokens]
        return str(self.tokenizer.decode(token_ids))

    def optimize_context(
        self,
        query: str,
        documents: Sequence[Union[RetrievalResult, RankedResult, Dict[str, Any]]],
    ) -> OptimizedContext:
        """Greedily add highest-priority chunks until the token budget is exhausted.

        The query's tokens are charged against ``max_context_tokens`` first.
        Chunks are then visited in the given order: a chunk that fits is added
        whole; otherwise, if more than 64 tokens remain, a compressed version
        is tried; a chunk that still does not fit is dropped and the loop
        moves on to the next one.

        Args:
            query: Query text whose tokens count toward the budget.
            documents: Chunks in priority order (callers should pass the
                reranked order).

        Returns:
            An ``OptimizedContext`` whose ``total_tokens`` includes the query.
        """
        if not documents:
            return OptimizedContext(documents=[], total_tokens=self._count(query), was_truncated=False, dropped_count=0)
        # Preserve incoming order as priority (caller should pass reranked order)
        chunks: List[RetrievalResult] = [_unwrap_chunk(d) for d in documents]
        dropped = 0
        selected: List[Dict[str, Any]] = []
        used = self._count(query)
        for doc in chunks:
            header = f"[{doc.id}]\n"
            block_tokens = self._count(f"{header}{doc.text}")
            remaining = self.max_context_tokens - used
            if block_tokens <= remaining:
                selected.append(
                    {
                        "id": doc.id,
                        "text": doc.text,
                        "metadata": doc.metadata,
                        "fusion_score": doc.fusion_score,
                    }
                )
                used += block_tokens
                continue
            if remaining > 64:
                # Never ask for fewer than 32 tokens of text; the fit check
                # below rejects the result if that overshoots the budget.
                budget = max(remaining - self._count(header), 32)
                compressed = self.compress_document(doc.text, max_tokens=budget)
                # Tokenize the compressed block once and reuse the count.
                compressed_tokens = self._count(f"{header}{compressed}")
                if compressed_tokens <= remaining:
                    selected.append(
                        {
                            "id": doc.id,
                            "text": compressed,
                            "metadata": doc.metadata,
                            "fusion_score": doc.fusion_score,
                            "was_compressed": True,
                        }
                    )
                    used += compressed_tokens
                    continue
            dropped += 1
        was_truncated = dropped > 0 or any(d.get("was_compressed") for d in selected)
        return OptimizedContext(
            documents=selected,
            total_tokens=used,
            was_truncated=was_truncated,
            dropped_count=dropped,
        )