# Commit f06d2ef — "Removed streaming" (ryandt). Scraped commit header
# converted to a comment so the module parses.
"""
Beam search inversion engine for ZSInvert.
Cosine-similarity-guided beam search that reconstructs text
from an embedding vector using a small LLM as the token
proposal engine.
Part of E04: ZSInvert.
"""
from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import Callable
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from model import get_chat_format
# Tokens to mask from generation (special/formatting tokens). Only strings
# that encode to a single token can be suppressed at the logit level; see
# _build_mask_token_ids, which also scans the whole vocab for newline tokens.
_MASK_STRINGS = [
    "<|im_end|>", "<|end_header_id|>", "<|start_header_id|>",
    "<|eot_id|>", "<|eom_id|>", "<|python_tag|>",
    "@", "\xa0", '"', "\n", "\n\n", " \n\n",
]
# Number of top beams always kept deterministically when randomness mode is
# enabled; the rest of the beam is filled by random sampling (see beam_search).
_FIXED_KEEP = 5
@dataclass
class Candidate:
    """A beam search candidate: a partial token sequence plus its scores."""
    token_ids: list[int] = field(default_factory=list)  # generated token IDs so far
    seq_str: str = ""  # decoded text of token_ids
    score: float = 0.0  # ranking score (set equal to cos_sim by _score_candidates)
    cos_sim: float = 0.0  # cosine similarity of seq_str's embedding to the target
    # Per-candidate KV cache from the last forward pass; excluded from repr
    # because DynamicCache repr is large and unhelpful.
    kv_cache: DynamicCache | None = field(default=None, repr=False)
@dataclass
class InversionResult:
    """Result of a full inversion run (see invert)."""
    original_text: str | None = None  # input text that was encoded, if known
    target_embedding: torch.Tensor | None = None  # embedding being inverted
    stage1_text: str = ""  # best text from the seed-generation stage
    stage1_cos_sim: float = 0.0  # its cosine similarity to the target
    stage2_text: str = ""  # best text from the paraphrase-refinement stage
    stage2_cos_sim: float = 0.0  # its cosine similarity to the target
