broadfield-dev committed
Commit 52357b2 · verified · 1 Parent(s): c9573f1

Update app.py

Files changed (1)
  1. app.py +28 -35
app.py CHANGED
@@ -12,27 +12,11 @@ _cache = {}
 
 
 def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
-    """
-    Derive the encryption permutation from the secret seed.
-    This is the CLIENT'S secret key — it never leaves this Space.
-    The server only ever sees embeddings already scrambled by sigma.
-    """
     rng = np.random.default_rng(seed)
     return rng.permutation(hidden_size)
 
 
 def load_client_components(ee_model_name: str):
-    """
-    Load and cache:
-    - ee_config → hidden_size + original model name
-    - tokenizer → from EE model
-    - embed_layer → from the ORIGINAL (untransformed) model
-
-    The original embed_layer is used to produce plain vectors from token IDs.
-    The client then applies sigma to those plain vectors before sending.
-    The server's EE model has weights permuted with sigma_inv, so:
-    EE_model(sigma(plain_embed(tokens))) == original_model(plain_embed(tokens))
-    """
     if ee_model_name in _cache:
         return _cache[ee_model_name]
 
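Note: the deleted docstrings describe the core identity of the scheme, EE_model(sigma(plain_embed(tokens))) == original_model(plain_embed(tokens)). A minimal numeric sketch of why that holds for a single linear layer; the exact sigma/sigma_inv indexing convention below is an assumption, not taken from this repo:

import numpy as np

rng = np.random.default_rng(0)
hidden = 8
x = rng.normal(size=hidden)                   # one plain embedding vector
W = rng.normal(size=(4, hidden))              # a weight matrix of the original model

sigma = np.random.default_rng(1234).permutation(hidden)  # client secret, as in get_sigma
encrypted = x[sigma]                          # what the client sends
W_ee = W[:, sigma]                            # assumed server-side reindexing of input columns

assert np.allclose(W_ee @ encrypted, W @ x)   # same activations, so same outputs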
 
@@ -43,9 +27,10 @@ def load_client_components(ee_model_name: str):
     hidden_size = ee_config["hidden_size"]
     original_model_name = ee_config["original_model"]
 
+    # Tokenizer from EE model (same vocab as original)
     tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)
 
-    # Load ORIGINAL model just for its embed layer — discard everything else
+    # Load ORIGINAL model just to extract embed_tokens, then discard
     original_model = AutoModelForCausalLM.from_pretrained(
         original_model_name,
         torch_dtype=torch.float32,
@@ -70,34 +55,44 @@ def index():
     form_data = request.form.to_dict()
     server_url = request.form["server_url"].rstrip("/")
     ee_model_name = request.form["ee_model_name"].strip()
-    ee_seed = int(request.form["ee_seed"])  # SECRET — client only
+    ee_seed = int(request.form["ee_seed"])
     prompt = request.form["prompt"].strip()
     max_tokens = int(request.form.get("max_tokens", 256))
 
     try:
         tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)
 
-        # --- CLIENT-SIDE ENCRYPTION ---
-
-        # Step 1: tokenize
-        inputs = tokenizer(prompt, return_tensors="pt")
-
-        # Step 2: embed with ORIGINAL model embed layer → plain vectors
+        # --- Step 1: Apply chat template ---
+        # Qwen3 (and most instruct models) require the prompt wrapped in the
+        # chat template before tokenization, otherwise the model sees raw text
+        # with no special tokens and produces garbage.
+        messages = [{"role": "user", "content": prompt}]
+        formatted = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,  # appends <|im_start|>assistant\n
+        )
+
+        # --- Step 2: Tokenize the formatted prompt ---
+        inputs = tokenizer(formatted, return_tensors="pt")
+        input_ids = inputs.input_ids  # (1, seq_len)
+        input_len = input_ids.shape[1]
+
+        # --- Step 3: Embed with ORIGINAL model's embed layer ---
         with torch.no_grad():
-            plain_embeds = embed_layer(inputs.input_ids)  # (1, seq_len, hidden)
+            plain_embeds = embed_layer(input_ids)  # (1, seq_len, hidden)
 
-        # Step 3: apply sigma permutation — this is the encryption
-        # The server NEVER sees plain_embeds, only the scrambled version.
-        # Without knowing the seed, the server cannot recover the original.
+        # --- Step 4: Encrypt — permute hidden dim with secret sigma ---
         sigma = get_sigma(hidden_size, ee_seed)
-        encrypted_embeds = plain_embeds[..., sigma]  # (1, seq_len, hidden)
+        encrypted_embeds = plain_embeds[..., sigma]  # (1, seq_len, hidden)
         encrypted_embeds = encrypted_embeds.to(torch.float16)
 
-        # --- SEND TO SERVER ---
+        # --- Step 5: Send to server ---
        payload = {
             "encrypted_embeds": encrypted_embeds.tolist(),
             "attention_mask": inputs.attention_mask.tolist(),
             "max_new_tokens": max_tokens,
+            "input_len": input_len,  # so server can strip prompt tokens
         }
 
         resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)
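For Step 1 above, a standalone illustration of what apply_chat_template wraps around the user prompt for a ChatML-style model; the model id below is a placeholder, and the exact rendered string depends on the tokenizer's bundled template:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # placeholder model id
messages = [{"role": "user", "content": "Hello"}]
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# Prints roughly: <|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n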
@@ -107,13 +102,11 @@ def index():
 
         body = resp.json()
         if "error" in body:
-            raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback','')}")
+            raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback', '')}")
 
-        # --- OUTPUT DECODING ---
-        # The EE model's lm_head rows are permuted with sigma_inv, so output
-        # logits correctly index the real vocabulary — decode normally.
+        # --- Step 6: Decode only the NEW tokens (strip echoed prompt) ---
         gen_ids = body["generated_ids"]
-        result = tokenizer.decode(gen_ids, skip_special_tokens=True)
+        result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
 
     except RuntimeError as e:
         error = str(e)
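The server side is not part of this commit; the sketch below is a hypothetical /generate counterpart showing how the new input_len field could be consumed. The Flask app, the model id, and the stripping logic are all assumptions:

from flask import Flask, request, jsonify
import torch
from transformers import AutoModelForCausalLM

app = Flask(__name__)
# Hypothetical EE model id; in practice this is the sigma_inv-transformed model.
ee_model = AutoModelForCausalLM.from_pretrained("your-org/ee-model", trust_remote_code=True)

@app.post("/generate")
def generate():
    body = request.get_json()
    embeds = torch.tensor(body["encrypted_embeds"]).to(ee_model.dtype)
    mask = torch.tensor(body["attention_mask"])
    with torch.no_grad():
        out = ee_model.generate(
            inputs_embeds=embeds,
            attention_mask=mask,
            max_new_tokens=body["max_new_tokens"],
        )
    gen_ids = out[0].tolist()
    # With inputs_embeds, recent transformers versions return only the new
    # tokens; if prompt positions are echoed, input_len lets the server drop them.
    if len(gen_ids) > body["max_new_tokens"]:
        gen_ids = gen_ids[body["input_len"]:]
    return jsonify({"generated_ids": gen_ids})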
 