"""
ZSInvert — Zero-Shot Embedding Inversion Explorer.
Interactive tool demonstrating embedding inversion via
adversarial decoding beam search. Reconstructs text from
embedding vectors without training embedding-specific models.
Part of E04: ZSInvert.
"""
import html
import time

import gradio as gr
import torch
# On Hugging Face ZeroGPU Spaces the `spaces` package is available and
# `spaces.GPU(duration=120)` wraps a function so it runs with a GPU
# allocation (presumably up to 120 s — confirm against Spaces docs).
# Elsewhere the import fails and we fall back to a no-op decorator so the
# app still runs without the `spaces` package.
try:
    import spaces
    gpu_decorator = spaces.GPU(duration=120)
except ImportError:
    # Not on Spaces — identity decorator, functions run in-process.
    gpu_decorator = lambda fn: fn
from model import load_llm, load_encoder, encode_text, ENCODERS
from invert import beam_search
# Prompt used for the first beam-search stage (no prior reconstruction).
_STAGE1_PROMPT = "tell me a story"
# Prompt for refinement stages: seeds the LLM with the previous stage's
# best reconstruction (substituted for {seed} in run_stage).
_STAGE2_PROMPT_TEMPLATE = "write a sentence similar to this: {seed}"
# Encoder choices (drop contriever — broken)
_ENCODER_CHOICES = [k for k in ENCODERS if k != "contriever"]
def _sim_color(cos_sim: float) -> str:
"""Return hex color for a cosine similarity value."""
if cos_sim > 0.99:
return "#3b82f6" # blue
if cos_sim > 0.95:
return "#16a34a" # dark green
if cos_sim > 0.85:
return "#65a30d" # green
if cos_sim > 0.70:
return "#ca8a04" # amber
if cos_sim > 0.50:
return "#ef4444" # red
return "#a855f7" # purple
def _format_results(stage_results: list[dict]) -> str:
"""Render accumulated stage results as styled HTML."""
if not stage_results:
return ""
rows = []
for r in stage_results:
color = _sim_color(r["cos_sim"])
rows.append(
f'
'
f'S{r["stage"]} '
f'"{r["text"]}"
'
f'cos={r["cos_sim"]:.4f}'
f' len={r["length"]}'
f' {r["time"]:.1f}s'
f' steps={r["steps"]}'
f'
'
)
return "".join(rows)
@gpu_decorator
def _run_stage_gpu(
    target_emb, encoder_name, prompt,
    beam_width, top_k, patience, max_steps, min_similarity, randomness,
    encode_text_input=None,
):
    """Execute one beam-search inversion stage inside the GPU context.

    All CUDA work happens within this decorated function. When
    `target_emb` is None and `encode_text_input` is given (Stage 1), the
    text is encoded here first. Returns a plain-data result dict, the
    elapsed seconds, the last step index reported by the search, and the
    target embedding moved to CPU for storage in session state.
    """
    llm, tokenizer = load_llm()
    encoder = load_encoder(encoder_name)

    if target_emb is None:
        if encode_text_input is not None:
            # Stage 1: encode the raw text here so CUDA stays in this fn.
            target_emb = encode_text(encode_text_input, encoder)
    else:
        # Later stages: the state tensor was parked on CPU between calls;
        # bring it back to whatever device the LLM lives on.
        target_emb = target_emb.to(next(llm.parameters()).device)

    # Mutable cell so the callback can record progress without `nonlocal`.
    progress = {"steps": 0}

    def _track(step, cand):
        progress["steps"] = step

    start = time.time()
    best = beam_search(
        llm, tokenizer, encoder, target_emb,
        prompt=prompt,
        beam_width=int(beam_width),
        max_steps=int(max_steps),
        top_k=int(top_k),
        patience=int(patience),
        min_similarity=float(min_similarity),
        randomness=bool(randomness),
        on_step=_track,
    )
    duration = time.time() - start

    # Return only CPU/plain data so the main process never touches CUDA
    # (required on ZeroGPU, where CUDA init outside this fn would fail).
    payload = {
        "seq_str": best.seq_str,
        "cos_sim": best.cos_sim,
        "token_ids": best.token_ids,
    }
    return payload, duration, progress["steps"], target_emb.cpu()
