File size: 13,754 Bytes
5689bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f06d2ef
5689bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f06d2ef
5689bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
"""
Beam search inversion engine for ZSInvert.

Cosine-similarity-guided beam search that reconstructs text
from an embedding vector using a small LLM as the token
proposal engine.

Part of E04: ZSInvert.
"""

from __future__ import annotations

import random
from dataclasses import dataclass, field
from typing import Callable

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

from model import get_chat_format

# Strings whose single-token encodings are suppressed during generation:
# chat-template/special tokens plus formatting strings ("@", NBSP, double
# quote, newline variants). Only entries that encode to exactly one token
# are masked (see _build_mask_token_ids).
_MASK_STRINGS = [
    "<|im_end|>", "<|end_header_id|>", "<|start_header_id|>",
    "<|eot_id|>", "<|eom_id|>", "<|python_tag|>",
    "@", "\xa0", '"', "\n", "\n\n", " \n\n",
]

# In randomness mode, this many top-scoring beams are always kept
# deterministically; the remaining beam slots are filled by random sampling
# from the rest of the expansion (see beam_search).
_FIXED_KEEP = 5


@dataclass
class Candidate:
    """A beam search candidate: a partial token sequence plus its scores."""
    # Generated token IDs so far (excludes the chat-template/prompt tokens).
    token_ids: list[int] = field(default_factory=list)
    # Decoded text of token_ids.
    seq_str: str = ""
    # Ranking score used for beam selection (set equal to cos_sim by
    # _score_candidates).
    score: float = 0.0
    # Cosine similarity between this text's embedding and the target.
    cos_sim: float = 0.0
    # Per-candidate KV-cache slice for incremental decoding; excluded from
    # repr because it is large and unprintable.
    kv_cache: DynamicCache | None = field(default=None, repr=False)


@dataclass
class InversionResult:
    """Result of a full inversion run (both stages; see invert())."""
    # The input text that was embedded, when known.
    original_text: str | None = None
    # The embedding being inverted, shape (1, dim).
    target_embedding: torch.Tensor | None = None
    # Stage 1 (seed generation) best text and its cosine similarity.
    stage1_text: str = ""
    stage1_cos_sim: float = 0.0
    # Stage 2 (paraphrase refinement) best text and its cosine similarity.
    # When two_stage=False, invert() copies the stage-1 values here.
    stage2_text: str = ""
    stage2_cos_sim: float = 0.0


def _top_k_top_p_filter(logits: torch.Tensor, top_k: int, top_p: float) -> list[int]:
    """Return indices that survive top-k and top-p filtering."""
    # Top-k: keep only top_k highest logits
    topk_vals, topk_idx = torch.topk(logits, min(top_k, logits.size(-1)))

    # Top-p (nucleus): keep smallest set whose cumulative prob >= top_p
    probs = F.softmax(topk_vals, dim=-1)
    cumulative = torch.cumsum(probs, dim=-1)
    # Mask tokens beyond the nucleus
    mask = cumulative - probs <= top_p
    filtered_idx = topk_idx[mask]

    return filtered_idx.tolist()


_cached_mask_ids: list[int] | None = None


def _build_mask_token_ids(tokenizer: AutoTokenizer) -> list[int]:
    """Collect the token IDs to suppress during generation. Memoized.

    Two sources are merged:
      * single-token encodings of the strings in ``_MASK_STRINGS``, plus the
        tokenizer's EOS token when one exists;
      * every vocab entry whose decoded text contains a newline, which
        catches merged tokens such as ``'.\\n'`` that the single-token
        comparison above would miss.
    """
    global _cached_mask_ids
    if _cached_mask_ids is None:
        banned: set[int] = set()
        for text in _MASK_STRINGS:
            encoded = tokenizer.encode(text, add_special_tokens=False)
            # Multi-token encodings are skipped: masking their pieces would
            # also forbid unrelated words sharing those pieces.
            if len(encoded) == 1:
                banned.add(encoded[0])
        if tokenizer.eos_token_id is not None:
            banned.add(tokenizer.eos_token_id)
        # Sweep the whole vocab for newline-bearing tokens.
        banned.update(
            tid for tid in range(tokenizer.vocab_size)
            if "\n" in tokenizer.decode([tid])
        )
        _cached_mask_ids = list(banned)
    return _cached_mask_ids


