Spaces:
Sleeping
Sleeping
| """ | |
| ZSInvert — Zero-Shot Embedding Inversion Explorer. | |
| Interactive tool demonstrating embedding inversion via | |
| adversarial decoding beam search. Reconstructs text from | |
| embedding vectors without training embedding-specific models. | |
| Part of E04: ZSInvert. | |
| """ | |
| import time | |
| import gradio as gr | |
| import torch | |
| try: | |
| import spaces | |
| gpu_decorator = spaces.GPU(duration=120) | |
| except ImportError: | |
| gpu_decorator = lambda fn: fn | |
| from model import load_llm, load_encoder, encode_text, ENCODERS | |
| from invert import beam_search | |
# Prompt for the first stage, where no prior hypothesis text exists yet.
_STAGE1_PROMPT = "tell me a story"
# Later stages seed the LLM with the best reconstruction from the previous stage.
_STAGE2_PROMPT_TEMPLATE = "write a sentence similar to this: {seed}"
# Encoder choices (drop contriever — broken)
_ENCODER_CHOICES = [k for k in ENCODERS if k != "contriever"]
| def _sim_color(cos_sim: float) -> str: | |
| """Return hex color for a cosine similarity value.""" | |
| if cos_sim > 0.99: | |
| return "#3b82f6" # blue | |
| if cos_sim > 0.95: | |
| return "#16a34a" # dark green | |
| if cos_sim > 0.85: | |
| return "#65a30d" # green | |
| if cos_sim > 0.70: | |
| return "#ca8a04" # amber | |
| if cos_sim > 0.50: | |
| return "#ef4444" # red | |
| return "#a855f7" # purple | |
| def _format_results(stage_results: list[dict]) -> str: | |
| """Render accumulated stage results as styled HTML.""" | |
| if not stage_results: | |
| return "" | |
| rows = [] | |
| for r in stage_results: | |
| color = _sim_color(r["cos_sim"]) | |
| rows.append( | |
| f'<div style="margin-bottom:12px;padding:10px;border:1px solid #333;border-radius:6px;' | |
| f'background:#1a1a2e;">' | |
| f'<span style="font-weight:bold;color:#ccc;">S{r["stage"]}</span> ' | |
| f'<span style="color:#eee;font-style:italic;">"{r["text"]}"</span><br>' | |
| f'<span style="color:{color};font-weight:bold;">cos={r["cos_sim"]:.4f}</span>' | |
| f' len={r["length"]}' | |
| f' {r["time"]:.1f}s' | |
| f' steps={r["steps"]}' | |
| f'</div>' | |
| ) | |
| return "".join(rows) | |
@gpu_decorator
def _run_stage_gpu(
    target_emb, encoder_name, prompt,
    beam_width, top_k, patience, max_steps, min_similarity, randomness,
    encode_text_input=None,
):
    """Run a single beam-search inversion stage on the GPU.

    All CUDA operations happen inside this decorated function. Fix: the
    `gpu_decorator` built at import time (`spaces.GPU(duration=120)`, or a
    no-op outside Spaces) was never actually applied, so on ZeroGPU this
    function would run without a GPU allocation despite the docstring's
    claim — it is now applied.

    Args:
        target_emb: target embedding as a CPU tensor, or None on Stage 1
            (then ``encode_text_input`` must be provided).
        encoder_name: key into ``ENCODERS`` selecting the embedding model.
        prompt: LLM prompt guiding the beam search.
        beam_width, top_k, patience, max_steps, min_similarity, randomness:
            beam-search hyperparameters; coerced to int/float/bool before
            being passed to ``beam_search``.
        encode_text_input: raw text to encode inside the GPU context when
            ``target_emb`` is None (Stage 1).

    Returns:
        Tuple ``(result_dict, elapsed_seconds, step_count, target_emb_cpu)``
        where ``result_dict`` contains only plain-Python fields so no CUDA
        tensors leak back into the main process.
    """
    llm, tokenizer = load_llm()
    encoder = load_encoder(encoder_name)
    if target_emb is None and encode_text_input is not None:
        # Stage 1: encode the user text here so CUDA init stays in this process.
        target_emb = encode_text(encode_text_input, encoder)
    elif target_emb is not None:
        # Later stages: move the stored CPU tensor back onto the LLM's device.
        device = next(llm.parameters()).device
        target_emb = target_emb.to(device)

    step_count = 0

    def count_steps(step, cand):
        # Track the latest step index reported by the search callback.
        nonlocal step_count
        step_count = step

    t0 = time.time()
    result = beam_search(
        llm, tokenizer, encoder, target_emb,
        prompt=prompt,
        beam_width=int(beam_width),
        max_steps=int(max_steps),
        top_k=int(top_k),
        patience=int(patience),
        min_similarity=float(min_similarity),
        randomness=bool(randomness),
        on_step=count_steps,
    )
    elapsed = time.time() - t0
    # Return only CPU/plain data to avoid CUDA init in main process on ZeroGPU
    return {
        "seq_str": result.seq_str,
        "cos_sim": result.cos_sim,
        "token_ids": result.token_ids,
    }, elapsed, step_count, target_emb.cpu()
