"""
ZSInvert — Zero-Shot Embedding Inversion Explorer.

Interactive tool demonstrating embedding inversion via adversarial decoding
beam search. Reconstructs text from embedding vectors without training
embedding-specific models.

Part of E04: ZSInvert.
"""

import html
import time

import gradio as gr
import torch

try:
    import spaces
    gpu_decorator = spaces.GPU(duration=120)
except ImportError:
    # `spaces` is not installed: fall back to a decorator that returns the
    # function unchanged.
    def gpu_decorator(fn):
        return fn

from model import load_llm, load_encoder, encode_text, ENCODERS
from invert import beam_search

_STAGE1_PROMPT = "tell me a story"
_STAGE2_PROMPT_TEMPLATE = "write a sentence similar to this: {seed}"

# Encoder choices (drop contriever — broken)
_ENCODER_CHOICES = [k for k in ENCODERS if k != "contriever"]


def _sim_color(cos_sim: float) -> str:
    """Return hex color for a cosine similarity value.

    Thresholds are strict (``>``), checked from highest to lowest; anything
    at or below 0.50 falls through to purple.
    """
    if cos_sim > 0.99:
        return "#3b82f6"  # blue
    if cos_sim > 0.95:
        return "#16a34a"  # dark green
    if cos_sim > 0.85:
        return "#65a30d"  # green
    if cos_sim > 0.70:
        return "#ca8a04"  # amber
    if cos_sim > 0.50:
        return "#ef4444"  # red
    return "#a855f7"  # purple


def _format_results(stage_results: list[dict]) -> str:
    """Render accumulated stage results as styled HTML.

    Args:
        stage_results: One dict per completed stage with the keys
            ``stage``, ``text``, ``cos_sim``, ``length``, ``time`` and
            ``steps`` (as built by ``run_stage``).

    Returns:
        One HTML ``<div>`` per stage, concatenated; ``""`` when there are no
        results yet.
    """
    if not stage_results:
        return ""

    rows = []
    for r in stage_results:
        color = _sim_color(r["cos_sim"])
        # The reconstructed text is model output and may contain characters
        # that are meaningful in HTML; escape it before interpolation.
        text = html.escape(str(r["text"]))
        rows.append(
            f'<div style="margin:0 0 12px 0;padding:8px 12px;'
            f'border-left:4px solid {color};">'
            f'<div><strong>S{r["stage"]}</strong> "{text}"</div>'
            f'<div style="color:{color};font-family:monospace;">'
            f'cos={r["cos_sim"]:.4f}'
            f'&nbsp;&nbsp;len={r["length"]}'
            f'&nbsp;&nbsp;{r["time"]:.1f}s'
            f'&nbsp;&nbsp;steps={r["steps"]}'
            f'</div>'
            f'</div>'
        )
    return "".join(rows)


@gpu_decorator
def _run_stage_gpu(
    target_emb,
    encoder_name,
    prompt,
    beam_width,
    top_k,
    patience,
    max_steps,
    min_similarity,
    randomness,
    encode_text_input=None,
):
    """Run a single beam search stage on GPU.

    All CUDA operations happen inside this decorated function. If
    encode_text_input is provided and target_emb is None, encodes the text
    first (Stage 1).

    Args:
        target_emb: Previously computed target embedding (stored on CPU by
            the caller), or None on the first stage.
        encoder_name: Name passed to ``load_encoder``.
        prompt: Prompt string forwarded to ``beam_search``.
        beam_width, top_k, patience, max_steps: Forwarded to ``beam_search``
            after coercion to ``int``.
        min_similarity: Forwarded to ``beam_search`` after coercion to
            ``float``.
        randomness: Forwarded to ``beam_search`` after coercion to ``bool``.
        encode_text_input: Raw text to encode when ``target_emb`` is None.

    Returns:
        A 4-tuple ``(result_dict, elapsed, step_count, target_emb_cpu)``
        where ``result_dict`` has the keys ``seq_str``, ``cos_sim`` and
        ``token_ids``, ``elapsed`` is the wall-clock duration of the
        ``beam_search`` call in seconds, ``step_count`` is the last step
        value reported through the ``on_step`` callback, and
        ``target_emb_cpu`` is the target embedding moved to CPU.

    Raises:
        ValueError: If both ``target_emb`` and ``encode_text_input`` are
            None, since there is then nothing to invert.
    """
    if target_emb is None and encode_text_input is None:
        raise ValueError(
            "Either target_emb or encode_text_input must be provided."
        )

    llm, tokenizer = load_llm()
    encoder = load_encoder(encoder_name)

    if target_emb is None:
        target_emb = encode_text(encode_text_input, encoder)
    else:
        # Move CPU tensor back to GPU for beam search
        device = next(llm.parameters()).device
        target_emb = target_emb.to(device)

    step_count = 0

    def count_steps(step, cand):
        # Remember the most recent step number reported by beam_search.
        nonlocal step_count
        step_count = step

    t0 = time.time()
    result = beam_search(
        llm,
        tokenizer,
        encoder,
        target_emb,
        prompt=prompt,
        beam_width=int(beam_width),
        max_steps=int(max_steps),
        top_k=int(top_k),
        patience=int(patience),
        min_similarity=float(min_similarity),
        randomness=bool(randomness),
        on_step=count_steps,
    )
    elapsed = time.time() - t0

    # Return only CPU/plain data to avoid CUDA init in main process on ZeroGPU
    return {
        "seq_str": result.seq_str,
        "cos_sim": result.cos_sim,
        "token_ids": result.token_ids,
    }, elapsed, step_count, target_emb.cpu()