def run_stage(
    text, encoder_name,
    beam_width, top_k, patience, max_steps, min_similarity, randomness,
    target_emb_state, stage_results_state,
):
    """Run the next inversion stage and append its result to session state.

    Returns updates for (target embedding state, stage results state,
    results HTML, run button). Stage 1 encodes the input text; later
    stages refine the previous stage's best reconstruction.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        # Nothing to invert — warn and leave every output unchanged.
        gr.Warning("Please enter some text.")
        return (
            target_emb_state,
            stage_results_state,
            _format_results(stage_results_state),
            gr.update(),
        )

    stage_num = len(stage_results_state) + 1
    if stage_num == 1:
        prompt = _STAGE1_PROMPT
        # Pass raw text so the encoding happens inside the GPU context.
        encode_input = cleaned
    else:
        seed = stage_results_state[-1]["text"]
        prompt = _STAGE2_PROMPT_TEMPLATE.format(seed=seed)
        encode_input = None

    result, elapsed, steps, emb_cpu = _run_stage_gpu(
        target_emb_state, encoder_name, prompt,
        beam_width, top_k, patience, max_steps, min_similarity, randomness,
        encode_text_input=encode_input,
    )

    # Keep the embedding on CPU between calls; _run_stage_gpu moves it
    # back to the GPU when the next stage runs.
    new_results = stage_results_state + [{
        "stage": stage_num,
        "text": result["seq_str"],
        "cos_sim": result["cos_sim"],
        "length": len(result["token_ids"]),
        "time": elapsed,
        "steps": steps,
    }]
    return (
        emb_cpu,
        new_results,
        _format_results(new_results),
        gr.update(value=f"Run Stage {stage_num + 1}", visible=True),
    )
def reset_state():
    """Reset all session state for a fresh run.

    Returns values for (target_emb_state, stage_results_state,
    results_html, run_btn): clears the cached target embedding and the
    accumulated stage results, blanks the results panel, and relabels
    the button back to "Run Stage 1".
    """
    return None, [], "", gr.update(value="Run Stage 1", visible=True)
with gr.Blocks(title="ZSInvert") as demo:
    # NOTE(review): gr.Blocks also accepts a `theme=` argument; the
    # launch() call at the bottom of this file passes `theme`, which
    # Blocks.launch() does not accept — consider moving it here.
    gr.Markdown("# Inverting Embeddings")
    # NOTE(review): citation looks inconsistent — arXiv 2504.00147 is a
    # 2025 paper, while the quoted title and "2023" match Morris et al.,
    # arXiv 2310.06816. Confirm which work is actually meant.
    gr.Markdown(
        "Reconstruct text from its embedding vector using "
        "cosine-similarity-guided beam search. "
        "Based on [Text Embeddings Reveal (Almost) As Much As Text](https://arxiv.org/abs/2504.00147) "
        "(Zhang, Morris, Shmatikov 2023)."
    )
    # --- State ---
    # target_emb_state: CPU tensor of the embedding being inverted
    # (None until Stage 1 runs). stage_results_state: list of per-stage
    # result dicts consumed by _format_results.
    target_emb_state = gr.State(value=None)
    stage_results_state = gr.State(value=[])
    # --- Input row ---
    with gr.Row():
        text_input = gr.Textbox(
            label="Input text",
            placeholder="Enter text to encode and invert...",
            scale=4,
        )
        encoder_dd = gr.Dropdown(
            choices=_ENCODER_CHOICES,
            value="gte",
            label="Encoder",
            scale=1,
        )
    # --- Advanced settings ---
    # Beam-search knobs, forwarded verbatim to run_stage/_run_stage_gpu.
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            beam_width_sl = gr.Slider(5, 50, value=10, step=1, label="beam_width")
            top_k_sl = gr.Slider(5, 50, value=10, step=1, label="top_k")
            patience_sl = gr.Slider(0, 20, value=5, step=1, label="patience (0=off)")
        with gr.Row():
            max_steps_sl = gr.Slider(0, 64, value=0, step=1, label="max_steps (0=unlimited)")
            min_sim_sl = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="min_similarity (0=off)")
            randomness_cb = gr.Checkbox(value=True, label="randomness")
    # --- Run button ---
    # Label advances to "Run Stage N" after each successful stage.
    run_btn = gr.Button("Run Stage 1", variant="primary")
    # --- Results ---
    results_html = gr.HTML(value="", label="Results")
    # --- Wiring ---
    # Input order must match run_stage's parameter order; output order
    # must match the 4-tuples returned by run_stage and reset_state.
    all_inputs = [
        text_input, encoder_dd,
        beam_width_sl, top_k_sl, patience_sl, max_steps_sl, min_sim_sl, randomness_cb,
        target_emb_state, stage_results_state,
    ]
    all_outputs = [
        target_emb_state, stage_results_state,
        results_html, run_btn,
    ]
    run_btn.click(fn=run_stage, inputs=all_inputs, outputs=all_outputs)
    # Reset when input text or encoder changes
    text_input.change(fn=reset_state, inputs=[], outputs=all_outputs)
    encoder_dd.change(fn=reset_state, inputs=[], outputs=all_outputs)
if __name__ == "__main__":
    # BUG FIX: `theme` is a gr.Blocks() constructor argument, not a
    # launch() kwarg — passing it to launch() raises TypeError on
    # current Gradio. Set the theme where Blocks is created instead.
    demo.launch(server_port=7860)