File size: 14,407 Bytes
0b89610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""
PyTorch-backed context tuning for retrieval and prompt optimization.

This module adapts the paper's core idea to this benchmark:
- initialize a lightweight context representation from task-specific demonstrations
- optimize that context rather than the underlying language model
- apply leave-one-out masking and token dropout for regularized tuning
"""

from __future__ import annotations

from dataclasses import dataclass
import hashlib
import math
import re
from typing import Iterable

from rag_optimizer_env.corpus import Chunk
from rag_optimizer_env.retriever import HybridRetriever
from rag_optimizer_env.tasks import Task

try:
    import torch
    import torch.nn.functional as F
except ModuleNotFoundError:  # pragma: no cover - optional dependency at runtime
    torch = None
    F = None


# One-hot feature triple per corpus domain, consumed by
# ContextTunedPlanner._feature_vector (features 6-8 of the 10-dim vector).
_DOMAIN_TO_FLAGS = {
    "Customer Support Operations": (1.0, 0.0, 0.0),
    "Incident Response Playbooks": (0.0, 1.0, 0.0),
    "Platform Reliability & Release Engineering": (0.0, 0.0, 1.0),
}

# Normalizes Task.domain_filter slugs to the display names used both as
# Chunk.domain values and as keys of _DOMAIN_TO_FLAGS.
_DOMAIN_FILTER_MAP = {
    "customer_support_operations": "Customer Support Operations",
    "incident_response_playbooks": "Incident Response Playbooks",
    "platform_reliability_release_engineering": "Platform Reliability & Release Engineering",
}

# Common English function words dropped by _tokenize before any
# token-overlap similarity is computed.
_STOPWORDS = {
    "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "how", "if", "in",
    "into", "is", "it", "its", "of", "on", "or", "that", "the", "their", "them", "there",
    "these", "this", "to", "was", "were", "what", "when", "where", "which", "while", "with",
    "without", "you", "your",
}


def _tokenize(text: str) -> set[str]:
    """Extract the set of lowercase alphanumeric tokens worth matching on.

    Tokens of two characters or fewer and common stopwords are discarded,
    so only content-bearing terms survive for overlap scoring.
    """
    tokens: set[str] = set()
    for candidate in re.findall(r"[a-z0-9]+", text.lower()):
        if len(candidate) > 2 and candidate not in _STOPWORDS:
            tokens.add(candidate)
    return tokens


def _jaccard(left: Iterable[str], right: Iterable[str]) -> float:
    left_set = set(left)
    right_set = set(right)
    if not left_set or not right_set:
        return 0.0
    union = left_set | right_set
    if not union:
        return 0.0
    return len(left_set & right_set) / len(union)


@dataclass(frozen=True, slots=True)
class DemoCase:
    """One demonstration example (gold task or paraphrased variant) used to warm-start tuning."""

    name: str  # unique label, e.g. "<task>_gold" or "<task>_variant_1"
    query: str  # demonstration query text
    positive_chunk_ids: tuple[str, ...]  # chunk ids treated as relevant for this demo
    expected_citations: tuple[str, ...]  # chunk ids expected to be cited in the answer
    preferred_domains: tuple[str, ...]  # domain display names favored by this demo

@dataclass(frozen=True, slots=True)
class TunedChunkScore:
    """Per-chunk scoring breakdown produced by ContextTunedPlanner.tune."""

    chunk_id: str  # id of the scored chunk
    base_score: float  # retriever's raw hybrid score (rounded to 4 dp)
    tuned_score: float  # sigmoid of the learned feature projection (rounded to 4 dp)
    final_score: float  # clamped blend of base and tuned scores (rounded to 4 dp)
    citation_prior: float  # demo-weighted prior that this chunk should be cited
    compression_ratio: float  # suggested context-compression ratio in [0.38, 0.84]

@dataclass(frozen=True, slots=True)
class ContextTuningResult:
    """Aggregate output of one tuning pass over a set of candidate chunks."""

    mode: str  # "context_tuned_pytorch" when torch is available, else "context_tuned_analytic"
    top_demo_cases: list[str]  # names of the demo cases selected for this query
    suggested_citations: list[str]  # up to 3 chunk ids whose final score cleared the threshold
    tuned_scores: dict[str, TunedChunkScore]  # per-chunk breakdown keyed by chunk id
    token_dropout: float  # dropout rate applied during optimization
    leave_one_out: bool  # whether leave-one-out demo masking was used (always True here)

class ContextTunedPlanner:
    """Warm-start a retrieval policy from demonstrations and optimize context weights.

    The planner builds a small set of demonstration cases from the benchmark
    tasks, derives a 10-dimensional feature vector per (query, chunk) pair,
    initializes a linear context vector analytically from the demos, and —
    when PyTorch is available — refines it with Adam under leave-one-out demo
    masking and feature dropout. The tuned vector is then used to re-score
    candidate chunks and suggest citations.
    """

    def __init__(self, retriever: HybridRetriever, corpus: list[Chunk], tasks: list[Task]):
        """Cache collaborators, index the corpus, and build demo cases from *tasks*.

        The corpus list is copied defensively. The three trailing constants
        are fixed tuning hyper-parameters, not user-facing configuration.
        """
        self.retriever = retriever
        self.corpus = list(corpus)
        # chunk_id -> Chunk lookup used throughout demo construction.
        self._chunk_map = {chunk.chunk_id: chunk for chunk in self.corpus}
        self._demo_cases = self._build_demo_cases(tasks)
        # Probability of zeroing each context feature during a training step.
        self._token_dropout = 0.14
        # Number of Adam optimization steps in _optimize_with_torch.
        self._train_steps = 28
        # Length of the vector returned by _feature_vector.
        self._feature_count = 10

    def _build_demo_cases(self, tasks: list[Task]) -> list[DemoCase]:
        """Create one gold demo per task plus hand-written paraphrase variants.

        Variants share the task's positive chunks and expected citations but
        use alternate query phrasings, giving the similarity-based demo
        selection more surface area to match against.
        """
        demo_cases: list[DemoCase] = []
        # Hard-coded paraphrases keyed by task name; tasks without an entry
        # simply get no variants.
        query_variants = {
            "refund_triage_easy": [
                "Refund triage memo after a confirmed outage with billing checks and finance review steps.",
                "Business customer outage escalation: verify the billing ledger, incident evidence, and compensation path.",
            ],
            "cross_function_brief_medium": [
                "Cross-functional outage brief linking support handling, incident command discipline, and release rollback safeguards.",
                "Major outage coordination memo for support, incident response, and release engineering teams.",
            ],
            "executive_escalation_hard": [
                "Executive escalation note for compromised admin account response with customer protection and change freeze controls.",
                "Severe security incident brief balancing customer harm reduction, evidence preservation, and release safeguards.",
            ],
        }
        for task in tasks:
            # Translate the task's domain slug to a display name; fall back to
            # the raw value (or "") when the slug is unknown.
            normalized_domain = _DOMAIN_FILTER_MAP.get(task.domain_filter or "", task.domain_filter or "")
            # Without an explicit filter, infer preferred domains from the
            # domains of the task's required artifacts.
            preferred_domains = (
                (normalized_domain,) if normalized_domain else tuple(sorted({
                    self._chunk_map[chunk_id].domain
                    for chunk_id in task.required_artifact_ids
                    if chunk_id in self._chunk_map
                }))
            )
            base = DemoCase(
                name=f"{task.name}_gold",
                query=task.query,
                positive_chunk_ids=tuple(task.required_artifact_ids),
                # Fall back to required artifacts when no explicit citations exist.
                expected_citations=tuple(task.expected_citation_ids or task.required_artifact_ids),
                preferred_domains=preferred_domains,
            )
            demo_cases.append(base)
            for index, variant in enumerate(query_variants.get(task.name, []), start=1):
                demo_cases.append(
                    DemoCase(
                        name=f"{task.name}_variant_{index}",
                        query=variant,
                        positive_chunk_ids=tuple(task.required_artifact_ids),
                        expected_citations=tuple(task.expected_citation_ids or task.required_artifact_ids),
                        preferred_domains=preferred_domains,
                    )
                )
        return demo_cases

    def _demo_similarity(self, query: str, demo: DemoCase) -> float:
        """Jaccard similarity between the query's and the demo query's token sets."""
        query_terms = _tokenize(query)
        demo_terms = _tokenize(demo.query)
        return _jaccard(query_terms, demo_terms)

    def _select_demo_cases(self, query: str, limit: int = 4) -> list[DemoCase]:
        """Pick the 2..limit demos most similar to *query* (name breaks ties).

        If no selected demo has any similarity at all, fall back to the first
        *limit* demos in construction order rather than an arbitrary subset.
        """
        ranked = sorted(
            self._demo_cases,
            key=lambda demo: (-self._demo_similarity(query, demo), demo.name),
        )
        chosen = ranked[: max(2, min(limit, len(ranked)))]
        if all(self._demo_similarity(query, demo) == 0.0 for demo in chosen):
            return self._demo_cases[: min(limit, len(self._demo_cases))]
        return chosen

    def _citation_prior(self, chunk_id: str, demos: list[DemoCase], weights: list[float]) -> float:
        """Weighted fraction of demos that expect *chunk_id* to be cited (0.0..1.0)."""
        if not demos:
            return 0.0
        matched = 0.0
        # `or 1.0` guards against an all-zero weight vector.
        total = sum(weights) or 1.0
        for demo, weight in zip(demos, weights, strict=False):
            if chunk_id in demo.expected_citations:
                matched += weight
        return matched / total

    def _domain_prior(self, chunk: Chunk, demos: list[DemoCase], weights: list[float]) -> float:
        """Weighted fraction of demos whose preferred domains include the chunk's domain."""
        if not demos:
            return 0.0
        matched = 0.0
        total = sum(weights) or 1.0
        for demo, weight in zip(demos, weights, strict=False):
            if chunk.domain in demo.preferred_domains:
                matched += weight
        return matched / total

    def _query_chunk_overlap(self, query: str, chunk: Chunk) -> float:
        """Jaccard overlap between query tokens and chunk text + keyword tokens."""
        query_terms = _tokenize(query)
        chunk_terms = _tokenize(chunk.text) | _tokenize(" ".join(chunk.keywords))
        return _jaccard(query_terms, chunk_terms)

    def _feature_vector(self, query: str, chunk: Chunk, demos: list[DemoCase], weights: list[float]) -> list[float]:
        """Build the 10-dim feature vector for one (query, chunk) pair.

        Layout: [hybrid, bm25, keyword, token-overlap, token-efficiency,
        3 one-hot domain flags, citation prior, domain prior]. Must stay in
        sync with self._feature_count.
        """
        base = self.retriever.hybrid_score(query, chunk)
        bm25 = self.retriever.bm25_score(query, chunk)
        keyword = self.retriever.keyword_overlap_score(query, chunk)
        # 1.0 for tiny chunks, 0.0 at/above 700 tokens — rewards compactness.
        token_efficiency = 1.0 - min(chunk.tokens, 700) / 700.0
        domain_flags = _DOMAIN_TO_FLAGS.get(chunk.domain, (0.0, 0.0, 0.0))
        return [
            base,
            bm25,
            keyword,
            self._query_chunk_overlap(query, chunk),
            token_efficiency,
            domain_flags[0],
            domain_flags[1],
            domain_flags[2],
            self._citation_prior(chunk.chunk_id, demos, weights),
            self._domain_prior(chunk, demos, weights),
        ]

    def _context_init(self, query: str, chunks: list[Chunk], demos: list[DemoCase]) -> list[float]:
        """Analytic warm start: mean positive feature minus mean negative feature.

        Each demo contributes with weight 0.25 + similarity(query, demo);
        chunks in the demo's positives accumulate on the positive side, all
        others on the negative side. The returned difference vector acts as
        the initial linear context (and as the full answer when torch is
        unavailable).
        """
        weights = [0.25 + self._demo_similarity(query, demo) for demo in demos]
        positive_acc = [0.0] * self._feature_count
        negative_acc = [0.0] * self._feature_count
        positive_mass = 0.0
        negative_mass = 0.0

        for demo, demo_weight in zip(demos, weights, strict=False):
            for chunk in chunks:
                features = self._feature_vector(demo.query, chunk, demos, weights)
                if chunk.chunk_id in demo.positive_chunk_ids:
                    positive_acc = [value + (demo_weight * feature) for value, feature in zip(positive_acc, features, strict=False)]
                    positive_mass += demo_weight
                else:
                    negative_acc = [value + (demo_weight * feature) for value, feature in zip(negative_acc, features, strict=False)]
                    negative_mass += demo_weight

        # max(..., 1e-6) avoids division by zero when one side saw no chunks.
        positive_mean = [value / max(positive_mass, 1e-6) for value in positive_acc]
        negative_mean = [value / max(negative_mass, 1e-6) for value in negative_acc]
        return [positive - negative for positive, negative in zip(positive_mean, negative_mean, strict=False)]

    def _stable_seed(self, query: str) -> int:
        """Derive a deterministic 32-bit seed from the query so tuning is reproducible."""
        digest = hashlib.sha256(query.encode("utf-8")).hexdigest()[:8]
        return int(digest, 16)

    def _optimize_with_torch(self, query: str, chunks: list[Chunk], demos: list[DemoCase]) -> list[float]:
        """Refine the analytic context vector with Adam; fall back to it without torch.

        For each step and each demo, the demo is masked out (leave-one-out)
        when computing the auxiliary init vector, feature dropout is applied
        (inverted-dropout scaling), and a class-balanced BCE-with-logits loss
        is accumulated over all chunks.
        """
        init = self._context_init(query, chunks, demos)
        if torch is None or F is None or not chunks:
            # No torch (optional dependency) or nothing to score: analytic init.
            return init

        seed = self._stable_seed(query)
        torch.manual_seed(seed)
        theta = torch.nn.Parameter(torch.tensor(init, dtype=torch.float32))
        optimizer = torch.optim.Adam([theta], lr=0.12)

        for _ in range(self._train_steps):
            optimizer.zero_grad()
            total_loss = torch.tensor(0.0, dtype=torch.float32)
            for demo_index, demo in enumerate(demos):
                # Leave-one-out: compute the auxiliary init without this demo.
                masked_demos = [item for index, item in enumerate(demos) if index != demo_index]
                if not masked_demos:
                    masked_demos = demos
                masked_init = torch.tensor(
                    self._context_init(demo.query, chunks, masked_demos),
                    dtype=torch.float32,
                )
                # Feature dropout with inverted scaling to keep expectation constant.
                drop_mask = (torch.rand_like(masked_init) > self._token_dropout).float()
                drop_mask = drop_mask / max(1e-6, 1.0 - self._token_dropout)
                # Blend the learned parameter with the leave-one-out init
                # before masking; only `theta` receives gradients.
                effective_theta = (0.55 * theta + 0.45 * masked_init) * drop_mask

                weights = [0.25 + self._demo_similarity(demo.query, item) for item in masked_demos]
                matrix = torch.tensor(
                    [self._feature_vector(demo.query, chunk, masked_demos, weights) for chunk in chunks],
                    dtype=torch.float32,
                )
                labels = torch.tensor(
                    [1.0 if chunk.chunk_id in demo.positive_chunk_ids else 0.0 for chunk in chunks],
                    dtype=torch.float32,
                )
                logits = matrix @ effective_theta
                # pos_weight = 1 + (#positives / #negatives) counteracts class imbalance.
                positive_weight = 1.0 + (labels.sum().item() / max(1.0, len(labels) - labels.sum().item()))
                total_loss = total_loss + F.binary_cross_entropy_with_logits(
                    logits,
                    labels,
                    pos_weight=torch.tensor(positive_weight, dtype=torch.float32),
                )

            total_loss.backward()
            optimizer.step()

        return theta.detach().tolist()

    def tune(self, query: str, candidate_chunks: list[Chunk]) -> ContextTuningResult:
        """Score *candidate_chunks* against *query* using the tuned context vector.

        Returns per-chunk score breakdowns, a compression-ratio suggestion per
        chunk, and up to three citation suggestions (final score >= 0.35).
        NOTE(review): with torch installed and an empty candidate list, the
        matrix/theta matmul below would operate on a zero-length tensor —
        callers appear to always pass non-empty candidates; confirm.
        """
        chunks = list(candidate_chunks)
        demos = self._select_demo_cases(query)
        demo_weights = [0.25 + self._demo_similarity(query, demo) for demo in demos]
        theta_values = self._optimize_with_torch(query, chunks, demos)
        mode = "context_tuned_pytorch" if torch is not None else "context_tuned_analytic"

        if torch is not None:
            # Vectorized scoring: sigmoid of the feature-matrix projection.
            theta_tensor = torch.tensor(theta_values, dtype=torch.float32)
            matrix_tensor = torch.tensor(
                [self._feature_vector(query, chunk, demos, demo_weights) for chunk in chunks],
                dtype=torch.float32,
            )
            tuned_values = torch.sigmoid(matrix_tensor @ theta_tensor).tolist()
        else:
            # Pure-Python fallback: manual dot product + logistic sigmoid.
            tuned_values = []
            for chunk in chunks:
                features = self._feature_vector(query, chunk, demos, demo_weights)
                raw = sum(weight * feature for weight, feature in zip(theta_values, features, strict=False))
                tuned_values.append(1.0 / (1.0 + math.exp(-raw)))

        tuned_scores: dict[str, TunedChunkScore] = {}
        for chunk, tuned_score in zip(chunks, tuned_values, strict=False):
            base_score = self.retriever.hybrid_score(query, chunk)
            citation_prior = self._citation_prior(chunk.chunk_id, demos, demo_weights)
            # 40/60 blend of retriever and tuned scores, clamped to [0, 1].
            final_score = max(0.0, min(1.0, (0.40 * base_score) + (0.60 * tuned_score)))
            # Higher-scoring / likely-cited chunks get compressed less.
            compression_ratio = 0.82 - (0.34 * final_score) - (0.14 * citation_prior)
            compression_ratio = max(0.38, min(0.84, compression_ratio))
            tuned_scores[chunk.chunk_id] = TunedChunkScore(
                chunk_id=chunk.chunk_id,
                base_score=round(base_score, 4),
                tuned_score=round(float(tuned_score), 4),
                final_score=round(final_score, 4),
                citation_prior=round(citation_prior, 4),
                compression_ratio=round(compression_ratio, 2),
            )

        # Deterministic ranking: score desc, citation prior desc, id asc.
        ranked = sorted(
            tuned_scores.values(),
            key=lambda item: (-item.final_score, -item.citation_prior, item.chunk_id),
        )
        suggested_citations = [item.chunk_id for item in ranked[:3] if item.final_score >= 0.35]
        return ContextTuningResult(
            mode=mode,
            top_demo_cases=[demo.name for demo in demos],
            suggested_citations=suggested_citations,
            tuned_scores=tuned_scores,
            token_dropout=self._token_dropout,
            leave_one_out=True,
        )