File size: 9,302 Bytes
f866820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""
Context shaping module for optimizing retrieved chunks.

Performs:
- Deduplication: Remove semantically similar chunks
- Token budgeting: Allocate tokens based on relevance
- Pruning: Remove irrelevant sentences within chunks
- Compression: Summarize if over budget
"""

from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
import re

# Lazy imports
_sentence_model = None


@dataclass
class ContextShapeResult:
    """Result of context shaping.

    Bundles the shaped chunk list with bookkeeping about how much the
    shaping pipeline reduced the context (token counts before/after,
    chunks dropped, whether LLM compression ran).
    """
    # Shaped chunks: possibly trimmed, pruned, or replaced by a single
    # compressed chunk (id "compressed_context") when compression ran.
    chunks: List[Dict[str, Any]]
    # Estimated token count of the input chunks before any shaping.
    original_tokens: int
    # Estimated token count of the output chunks after shaping.
    final_tokens: int
    # Number of chunks removed by semantic deduplication.
    chunks_removed: int
    # True when LLM summarization was applied to fit the token budget.
    compression_applied: bool


def _estimate_tokens(text: str) -> int:
    """Estimate token count (rough: 1 token ≈ 4 chars)."""
    return len(text) // 4


def _get_sentence_model():
    """Return the shared SentenceTransformer, loading it on first use.

    Returns None when sentence-transformers is not installed; in that
    case the import is retried on the next call.
    """
    global _sentence_model
    if _sentence_model is not None:
        return _sentence_model
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        return None
    _sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
    return _sentence_model


def _compute_similarity(text1: str, text2: str) -> float:
    """Cosine similarity between the embeddings of *text1* and *text2*.

    Returns 0.0 when the embedding model is unavailable or when encoding
    or the similarity computation fails for any reason.
    """
    model = _get_sentence_model()
    if model is None:
        return 0.0

    try:
        import numpy as np
        vec_a, vec_b = model.encode([text1, text2])
        denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
        return float(np.dot(vec_a, vec_b) / denominator)
    except Exception:
        # Best-effort: treat any embedding failure as "not similar".
        return 0.0


def _split_sentences(text: str) -> List[str]:
    """Split text into sentences."""
    # Simple sentence splitter
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return [s.strip() for s in sentences if s.strip()]


def deduplicate_chunks(
    chunks: List[Dict[str, Any]],
    threshold: float = 0.85
) -> Tuple[List[Dict[str, Any]], int]:
    """
    Drop chunks that are near-duplicates of an earlier, already-kept chunk.

    Chunks are scanned in order; each candidate is compared against every
    chunk kept so far and discarded if any similarity meets *threshold*.

    Args:
        chunks: List of chunks
        threshold: Similarity threshold for deduplication

    Returns:
        Tuple of (deduplicated chunks, count removed)
    """
    if len(chunks) <= 1:
        return chunks, 0

    kept: List[Dict[str, Any]] = []
    dropped = 0

    for candidate in chunks:
        candidate_text = candidate.get("text", "")
        # any() short-circuits on the first kept chunk over the threshold,
        # mirroring a break out of an explicit comparison loop.
        is_duplicate = any(
            _compute_similarity(candidate_text, existing.get("text", "")) >= threshold
            for existing in kept
        )
        if is_duplicate:
            dropped += 1
        else:
            kept.append(candidate)

    return kept, dropped


def budget_chunks(
    chunks: List[Dict[str, Any]],
    token_budget: int,
    min_tokens_per_chunk: int = 50
) -> List[Dict[str, Any]]:
    """
    Allocate token budget across chunks based on relevance scores.

    Each chunk receives a share of ``token_budget`` proportional to its
    ``score`` (0.5 when missing), floored at ``min_tokens_per_chunk`` and
    capped by the budget still remaining. Texts over their allocation are
    truncated at a word boundary with a "..." marker.

    Args:
        chunks: List of chunks with scores
        token_budget: Total token budget
        min_tokens_per_chunk: Minimum tokens to keep per chunk

    Returns:
        Chunks (copies) with text trimmed to fit budget and a
        "budget_allocated" key added. Trailing chunks may be dropped
        once the budget is exhausted.
    """
    if not chunks:
        return []

    def _tokens(s: str) -> int:
        # Mirrors the module-wide heuristic: 1 token ≈ 4 characters
        # (the inverse "* 4" conversion below depends on the same ratio).
        return len(s) // 4

    scores = [c.get("score", 0.5) for c in chunks]
    total_score = sum(scores)
    if total_score <= 0:
        # Bug fix: previously only total_score was replaced while each
        # chunk's score stayed 0, so every proportional allocation was
        # int(0) and all chunks silently fell back to the minimum.
        # Substitute genuinely equal weights instead.
        scores = [1.0] * len(chunks)
        total_score = float(len(chunks))

    budgeted = []
    remaining_budget = token_budget

    for chunk, score in zip(chunks, scores):
        text = chunk.get("text", "")

        # Proportional share, floored at the minimum, capped by what's left.
        chunk_budget = int((score / total_score) * token_budget)
        chunk_budget = max(chunk_budget, min_tokens_per_chunk)
        chunk_budget = min(chunk_budget, remaining_budget)

        if chunk_budget <= 0:
            continue

        # Trim text if needed, cutting back to the last word boundary.
        if _tokens(text) > chunk_budget:
            char_limit = chunk_budget * 4
            text = text[:char_limit].rsplit(" ", 1)[0] + "..."

        new_chunk = chunk.copy()
        new_chunk["text"] = text
        new_chunk["budget_allocated"] = chunk_budget
        budgeted.append(new_chunk)

        remaining_budget -= _tokens(text)
        if remaining_budget <= 0:
            break

    return budgeted


def prune_irrelevant_sentences(
    chunk: Dict[str, Any],
    query: str,
    relevance_threshold: float = 0.3
) -> Dict[str, Any]:
    """
    Drop sentences in *chunk* whose similarity to *query* falls below threshold.

    Sentences shorter than 10 characters are always kept, and the first
    sentence survives even when nothing clears the threshold. Empty or
    single-sentence chunks are returned unchanged.

    Args:
        chunk: Chunk to prune
        query: Query for relevance comparison
        relevance_threshold: Minimum similarity to keep sentence

    Returns:
        A copy of the chunk with irrelevant sentences removed and a
        "sentences_pruned" count added.
    """
    text = chunk.get("text", "")
    if not text:
        return chunk

    sentences = _split_sentences(text)
    if len(sentences) <= 1:
        return chunk

    kept = [
        sentence for sentence in sentences
        # Short fragments are kept unconditionally; others must clear the bar.
        if len(sentence) < 10
        or _compute_similarity(query, sentence) >= relevance_threshold
    ]

    if not kept:
        # Never return an empty chunk — fall back to the first sentence.
        kept = sentences[:1]

    pruned_chunk = chunk.copy()
    pruned_chunk["text"] = " ".join(kept)
    pruned_chunk["sentences_pruned"] = len(sentences) - len(kept)
    return pruned_chunk


def compress_with_llm(
    chunks: List[Dict[str, Any]],
    query: str,
    target_tokens: int
) -> List[Dict[str, Any]]:
    """
    Compress chunks into one summary chunk via LLM summarization.

    Falls back to returning *chunks* unchanged when the LLM provider is
    unavailable, the combined text already fits the target, or the call
    fails for any reason.

    Args:
        chunks: Chunks to compress
        query: Query for context-aware compression
        target_tokens: Target token count

    Returns:
        A one-element list holding the compressed chunk, or the original
        chunks on any fallback.
    """
    try:
        from src.llm_providers import call_llm
    except ImportError:
        return chunks

    combined = "\n\n".join(c.get("text", "") for c in chunks)
    if _estimate_tokens(combined) <= target_tokens:
        # Already within budget — nothing to do.
        return chunks

    prompt = f"""Summarize the following context to approximately {target_tokens * 4} characters.
