broadfield-dev committed
Commit aa316bb · verified · 1 Parent(s): 2c143bb

Update app.py

Files changed (1):
  1. app.py +40 -23
app.py CHANGED
@@ -8,22 +8,27 @@ from huggingface_hub import hf_hub_download
 
 app = Flask(__name__)
 
-# Cache tokenizer/embed layer so repeated requests don't reload from scratch
+# Cache tokenizer + embed layer so repeated requests don't reload
 _cache = {}
 
 
 def get_sigma(hidden_size: int, seed: int):
-    """Derive client-side encryption permutation from secret seed."""
+    """Derive encryption permutation from secret seed."""
     rng = np.random.default_rng(seed)
     return rng.permutation(hidden_size)
 
 
 def load_client_components(ee_model_name: str):
-    """Load (and cache) tokenizer + original embed layer for a given EE model."""
+    """
+    Load (and cache) only what the client needs:
+      - tokenizer from the EE model
+      - embedding layer from the ORIGINAL model (just embed_tokens, not the full LLM)
+      - hidden_size from ee_config
+    """
     if ee_model_name in _cache:
         return _cache[ee_model_name]
 
-    # 1. Fetch EE config to discover hidden_size + original model name
+    # 1. Read EE config to get hidden_size + original model name
     config_path = hf_hub_download(ee_model_name, "ee_config.json")
     with open(config_path) as f:
         ee_config = json.load(f)
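For reference, a minimal sketch (not part of the commit) of why this permutation can act as a shared secret: np.random.default_rng is deterministic for a given seed, so anyone holding the seed derives the same sigma, and np.argsort(sigma) yields the inverse permutation that undoes the scramble.

import numpy as np

def get_sigma(hidden_size: int, seed: int):
    rng = np.random.default_rng(seed)
    return rng.permutation(hidden_size)

sigma = get_sigma(8, seed=42)
assert np.array_equal(sigma, get_sigma(8, seed=42))  # same seed, same permutation

inv = np.argsort(sigma)                    # inverse permutation
x = np.arange(8.0)
assert np.array_equal(x[sigma][inv], x)    # scramble, then restore exactly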
@@ -31,17 +36,21 @@ def load_client_components(ee_model_name: str):
     hidden_size = ee_config["hidden_size"]
     original_model_name = ee_config["original_model"]
 
-    # 2. Load tokenizer (from the EE model)
+    # 2. Tokenizer from the EE model
     tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)
 
-    # 3. Load ONLY the original embedding layer (CPU is fine - no forward pass needed)
-    embed_model = AutoModelForCausalLM.from_pretrained(
+    # 3. Load ONLY the original model's embed_tokens layer - we don't need the full LLM,
+    #    but HF doesn't support partial loading so we load it fully then discard the rest.
+    #    float32 on CPU is fine - we're only doing one embedding lookup, no generation.
+    original_model = AutoModelForCausalLM.from_pretrained(
         original_model_name,
-        torch_dtype=torch.float16,
+        torch_dtype=torch.float32,
         device_map="cpu",
         trust_remote_code=True,
     )
-    embed_layer = embed_model.model.embed_tokens
+    embed_layer = original_model.model.embed_tokens
+    embed_layer.eval()
+    del original_model  # free memory - we only need the embed layer
 
     _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
     return tokenizer, embed_layer, hidden_size
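The comment above notes that HF has no partial loading, so the whole model is materialized just to keep embed_tokens. A lighter-weight alternative sketch, not what this commit does: read only the embedding matrix from the checkpoint with safetensors. The file name model.safetensors and the tensor key model.embed_tokens.weight are assumptions here; both vary by architecture, and a sharded checkpoint would need its index file handled as well.

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Assumed single-file checkpoint and key layout; original_model_name as above.
weights_path = hf_hub_download(original_model_name, "model.safetensors")
with safe_open(weights_path, framework="pt", device="cpu") as f:
    embed_weight = f.get_tensor("model.embed_tokens.weight")

embed_layer = torch.nn.Embedding.from_pretrained(embed_weight, freeze=True)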
@@ -55,29 +64,32 @@ def index():
 
     if request.method == "POST":
         form_data = request.form.to_dict()
-        server_url = request.form["server_url"].rstrip("/")
+        server_url = request.form["server_url"].rstrip("/")
         ee_model_name = request.form["ee_model_name"].strip()
-        ee_seed = int(request.form["ee_seed"])
-        prompt = request.form["prompt"].strip()
-        max_tokens = int(request.form.get("max_tokens", 256))
+        ee_seed = int(request.form["ee_seed"])
+        prompt = request.form["prompt"].strip()
+        max_tokens = int(request.form.get("max_tokens", 256))
 
         try:
             tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)
 
-            # Derive encryption key
+            # Derive encryption permutation from secret seed
             sigma = get_sigma(hidden_size, ee_seed)
 
             # Tokenize
             inputs = tokenizer(prompt, return_tensors="pt")
 
-            # Compute plain embeddings
+            # Compute plain embeddings from original model's embed layer
            with torch.no_grad():
                normal_embeds = embed_layer(inputs.input_ids)  # (1, seq_len, hidden)
 
-            # Encrypt: permute hidden dimension - server sees only scrambled vectors
-            encrypted_embeds = normal_embeds[..., sigma]
+            # Encrypt: permute hidden dimension with secret key
+            # Server sees only scrambled vectors - can't recover original prompt
+            encrypted_embeds = normal_embeds[..., sigma]  # (1, seq_len, hidden)
+
+            # Cast to float16 to match server model dtype
+            encrypted_embeds = encrypted_embeds.to(torch.float16)
 
-            # Send to server
             payload = {
                 "encrypted_embeds": encrypted_embeds.tolist(),
                 "attention_mask": inputs.attention_mask.tolist(),
@@ -89,17 +101,22 @@
                 json=payload,
                 timeout=300,
             )
-            resp.raise_for_status()
+
+            # Surface the server's error body if it returns non-2xx
+            if not resp.ok:
+                raise RuntimeError(
+                    f"Server returned {resp.status_code}: {resp.text[:500]}"
+                )
 
             gen_ids = resp.json()["generated_ids"]
             result = tokenizer.decode(gen_ids, skip_special_tokens=True)
 
+        except RuntimeError as e:
+            error = str(e)
         except requests.exceptions.ConnectionError:
-            error = f"Could not connect to server at {server_url}. Is it running?"
-        except requests.exceptions.HTTPError as e:
-            error = f"Server returned an error: {e.response.status_code} - {e.response.text}"
+            error = f"Could not connect to {server_url} - is the server Space running?"
         except Exception as e:
-            error = str(e)
+            error = f"{type(e).__name__}: {e}"
 
     return render_template("client.html", result=result, error=error, form=form_data)
 
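The manual resp.ok check replaces the old raise_for_status() / HTTPError pair so the server's response body reaches the UI instead of a bare status line. For comparison, a sketch of the same effect while keeping raise_for_status(); resp here is the requests.Response from the post above.

import requests

try:
    resp.raise_for_status()
except requests.exceptions.HTTPError as e:
    raise RuntimeError(
        f"Server returned {e.response.status_code}: {e.response.text[:500]}"
    ) from e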