def _get_next_token_candidates(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prefix: list[int],
    suffix: list[int],
    prompt_tokens: list[int],
    candidates: list[Candidate],
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    mask_ids: list[int],
) -> tuple[list[list[tuple[int, float]]], list[DynamicCache | None]]:
    """Forward pass through the LLM to propose next tokens for each beam.

    Builds each input as: prefix + prompt_tokens + suffix + candidate.token_ids
    and reuses per-candidate KV-caches when every candidate has one.

    Args:
        model: Generator LLM.
        tokenizer: LLM tokenizer (currently unused; kept for interface
            symmetry with the other beam-search helpers).
        prefix: Chat-template tokens placed before the prompt.
        suffix: Chat-template tokens placed after the prompt.
        prompt_tokens: Encoded user prompt.
        candidates: Current beams; all token_ids must have equal length.
        top_k: Top-k cutoff per beam.
        top_p: Nucleus threshold per beam.
        repetition_penalty: >1.0 discourages tokens already in the sequence.
        mask_ids: Token IDs suppressed entirely (set to -1e10 in logits).

    Returns:
        A pair ``(proposals, split_kv)`` where ``proposals[i]`` is a list of
        ``(token_id, log_prob)`` pairs sorted by log-prob descending for
        ``candidates[i]``, and ``split_kv[i]`` is that candidate's updated
        KV-cache (or None if the batch could not be split).
    """
    device = next(model.parameters()).device

    # Build full token sequences
    base = prefix + prompt_tokens + suffix
    batch_tokens = [base + c.token_ids for c in candidates]

    # All sequences should have the same length (beam search invariant)
    assert len(set(len(t) for t in batch_tokens)) == 1

    input_ids = torch.tensor(batch_tokens, device=device)

    # KV-cache is only usable when every candidate carries one; a mixed batch
    # would misalign cache lengths.
    batch_kv = [c.kv_cache for c in candidates]
    use_cache = all(kv is not None for kv in batch_kv)

    if use_cache:
        kv_cache = DynamicCache.from_batch_splits(batch_kv)
        cache_len = kv_cache.get_seq_length()
        # Only feed the tokens the cache has not seen yet.
        model_input = input_ids[:, cache_len:]
        attn_mask = torch.ones_like(input_ids)
    else:
        kv_cache = DynamicCache()
        model_input = input_ids
        attn_mask = None

    with torch.no_grad():
        outputs = model(
            input_ids=model_input,
            attention_mask=attn_mask,
            past_key_values=kv_cache,
            use_cache=True,
        )

    # Split KV-cache back per candidate so each beam can continue
    # independently next step.
    next_kv = outputs.past_key_values
    try:
        split_kv = next_kv.batch_split(len(candidates), 1) if next_kv else [None] * len(candidates)
    except Exception:
        # Best-effort: fall back to cache-less operation rather than abort
        # the search (e.g. cache implementations without batch_split).
        split_kv = [None] * len(candidates)

    logits = outputs.logits[:, -1, :]  # (batch, vocab)

    # Apply repetition penalty: divide positive logits / multiply negative
    # ones for every token already present in the sequence.
    if repetition_penalty != 1.0:
        for i, tokens in enumerate(batch_tokens):
            for tid in set(tokens):
                if logits[i, tid] > 0:
                    logits[i, tid] /= repetition_penalty
                else:
                    logits[i, tid] *= repetition_penalty

    # Mask special tokens
    logits[:, mask_ids] = -1e10

    log_probs = F.log_softmax(logits, dim=-1)

    results = []
    for i in range(len(candidates)):
        filtered = _top_k_top_p_filter(logits[i], top_k, top_p)
        pairs = [(tid, log_probs[i, tid].item()) for tid in filtered]
        pairs.sort(key=lambda x: x[1], reverse=True)
        results.append(pairs)

    return results, split_kv


