Spaces:
Sleeping
Sleeping
| """ | |
| Beam search inversion engine for ZSInvert. | |
| Cosine-similarity-guided beam search that reconstructs text | |
| from an embedding vector using a small LLM as the token | |
| proposal engine. | |
| Part of E04: ZSInvert. | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from dataclasses import dataclass, field | |
| from typing import Callable | |
| import torch | |
| import torch.nn.functional as F | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache | |
| from model import get_chat_format | |
# Tokens to mask from generation (special/formatting tokens).
# Each string that encodes to a single token is suppressed; chat-template
# markers plus characters that would break single-line output. Newlines are
# additionally caught vocab-wide in _build_mask_token_ids.
_MASK_STRINGS = [
    "<|im_end|>", "<|end_header_id|>", "<|start_header_id|>",
    "<|eot_id|>", "<|eom_id|>", "<|python_tag|>",
    "@", "\xa0", '"', "\n", "\n\n", " \n\n",
]

# Number of top beams kept deterministically in randomness mode
# (the rest of the beam is filled by random sampling — see beam_search).
_FIXED_KEEP = 5
@dataclass
class Candidate:
    """A beam search candidate.

    Fix: the original class used ``field(...)`` defaults and is constructed
    with keyword arguments throughout the file, but lacked the ``@dataclass``
    decorator — without it ``Candidate(token_ids=...)`` raises TypeError and
    the ``field`` objects become shared class attributes.
    """

    # Generated token ids so far (excluding the prompt tokens).
    token_ids: list[int] = field(default_factory=list)
    # Decoded text of token_ids.
    seq_str: str = ""
    # Ranking score; currently mirrors cos_sim (see _score_candidates).
    score: float = 0.0
    # Cosine similarity of seq_str's embedding to the target embedding.
    cos_sim: float = 0.0
    # Per-candidate KV cache for the generator LLM; hidden from repr
    # because it is large and unreadable.
    kv_cache: DynamicCache | None = field(default=None, repr=False)
@dataclass
class InversionResult:
    """Result of a full inversion run.

    Fix: the original class was constructed with keyword arguments in
    ``invert`` but lacked the ``@dataclass`` decorator, so
    ``InversionResult(original_text=...)`` raised TypeError.
    """

    # The input text that was encoded (None if inverting a raw embedding).
    original_text: str | None = None
    # The embedding being inverted, shape (1, dim).
    target_embedding: torch.Tensor | None = None
    # Stage 1 (seed generation) best text and its cosine similarity.
    stage1_text: str = ""
    stage1_cos_sim: float = 0.0
    # Stage 2 (paraphrase refinement) best text and its cosine similarity.
    stage2_text: str = ""
    stage2_cos_sim: float = 0.0
| def _top_k_top_p_filter(logits: torch.Tensor, top_k: int, top_p: float) -> list[int]: | |
| """Return indices that survive top-k and top-p filtering.""" | |
| # Top-k: keep only top_k highest logits | |
| topk_vals, topk_idx = torch.topk(logits, min(top_k, logits.size(-1))) | |
| # Top-p (nucleus): keep smallest set whose cumulative prob >= top_p | |
| probs = F.softmax(topk_vals, dim=-1) | |
| cumulative = torch.cumsum(probs, dim=-1) | |
| # Mask tokens beyond the nucleus | |
| mask = cumulative - probs <= top_p | |
| filtered_idx = topk_idx[mask] | |
| return filtered_idx.tolist() | |
# Module-level cache so the vocab scan only ever runs once per process.
_cached_mask_ids: list[int] | None = None


def _build_mask_token_ids(tokenizer: AutoTokenizer) -> list[int]:
    """Build set of token IDs to suppress during generation. Cached.

    Masks both exact single-token matches for _MASK_STRINGS and any
    vocab token whose decoded form contains a newline (catches merged
    tokens like '.\\n' that bypass the single-token check).
    """
    global _cached_mask_ids
    if _cached_mask_ids is None:
        banned: set[int] = set()
        # Single-token encodings of the blocked strings.
        for text in _MASK_STRINGS:
            encoded = list(tokenizer.encode(text, add_special_tokens=False))
            if len(encoded) == 1:
                banned.add(encoded[0])
        # Never let the model end the sequence on its own.
        if tokenizer.eos_token_id is not None:
            banned.add(tokenizer.eos_token_id)
        # Full vocab sweep for any token that decodes with a newline.
        banned.update(
            tid for tid in range(tokenizer.vocab_size)
            if "\n" in tokenizer.decode([tid])
        )
        _cached_mask_ids = list(banned)
    return _cached_mask_ids
def _get_next_token_candidates(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prefix: list[int],
    suffix: list[int],
    prompt_tokens: list[int],
    candidates: list[Candidate],
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    mask_ids: list[int],
) -> tuple[list[list[tuple[int, float]]], list[DynamicCache | None]]:
    """Forward pass through LLM to get candidate next tokens.

    Builds input as: prefix + prompt_tokens + suffix + candidate.token_ids
    Uses KV-cache from candidates when available (only when every candidate
    carries one, since the batch shares a single cache).

    Returns:
        A 2-tuple of
        - one list of (token_id, log_prob) pairs per candidate, sorted by
          log-probability descending, filtered by top-k/top-p; and
        - the batch KV-cache split back into one per-candidate cache
          (all None when splitting is unsupported or fails).
    """
    device = next(model.parameters()).device
    # Build full token sequences
    base = prefix + prompt_tokens + suffix
    batch_tokens = [base + c.token_ids for c in candidates]
    # All sequences should have the same length (beam search invariant:
    # every candidate has generated the same number of tokens)
    assert len(set(len(t) for t in batch_tokens)) == 1
    input_ids = torch.tensor(batch_tokens, device=device)
    # Check for usable KV-cache — only usable if every candidate has one
    batch_kv = [c.kv_cache for c in candidates]
    use_cache = all(kv is not None for kv in batch_kv)
    if use_cache:
        # Merge per-candidate caches and feed only the uncached suffix
        kv_cache = DynamicCache.from_batch_splits(batch_kv)
        cache_len = kv_cache.get_seq_length()
        model_input = input_ids[:, cache_len:]
        # Attention mask must still cover the full (cached + new) length
        attn_mask = torch.ones_like(input_ids, device=device)
    else:
        kv_cache = DynamicCache()
        model_input = input_ids
        attn_mask = None
    with torch.no_grad():
        outputs = model(
            input_ids=model_input,
            attention_mask=attn_mask,
            past_key_values=kv_cache,
            use_cache=True,
        )
    # Split KV-cache back per candidate; fall back to no-cache on failure
    next_kv = outputs.past_key_values
    try:
        split_kv = next_kv.batch_split(len(candidates), 1) if next_kv else [None] * len(candidates)
    except Exception:
        split_kv = [None] * len(candidates)
    logits = outputs.logits[:, -1, :]  # (batch, vocab) — last position only
    # Apply repetition penalty (CTRL-style: divide positive logits,
    # multiply negative ones, for every token already in the sequence)
    if repetition_penalty != 1.0:
        for i, tokens in enumerate(batch_tokens):
            for tid in set(tokens):
                if logits[i, tid] > 0:
                    logits[i, tid] /= repetition_penalty
                else:
                    logits[i, tid] *= repetition_penalty
    # Mask special tokens by pushing their logits to effectively -inf
    logits[:, mask_ids] = -1e10
    log_probs = F.log_softmax(logits, dim=-1)
    results = []
    for i in range(len(candidates)):
        filtered = _top_k_top_p_filter(logits[i], top_k, top_p)
        pairs = [(tid, log_probs[i, tid].item()) for tid in filtered]
        pairs.sort(key=lambda x: x[1], reverse=True)
        results.append(pairs)
    return results, split_kv
| def _score_candidates( | |
| encoder: SentenceTransformer, | |
| target_embedding: torch.Tensor, | |
| candidates: list[Candidate], | |
| ) -> None: | |
| """Score candidates by cosine similarity to target embedding. Mutates in place.""" | |
| if not candidates: | |
| return | |
| texts = [c.seq_str for c in candidates] | |
| embs = encoder.encode(texts, convert_to_tensor=True, normalize_embeddings=True) | |
| # target_embedding shape: (1, dim) — broadcast | |
| target_norm = F.normalize(target_embedding, dim=-1) | |
| sims = torch.matmul(embs, target_norm.squeeze(0)) # (batch,) | |
| for i, c in enumerate(candidates): | |
| c.cos_sim = sims[i].item() | |
| c.score = c.cos_sim | |
def beam_search(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    encoder: SentenceTransformer,
    target_embedding: torch.Tensor,
    prompt: str,
    beam_width: int = 30,
    max_steps: int = 0,
    top_k: int = 30,
    top_p: float = 1.0,
    repetition_penalty: float = 1.5,
    randomness: bool = True,
    patience: int = 5,
    min_similarity: float = 0.0,
    on_step: Callable | None = None,
) -> Candidate:
    """Run cosine-similarity-guided beam search.

    Args:
        model: Generator LLM.
        tokenizer: LLM tokenizer.
        encoder: Embedding encoder for scoring.
        target_embedding: Target embedding to invert. Shape (1, dim).
        prompt: User-facing prompt (becomes chat user message).
        beam_width: Number of candidates to maintain per step.
        max_steps: Maximum tokens to generate. 0 means no limit (stop via patience only).
        top_k: Top-k tokens to consider per expansion.
        top_p: Nucleus sampling threshold.
        repetition_penalty: Penalty for repeated tokens in logits.
        randomness: If True, keep top _FIXED_KEEP deterministically + sample rest.
        patience: Stop after this many steps with no improvement in best cosine sim.
            Set to 0 to disable early stopping.
        min_similarity: Stop immediately when cosine sim reaches this threshold.
            Set to 0.0 to disable.
        on_step: Callback(step, best_candidate) fired each step.

    Returns:
        Best candidate found during search (highest cosine similarity among
        the best-ever snapshot, the best complete sentence, and the current
        top beam).
    """
    prefix, suffix = get_chat_format(tokenizer)
    prompt_tokens = list(tokenizer.encode(prompt, add_special_tokens=False))
    mask_ids = _build_mask_token_ids(tokenizer)
    # Start from a single empty candidate
    candidates = [Candidate()]
    best_complete: Candidate | None = None   # best candidate ending a sentence
    best_ever: Candidate | None = None       # best cos_sim seen at any step
    steps_since_improvement = 0
    step = 0
    while max_steps <= 0 or step < max_steps:
        step += 1
        # Expand: get next-token proposals for each candidate
        token_proposals, split_kv = _get_next_token_candidates(
            model, tokenizer, prefix, suffix, prompt_tokens,
            candidates, top_k, top_p, repetition_penalty, mask_ids,
        )
        # Build expanded candidates (one per surviving proposal)
        expanded: list[Candidate] = []
        for i, cand in enumerate(candidates):
            for tid, _logp in token_proposals[i]:
                new_ids = cand.token_ids + [tid]
                expanded.append(Candidate(
                    token_ids=new_ids,
                    seq_str=tokenizer.decode(new_ids),
                    # Reuse the parent's split cache when available
                    kv_cache=split_kv[i] if split_kv[i] is not None else None,
                ))
        # Score by cosine similarity (mutates expanded in place)
        _score_candidates(encoder, target_embedding, expanded)
        # Sort by score descending
        expanded.sort(key=lambda c: c.score, reverse=True)
        # Track best-ever candidate (highest cosine sim at any step);
        # copy without kv_cache so the snapshot stays lightweight
        step_best = expanded[0]
        if best_ever is None or step_best.cos_sim > best_ever.cos_sim:
            best_ever = Candidate(
                token_ids=list(step_best.token_ids),
                seq_str=step_best.seq_str,
                score=step_best.score,
                cos_sim=step_best.cos_sim,
            )
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
            # Early stop when the beam has plateaued
            if patience > 0 and steps_since_improvement >= patience:
                break
        # Early stop when the similarity target has been reached
        if min_similarity > 0 and best_ever.cos_sim >= min_similarity:
            break
        # Track best complete sentence (text ending in terminal punctuation)
        for c in expanded:
            if c.seq_str and c.seq_str.rstrip()[-1:] in ".?!":
                if best_complete is None or c.score > best_complete.score:
                    best_complete = Candidate(
                        token_ids=list(c.token_ids),
                        seq_str=c.seq_str,
                        score=c.score,
                        cos_sim=c.cos_sim,
                    )
        # Select: top beam_width candidates (with optional randomness —
        # top _FIXED_KEEP kept greedily, remainder sampled for diversity)
        if randomness and len(expanded) > _FIXED_KEEP:
            keep = min(_FIXED_KEEP, beam_width)
            remainder = min(beam_width - keep, len(expanded) - keep)
            candidates = expanded[:keep]
            if remainder > 0:
                candidates += random.sample(expanded[keep:], remainder)
        else:
            candidates = expanded[:beam_width]
        # Callback for progress reporting
        if on_step is not None:
            best_so_far = best_complete if best_complete else candidates[0]
            on_step(step, best_so_far)
    # Return the candidate with the highest cosine similarity across all tracking
    finalists = [c for c in [best_ever, best_complete, candidates[0]] if c is not None]
    return max(finalists, key=lambda c: c.cos_sim)
# Prompts driving the two inversion stages: a generic seed prompt for
# Stage 1 and a paraphrase template (filled with the Stage 1 output)
# for Stage 2.
_STAGE1_PROMPT = "tell me a story"
_STAGE2_PROMPT_TEMPLATE = "write a sentence similar to this: {seed}"
def invert(
    text: str,
    encoder_name: str = "gte",
    beam_width: int = 30,
    max_steps: int = 0,
    top_k: int = 30,
    two_stage: bool = True,
    on_progress: Callable | None = None,
) -> InversionResult:
    """Run the full two-stage ZSInvert inversion pipeline.

    Stage 1: Seed generation with a generic prompt.
    Stage 2: Paraphrase refinement using the Stage 1 output as context.

    Args:
        text: Input text to encode and then invert.
        encoder_name: Which embedding encoder to use ("gte", "gtr", "contriever").
        beam_width: Beam search width.
        max_steps: Maximum tokens per stage.
        top_k: Top-k tokens per expansion step.
        two_stage: If True, run both stages. If False, Stage 1 only.
        on_progress: Callback(stage, step, best_candidate) for UI updates.
            stage is 1 or 2, step is the beam search step index.

    Returns:
        InversionResult with results from both stages.
    """
    from model import load_llm, load_encoder, encode_text

    model, tokenizer = load_llm()
    encoder = load_encoder(encoder_name)
    target = encode_text(text, encoder)

    def _reporter(stage: int) -> Callable:
        # Bind the stage number into a per-stage beam-search callback.
        def _cb(step: int, cand: Candidate) -> None:
            if on_progress is not None:
                on_progress(stage, step, cand)
        return _cb

    # Stage 1: seed generation from a generic prompt.
    seed = beam_search(
        model, tokenizer, encoder, target,
        prompt=_STAGE1_PROMPT,
        beam_width=beam_width,
        max_steps=max_steps,
        top_k=top_k,
        randomness=True,
        on_step=_reporter(1),
    )
    result = InversionResult(
        original_text=text,
        target_embedding=target,
        stage1_text=seed.seq_str,
        stage1_cos_sim=seed.cos_sim,
    )
    if not two_stage:
        # Single-stage mode: Stage 1 output doubles as the final answer.
        result.stage2_text = result.stage1_text
        result.stage2_cos_sim = result.stage1_cos_sim
        return result

    # Stage 2: paraphrase refinement seeded by the Stage 1 text.
    refined = beam_search(
        model, tokenizer, encoder, target,
        prompt=_STAGE2_PROMPT_TEMPLATE.format(seed=seed.seq_str),
        beam_width=beam_width,
        max_steps=max_steps,
        top_k=top_k,
        randomness=True,
        on_step=_reporter(2),
    )
    result.stage2_text = refined.seq_str
    result.stage2_cos_sim = refined.cos_sim
    return result