Spaces:
Sleeping
Sleeping
File size: 9,302 Bytes
f866820 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
"""
Context shaping module for optimizing retrieved chunks.
Performs:
- Deduplication: Remove semantically similar chunks
- Token budgeting: Allocate tokens based on relevance
- Pruning: Remove irrelevant sentences within chunks
- Compression: Summarize if over budget
"""
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
import re
# Module-level cache for the SentenceTransformer instance. It is loaded
# lazily by _get_sentence_model() so importing this module never pulls in
# sentence_transformers (and its heavy deps) eagerly.
_sentence_model = None
@dataclass
class ContextShapeResult:
    """Result of context shaping.

    Produced by shape_context(); token counts use the rough
    4-characters-per-token estimate from _estimate_tokens().
    """
    chunks: List[Dict[str, Any]]  # final shaped chunks, in retained order
    original_tokens: int  # estimated tokens across input chunks, pre-shaping
    final_tokens: int  # estimated tokens across the returned chunks
    chunks_removed: int  # chunks dropped as near-duplicates during dedup
    compression_applied: bool  # True if LLM compression was triggered
def _estimate_tokens(text: str) -> int:
"""Estimate token count (rough: 1 token ≈ 4 chars)."""
return len(text) // 4
def _get_sentence_model():
    """Return the shared SentenceTransformer, loading it on first use.

    Returns:
        The cached "all-MiniLM-L6-v2" model, or None when the
        sentence_transformers package is not installed.
    """
    global _sentence_model
    if _sentence_model is not None:
        return _sentence_model
    try:
        from sentence_transformers import SentenceTransformer
        _sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
    except ImportError:
        # Optional dependency missing: callers treat None as "no similarity".
        return None
    return _sentence_model
def _compute_similarity(text1: str, text2: str) -> float:
    """Cosine similarity between the embeddings of two texts.

    Best-effort: returns 0.0 when the embedding model is unavailable or
    anything fails during encoding/arithmetic.
    """
    model = _get_sentence_model()
    if model is None:
        return 0.0
    try:
        import numpy as np
        vec_a, vec_b = model.encode([text1, text2])
        denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
        return float(np.dot(vec_a, vec_b) / denominator)
    except Exception:
        # Deliberate catch-all: similarity is advisory, never fatal.
        return 0.0
def _split_sentences(text: str) -> List[str]:
"""Split text into sentences."""
# Simple sentence splitter
sentences = re.split(r'(?<=[.!?])\s+', text)
return [s.strip() for s in sentences if s.strip()]
def deduplicate_chunks(
    chunks: List[Dict[str, Any]],
    threshold: float = 0.85
) -> Tuple[List[Dict[str, Any]], int]:
    """Drop chunks whose text is semantically near-identical to a kept chunk.

    Chunks are scanned in order; each one is compared against the chunks
    already kept and discarded when any similarity reaches the threshold.

    Args:
        chunks: Candidate chunks (dicts with a "text" key).
        threshold: Cosine-similarity cutoff for counting as a duplicate.

    Returns:
        Tuple of (surviving chunks in original order, number removed).
    """
    if len(chunks) <= 1:
        return chunks, 0

    kept_indices: List[int] = []
    dropped = 0
    for index, candidate in enumerate(chunks):
        candidate_text = candidate.get("text", "")
        # any() short-circuits on the first sufficiently similar kept chunk.
        duplicates_kept = any(
            _compute_similarity(candidate_text, chunks[kept].get("text", "")) >= threshold
            for kept in kept_indices
        )
        if duplicates_kept:
            dropped += 1
        else:
            kept_indices.append(index)
    return [chunks[i] for i in kept_indices], dropped
def budget_chunks(
    chunks: List[Dict[str, Any]],
    token_budget: int,
    min_tokens_per_chunk: int = 50
) -> List[Dict[str, Any]]:
    """
    Allocate a token budget across chunks proportionally to relevance scores.

    Each chunk gets a share of ``token_budget`` proportional to its "score"
    (default 0.5 when absent), floored at ``min_tokens_per_chunk`` and capped
    by the budget still remaining. Text longer than its allocation is cut at
    roughly 4 chars/token, backed up to a word boundary, with "..." appended.

    Args:
        chunks: Chunks with optional "score" and "text" keys.
        token_budget: Total token budget to distribute.
        min_tokens_per_chunk: Floor on any single chunk's allocation.

    Returns:
        New chunk dicts (inputs are not mutated) with "text" possibly trimmed
        and "budget_allocated" recording each allocation. Iteration stops once
        the budget is exhausted, so trailing chunks may be dropped entirely.
    """
    if not chunks:
        return []

    scores = [c.get("score", 0.5) for c in chunks]
    total_score = sum(scores)
    if total_score == 0:
        # BUGFIX: previously only the divisor was replaced (total = len(chunks)),
        # so every proportional share was still 0 and collapsed to the floor.
        # Substitute unit scores so the budget is genuinely split evenly.
        scores = [1.0] * len(chunks)
        total_score = float(len(chunks))

    budgeted = []
    remaining_budget = token_budget
    for chunk, score in zip(chunks, scores):
        text = chunk.get("text", "")
        # Proportional allocation, floored, then capped by what's left.
        chunk_budget = int((score / total_score) * token_budget)
        chunk_budget = max(chunk_budget, min_tokens_per_chunk)
        chunk_budget = min(chunk_budget, remaining_budget)
        if chunk_budget <= 0:
            continue
        if _estimate_tokens(text) > chunk_budget:
            # Keep the first ~chunk_budget tokens, trimming back to the last
            # full word so the truncation does not split a word in half.
            char_limit = chunk_budget * 4
            text = text[:char_limit].rsplit(" ", 1)[0] + "..."
        new_chunk = chunk.copy()
        new_chunk["text"] = text
        new_chunk["budget_allocated"] = chunk_budget
        budgeted.append(new_chunk)
        remaining_budget -= _estimate_tokens(text)
        if remaining_budget <= 0:
            break
    return budgeted
