from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download

app = Flask(__name__)

_cache = {}


def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
    """
    Derive the encryption permutation from the secret seed.

    This is the CLIENT'S secret key — it never leaves this Space.
    The server only ever sees embeddings already scrambled by sigma.
    """
    rng = np.random.default_rng(seed)
    return rng.permutation(hidden_size)


def load_client_components(ee_model_name: str):
    """
    Load and cache:
      - ee_config   → hidden_size + original model name
      - tokenizer   → from the EE model
      - embed_layer → from the ORIGINAL (untransformed) model

    The original embed_layer is used to produce plain vectors from token IDs.
    The client then applies sigma to those plain vectors before sending.
    The server's EE model has weights permuted with sigma_inv, so:

        EE_model(sigma(plain_embed(tokens))) == original_model(plain_embed(tokens))
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # Load the ORIGINAL model just for its embed layer — discard everything else.
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    # get_input_embeddings() is the architecture-agnostic accessor; reaching
    # into original_model.model.embed_tokens breaks on models whose attribute
    # layout differs.
    embed_layer = original_model.get_input_embeddings()
    embed_layer.eval()
    del original_model

    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size


@app.route("/", methods=["GET", "POST"])
def index():
    result = None
    error = None
    form_data = {}

    if request.method == "POST":
        form_data = request.form.to_dict()
        server_url = request.form["server_url"].rstrip("/")
        ee_model_name = request.form["ee_model_name"].strip()
        ee_seed = int(request.form["ee_seed"])  # SECRET — client only
        prompt = request.form["prompt"].strip()
        max_tokens = int(request.form.get("max_tokens", 256))

        try:
            tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)

            # --- CLIENT-SIDE ENCRYPTION ---
            # Step 1: tokenize.
            inputs = tokenizer(prompt, return_tensors="pt")

            # Step 2: embed with the ORIGINAL model's embed layer → plain vectors.
            with torch.no_grad():
                plain_embeds = embed_layer(inputs.input_ids)  # (1, seq_len, hidden)

            # Step 3: apply the sigma permutation — this is the encryption.
            # The server NEVER sees plain_embeds, only the scrambled version.
            # Without knowing the seed, the server cannot recover the original.
            sigma = get_sigma(hidden_size, ee_seed)
            encrypted_embeds = plain_embeds[..., sigma]  # (1, seq_len, hidden)
            encrypted_embeds = encrypted_embeds.to(torch.float16)

            # --- SEND TO SERVER ---
            payload = {
                "encrypted_embeds": encrypted_embeds.tolist(),
                "attention_mask": inputs.attention_mask.tolist(),
                "max_new_tokens": max_tokens,
            }
            resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)
            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:600]}")
            body = resp.json()
            if "error" in body:
                raise RuntimeError(
                    f"Server error: {body['error']}\n{body.get('traceback', '')}"
                )

            # --- OUTPUT DECODING ---
            # The EE model's lm_head rows are permuted with sigma_inv, so output
            # logits correctly index the real vocabulary — decode normally.
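            # Only the hidden/feature dimension was ever permuted; the token
            # IDs themselves are plain vocabulary indices, so no inverse
            # permutation is needed client-side before decoding.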
            gen_ids = body["generated_ids"]
            result = tokenizer.decode(gen_ids, skip_special_tokens=True)

        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            error = f"Could not connect to {server_url} — is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"

    return render_template("client.html", result=result, error=error, form=form_data)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
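
# Illustrative sanity check of the permutation scheme (a minimal sketch, not
# part of the request path). It assumes the server derives its inverse
# permutation as np.argsort(sigma) from the same secret seed; the name
# `sigma_inv` and this helper are illustrative, not part of the EE protocol
# above. Because sigma followed by sigma_inv is the identity on the feature
# axis, the server can run EE_model(sigma(x)) and match original_model(x).
# Defined below the __main__ guard, so it is available when this module is
# imported (e.g., in a test or a REPL).
def check_sigma_roundtrip(hidden_size: int = 8, seed: int = 1234) -> None:
    sigma = get_sigma(hidden_size, seed)
    sigma_inv = np.argsort(sigma)  # inverse permutation: sigma[sigma_inv] == identity
    x = np.random.default_rng(0).standard_normal(hidden_size)
    # Scrambling with sigma and then unscrambling with sigma_inv recovers the
    # plain vector exactly.
    assert np.allclose(x[sigma][sigma_inv], x)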