def _top_k_top_p_filter(logits: torch.Tensor, top_k: int, top_p: float) -> list[int]:
"""Return indices that survive top-k and top-p filtering."""
# Top-k: keep only top_k highest logits
topk_vals, topk_idx = torch.topk(logits, min(top_k, logits.size(-1)))
# Top-p (nucleus): keep smallest set whose cumulative prob >= top_p
probs = F.softmax(topk_vals, dim=-1)
cumulative = torch.cumsum(probs, dim=-1)
# Mask tokens beyond the nucleus
mask = cumulative - probs <= top_p
filtered_idx = topk_idx[mask]
return filtered_idx.tolist()
_cached_mask_ids: list[int] | None = None
def _build_mask_token_ids(tokenizer: AutoTokenizer) -> list[int]:
"""Build set of token IDs to suppress during generation. Cached.
Masks both exact single-token matches for _MASK_STRINGS and any
vocab token whose decoded form contains a newline (catches merged
tokens like '.\\n' that bypass the single-token check).
"""
global _cached_mask_ids
if _cached_mask_ids is not None:
return _cached_mask_ids
mask_ids = set()
for s in _MASK_STRINGS:
tokens = list(tokenizer.encode(s, add_special_tokens=False))
if len(tokens) == 1:
mask_ids.add(tokens[0])
if tokenizer.eos_token_id is not None:
mask_ids.add(tokenizer.eos_token_id)
# Also mask any vocab token containing a newline
for tid in range(tokenizer.vocab_size):
decoded = tokenizer.decode([tid])
if "\n" in decoded:
mask_ids.add(tid)
_cached_mask_ids = list(mask_ids)
return _cached_mask_ids
def _get_next_token_candidates(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prefix: list[int],
    suffix: list[int],
    prompt_tokens: list[int],
    candidates: list[Candidate],
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    mask_ids: list[int],
) -> tuple[list[list[tuple[int, float]]], list[DynamicCache | None]]:
    """Forward pass through the LLM to propose next tokens per candidate.

    Builds each input as: prefix + prompt_tokens + suffix + candidate.token_ids,
    batches all candidates into a single forward pass, and reuses per-candidate
    KV-caches when every candidate has one.

    Args:
        model: Generator LLM.
        tokenizer: LLM tokenizer (currently unused here; kept for call-site symmetry).
        prefix: Chat-template token IDs placed before the prompt.
        suffix: Chat-template token IDs placed after the prompt.
        prompt_tokens: Encoded user prompt.
        candidates: Current beam; all token_ids must have equal length.
        top_k: Top-k filter size per expansion.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Values != 1.0 penalize already-generated tokens.
        mask_ids: Token IDs to suppress entirely.

    Returns:
        Tuple of:
        - per-candidate lists of (token_id, log_prob) pairs sorted by
          log-prob descending;
        - per-candidate KV-caches from this pass (None entries when the
          batch cache could not be split).
    """
    device = next(model.parameters()).device
    # Build full token sequences
    base = prefix + prompt_tokens + suffix
    batch_tokens = [base + c.token_ids for c in candidates]
    # All sequences should have the same length (beam search invariant);
    # otherwise the batched tensor below would be ragged.
    assert len(set(len(t) for t in batch_tokens)) == 1
    input_ids = torch.tensor(batch_tokens, device=device)
    # Check for usable KV-cache: only reuse when every candidate carries one.
    batch_kv = [c.kv_cache for c in candidates]
    use_cache = all(kv is not None for kv in batch_kv)
    if use_cache:
        kv_cache = DynamicCache.from_batch_splits(batch_kv)
        cache_len = kv_cache.get_seq_length()
        # Feed only the tokens not already covered by the cache.
        model_input = input_ids[:, cache_len:]
        attn_mask = torch.ones_like(input_ids, device=device)
    else:
        kv_cache = DynamicCache()
        model_input = input_ids
        attn_mask = None
    with torch.no_grad():
        outputs = model(
            input_ids=model_input,
            attention_mask=attn_mask,
            past_key_values=kv_cache,
            use_cache=True,
        )
    # Split KV-cache back per candidate so each beam carries its own slice.
    next_kv = outputs.past_key_values
    try:
        split_kv = next_kv.batch_split(len(candidates), 1) if next_kv else [None] * len(candidates)
    except Exception:
        # Best-effort: fall back to cache-less operation rather than crash.
        split_kv = [None] * len(candidates)
    logits = outputs.logits[:, -1, :]  # (batch, vocab)
    # Apply repetition penalty: divide positive logits, multiply negative ones,
    # so repeated tokens always become less likely.
    if repetition_penalty != 1.0:
        for i, tokens in enumerate(batch_tokens):
            for tid in set(tokens):
                if logits[i, tid] > 0:
                    logits[i, tid] /= repetition_penalty
                else:
                    logits[i, tid] *= repetition_penalty
    # Mask special tokens by pushing their logits to effectively -inf
    logits[:, mask_ids] = -1e10
    log_probs = F.log_softmax(logits, dim=-1)
    results = []
    for i in range(len(candidates)):
        filtered = _top_k_top_p_filter(logits[i], top_k, top_p)
        pairs = [(tid, log_probs[i, tid].item()) for tid in filtered]
        pairs.sort(key=lambda x: x[1], reverse=True)
        results.append(pairs)
    return results, split_kv