def _score_candidates(
    encoder: SentenceTransformer,
    target_embedding: torch.Tensor,
    candidates: list[Candidate],
) -> None:
    """Score candidates by cosine similarity to target embedding. Mutates in place."""
    if not candidates:
        return

    texts = [c.seq_str for c in candidates]
    embs = encoder.encode(texts, convert_to_tensor=True, normalize_embeddings=True)

    # target_embedding shape: (1, dim) — broadcast
    target_norm = F.normalize(target_embedding, dim=-1)
    sims = torch.matmul(embs, target_norm.squeeze(0))  # (batch,)

    for i, c in enumerate(candidates):
        c.cos_sim = sims[i].item()
        c.score = c.cos_sim


def beam_search(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    encoder: SentenceTransformer,
    target_embedding: torch.Tensor,
    prompt: str,
    beam_width: int = 30,
    max_steps: int = 0,
    top_k: int = 30,
    top_p: float = 1.0,
    repetition_penalty: float = 1.5,
    randomness: bool = True,
    patience: int = 5,
    min_similarity: float = 0.0,
    on_step: Callable | None = None,
) -> Candidate:
    """Run cosine-similarity-guided beam search.

    Args:
        model: Generator LLM.
        tokenizer: LLM tokenizer.
        encoder: Embedding encoder for scoring.
        target_embedding: Target embedding to invert. Shape (1, dim).
        prompt: User-facing prompt (becomes chat user message).
        beam_width: Number of candidates to maintain per step.
        max_steps: Maximum tokens to generate. 0 means no limit (stop via patience only).
        top_k: Top-k tokens to consider per expansion.
        top_p: Nucleus sampling threshold.
        repetition_penalty: Penalty for repeated tokens in logits.
        randomness: If True, keep top 5 deterministically + sample rest.
        patience: Stop after this many steps with no improvement in best cosine sim.
            Set to 0 to disable early stopping.
        min_similarity: Stop immediately when cosine sim reaches this threshold.
            Set to 0.0 to disable.
        on_step: Callback(step, best_candidate) fired each step.

    Returns:
        Best candidate found during search.
    """
    prefix, suffix = get_chat_format(tokenizer)
    prompt_tokens = list(tokenizer.encode(prompt, add_special_tokens=False))
    mask_ids = _build_mask_token_ids(tokenizer)

    # Search starts from a single empty candidate (no generated tokens yet).
    candidates = [Candidate()]
    # Highest-scoring candidate that ends in sentence punctuation (".?!").
    best_complete: Candidate | None = None
    # Highest-cosine-similarity candidate seen at any step.
    best_ever: Candidate | None = None
    steps_since_improvement = 0

    step = 0
    # max_steps <= 0 means unbounded; stopping then relies on patience
    # and/or min_similarity.
    while max_steps <= 0 or step < max_steps:
        step += 1
        # Expand: get next-token proposals for each candidate
        token_proposals, split_kv = _get_next_token_candidates(
            model, tokenizer, prefix, suffix, prompt_tokens,
            candidates, top_k, top_p, repetition_penalty, mask_ids,
        )

        # Build expanded candidates
        expanded: list[Candidate] = []
        for i, cand in enumerate(candidates):
            for tid, _logp in token_proposals[i]:
                new_ids = cand.token_ids + [tid]
                expanded.append(Candidate(
                    token_ids=new_ids,
                    seq_str=tokenizer.decode(new_ids),
                    # NOTE: siblings share the parent's cache object; presumed
                    # safe because the cache is only read/extended downstream
                    # — confirm if cache mutation semantics change.
                    kv_cache=split_kv[i] if split_kv[i] is not None else None,
                ))

        # Score by cosine similarity
        _score_candidates(encoder, target_embedding, expanded)

        # Sort by score descending
        expanded.sort(key=lambda c: c.score, reverse=True)

        # Track best-ever candidate (highest cosine sim at any step).
        # expanded is non-empty for top_p >= 0: the nucleus filter always
        # keeps at least the top token of each beam.
        step_best = expanded[0]
        if best_ever is None or step_best.cos_sim > best_ever.cos_sim:
            # Copy without the KV-cache so the snapshot stays small.
            best_ever = Candidate(
                token_ids=list(step_best.token_ids),
                seq_str=step_best.seq_str,
                score=step_best.score,
                cos_sim=step_best.cos_sim,
            )
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
            if patience > 0 and steps_since_improvement >= patience:
                break

        if min_similarity > 0 and best_ever.cos_sim >= min_similarity:
            break

        # Track best complete sentence
        for c in expanded:
            if c.seq_str and c.seq_str.rstrip()[-1:] in ".?!":
                if best_complete is None or c.score > best_complete.score:
                    best_complete = Candidate(
                        token_ids=list(c.token_ids),
                        seq_str=c.seq_str,
                        score=c.score,
                        cos_sim=c.cos_sim,
                    )

        # Select: top beam_width candidates (with optional randomness).
        # The top _FIXED_KEEP survive deterministically; the remaining slots
        # are sampled uniformly from the rest to keep diversity.
        if randomness and len(expanded) > _FIXED_KEEP:
            keep = min(_FIXED_KEEP, beam_width)
            remainder = min(beam_width - keep, len(expanded) - keep)
            candidates = expanded[:keep]
            if remainder > 0:
                candidates += random.sample(expanded[keep:], remainder)
        else:
            candidates = expanded[:beam_width]

        # Callback. candidates[0] is the top-scoring beam: selection above
        # preserves sorted order at the front of the list.
        if on_step is not None:
            best_so_far = best_complete if best_complete else candidates[0]
            on_step(step, best_so_far)

    # Return the candidate with the highest cosine similarity across all tracking
    finalists = [c for c in [best_ever, best_complete, candidates[0]] if c is not None]
    return max(finalists, key=lambda c: c.cos_sim)


