from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download
app = Flask(__name__)
_cache = {}
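
# Client half of a permuted-embedding ("EE") chat demo. The prompt is
# tokenized and embedded locally with the ORIGINAL model's embedding table,
# the hidden dimension is shuffled with a secret, seed-derived permutation,
# and only the shuffled embeddings are sent over the wire. The server-side
# EE model is assumed to be adapted to the same permutation, so plaintext
# tokens never leave this process.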


def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
    """Derive the secret permutation sigma from the shared seed.

    np.random.default_rng(seed) is deterministic for a fixed seed, so any
    party holding the seed reconstructs the identical permutation.
    """
    rng = np.random.default_rng(seed)
    return rng.permutation(hidden_size)


def load_client_components(ee_model_name: str):
    if ee_model_name in _cache:
        return _cache[ee_model_name]
    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]
    # Tokenizer from EE model (same vocab as original)
    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)
    # Load ORIGINAL model just to extract embed_tokens, then discard
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    embed_layer = original_model.model.embed_tokens
    embed_layer.eval()
    del original_model
    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size
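

# Server contract, as inferred from the request/response handling below:
# POST {server_url}/generate with JSON
#   {"encrypted_embeds": [...], "attention_mask": [...],
#    "max_new_tokens": int, "input_len": int}
# and expect {"generated_ids": [...]} on success, or
# {"error": ..., "traceback": ...} on failure.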


@app.route("/", methods=["GET", "POST"])
def index():
    result = None
    error = None
    form_data = {}
    if request.method == "POST":
        form_data = request.form.to_dict()
        try:
            # Parse inside the try block so bad input (e.g. a non-integer
            # seed) renders as an error message instead of a bare 500.
            server_url = request.form["server_url"].rstrip("/")
            ee_model_name = request.form["ee_model_name"].strip()
            ee_seed = int(request.form["ee_seed"])
            prompt = request.form["prompt"].strip()
            max_tokens = int(request.form.get("max_tokens", 256))
            tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)
            # --- Step 1: Apply chat template ---
            # Qwen3 (and most instruct models) require the prompt to be
            # wrapped in the chat template before tokenization; otherwise
            # the model sees raw text with no special tokens and produces
            # garbage.
            messages = [{"role": "user", "content": prompt}]
            formatted = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,  # appends <|im_start|>assistant\n
            )
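            # With Qwen's ChatML-style template, `formatted` is roughly:
            #   <|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n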
            # --- Step 2: Tokenize the formatted prompt ---
            inputs = tokenizer(formatted, return_tensors="pt")
            input_ids = inputs.input_ids  # (1, seq_len)
            input_len = input_ids.shape[1]
            # --- Step 3: Embed with the ORIGINAL model's embed layer ---
            with torch.no_grad():
                plain_embeds = embed_layer(input_ids)  # (1, seq_len, hidden)
            # --- Step 4: Encrypt: permute the hidden dim with secret sigma ---
            sigma = get_sigma(hidden_size, ee_seed)
            encrypted_embeds = plain_embeds[..., sigma]  # (1, seq_len, hidden)
            encrypted_embeds = encrypted_embeds.to(torch.float16)
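            # The same reordering is applied at every token position; anyone
            # holding the seed can invert it with np.argsort(sigma). float16
            # halves the payload size relative to float32.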
            # --- Step 5: Send to server ---
            payload = {
                "encrypted_embeds": encrypted_embeds.tolist(),
                "attention_mask": inputs.attention_mask.tolist(),
                "max_new_tokens": max_tokens,
                "input_len": input_len,  # so the server can strip prompt tokens
            }
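            # Note: JSON-encoding the embeddings is simple but bulky, on the
            # order of seq_len * hidden_size numbers serialized as text.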
            resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)
            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:600]}")
            body = resp.json()
            if "error" in body:
                raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback', '')}")
            # --- Step 6: Decode only the NEW tokens (strip echoed prompt) ---
            gen_ids = body["generated_ids"]
            result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
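            # generated_ids is expected to exclude the prompt (the server
            # strips the first input_len positions via input_len), so the
            # decode yields only the model's reply.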
        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            error = f"Could not connect to {server_url} — is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"
    return render_template("client.html", result=result, error=error, form=form_data)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)