def _score_candidates(
encoder: SentenceTransformer,
target_embedding: torch.Tensor,
candidates: list[Candidate],
) -> None:
"""Score candidates by cosine similarity to target embedding. Mutates in place."""
if not candidates:
return
texts = [c.seq_str for c in candidates]
embs = encoder.encode(texts, convert_to_tensor=True, normalize_embeddings=True)
# target_embedding shape: (1, dim) — broadcast
target_norm = F.normalize(target_embedding, dim=-1)
sims = torch.matmul(embs, target_norm.squeeze(0)) # (batch,)
for i, c in enumerate(candidates):
c.cos_sim = sims[i].item()
c.score = c.cos_sim
def beam_search(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    encoder: SentenceTransformer,
    target_embedding: torch.Tensor,
    prompt: str,
    beam_width: int = 30,
    max_steps: int = 0,
    top_k: int = 30,
    top_p: float = 1.0,
    repetition_penalty: float = 1.5,
    randomness: bool = True,
    patience: int = 5,
    min_similarity: float = 0.0,
    on_step: Callable | None = None,
) -> Candidate:
    """Run cosine-similarity-guided beam search.

    Each step expands every beam candidate by its proposed next tokens,
    scores all expansions by cosine similarity of their decoded text to
    the target embedding, and keeps the best beam_width of them.

    Args:
        model: Generator LLM.
        tokenizer: LLM tokenizer.
        encoder: Embedding encoder for scoring.
        target_embedding: Target embedding to invert. Shape (1, dim).
        prompt: User-facing prompt (becomes chat user message).
        beam_width: Number of candidates to maintain per step.
        max_steps: Maximum tokens to generate. 0 means no limit (stop via patience only).
        top_k: Top-k tokens to consider per expansion.
        top_p: Nucleus sampling threshold.
        repetition_penalty: Penalty for repeated tokens in logits.
        randomness: If True, keep top _FIXED_KEEP deterministically + sample rest.
        patience: Stop after this many steps with no improvement in best cosine sim.
            Set to 0 to disable early stopping.
        min_similarity: Stop immediately when cosine sim reaches this threshold.
            Set to 0.0 to disable.
        on_step: Callback(step, best_candidate) fired each step.

    Returns:
        Best candidate found during search: the highest cosine similarity
        among the best-ever, best-complete-sentence, and current-top candidates.
    """
    prefix, suffix = get_chat_format(tokenizer)
    prompt_tokens = list(tokenizer.encode(prompt, add_special_tokens=False))
    mask_ids = _build_mask_token_ids(tokenizer)
    # Start from a single empty candidate; the beam grows from here.
    candidates = [Candidate()]
    best_complete: Candidate | None = None  # best candidate ending a sentence
    best_ever: Candidate | None = None  # highest cosine sim seen at any step
    steps_since_improvement = 0
    step = 0
    while max_steps <= 0 or step < max_steps:
        step += 1
        # Expand: get next-token proposals for each candidate
        token_proposals, split_kv = _get_next_token_candidates(
            model, tokenizer, prefix, suffix, prompt_tokens,
            candidates, top_k, top_p, repetition_penalty, mask_ids,
        )
        # Build expanded candidates. Children of parent i share split_kv[i],
        # which covers the parent's tokens; the appended token is computed
        # on the next forward pass.
        expanded: list[Candidate] = []
        for i, cand in enumerate(candidates):
            for tid, _logp in token_proposals[i]:
                new_ids = cand.token_ids + [tid]
                expanded.append(Candidate(
                    token_ids=new_ids,
                    seq_str=tokenizer.decode(new_ids),
                    kv_cache=split_kv[i] if split_kv[i] is not None else None,
                ))
        # Score by cosine similarity
        _score_candidates(encoder, target_embedding, expanded)
        # Sort by score descending
        expanded.sort(key=lambda c: c.score, reverse=True)
        # Track best-ever candidate (highest cosine sim at any step);
        # snapshot a copy (without the KV cache) so later steps can't mutate it.
        step_best = expanded[0]
        if best_ever is None or step_best.cos_sim > best_ever.cos_sim:
            best_ever = Candidate(
                token_ids=list(step_best.token_ids),
                seq_str=step_best.seq_str,
                score=step_best.score,
                cos_sim=step_best.cos_sim,
            )
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
            # Early stop when the beam has stalled for `patience` steps.
            if patience > 0 and steps_since_improvement >= patience:
                break
        if min_similarity > 0 and best_ever.cos_sim >= min_similarity:
            break
        # Track best complete sentence (one ending in terminal punctuation)
        for c in expanded:
            if c.seq_str and c.seq_str.rstrip()[-1:] in ".?!":
                if best_complete is None or c.score > best_complete.score:
                    best_complete = Candidate(
                        token_ids=list(c.token_ids),
                        seq_str=c.seq_str,
                        score=c.score,
                        cos_sim=c.cos_sim,
                    )
        # Select: top beam_width candidates. With randomness enabled, keep the
        # top few deterministically and sample the remainder for diversity.
        if randomness and len(expanded) > _FIXED_KEEP:
            keep = min(_FIXED_KEEP, beam_width)
            remainder = min(beam_width - keep, len(expanded) - keep)
            candidates = expanded[:keep]
            if remainder > 0:
                candidates += random.sample(expanded[keep:], remainder)
        else:
            candidates = expanded[:beam_width]
        # Callback: prefer a complete sentence when one has been found.
        if on_step is not None:
            best_so_far = best_complete if best_complete else candidates[0]
            on_step(step, best_so_far)
    # Return the candidate with the highest cosine similarity across all tracking
    finalists = [c for c in [best_ever, best_complete, candidates[0]] if c is not None]
    return max(finalists, key=lambda c: c.cos_sim)