# Stage 1 uses a generic open-ended prompt; Stage 2 asks for a paraphrase of
# the Stage 1 output, which is substituted for {seed} (see invert()).
_STAGE1_PROMPT = "tell me a story"
_STAGE2_PROMPT_TEMPLATE = "write a sentence similar to this: {seed}"


def invert(
    text: str,
    encoder_name: str = "gte",
    beam_width: int = 30,
    max_steps: int = 0,
    top_k: int = 30,
    two_stage: bool = True,
    on_progress: Callable | None = None,
) -> InversionResult:
    """Run the full two-stage ZSInvert inversion pipeline.

    Stage 1 generates a seed from a generic prompt; Stage 2 refines it by
    asking the LLM to paraphrase the Stage 1 output. Both stages use the
    same cosine-similarity-guided beam search against the target embedding.

    Args:
        text: Input text to encode and then invert.
        encoder_name: Which embedding encoder to use ("gte", "gtr", "contriever").
        beam_width: Beam search width.
        max_steps: Maximum tokens per stage.
        top_k: Top-k tokens per expansion step.
        two_stage: If True, run both stages. If False, Stage 1 only.
        on_progress: Callback(stage, step, best_candidate) for UI updates.
            stage is 1 or 2, step is the beam search step index.

    Returns:
        InversionResult with results from both stages.
    """
    # Imported lazily so the heavy model machinery loads only when needed.
    from model import load_llm, load_encoder, encode_text

    model, tokenizer = load_llm()
    encoder = load_encoder(encoder_name)
    target_embedding = encode_text(text, encoder)

    def _make_callback(stage: int) -> Callable:
        # Forward beam-search progress to the caller, tagged with the stage.
        def _callback(step: int, cand: Candidate) -> None:
            if on_progress is not None:
                on_progress(stage, step, cand)
        return _callback

    def _run_stage(prompt: str, stage: int) -> Candidate:
        # One beam-search pass with the shared search settings.
        return beam_search(
            model, tokenizer, encoder, target_embedding,
            prompt=prompt,
            beam_width=beam_width,
            max_steps=max_steps,
            top_k=top_k,
            randomness=True,
            on_step=_make_callback(stage),
        )

    # Stage 1: seed generation
    stage1 = _run_stage(_STAGE1_PROMPT, 1)
    result = InversionResult(
        original_text=text,
        target_embedding=target_embedding,
        stage1_text=stage1.seq_str,
        stage1_cos_sim=stage1.cos_sim,
    )

    if not two_stage:
        # Single-stage mode: report the Stage 1 output under both keys.
        result.stage2_text = result.stage1_text
        result.stage2_cos_sim = result.stage1_cos_sim
        return result

    # Stage 2: paraphrase refinement seeded by the Stage 1 text
    stage2 = _run_stage(_STAGE2_PROMPT_TEMPLATE.format(seed=stage1.seq_str), 2)
    result.stage2_text = stage2.seq_str
    result.stage2_cos_sim = stage2.cos_sim
    return result