File size: 4,768 Bytes
0e77718
 
6ad0d9a
0e77718
7a0f913
 
0e77718
 
 
 
9737a84
 
 
9272618
 
 
 
 
 
7a0f913
9737a84
 
 
 
aa316bb
9272618
07ee289
9272618
 
 
 
 
 
 
aa316bb
9737a84
 
 
 
 
 
 
 
 
 
 
 
9272618
aa316bb
9737a84
9272618
9737a84
 
 
aa316bb
 
9272618
9737a84
 
 
 
7a0f913
0e77718
 
 
7a0f913
9737a84
7a0f913
0e77718
9737a84
aa316bb
7a0f913
9272618
aa316bb
 
0e77718
 
9737a84
7a0f913
9272618
 
 
0e77718
9737a84
9272618
0e77718
9272618
7a0f913
9272618
 
 
07ee289
9272618
aa316bb
0e77718
9272618
0e77718
7a0f913
9272618
 
0e77718
7a0f913
9272618
aa316bb
 
9272618
0e77718
07ee289
 
 
 
9272618
 
 
07ee289
0e77718
 
aa316bb
 
9737a84
aa316bb
0e77718
aa316bb
0e77718
9737a84
 
0e77718
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download

app = Flask(__name__)

# Per-process cache: ee_model_name -> (tokenizer, embed_layer, hidden_size).
# Populated lazily by load_client_components(); never evicted for the
# lifetime of the process.
_cache = {}


def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
    """
    Build the secret encryption permutation over the embedding dimension.

    The permutation is derived deterministically from `seed`, which is the
    CLIENT'S secret key and never leaves this Space — the server only ever
    receives embeddings that have already been scrambled with sigma.

    Args:
        hidden_size: dimensionality of the model's hidden states.
        seed: secret integer seed held only by the client.

    Returns:
        A numpy array of shape (hidden_size,) containing a permutation of
        0..hidden_size-1.
    """
    return np.random.default_rng(seed).permutation(hidden_size)


def load_client_components(ee_model_name: str):
    """
    Load and cache the client-side pieces for one EE model:
      - ee_config   -> hidden_size + original model name (from ee_config.json)
      - tokenizer   -> from the EE model repo
      - embed_layer -> input-embedding layer of the ORIGINAL (untransformed) model

    The original embed_layer produces plain vectors from token IDs; the client
    then applies sigma to those plain vectors before sending. The server's EE
    model has weights permuted with sigma_inv, so:
        EE_model(sigma(plain_embed(tokens))) == original_model(plain_embed(tokens))

    Args:
        ee_model_name: Hugging Face repo id of the EE (permuted) model.

    Returns:
        Tuple of (tokenizer, embed_layer, hidden_size).
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)

    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # Load ORIGINAL model just for its embed layer — discard everything else
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    # get_input_embeddings() is the architecture-agnostic accessor;
    # `original_model.model.embed_tokens` breaks for architectures whose
    # backbone attribute is not named "model".
    embed_layer = original_model.get_input_embeddings()
    embed_layer.eval()
    # Drop the reference to the full model so everything except the embed
    # layer can be garbage-collected.
    del original_model

    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size


@app.route("/", methods=["GET", "POST"])
def index():
    result = None
    error = None
    form_data = {}

    if request.method == "POST":
        form_data = request.form.to_dict()
        server_url    = request.form["server_url"].rstrip("/")
        ee_model_name = request.form["ee_model_name"].strip()
        ee_seed       = int(request.form["ee_seed"])   # SECRET β€” client only
        prompt        = request.form["prompt"].strip()
        max_tokens    = int(request.form.get("max_tokens", 256))

        try:
            tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)

            # --- CLIENT-SIDE ENCRYPTION ---

            # Step 1: tokenize
            inputs = tokenizer(prompt, return_tensors="pt")

            # Step 2: embed with ORIGINAL model embed layer β†’ plain vectors
            with torch.no_grad():
                plain_embeds = embed_layer(inputs.input_ids)  # (1, seq_len, hidden)

            # Step 3: apply sigma permutation β€” this is the encryption
            # The server NEVER sees plain_embeds, only the scrambled version.
            # Without knowing the seed, the server cannot recover the original.
            sigma = get_sigma(hidden_size, ee_seed)
            encrypted_embeds = plain_embeds[..., sigma]        # (1, seq_len, hidden)
            encrypted_embeds = encrypted_embeds.to(torch.float16)

            # --- SEND TO SERVER ---
            payload = {
                "encrypted_embeds": encrypted_embeds.tolist(),
                "attention_mask":   inputs.attention_mask.tolist(),
                "max_new_tokens":   max_tokens,
            }

            resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)

            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:600]}")

            body = resp.json()
            if "error" in body:
                raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback','')}")

            # --- OUTPUT DECODING ---
            # The EE model's lm_head rows are permuted with sigma_inv, so output
            # logits correctly index the real vocabulary β€” decode normally.
            gen_ids = body["generated_ids"]
            result = tokenizer.decode(gen_ids, skip_special_tokens=True)

        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            error = f"Could not connect to {server_url} β€” is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"

    return render_template("client.html", result=result, error=error, form=form_data)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)