from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download
app = Flask(__name__)
# Per-process cache: ee_model_name -> (tokenizer, embed_layer, hidden_size).
# Populated by load_client_components so repeated requests for the same EE
# model skip the (slow) tokenizer/model downloads.
_cache = {}
def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
    """Derive the client-side encryption permutation from the secret seed.

    This is the CLIENT'S secret key — it never leaves this Space. The server
    only ever sees embeddings already scrambled by sigma.

    Args:
        hidden_size: Dimensionality of the embedding vectors to permute.
        seed: Secret integer seed; must match the one used to build the
            EE (permuted) model on the server side.

    Returns:
        A 1-D ndarray holding a permutation of ``range(hidden_size)``.
    """
    generator = np.random.default_rng(seed)
    return generator.permutation(hidden_size)
def load_client_components(ee_model_name: str):
    """
    Load and cache the client-side components for an EE model.

    Cached per ``ee_model_name`` in the module-level ``_cache``:
      - ee_config    — provides hidden_size + original model name
      - tokenizer    — from the EE model repo
      - embed_layer  — from the ORIGINAL (untransformed) model

    The original embed_layer is used to produce plain vectors from token IDs.
    The client then applies sigma to those plain vectors before sending.
    The server's EE model has weights permuted with sigma_inv, so:
        EE_model(sigma(plain_embed(tokens))) == original_model(plain_embed(tokens))

    Returns:
        Tuple of (tokenizer, embed_layer, hidden_size).
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # Load the ORIGINAL model just for its embedding layer — everything else
    # is discarded below.
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    # Use the public accessor instead of hard-coding `.model.embed_tokens`:
    # get_input_embeddings() works for any architecture, whereas the attribute
    # path breaks for models whose decoder is not nested under `.model`.
    embed_layer = original_model.get_input_embeddings()
    embed_layer.eval()
    del original_model  # drop the remaining weights; embed_layer stays alive

    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the client UI; on POST, encrypt the prompt and query the server.

    Pipeline: tokenize → embed with the ORIGINAL model's embedding layer →
    permute the feature dimension with the secret sigma (the "encryption") →
    POST the scrambled embeddings to the EE server → decode the returned
    token ids locally.

    Any failure (bad form input, unreachable server, server-side error) is
    rendered back to the user via the ``error`` template variable.
    """
    result = None
    error = None
    form_data = {}
    if request.method == "POST":
        form_data = request.form.to_dict()
        try:
            # Parse the form INSIDE the try block so malformed input (e.g. a
            # non-integer seed or a missing field) renders a friendly error
            # instead of crashing the request with a 500.
            server_url = request.form["server_url"].rstrip("/")
            ee_model_name = request.form["ee_model_name"].strip()
            ee_seed = int(request.form["ee_seed"])  # SECRET — client only
            prompt = request.form["prompt"].strip()
            max_tokens = int(request.form.get("max_tokens", 256))

            tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)

            # --- CLIENT-SIDE ENCRYPTION ---
            # Step 1: tokenize
            inputs = tokenizer(prompt, return_tensors="pt")
            # Step 2: embed with ORIGINAL model embed layer — plain vectors
            with torch.no_grad():
                plain_embeds = embed_layer(inputs.input_ids)  # (1, seq_len, hidden)
            # Step 3: apply sigma permutation — this is the encryption.
            # The server NEVER sees plain_embeds, only the scrambled version.
            # Without knowing the seed, the server cannot recover the original.
            sigma = get_sigma(hidden_size, ee_seed)
            encrypted_embeds = plain_embeds[..., sigma]  # (1, seq_len, hidden)
            encrypted_embeds = encrypted_embeds.to(torch.float16)

            # --- SEND TO SERVER ---
            payload = {
                "encrypted_embeds": encrypted_embeds.tolist(),
                "attention_mask": inputs.attention_mask.tolist(),
                "max_new_tokens": max_tokens,
            }
            resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)
            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:600]}")
            body = resp.json()
            if "error" in body:
                raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback','')}")

            # --- OUTPUT DECODING ---
            # The EE model's lm_head rows are permuted with sigma_inv, so output
            # logits correctly index the real vocabulary — decode normally.
            gen_ids = body["generated_ids"]
            result = tokenizer.decode(gen_ids, skip_special_tokens=True)
        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            # Mojibake fix: the dash in this user-facing message was "β".
            error = f"Could not connect to {server_url} — is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"
    return render_template("client.html", result=result, error=error, form=form_data)
if __name__ == "__main__":
    # 7860 is the conventional port for Hugging Face Spaces; bind on all
    # interfaces so the Space's proxy can reach the app.
    app.run(host="0.0.0.0", port=7860)