def prune_irrelevant_sentences(
    chunk: Dict[str, Any],
    query: str,
    relevance_threshold: float = 0.3
) -> Dict[str, Any]:
    """Drop sentences in a chunk that fall below the query-relevance threshold.

    Args:
        chunk: Chunk whose "text" is filtered sentence by sentence.
        query: Query used for relevance comparison.
        relevance_threshold: Minimum cosine similarity to keep a sentence.

    Returns:
        A copy of the chunk with filtered "text" plus a "sentences_pruned"
        count. Chunks with empty text or a single sentence are returned
        unchanged; if everything would be pruned, the first sentence is kept.
    """
    text = chunk.get("text", "")
    if not text:
        return chunk
    sentences = _split_sentences(text)
    if len(sentences) <= 1:
        return chunk

    def _keep(sentence: str) -> bool:
        # Very short fragments are kept unconditionally — similarity
        # scores on tiny strings are too noisy to trust.
        if len(sentence) < 10:
            return True
        return _compute_similarity(query, sentence) >= relevance_threshold

    surviving = [s for s in sentences if _keep(s)]
    if not surviving:
        # Never return an empty chunk; fall back to the opening sentence.
        surviving = sentences[:1]

    pruned_chunk = chunk.copy()
    pruned_chunk["text"] = " ".join(surviving)
    pruned_chunk["sentences_pruned"] = len(sentences) - len(surviving)
    return pruned_chunk
def compress_with_llm(
    chunks: List[Dict[str, Any]],
    query: str,
    target_tokens: int
) -> List[Dict[str, Any]]:
    """Summarize chunks down to roughly target_tokens via the project LLM.

    Best-effort: when the LLM provider module is unavailable, the combined
    text already fits the target, or the call fails, the input chunks are
    returned unchanged.

    Args:
        chunks: Chunks to compress.
        query: Query used to steer the summary toward relevant facts.
        target_tokens: Desired token count of the compressed context.

    Returns:
        A single-element list holding one compressed chunk (carrying the best
        original score), or the original chunks when compression is skipped
        or fails.
    """
    try:
        from src.llm_providers import call_llm
    except ImportError:
        return chunks

    combined = "\n\n".join(c.get("text", "") for c in chunks)
    if _estimate_tokens(combined) <= target_tokens:
        # Already within budget — nothing to do.
        return chunks

    prompt = f"""Summarize the following context to approximately {target_tokens * 4} characters.
Preserve all key facts relevant to this query: {query}
Keep specific names, numbers, and dates.
Context:
{combined}
Summary:"""
    try:
        response = call_llm(prompt=prompt, temperature=0.0, max_tokens=target_tokens)
        summary = response.get("text", "").strip()
        # Collapse everything into a single compressed chunk.
        return [{
            "id": "compressed_context",
            "text": summary,
            "score": max(c.get("score", 0) for c in chunks),
            "compressed_from": len(chunks)
        }]
    except Exception:
        # Any provider failure falls back to the uncompressed chunks.
        return chunks
def shape_context(
    chunks: List[Dict[str, Any]],
    query: str,
    token_budget: int = 3000,
    dedup_threshold: float = 0.85,
    enable_pruning: bool = True,
    enable_compression: bool = True,
    relevance_threshold: float = 0.3
) -> ContextShapeResult:
    """Run the full shaping pipeline: dedup -> prune -> budget -> compress.

    Args:
        chunks: Retrieved chunks.
        query: User query used for relevance scoring.
        token_budget: Maximum tokens for the final context.
        dedup_threshold: Similarity cutoff for dropping duplicate chunks.
        enable_pruning: Whether to strip query-irrelevant sentences.
        enable_compression: Whether to LLM-compress when over budget.
        relevance_threshold: Minimum sentence relevance kept by pruning.

    Returns:
        ContextShapeResult with the shaped chunks plus token/removal metadata.
    """
    if not chunks:
        return ContextShapeResult(
            chunks=[],
            original_tokens=0,
            final_tokens=0,
            chunks_removed=0,
            compression_applied=False,
        )

    def _total_tokens(items: List[Dict[str, Any]]) -> int:
        return sum(_estimate_tokens(item.get("text", "")) for item in items)

    original_tokens = _total_tokens(chunks)

    # Step 1: drop near-duplicate chunks.
    shaped, removed = deduplicate_chunks(chunks, threshold=dedup_threshold)

    # Step 2 (optional): strip sentences irrelevant to the query.
    if enable_pruning:
        shaped = [
            prune_irrelevant_sentences(c, query, relevance_threshold)
            for c in shaped
        ]

    # Step 3: distribute the token budget by relevance score.
    shaped = budget_chunks(shaped, token_budget)

    # Step 4: compress only when still more than 20% over budget.
    compression_applied = False
    if enable_compression and _total_tokens(shaped) > token_budget * 1.2:
        shaped = compress_with_llm(shaped, query, token_budget)
        compression_applied = True

    return ContextShapeResult(
        chunks=shaped,
        original_tokens=original_tokens,
        final_tokens=_total_tokens(shaped),
        chunks_removed=removed,
        compression_applied=compression_applied,
    )
|