from flask import Flask, render_template, request, flash, jsonify
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import os, json
app = Flask(__name__)
app.secret_key = os.urandom(24)
ee_model = None
ee_tokenizer = None
ee_config = None
SPACE_HOST = os.environ.get("SPACE_HOST", "")
SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
@app.route("/", methods=["GET", "POST"])
def index():
global ee_model, ee_tokenizer, ee_config
if request.method == "POST":
ee_model_name = request.form["ee_model_name"].strip()
hf_token = request.form["hf_token"].strip()
try:
login(token=hf_token)
ee_model = AutoModelForCausalLM.from_pretrained(
ee_model_name, torch_dtype=torch.float16,
device_map="auto", trust_remote_code=True
)
ee_tokenizer = AutoTokenizer.from_pretrained(
ee_model_name, trust_remote_code=True
)
from huggingface_hub import hf_hub_download
config_path = hf_hub_download(ee_model_name, "ee_config.json")
with open(config_path) as f:
ee_config = json.load(f)
flash(f"✅ Model loaded: {ee_model_name}", "success")
flash("Point your Client Space to this Space's URL below.", "info")
except Exception as e:
flash(f"Error: {str(e)}", "danger")
return render_template(
"index.html",
server_ready=(ee_model is not None),
model_name=ee_config["original_model"] if ee_config else None,
space_url=SPACE_URL,
)
@app.route("/generate", methods=["POST"])
def generate():
"""
Receives sigma-encrypted embeddings + optional past_key_values.
Returns last hidden state (still in sigma-space) + new KV cache.
Does NOT run lm_head — that stays on the client.
Server never sees token IDs, logits, or plaintext.
"""
if ee_model is None:
return jsonify({"error": "Server not started yet"}), 400
try:
data = request.json
model_dtype = next(ee_model.parameters()).dtype
inputs_embeds = torch.tensor(data["inputs_embeds"]).to(
dtype=model_dtype, device=ee_model.device
)
attention_mask = torch.tensor(
data.get("attention_mask", [[1] * inputs_embeds.shape[1]])
).to(device=ee_model.device)
past_key_values = None
if data.get("past_key_values"):
past_key_values = tuple(
tuple(
torch.tensor(t).to(dtype=model_dtype, device=ee_model.device)
for t in layer
)
for layer in data["past_key_values"]
)
# Ensure model config has caching enabled
ee_model.config.use_cache = True
with torch.no_grad():
out = ee_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
output_hidden_states=True,
)
# Final hidden state (sigma-space) — client decrypts + runs lm_head
last_hidden = out.hidden_states[-1] # (1, seq_len, hidden)
# Serialize KV cache — guard against None (some models/configs don't return it)
new_past = None
if out.past_key_values is not None:
new_past = [
[t.cpu().tolist() for t in layer]
for layer in out.past_key_values
]
return jsonify({
"last_hidden": last_hidden.cpu().tolist(),
"past_key_values": new_past,
})
except Exception as e:
import traceback
return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860) |