import json

import numpy as np
import requests
import torch
from flask import Flask, render_template, request
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM


app = Flask(__name__)

# Per-process cache of client components, keyed by EE model name:
# each entry holds a (tokenizer, embed_layer, hidden_size) tuple.
_cache = {}


def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
    """
    Derive the encryption permutation from the secret seed.
    This is the CLIENT'S secret key: it never leaves this Space.
    The server only ever sees embeddings already scrambled by sigma.
    """
    rng = np.random.default_rng(seed)
    return rng.permutation(hidden_size)
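
# For reference (an assumption about the server side, never used by the
# client): the inverse permutation the server would bake into its EE weights
# can be derived from sigma, since np.argsort inverts a permutation:
#     sigma = get_sigma(hidden_size, seed)
#     sigma_inv = np.argsort(sigma)        # sigma_inv[sigma[i]] == i
# so x[sigma][sigma_inv] == x for any vector x.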


def load_client_components(ee_model_name: str):
    """
    Load and cache:
      - ee_config: hidden_size + original model name
      - tokenizer: from the EE model
      - embed_layer: from the ORIGINAL (untransformed) model

    The original embed_layer is used to produce plain vectors from token IDs.
    The client then applies sigma to those plain vectors before sending.
    The server's EE model has weights permuted with sigma_inv, so:

        EE_model(sigma(plain_embed(tokens))) == original_model(plain_embed(tokens))
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)
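
    # Illustrative ee_config.json contents (both fields are read just below;
    # the model name shown is hypothetical):
    #   {"hidden_size": 2048, "original_model": "some-org/base-model"}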
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # Load the ORIGINAL model only to grab its plain embedding layer; the
    # rest of the network is discarded immediately to keep memory low.
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    # get_input_embeddings() is more robust than reaching into
    # model.model.embed_tokens, which not every architecture exposes.
    embed_layer = original_model.get_input_embeddings()
    embed_layer.eval()
    del original_model

    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size


@app.route("/", methods=["GET", "POST"])
def index():
    result = None
    error = None
    form_data = {}

    if request.method == "POST":
        form_data = request.form.to_dict()
        try:
            # Parse the form inside the try block so bad input (e.g. a
            # non-numeric seed) surfaces as a friendly error, not a 500.
            server_url = request.form["server_url"].rstrip("/")
            ee_model_name = request.form["ee_model_name"].strip()
            ee_seed = int(request.form["ee_seed"])
            prompt = request.form["prompt"].strip()
            max_tokens = int(request.form.get("max_tokens", 256))

            tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)

            # Tokenize locally; the raw token IDs never leave this Space.
            inputs = tokenizer(prompt, return_tensors="pt")

            # Plain embeddings, shape [1, seq_len, hidden_size].
            with torch.no_grad():
                plain_embeds = embed_layer(inputs.input_ids)

            # Encrypt: permute the hidden dimension with the secret sigma,
            # then downcast to fp16 to shrink the payload.
            sigma = get_sigma(hidden_size, ee_seed)
            encrypted_embeds = plain_embeds[..., sigma]
            encrypted_embeds = encrypted_embeds.to(torch.float16)
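
            # Index semantics of the permutation: position i of the output
            # holds feature sigma[i] of the input, i.e.
            #   encrypted_embeds[..., i] == plain_embeds[..., sigma[i]]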

            payload = {
                "encrypted_embeds": encrypted_embeds.tolist(),
                "attention_mask": inputs.attention_mask.tolist(),
                "max_new_tokens": max_tokens,
            }
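
            # Illustrative payload shape (values made up, hidden_size == 4):
            #   {"encrypted_embeds": [[[0.12, -0.53, 1.08, 0.40], ...]],  # [1, seq_len, hidden]
            #    "attention_mask": [[1, 1, 1]],                           # [1, seq_len]
            #    "max_new_tokens": 256}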

            resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)

            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:600]}")

            body = resp.json()
            if "error" in body:
                raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback', '')}")
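            # Assumption about the server contract: "generated_ids" is a flat
            # list of token IDs that tokenizer.decode() turns back into text.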
            gen_ids = body["generated_ids"]
            result = tokenizer.decode(gen_ids, skip_special_tokens=True)

        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            error = f"Could not connect to {server_url}. Is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"

    return render_template("client.html", result=result, error=error, form=form_data)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
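
# Run with `python app.py`. Hugging Face Spaces route traffic to port 7860 by
# default, which is why the server binds there.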