def run_stage(
    text,
    encoder_name,
    beam_width,
    top_k,
    patience,
    max_steps,
    min_similarity,
    randomness,
    target_emb_state,
    stage_results_state,
):
    """Run the next stage of inversion.

    The stage number is derived from how many results are already stored.
    Stage 1 uses a fixed prompt and encodes ``text``; later stages build the
    prompt from the previous stage's text and reuse the stored embedding.

    Returns:
        A 4-tuple ``(target_emb_state, stage_results_state, html, button)``
        matching the Gradio outputs wired to the run button. When ``text`` is
        empty or whitespace-only, a warning is shown and the states are
        returned unchanged.
    """
    if not text or not text.strip():
        gr.Warning("Please enter some text.")
        return (
            target_emb_state,
            stage_results_state,
            _format_results(stage_results_state),
            gr.update(),
        )

    stage_num = len(stage_results_state) + 1

    # Build prompt
    if stage_num == 1:
        prompt = _STAGE1_PROMPT
    else:
        prev_text = stage_results_state[-1]["text"]
        prompt = _STAGE2_PROMPT_TEMPLATE.format(seed=prev_text)

    # On Stage 1, pass raw text so encoding happens inside GPU context
    encode_input = text.strip() if stage_num == 1 else None

    result_dict, elapsed, steps, returned_emb_cpu = _run_stage_gpu(
        target_emb_state,
        encoder_name,
        prompt,
        beam_width,
        top_k,
        patience,
        max_steps,
        min_similarity,
        randomness,
        encode_text_input=encode_input,
    )

    # Store embedding on CPU — it gets moved back to GPU inside _run_stage_gpu
    target_emb_state = returned_emb_cpu

    # Build a new list rather than appending in place so the incoming state
    # object is not mutated.
    stage_results_state = stage_results_state + [{
        "stage": stage_num,
        "text": result_dict["seq_str"],
        "cos_sim": result_dict["cos_sim"],
        "length": len(result_dict["token_ids"]),
        "time": elapsed,
        "steps": steps,
    }]

    rendered = _format_results(stage_results_state)
    btn_label = f"Run Stage {stage_num + 1}"

    return (
        target_emb_state,
        stage_results_state,
        rendered,
        gr.update(value=btn_label, visible=True),
    )


def reset_state():
    """Reset all state for a fresh run.

    Returns values for, in order: the target-embedding state, the stage
    results state, the results HTML, and the run button.
    """
    return None, [], "", gr.update(value="Run Stage 1", visible=True)


with gr.Blocks(title="ZSInvert") as demo:
    gr.Markdown("# Inverting Embeddings")
    gr.Markdown(
        "Reconstruct text from its embedding vector using "
        "cosine-similarity-guided beam search. "
        "Based on [Universal Zero-shot Embedding Inversion](https://arxiv.org/abs/2504.00147) "
        "(Zhang, Morris, Shmatikov 2025)."
    )

    # --- State ---
    target_emb_state = gr.State(value=None)
    stage_results_state = gr.State(value=[])

    # --- Input row ---
    with gr.Row():
        text_input = gr.Textbox(
            label="Input text",
            placeholder="Enter text to encode and invert...",
            scale=4,
        )
        encoder_dd = gr.Dropdown(
            choices=_ENCODER_CHOICES,
            value="gte",
            label="Encoder",
            scale=1,
        )

    # --- Advanced settings ---
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            beam_width_sl = gr.Slider(5, 50, value=10, step=1, label="beam_width")
            top_k_sl = gr.Slider(5, 50, value=10, step=1, label="top_k")
            patience_sl = gr.Slider(0, 20, value=5, step=1, label="patience (0=off)")
        with gr.Row():
            max_steps_sl = gr.Slider(0, 64, value=0, step=1, label="max_steps (0=unlimited)")
            min_sim_sl = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="min_similarity (0=off)")
            randomness_cb = gr.Checkbox(value=True, label="randomness")

    # --- Run button ---
    run_btn = gr.Button("Run Stage 1", variant="primary")

    # --- Results ---
    results_html = gr.HTML(value="", label="Results")

    # --- Wiring ---
    all_inputs = [
        text_input,
        encoder_dd,
        beam_width_sl,
        top_k_sl,
        patience_sl,
        max_steps_sl,
        min_sim_sl,
        randomness_cb,
        target_emb_state,
        stage_results_state,
    ]
    all_outputs = [
        target_emb_state,
        stage_results_state,
        results_html,
        run_btn,
    ]

    run_btn.click(fn=run_stage, inputs=all_inputs, outputs=all_outputs)

    # Reset when input text or encoder changes
    text_input.change(fn=reset_state, inputs=[], outputs=all_outputs)
    encoder_dd.change(fn=reset_state, inputs=[], outputs=all_outputs)


if __name__ == "__main__":
    demo.launch(server_port=7860, theme=gr.themes.Base())