"""
Client Space for encrypted-embedding (EE) inference.

Tokenizes and embeds prompts locally, scrambles the embedding vectors with a
secret permutation derived from a client-held seed, and sends only the
scrambled vectors to a remote EE server Space for generation.
"""
from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download
app = Flask(__name__)
# Per-process cache: EE model name -> (tokenizer, embed_layer, hidden_size),
# so repeated requests don't re-download and re-load model components.
_cache: dict = {}
def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
    """
    Derive the encryption permutation sigma from the secret seed.

    The seed is the CLIENT'S secret key and never leaves this Space; the
    server only ever sees embeddings already scrambled by sigma.

    Args:
        hidden_size: dimensionality of the embedding vectors to permute.
        seed: secret integer seed; the same seed always yields the same sigma.

    Returns:
        A length-``hidden_size`` index array (a permutation of 0..hidden_size-1).
    """
    generator = np.random.default_rng(seed)
    return generator.permutation(hidden_size)
def load_client_components(ee_model_name: str):
    """
    Load and cache the client-side components for one EE model.

    Cached per EE model name:
      - tokenizer    -> loaded from the EE model repo
      - embed_layer  -> input-embedding layer of the ORIGINAL (untransformed) model
      - hidden_size  -> read from ``ee_config.json`` in the EE repo

    The original embed layer produces plain vectors from token IDs; the client
    then applies sigma to those plain vectors before sending them to the
    server. The server's EE model has weights permuted with sigma_inv, so:
        EE_model(sigma(plain_embed(tokens))) == original_model(plain_embed(tokens))

    Args:
        ee_model_name: Hub repo id of the encrypted-embedding model.

    Returns:
        Tuple ``(tokenizer, embed_layer, hidden_size)``.

    Raises:
        ValueError: if the config's hidden_size disagrees with the weights.
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]
    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]
    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)
    # Load ORIGINAL model just for its embed layer — discard everything else
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    # get_input_embeddings() is the architecture-agnostic accessor; the old
    # `.model.embed_tokens` path only works for a subset of architectures.
    embed_layer = original_model.get_input_embeddings()
    embed_layer.eval()
    # Sanity check: the permutation length must match the embedding width.
    actual_size = embed_layer.weight.shape[-1]
    if actual_size != hidden_size:
        raise ValueError(
            f"ee_config hidden_size {hidden_size} != embedding width {actual_size}"
        )
    # embed_layer keeps its own reference to the weights; drop the rest of
    # the model to free memory.
    del original_model
    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size
@app.route("/", methods=["GET", "POST"])
def index():
    """
    Render the client form; on POST, encrypt the prompt and query the server.

    Flow: tokenize locally -> embed with the original model's embedding layer
    -> permute feature dimension with the secret sigma -> POST the scrambled
    embeddings to the EE server -> decode the returned token ids locally.

    Returns:
        Rendered ``client.html`` with ``result`` (generated text or None) and
        ``error`` (message string or None).
    """
    result = None
    error = None
    form_data = {}
    if request.method == "POST":
        form_data = request.form.to_dict()
        try:
            # Parse form fields INSIDE the try so a missing field or a
            # non-numeric seed yields a friendly error page, not an HTTP 500.
            server_url = request.form["server_url"].rstrip("/")
            ee_model_name = request.form["ee_model_name"].strip()
            ee_seed = int(request.form["ee_seed"])  # SECRET — client only
            prompt = request.form["prompt"].strip()
            max_tokens = int(request.form.get("max_tokens", 256))
            tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)
            # --- CLIENT-SIDE ENCRYPTION ---
            # Step 1: tokenize
            inputs = tokenizer(prompt, return_tensors="pt")
            # Step 2: embed with ORIGINAL model embed layer -> plain vectors
            with torch.no_grad():
                plain_embeds = embed_layer(inputs.input_ids)  # (1, seq_len, hidden)
            # Step 3: apply sigma permutation — this is the encryption.
            # The server NEVER sees plain_embeds, only the scrambled version.
            # Without knowing the seed, the server cannot recover the original.
            sigma = get_sigma(hidden_size, ee_seed)
            encrypted_embeds = plain_embeds[..., sigma]  # (1, seq_len, hidden)
            encrypted_embeds = encrypted_embeds.to(torch.float16)
            # --- SEND TO SERVER ---
            payload = {
                "encrypted_embeds": encrypted_embeds.tolist(),
                "attention_mask": inputs.attention_mask.tolist(),
                "max_new_tokens": max_tokens,
            }
            resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)
            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:600]}")
            body = resp.json()
            if "error" in body:
                raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback','')}")
            # --- OUTPUT DECODING ---
            # The EE model's lm_head rows are permuted with sigma_inv, so output
            # logits correctly index the real vocabulary — decode normally.
            gen_ids = body["generated_ids"]
            result = tokenizer.decode(gen_ids, skip_special_tokens=True)
        except KeyError as e:
            error = f"Missing form field: {e}"
        except ValueError as e:
            error = f"Invalid numeric field: {e}"
        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            # server_url is always bound here: a ConnectionError can only be
            # raised by requests.post, which runs after the field is parsed.
            error = f"Could not connect to {server_url} — is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"
    return render_template("client.html", result=result, error=error, form=form_data)
# Entry point when run directly: listen on all interfaces at port 7860
# (the default port for a Hugging Face Space).
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)