def run_stage(
    text, encoder_name,
    beam_width, top_k, patience, max_steps, min_similarity, randomness,
    target_emb_state, stage_results_state,
):
    """Run the next inversion stage and return updated UI state.

    Stage 1 encodes the raw input text inside the GPU context; later
    stages re-use the stored embedding and seed the prompt with the
    previous stage's best reconstruction.

    Returns:
        (target embedding state, results state, results HTML, button update).
    """
    cleaned = (text or "").strip()
    if not cleaned:
        # Nothing to invert — warn and leave all state untouched.
        gr.Warning("Please enter some text.")
        return (
            target_emb_state,
            stage_results_state,
            _format_results(stage_results_state),
            gr.update(),
        )

    stage_num = len(stage_results_state) + 1
    is_first_stage = stage_num == 1

    # Build the prompt: generic on Stage 1, seeded with the prior best afterwards.
    if is_first_stage:
        prompt = _STAGE1_PROMPT
    else:
        prompt = _STAGE2_PROMPT_TEMPLATE.format(seed=stage_results_state[-1]["text"])

    # On Stage 1, pass raw text so encoding happens inside the GPU context.
    result_dict, elapsed, steps, emb_cpu = _run_stage_gpu(
        target_emb_state, encoder_name, prompt,
        beam_width, top_k, patience, max_steps, min_similarity, randomness,
        encode_text_input=cleaned if is_first_stage else None,
    )

    new_entry = {
        "stage": stage_num,
        "text": result_dict["seq_str"],
        "cos_sim": result_dict["cos_sim"],
        "length": len(result_dict["token_ids"]),
        "time": elapsed,
        "steps": steps,
    }
    updated_results = stage_results_state + [new_entry]

    # Embedding is stored on CPU; _run_stage_gpu moves it back to GPU next stage.
    return (
        emb_cpu,
        updated_results,
        _format_results(updated_results),
        gr.update(value=f"Run Stage {stage_num + 1}", visible=True),
    )
def reset_state():
    """Clear the embedding and results and relabel the button for Stage 1."""
    fresh_button = gr.update(value="Run Stage 1", visible=True)
    return None, [], "", fresh_button
# Fix 1: `theme` is a gr.Blocks constructor argument, not a launch() argument —
# passing it to launch() raises TypeError. Moved to the Blocks constructor.
# Fix 2: the citation mixed two papers — the linked arXiv:2504.00147 (Zhang,
# Morris, Shmatikov) is the zero-shot inversion paper this app implements,
# not the 2023 "Text Embeddings Reveal..." paper whose title was shown.
with gr.Blocks(title="ZSInvert", theme=gr.themes.Base()) as demo:
    gr.Markdown("# Inverting Embeddings")
    gr.Markdown(
        "Reconstruct text from its embedding vector using "
        "cosine-similarity-guided beam search. "
        "Based on [Universal Zero-shot Embedding Inversion](https://arxiv.org/abs/2504.00147) "
        "(Zhang, Morris, Shmatikov 2025)."
    )

    # --- State (per browser session) ---
    # target_emb_state: CPU tensor of the target embedding (None until Stage 1).
    # stage_results_state: list of per-stage result dicts, oldest first.
    target_emb_state = gr.State(value=None)
    stage_results_state = gr.State(value=[])

    # --- Input row ---
    with gr.Row():
        text_input = gr.Textbox(
            label="Input text",
            placeholder="Enter text to encode and invert...",
            scale=4,
        )
        encoder_dd = gr.Dropdown(
            choices=_ENCODER_CHOICES,
            value="gte",
            label="Encoder",
            scale=1,
        )

    # --- Advanced settings ---
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            beam_width_sl = gr.Slider(5, 50, value=10, step=1, label="beam_width")
            top_k_sl = gr.Slider(5, 50, value=10, step=1, label="top_k")
            patience_sl = gr.Slider(0, 20, value=5, step=1, label="patience (0=off)")
        with gr.Row():
            max_steps_sl = gr.Slider(0, 64, value=0, step=1, label="max_steps (0=unlimited)")
            min_sim_sl = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="min_similarity (0=off)")
            randomness_cb = gr.Checkbox(value=True, label="randomness")

    # --- Run button ---
    run_btn = gr.Button("Run Stage 1", variant="primary")

    # --- Results ---
    results_html = gr.HTML(value="", label="Results")

    # --- Wiring ---
    all_inputs = [
        text_input, encoder_dd,
        beam_width_sl, top_k_sl, patience_sl, max_steps_sl, min_sim_sl, randomness_cb,
        target_emb_state, stage_results_state,
    ]
    all_outputs = [
        target_emb_state, stage_results_state,
        results_html, run_btn,
    ]
    run_btn.click(fn=run_stage, inputs=all_inputs, outputs=all_outputs)
    # Changing the input text or encoder invalidates the stored embedding — reset.
    text_input.change(fn=reset_state, inputs=[], outputs=all_outputs)
    encoder_dd.change(fn=reset_state, inputs=[], outputs=all_outputs)

if __name__ == "__main__":
    demo.launch(server_port=7860)