Preserve all key facts relevant to this query: {query}
Keep specific names, numbers, and dates.

Context:
{combined}

Summary:"""

    try:
        response = call_llm(prompt=prompt, temperature=0.0, max_tokens=target_tokens)
        summary_text = response.get("text", "").strip()
        return [{
            "id": "compressed_context",
            "text": summary_text,
            "score": max(c.get("score", 0) for c in chunks),
            "compressed_from": len(chunks)
        }]
    except Exception:
        # Best-effort: any provider failure leaves the input untouched.
        return chunks


def shape_context(
    chunks: List[Dict[str, Any]],
    query: str,
    token_budget: int = 3000,
    dedup_threshold: float = 0.85,
    enable_pruning: bool = True,
    enable_compression: bool = True,
    relevance_threshold: float = 0.3
) -> ContextShapeResult:
    """
    Run the full context-shaping pipeline over retrieved chunks.

    Stages: (1) semantic deduplication, (2) optional per-sentence relevance
    pruning, (3) proportional token budgeting, (4) optional LLM compression
    when the budgeted text still exceeds the budget by more than 20%.

    Args:
        chunks: Retrieved chunks
        query: User query for relevance
        token_budget: Maximum tokens for final context
        dedup_threshold: Similarity threshold for deduplication
        enable_pruning: Whether to prune irrelevant sentences
        enable_compression: Whether to compress if over budget
        relevance_threshold: Minimum relevance for sentence pruning

    Returns:
        ContextShapeResult with shaped chunks and metadata
    """
    def token_total(items: List[Dict[str, Any]]) -> int:
        # Estimated token sum across chunk texts.
        return sum(_estimate_tokens(item.get("text", "")) for item in items)

    if not chunks:
        return ContextShapeResult(
            chunks=[],
            original_tokens=0,
            final_tokens=0,
            chunks_removed=0,
            compression_applied=False
        )

    original_tokens = token_total(chunks)

    # Stage 1: drop near-duplicate chunks.
    shaped, dropped = deduplicate_chunks(chunks, threshold=dedup_threshold)

    # Stage 2 (optional): strip sentences irrelevant to the query.
    if enable_pruning:
        shaped = [
            prune_irrelevant_sentences(c, query, relevance_threshold)
            for c in shaped
        ]

    # Stage 3: proportional token budgeting.
    shaped = budget_chunks(shaped, token_budget)

    # Stage 4 (optional): summarize only when still >20% over budget.
    compressed = False
    if enable_compression and token_total(shaped) > token_budget * 1.2:
        shaped = compress_with_llm(shaped, query, token_budget)
        compressed = True

    return ContextShapeResult(
        chunks=shaped,
        original_tokens=original_tokens,
        final_tokens=token_total(shaped),
        chunks_removed=dropped,
        compression_applied=compressed
    )