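"""Client app for the encrypted-embeddings (EE) inference demo.

The client tokenizes the prompt and embeds it locally with the original
model's embedding layer, permutes the hidden dimension with a secret,
seed-derived permutation ("encryption"), and sends the permuted embeddings
to a remote generation server, which returns token ids to decode locally.
"""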
from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download

app = Flask(__name__)

# Per-process cache: ee_model_name -> (tokenizer, embed_layer, hidden_size)
_cache = {}


def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
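    """Derive a secret permutation of the hidden dimension deterministically from the seed."""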
    rng = np.random.default_rng(seed)
    return rng.permutation(hidden_size)
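
# Note (illustration only, not used in the request flow): the permutation is
# fully determined by the seed and is invertible, so a holder of the seed can
# restore the original hidden-dimension order:
#
#   sigma = get_sigma(hidden_size, seed)
#   inverse = np.argsort(sigma)          # inverse permutation
#   # embeds[..., sigma][..., inverse] recovers embeds exactly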


def load_client_components(ee_model_name: str):
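    """Load and cache the tokenizer, the original model's embedding layer, and the EE hidden size."""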
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)

    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    # Tokenizer from EE model (same vocab as original)
    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # Load ORIGINAL model just to extract embed_tokens, then discard
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    embed_layer = original_model.model.embed_tokens
    embed_layer.eval()
    del original_model

    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size


@app.route("/", methods=["GET", "POST"])
def index():
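    """Render the form; on POST, encrypt the prompt embeddings and query the generation server."""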
    result = None
    error = None
    form_data = {}

    if request.method == "POST":
        form_data = request.form.to_dict()
        server_url    = request.form["server_url"].rstrip("/")
        ee_model_name = request.form["ee_model_name"].strip()
        ee_seed       = int(request.form["ee_seed"])
        prompt        = request.form["prompt"].strip()
        max_tokens    = int(request.form.get("max_tokens", 256))

        try:
            tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)

            # --- Step 1: Apply chat template ---
            # Qwen3 (and most instruct models) require the prompt wrapped in the
            # chat template before tokenization, otherwise the model sees raw text
            # with no special tokens and produces garbage.
            messages = [{"role": "user", "content": prompt}]
            formatted = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,  # appends <|im_start|>assistant\n
            )

            # --- Step 2: Tokenize the formatted prompt ---
            inputs = tokenizer(formatted, return_tensors="pt")
            input_ids = inputs.input_ids  # (1, seq_len)
            input_len = input_ids.shape[1]

            # --- Step 3: Embed with ORIGINAL model's embed layer ---
            with torch.no_grad():
                plain_embeds = embed_layer(input_ids)  # (1, seq_len, hidden)

            # --- Step 4: Encrypt — permute hidden dim with secret sigma ---
            sigma = get_sigma(hidden_size, ee_seed)
            encrypted_embeds = plain_embeds[..., sigma]       # (1, seq_len, hidden)
            encrypted_embeds = encrypted_embeds.to(torch.float16)

            # --- Step 5: Send to server ---
            payload = {
                "encrypted_embeds": encrypted_embeds.tolist(),
                "attention_mask":   inputs.attention_mask.tolist(),
                "max_new_tokens":   max_tokens,
                "input_len":        input_len,  # so server can strip prompt tokens
            }

            resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)

            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:600]}")

            body = resp.json()
            if "error" in body:
                raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback', '')}")

            # --- Step 6: Decode only the NEW tokens (strip echoed prompt) ---
            gen_ids = body["generated_ids"]
            result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            error = f"Could not connect to {server_url} — is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"

    return render_template("client.html", result=result, error=error, form=form_data)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)