# Stage 1 uses a generic prompt to elicit a seed sentence.
_STAGE1_PROMPT = "tell me a story"
# Stage 2 asks the LLM to paraphrase the Stage 1 seed.
_STAGE2_PROMPT_TEMPLATE = "write a sentence similar to this: {seed}"
def invert(
    text: str,
    encoder_name: str = "gte",
    beam_width: int = 30,
    max_steps: int = 0,
    top_k: int = 30,
    two_stage: bool = True,
    on_progress: Callable | None = None,
) -> InversionResult:
    """Run the full two-stage ZSInvert inversion pipeline.

    Stage 1 generates a seed with a generic prompt; Stage 2 refines it by
    asking the LLM to paraphrase the Stage 1 output.

    Args:
        text: Input text to encode and then invert.
        encoder_name: Which embedding encoder to use ("gte", "gtr", "contriever").
        beam_width: Beam search width.
        max_steps: Maximum tokens per stage.
        top_k: Top-k tokens per expansion step.
        two_stage: If True, run both stages. If False, Stage 1 only
            (its result is mirrored into the stage 2 fields).
        on_progress: Callback(stage, step, best_candidate) for UI updates.
            stage is 1 or 2, step is the beam search step index.

    Returns:
        InversionResult with results from both stages.
    """
    from model import load_llm, load_encoder, encode_text
    llm, tok = load_llm()
    enc = load_encoder(encoder_name)
    target = encode_text(text, enc)

    def make_callback(stage: int) -> Callable:
        # Adapter: beam_search reports (step, candidate); prepend the stage.
        def callback(step: int, cand: Candidate) -> None:
            if on_progress is not None:
                on_progress(stage, step, cand)
        return callback

    # Search settings shared by both stages.
    common = dict(
        beam_width=beam_width,
        max_steps=max_steps,
        top_k=top_k,
        randomness=True,
    )
    # Stage 1: seed generation
    seed = beam_search(
        llm, tok, enc, target,
        prompt=_STAGE1_PROMPT,
        on_step=make_callback(1),
        **common,
    )
    result = InversionResult(
        original_text=text,
        target_embedding=target,
        stage1_text=seed.seq_str,
        stage1_cos_sim=seed.cos_sim,
    )
    if not two_stage:
        # Single-stage mode: mirror stage 1 into the stage 2 slots.
        result.stage2_text = result.stage1_text
        result.stage2_cos_sim = result.stage1_cos_sim
        return result
    # Stage 2: paraphrase refinement
    refined = beam_search(
        llm, tok, enc, target,
        prompt=_STAGE2_PROMPT_TEMPLATE.format(seed=seed.seq_str),
        on_step=make_callback(2),
        **common,
    )
    result.stage2_text = refined.seq_str
    result.stage2_cos_sim = refined.cos_sim
    return result