""" Beam search inversion engine for ZSInvert. Cosine-similarity-guided beam search that reconstructs text from an embedding vector using a small LLM as the token proposal engine. Part of E04: ZSInvert. """ from __future__ import annotations import random from dataclasses import dataclass, field from typing import Callable import torch import torch.nn.functional as F from sentence_transformers import SentenceTransformer from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache from model import get_chat_format # Tokens to mask from generation (special/formatting tokens) _MASK_STRINGS = [ "<|im_end|>", "<|end_header_id|>", "<|start_header_id|>", "<|eot_id|>", "<|eom_id|>", "<|python_tag|>", "@", "\xa0", '"', "\n", "\n\n", " \n\n", ] # Number of top beams kept deterministically in randomness mode _FIXED_KEEP = 5 @dataclass class Candidate: """A beam search candidate.""" token_ids: list[int] = field(default_factory=list) seq_str: str = "" score: float = 0.0 cos_sim: float = 0.0 kv_cache: DynamicCache | None = field(default=None, repr=False) @dataclass class InversionResult: """Result of a full inversion run.""" original_text: str | None = None target_embedding: torch.Tensor | None = None stage1_text: str = "" stage1_cos_sim: float = 0.0 stage2_text: str = "" stage2_cos_sim: float = 0.0 def _top_k_top_p_filter(logits: torch.Tensor, top_k: int, top_p: float) -> list[int]: """Return indices that survive top-k and top-p filtering.""" # Top-k: keep only top_k highest logits topk_vals, topk_idx = torch.topk(logits, min(top_k, logits.size(-1))) # Top-p (nucleus): keep smallest set whose cumulative prob >= top_p probs = F.softmax(topk_vals, dim=-1) cumulative = torch.cumsum(probs, dim=-1) # Mask tokens beyond the nucleus mask = cumulative - probs <= top_p filtered_idx = topk_idx[mask] return filtered_idx.tolist() _cached_mask_ids: list[int] | None = None def _build_mask_token_ids(tokenizer: AutoTokenizer) -> list[int]: """Build set of token IDs to suppress during generation. Cached. Masks both exact single-token matches for _MASK_STRINGS and any vocab token whose decoded form contains a newline (catches merged tokens like '.\\n' that bypass the single-token check). """ global _cached_mask_ids if _cached_mask_ids is not None: return _cached_mask_ids mask_ids = set() for s in _MASK_STRINGS: tokens = list(tokenizer.encode(s, add_special_tokens=False)) if len(tokens) == 1: mask_ids.add(tokens[0]) if tokenizer.eos_token_id is not None: mask_ids.add(tokenizer.eos_token_id) # Also mask any vocab token containing a newline for tid in range(tokenizer.vocab_size): decoded = tokenizer.decode([tid]) if "\n" in decoded: mask_ids.add(tid) _cached_mask_ids = list(mask_ids) return _cached_mask_ids def _get_next_token_candidates( model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prefix: list[int], suffix: list[int], prompt_tokens: list[int], candidates: list[Candidate], top_k: int, top_p: float, repetition_penalty: float, mask_ids: list[int], ) -> list[list[tuple[int, float]]]: """Forward pass through LLM to get candidate next tokens. Builds input as: prefix + prompt_tokens + suffix + candidate.token_ids Uses KV-cache from candidates when available. Returns list of [(token_id, log_prob), ...] per candidate. """ device = next(model.parameters()).device # Build full token sequences base = prefix + prompt_tokens + suffix batch_tokens = [base + c.token_ids for c in candidates] # All sequences should have the same length (beam search invariant) assert len(set(len(t) for t in batch_tokens)) == 1 input_ids = torch.tensor(batch_tokens, device=device) # Check for usable KV-cache batch_kv = [c.kv_cache for c in candidates] use_cache = all(kv is not None for kv in batch_kv) if use_cache: kv_cache = DynamicCache.from_batch_splits(batch_kv) cache_len = kv_cache.get_seq_length() model_input = input_ids[:, cache_len:] attn_mask = torch.ones_like(input_ids, device=device) else: kv_cache = DynamicCache() model_input = input_ids attn_mask = None with torch.no_grad(): outputs = model( input_ids=model_input, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True, ) # Split KV-cache back per candidate next_kv = outputs.past_key_values try: split_kv = next_kv.batch_split(len(candidates), 1) if next_kv else [None] * len(candidates) except Exception: split_kv = [None] * len(candidates) logits = outputs.logits[:, -1, :] # (batch, vocab) # Apply repetition penalty if repetition_penalty != 1.0: for i, tokens in enumerate(batch_tokens): for tid in set(tokens): if logits[i, tid] > 0: logits[i, tid] /= repetition_penalty else: logits[i, tid] *= repetition_penalty # Mask special tokens logits[:, mask_ids] = -1e10 log_probs = F.log_softmax(logits, dim=-1) results = [] for i in range(len(candidates)): filtered = _top_k_top_p_filter(logits[i], top_k, top_p) pairs = [(tid, log_probs[i, tid].item()) for tid in filtered] pairs.sort(key=lambda x: x[1], reverse=True) results.append(pairs) return results, split_kv def _score_candidates( encoder: SentenceTransformer, target_embedding: torch.Tensor, candidates: list[Candidate], ) -> None: """Score candidates by cosine similarity to target embedding. Mutates in place.""" if not candidates: return texts = [c.seq_str for c in candidates] embs = encoder.encode(texts, convert_to_tensor=True, normalize_embeddings=True) # target_embedding shape: (1, dim) — broadcast target_norm = F.normalize(target_embedding, dim=-1) sims = torch.matmul(embs, target_norm.squeeze(0)) # (batch,) for i, c in enumerate(candidates): c.cos_sim = sims[i].item() c.score = c.cos_sim def beam_search( model: AutoModelForCausalLM, tokenizer: AutoTokenizer, encoder: SentenceTransformer, target_embedding: torch.Tensor, prompt: str, beam_width: int = 30, max_steps: int = 0, top_k: int = 30, top_p: float = 1.0, repetition_penalty: float = 1.5, randomness: bool = True, patience: int = 5, min_similarity: float = 0.0, on_step: Callable | None = None, ) -> Candidate: """Run cosine-similarity-guided beam search. Args: model: Generator LLM. tokenizer: LLM tokenizer. encoder: Embedding encoder for scoring. target_embedding: Target embedding to invert. Shape (1, dim). prompt: User-facing prompt (becomes chat user message). beam_width: Number of candidates to maintain per step. max_steps: Maximum tokens to generate. 0 means no limit (stop via patience only). top_k: Top-k tokens to consider per expansion. top_p: Nucleus sampling threshold. repetition_penalty: Penalty for repeated tokens in logits. randomness: If True, keep top 5 deterministically + sample rest. patience: Stop after this many steps with no improvement in best cosine sim. Set to 0 to disable early stopping. min_similarity: Stop immediately when cosine sim reaches this threshold. Set to 0.0 to disable. on_step: Callback(step, best_candidate) fired each step. Returns: Best candidate found during search. """ prefix, suffix = get_chat_format(tokenizer) prompt_tokens = list(tokenizer.encode(prompt, add_special_tokens=False)) mask_ids = _build_mask_token_ids(tokenizer) candidates = [Candidate()] best_complete: Candidate | None = None best_ever: Candidate | None = None steps_since_improvement = 0 step = 0 while max_steps <= 0 or step < max_steps: step += 1 # Expand: get next-token proposals for each candidate token_proposals, split_kv = _get_next_token_candidates( model, tokenizer, prefix, suffix, prompt_tokens, candidates, top_k, top_p, repetition_penalty, mask_ids, ) # Build expanded candidates expanded: list[Candidate] = [] for i, cand in enumerate(candidates): for tid, _logp in token_proposals[i]: new_ids = cand.token_ids + [tid] expanded.append(Candidate( token_ids=new_ids, seq_str=tokenizer.decode(new_ids), kv_cache=split_kv[i] if split_kv[i] is not None else None, )) # Score by cosine similarity _score_candidates(encoder, target_embedding, expanded) # Sort by score descending expanded.sort(key=lambda c: c.score, reverse=True) # Track best-ever candidate (highest cosine sim at any step) step_best = expanded[0] if best_ever is None or step_best.cos_sim > best_ever.cos_sim: best_ever = Candidate( token_ids=list(step_best.token_ids), seq_str=step_best.seq_str, score=step_best.score, cos_sim=step_best.cos_sim, ) steps_since_improvement = 0 else: steps_since_improvement += 1 if patience > 0 and steps_since_improvement >= patience: break if min_similarity > 0 and best_ever.cos_sim >= min_similarity: break # Track best complete sentence for c in expanded: if c.seq_str and c.seq_str.rstrip()[-1:] in ".?!": if best_complete is None or c.score > best_complete.score: best_complete = Candidate( token_ids=list(c.token_ids), seq_str=c.seq_str, score=c.score, cos_sim=c.cos_sim, ) # Select: top beam_width candidates (with optional randomness) if randomness and len(expanded) > _FIXED_KEEP: keep = min(_FIXED_KEEP, beam_width) remainder = min(beam_width - keep, len(expanded) - keep) candidates = expanded[:keep] if remainder > 0: candidates += random.sample(expanded[keep:], remainder) else: candidates = expanded[:beam_width] # Callback if on_step is not None: best_so_far = best_complete if best_complete else candidates[0] on_step(step, best_so_far) # Return the candidate with the highest cosine similarity across all tracking finalists = [c for c in [best_ever, best_complete, candidates[0]] if c is not None] return max(finalists, key=lambda c: c.cos_sim) _STAGE1_PROMPT = "tell me a story" _STAGE2_PROMPT_TEMPLATE = "write a sentence similar to this: {seed}" def invert( text: str, encoder_name: str = "gte", beam_width: int = 30, max_steps: int = 0, top_k: int = 30, two_stage: bool = True, on_progress: Callable | None = None, ) -> InversionResult: """Run the full two-stage ZSInvert inversion pipeline. Stage 1: Seed generation with a generic prompt. Stage 2: Paraphrase refinement using the Stage 1 output as context. Args: text: Input text to encode and then invert. encoder_name: Which embedding encoder to use ("gte", "gtr", "contriever"). beam_width: Beam search width. max_steps: Maximum tokens per stage. top_k: Top-k tokens per expansion step. two_stage: If True, run both stages. If False, Stage 1 only. on_progress: Callback(stage, step, best_candidate) for UI updates. stage is 1 or 2, step is the beam search step index. Returns: InversionResult with results from both stages. """ from model import load_llm, load_encoder, encode_text model, tokenizer = load_llm() encoder = load_encoder(encoder_name) target_embedding = encode_text(text, encoder) # Stage 1: seed generation def stage1_callback(step: int, cand: Candidate) -> None: if on_progress is not None: on_progress(1, step, cand) stage1 = beam_search( model, tokenizer, encoder, target_embedding, prompt=_STAGE1_PROMPT, beam_width=beam_width, max_steps=max_steps, top_k=top_k, randomness=True, on_step=stage1_callback, ) result = InversionResult( original_text=text, target_embedding=target_embedding, stage1_text=stage1.seq_str, stage1_cos_sim=stage1.cos_sim, ) if not two_stage: result.stage2_text = result.stage1_text result.stage2_cos_sim = result.stage1_cos_sim return result # Stage 2: paraphrase refinement def stage2_callback(step: int, cand: Candidate) -> None: if on_progress is not None: on_progress(2, step, cand) stage2_prompt = _STAGE2_PROMPT_TEMPLATE.format(seed=stage1.seq_str) stage2 = beam_search( model, tokenizer, encoder, target_embedding, prompt=stage2_prompt, beam_width=beam_width, max_steps=max_steps, top_k=top_k, randomness=True, on_step=stage2_callback, ) result.stage2_text = stage2.seq_str result.stage2_cos_sim = stage2.cos_sim return result