broadfield-dev committed on
Commit
07ee289
·
verified ·
1 Parent(s): aa316bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -26
app.py CHANGED
@@ -8,27 +8,32 @@ from huggingface_hub import hf_hub_download
8
 
9
  app = Flask(__name__)
10
 
11
- # Cache tokenizer + embed layer so repeated requests don't reload
12
  _cache = {}
13
 
14
 
15
def get_sigma(hidden_size: int, seed: int):
    """Derive the encryption permutation from the secret seed.

    Seeding the generator makes the permutation deterministic, so client
    and server can independently reconstruct the same sigma from the
    shared secret without ever transmitting it.
    """
    generator = np.random.default_rng(seed)
    return generator.permutation(hidden_size)
19
 
20
 
21
def load_client_components(ee_model_name: str):
    """
    Load (and cache) only what the client needs:
      - tokenizer from the EE model
      - embedding layer from the ORIGINAL model (just the input embeddings,
        not the full LLM)
      - hidden_size from ee_config

    Returns:
        (tokenizer, embed_layer, hidden_size) tuple, cached per model name.
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    # 1. Read EE config to get hidden_size + original model name.
    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path, encoding="utf-8") as f:  # explicit encoding: JSON is UTF-8
        ee_config = json.load(f)
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    # 2. Tokenizer from the EE model.
    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # 3. Load the original model, keep only its input-embedding layer.
    #    HF has no partial loading, so load fully then discard the rest.
    #    float32 on CPU is fine -- we only do one embedding lookup here.
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    # Use the architecture-agnostic accessor: not every model exposes the
    # `.model.embed_tokens` attribute path (e.g. GPT-2 uses `transformer.wte`).
    embed_layer = original_model.get_input_embeddings()
    embed_layer.eval()
    del original_model  # free memory -- we only need the embed layer

    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size
@@ -73,23 +75,25 @@ def index():
73
  try:
74
  tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)
75
 
76
- # Derive encryption permutation from secret seed
77
- sigma = get_sigma(hidden_size, ee_seed)
78
-
79
- # Tokenize
80
  inputs = tokenizer(prompt, return_tensors="pt")
 
81
 
82
- # Compute plain embeddings from original model's embed layer
83
  with torch.no_grad():
84
- normal_embeds = embed_layer(inputs.input_ids) # (1, seq_len, hidden)
85
 
86
- # Encrypt: permute hidden dimension with secret key
87
- # Server sees only scrambled vectors β€” can't recover original prompt
88
- encrypted_embeds = normal_embeds[..., sigma] # (1, seq_len, hidden)
 
 
 
89
 
90
- # Cast to float16 to match server model dtype
91
  encrypted_embeds = encrypted_embeds.to(torch.float16)
92
 
 
93
  payload = {
94
  "encrypted_embeds": encrypted_embeds.tolist(),
95
  "attention_mask": inputs.attention_mask.tolist(),
@@ -102,13 +106,20 @@ def index():
102
  timeout=300,
103
  )
104
 
105
- # Surface the server's error body if it returns non-2xx
106
  if not resp.ok:
107
  raise RuntimeError(
108
- f"Server returned {resp.status_code}: {resp.text[:500]}"
109
  )
110
 
111
- gen_ids = resp.json()["generated_ids"]
 
 
 
 
 
 
 
 
112
  result = tokenizer.decode(gen_ids, skip_special_tokens=True)
113
 
114
  except RuntimeError as e:
 
8
 
9
  app = Flask(__name__)
10
 
11
+ # Cache per EE model name so repeated requests don't re-download
12
  _cache = {}
13
 
14
 
15
def get_sigma(hidden_size: int, seed: int):
    """Derive the hidden-dimension permutation from the secret seed."""
    # Seeded generator -> same permutation every call, so both ends of the
    # protocol can reconstruct sigma without sending it over the wire.
    return np.random.default_rng(seed).permutation(hidden_size)
19
 
20
 
21
def load_client_components(ee_model_name: str):
    """
    Load and cache everything the client needs:
      - ee_config   -> hidden_size + original model name
      - tokenizer   -> from the EE model
      - embed_layer -> from the ORIGINAL (unmodified) model

    Why we need the original embed layer:
      The EE model's weights were permuted with sigma, but its embedding
      table was NOT permuted (it maps token IDs -> plain vectors). The
      client must embed with the original model and then apply sigma to
      produce the scrambled vectors the EE model expects.

    Returns:
        (tokenizer, embed_layer, hidden_size) tuple, cached per model name.
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path, encoding="utf-8") as f:  # explicit encoding: JSON is UTF-8
        ee_config = json.load(f)

    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # We only need the input embeddings -- load the full model then discard
    # everything else (HF does not support partial loading).
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,  # float32 for precision on CPU
        device_map="cpu",
        trust_remote_code=True,
    )
    # get_input_embeddings() is the architecture-agnostic accessor; the
    # `.model.embed_tokens` attribute path does not exist on all models.
    embed_layer = original_model.get_input_embeddings()
    embed_layer.eval()
    del original_model  # free RAM -- we only keep the embed layer

    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size
 
75
  try:
76
  tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)
77
 
78
+ # --- Step 1: tokenize ---
 
 
 
79
  inputs = tokenizer(prompt, return_tensors="pt")
80
+ input_ids = inputs.input_ids # (1, seq_len)
81
 
82
+ # --- Step 2: embed with ORIGINAL model's embed layer ---
83
  with torch.no_grad():
84
+ plain_embeds = embed_layer(input_ids) # (1, seq_len, hidden)
85
 
86
+ # --- Step 3: ENCRYPT β€” permute hidden dim with secret sigma ---
87
+ # The EE model's weight matrices were pre-permuted with sigma,
88
+ # so feeding sigma-permuted embeddings is equivalent to feeding
89
+ # plain embeddings to the original model.
90
+ sigma = get_sigma(hidden_size, ee_seed)
91
+ encrypted_embeds = plain_embeds[..., sigma] # (1, seq_len, hidden)
92
 
93
+ # Match server model dtype (float16)
94
  encrypted_embeds = encrypted_embeds.to(torch.float16)
95
 
96
+ # --- Step 4: send to server ---
97
  payload = {
98
  "encrypted_embeds": encrypted_embeds.tolist(),
99
  "attention_mask": inputs.attention_mask.tolist(),
 
106
  timeout=300,
107
  )
108
 
 
109
  if not resp.ok:
110
  raise RuntimeError(
111
+ f"Server {resp.status_code}: {resp.text[:600]}"
112
  )
113
 
114
+ body = resp.json()
115
+ if "error" in body:
116
+ raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback','')}")
117
+
118
+ # --- Step 5: decode ---
119
+ # No decryption needed on the output β€” the EE model's lm_head was
120
+ # also permuted so output logits correctly map to the real vocabulary.
121
+ # We skip special tokens and strip the prompt echo if present.
122
+ gen_ids = body["generated_ids"]
123
  result = tokenizer.decode(gen_ids, skip_special_tokens=True)
124
 
125
  except RuntimeError as e: