File size: 7,651 Bytes
5689bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a261552
5689bad
 
248ed0c
5689bad
a261552
248ed0c
a261552
 
 
248ed0c
5689bad
 
 
248ed0c
 
5c47ed5
 
 
 
248ed0c
5689bad
a261552
5689bad
 
 
 
 
 
 
 
 
 
 
 
 
a261552
5689bad
 
5c47ed5
 
 
 
 
 
5689bad
 
 
 
 
 
 
a261552
5689bad
 
a261552
5689bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248ed0c
 
 
5c47ed5
a261552
 
 
 
5689bad
5c47ed5
 
a261552
5689bad
 
5c47ed5
 
 
5689bad
 
 
 
 
 
 
a261552
5689bad
 
 
a261552
5689bad
 
 
 
 
a261552
5689bad
 
 
80dd98d
5689bad
 
80dd98d
3b4c39f
80dd98d
5689bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
ZSInvert — Zero-Shot Embedding Inversion Explorer.

Interactive tool demonstrating embedding inversion via
adversarial decoding beam search. Reconstructs text from
embedding vectors without training embedding-specific models.

Part of E04: ZSInvert.
"""

import html
import time

import gradio as gr
import torch

try:
    # Hugging Face Spaces ZeroGPU helper; absent in local/CPU environments.
    import spaces
    gpu_decorator = spaces.GPU(duration=120)
except ImportError:
    # Local fallback: run the decorated function unchanged.
    # (A def, not a lambda assignment, per PEP 8 E731.)
    def gpu_decorator(fn):
        return fn

from model import load_llm, load_encoder, encode_text, ENCODERS
from invert import beam_search

# Stage 1 inverts from a generic prompt; later stages seed the LLM with the
# previous stage's best reconstruction via the {seed} placeholder.
_STAGE1_PROMPT = "tell me a story"
_STAGE2_PROMPT_TEMPLATE = "write a sentence similar to this: {seed}"

# Encoder choices offered in the UI (drop contriever — broken)
_ENCODER_CHOICES = [k for k in ENCODERS if k != "contriever"]


def _sim_color(cos_sim: float) -> str:
    """Return hex color for a cosine similarity value."""
    if cos_sim > 0.99:
        return "#3b82f6"  # blue
    if cos_sim > 0.95:
        return "#16a34a"  # dark green
    if cos_sim > 0.85:
        return "#65a30d"  # green
    if cos_sim > 0.70:
        return "#ca8a04"  # amber
    if cos_sim > 0.50:
        return "#ef4444"  # red
    return "#a855f7"      # purple


def _format_results(stage_results: list[dict]) -> str:
    """Render accumulated stage results as styled HTML."""
    if not stage_results:
        return ""
    rows = []
    for r in stage_results:
        color = _sim_color(r["cos_sim"])
        rows.append(
            f'<div style="margin-bottom:12px;padding:10px;border:1px solid #333;border-radius:6px;'
            f'background:#1a1a2e;">'
            f'<span style="font-weight:bold;color:#ccc;">S{r["stage"]}</span> '
            f'<span style="color:#eee;font-style:italic;">"{r["text"]}"</span><br>'
            f'<span style="color:{color};font-weight:bold;">cos={r["cos_sim"]:.4f}</span>'
            f'&nbsp;&nbsp;len={r["length"]}'
            f'&nbsp;&nbsp;{r["time"]:.1f}s'
            f'&nbsp;&nbsp;steps={r["steps"]}'
            f'</div>'
        )
    return "".join(rows)


@gpu_decorator
def _run_stage_gpu(
    target_emb, encoder_name, prompt,
    beam_width, top_k, patience, max_steps, min_similarity, randomness,
    encode_text_input=None,
):
    """Run a single beam search stage on GPU.

    All CUDA operations happen inside this decorated function.
    If encode_text_input is provided and target_emb is None,
    encodes the text first (Stage 1).

    Args:
        target_emb: Target embedding tensor on CPU from a previous stage,
            or None on Stage 1.
        encoder_name: Key handed to load_encoder() to select the encoder.
        prompt: Seed prompt passed through to beam_search.
        beam_width, top_k, patience, max_steps, min_similarity, randomness:
            Beam-search knobs; coerced to int/float/bool below because they
            arrive as raw Gradio widget values.
        encode_text_input: Raw text to encode when target_emb is None.

    Returns:
        Tuple of (plain-data result dict with seq_str/cos_sim/token_ids,
        elapsed seconds, last step index reported by beam_search, the
        target embedding moved back to CPU).
    """
    llm, tokenizer = load_llm()
    encoder = load_encoder(encoder_name)

    if target_emb is None and encode_text_input is not None:
        # Stage 1: encode inside the GPU context so no CUDA work happens
        # in the main process.
        target_emb = encode_text(encode_text_input, encoder)
    elif target_emb is not None:
        # Move CPU tensor back to GPU for beam search
        device = next(llm.parameters()).device
        target_emb = target_emb.to(device)

    # Track the latest step index via the search's on_step callback.
    step_count = 0
    def count_steps(step, cand):
        nonlocal step_count
        step_count = step

    t0 = time.time()
    result = beam_search(
        llm, tokenizer, encoder, target_emb,
        prompt=prompt,
        beam_width=int(beam_width),
        max_steps=int(max_steps),
        top_k=int(top_k),
        patience=int(patience),
        min_similarity=float(min_similarity),
        randomness=bool(randomness),
        on_step=count_steps,
    )
    elapsed = time.time() - t0
    # Return only CPU/plain data to avoid CUDA init in main process on ZeroGPU
    return {
        "seq_str": result.seq_str,
        "cos_sim": result.cos_sim,
        "token_ids": result.token_ids,
    }, elapsed, step_count, target_emb.cpu()

def run_stage(
    text, encoder_name,
    beam_width, top_k, patience, max_steps, min_similarity, randomness,
    target_emb_state, stage_results_state,
):
    """Advance the inversion by one beam-search stage.

    Stage 1 encodes the input text and searches from a generic prompt;
    every later stage seeds the prompt with the previous reconstruction.
    Returns updated (embedding state, results state, results HTML, button).
    """
    # Guard: nothing to invert.
    if not text or not text.strip():
        gr.Warning("Please enter some text.")
        return (
            target_emb_state,
            stage_results_state,
            _format_results(stage_results_state),
            gr.update(),
        )

    stage = len(stage_results_state) + 1
    first_stage = stage == 1

    # Generic prompt on the first stage, seeded prompt afterwards.
    if first_stage:
        prompt = _STAGE1_PROMPT
    else:
        last_text = stage_results_state[-1]["text"]
        prompt = _STAGE2_PROMPT_TEMPLATE.format(seed=last_text)

    # Encoding must happen inside the GPU context, so Stage 1 hands over
    # the raw text instead of an embedding.
    raw_text = text.strip() if first_stage else None

    result, elapsed, steps, emb_cpu = _run_stage_gpu(
        target_emb_state, encoder_name, prompt,
        beam_width, top_k, patience, max_steps, min_similarity, randomness,
        encode_text_input=raw_text,
    )

    # The embedding lives on CPU between stages; _run_stage_gpu moves it
    # back to the GPU when needed.
    target_emb_state = emb_cpu

    entry = {
        "stage": stage,
        "text": result["seq_str"],
        "cos_sim": result["cos_sim"],
        "length": len(result["token_ids"]),
        "time": elapsed,
        "steps": steps,
    }
    # New list (not in-place append) so Gradio sees the state change.
    stage_results_state = stage_results_state + [entry]

    return (
        target_emb_state,
        stage_results_state,
        _format_results(stage_results_state),
        gr.update(value=f"Run Stage {stage + 1}", visible=True),
    )


def reset_state():
    """Clear all session state so the next click starts at Stage 1."""
    fresh_button = gr.update(value="Run Stage 1", visible=True)
    return None, [], "", fresh_button


# `theme` belongs on the gr.Blocks constructor (launch() does not accept it).
with gr.Blocks(title="ZSInvert", theme=gr.themes.Base()) as demo:
    gr.Markdown("# Inverting Embeddings")
    # NOTE(review): arXiv:2504.00147 is "Universal Zero-shot Embedding
    # Inversion" (Zhang, Morris, Shmatikov, 2025); the title previously shown
    # here belonged to a different paper (arXiv:2310.06816, 2023).
    gr.Markdown(
        "Reconstruct text from its embedding vector using "
        "cosine-similarity-guided beam search. "
        "Based on [Universal Zero-shot Embedding Inversion](https://arxiv.org/abs/2504.00147) "
        "(Zhang, Morris, Shmatikov 2025)."
    )

    # --- State: CPU-resident target embedding + accumulated stage dicts ---
    target_emb_state = gr.State(value=None)
    stage_results_state = gr.State(value=[])

    # --- Input row ---
    with gr.Row():
        text_input = gr.Textbox(
            label="Input text",
            placeholder="Enter text to encode and invert...",
            scale=4,
        )
        encoder_dd = gr.Dropdown(
            choices=_ENCODER_CHOICES,
            value="gte",
            label="Encoder",
            scale=1,
        )

    # --- Advanced settings ---
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            beam_width_sl = gr.Slider(5, 50, value=10, step=1, label="beam_width")
            top_k_sl = gr.Slider(5, 50, value=10, step=1, label="top_k")
            patience_sl = gr.Slider(0, 20, value=5, step=1, label="patience (0=off)")
        with gr.Row():
            max_steps_sl = gr.Slider(0, 64, value=0, step=1, label="max_steps (0=unlimited)")
            min_sim_sl = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="min_similarity (0=off)")
            randomness_cb = gr.Checkbox(value=True, label="randomness")

    # --- Run button ---
    run_btn = gr.Button("Run Stage 1", variant="primary")

    # --- Results ---
    results_html = gr.HTML(value="", label="Results")

    # --- Wiring ---
    all_inputs = [
        text_input, encoder_dd,
        beam_width_sl, top_k_sl, patience_sl, max_steps_sl, min_sim_sl, randomness_cb,
        target_emb_state, stage_results_state,
    ]
    all_outputs = [
        target_emb_state, stage_results_state,
        results_html, run_btn,
    ]

    run_btn.click(fn=run_stage, inputs=all_inputs, outputs=all_outputs)

    # Changing the input text or the encoder invalidates the cached
    # embedding, so both events reset the whole run.
    text_input.change(fn=reset_state, inputs=[], outputs=all_outputs)
    encoder_dd.change(fn=reset_state, inputs=[], outputs=all_outputs)


if __name__ == "__main__":
    # `theme` is not a Blocks.launch() parameter (it is a gr.Blocks()
    # constructor argument); passing it here raises TypeError.
    demo.launch(server_port=7860)