broadfield-dev committed on
Commit
9e6e352
·
verified ·
1 Parent(s): 5cad3e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -52
app.py CHANGED
@@ -2,17 +2,14 @@ from flask import Flask, render_template, request, flash, jsonify
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
- import traceback
6
  import os, json
7
 
8
  app = Flask(__name__)
9
  app.secret_key = os.urandom(24)
10
 
11
- # Globals
12
  ee_model = None
13
  ee_tokenizer = None
14
  ee_config = None
15
- loaded_model_name = None
16
 
17
  SPACE_HOST = os.environ.get("SPACE_HOST", "")
18
  SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
@@ -20,91 +17,99 @@ SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
20
 
21
  @app.route("/", methods=["GET", "POST"])
22
  def index():
23
- global ee_model, ee_tokenizer, ee_config, loaded_model_name
24
 
25
  if request.method == "POST":
26
- action = request.form.get("action")
 
27
 
28
- if action == "start_server":
29
- ee_model_name = request.form["ee_model_name"].strip()
30
- hf_token = request.form["hf_token"].strip()
31
 
32
- try:
33
- login(token=hf_token)
34
-
35
- ee_model = AutoModelForCausalLM.from_pretrained(
36
- ee_model_name,
37
- torch_dtype=torch.float16,
38
- device_map="auto",
39
- trust_remote_code=True,
40
- )
41
- ee_tokenizer = AutoTokenizer.from_pretrained(
42
- ee_model_name, trust_remote_code=True
43
- )
44
 
45
- from huggingface_hub import hf_hub_download
46
- config_path = hf_hub_download(ee_model_name, "ee_config.json")
47
- with open(config_path) as f:
48
- ee_config = json.load(f)
49
 
50
- loaded_model_name = ee_model_name
51
- flash(f"Model loaded successfully: {ee_model_name}", "success")
52
- flash("Point your Client Space to this Space's URL below.", "info")
53
 
54
- except Exception as e:
55
- flash(f"Error: {str(e)}", "danger")
56
 
57
  return render_template(
58
  "index.html",
59
  server_ready=(ee_model is not None),
60
- model_name=loaded_model_name,
61
  space_url=SPACE_URL,
62
  )
63
 
64
 
65
  @app.route("/generate", methods=["POST"])
66
  def generate():
 
 
 
 
 
 
67
  if ee_model is None:
68
- return jsonify({"error": "Server not started yet — load a model first"}), 400
69
 
70
  try:
71
  data = request.json
72
- if data is None:
73
- return jsonify({"error": "Request body must be JSON"}), 400
74
-
75
  model_dtype = next(ee_model.parameters()).dtype
76
 
77
- # Cast incoming embeddings to model dtype + move to device
78
- encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(
79
  dtype=model_dtype, device=ee_model.device
80
- ) # (1, seq_len, hidden)
81
-
82
- input_seq_len = encrypted_embeds.shape[1]
83
 
84
  attention_mask = torch.tensor(
85
- data.get("attention_mask", [[1] * input_seq_len])
86
  ).to(device=ee_model.device)
87
 
88
- max_new = int(data.get("max_new_tokens", 256))
 
 
 
 
 
 
 
 
89
 
90
  with torch.no_grad():
91
- output_ids = ee_model.generate(
92
- inputs_embeds=encrypted_embeds,
93
  attention_mask=attention_mask,
94
- max_new_tokens=max_new,
95
- do_sample=True,
96
- temperature=0.7,
97
- top_p=0.9,
98
- pad_token_id=ee_tokenizer.eos_token_id,
99
  )
100
 
101
- # output_ids includes the full sequence; return only the newly generated tokens
102
- # (the client sent embeddings, not IDs, so output starts at position 0)
103
- new_ids = output_ids[0].tolist()
 
 
 
 
104
 
105
- return jsonify({"generated_ids": new_ids})
 
 
 
106
 
107
  except Exception as e:
 
108
  return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
109
 
110
 
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
 
5
  import os, json
6
 
7
  app = Flask(__name__)
8
  app.secret_key = os.urandom(24)
9
 
 
10
  ee_model = None
11
  ee_tokenizer = None
12
  ee_config = None
 
13
 
14
  SPACE_HOST = os.environ.get("SPACE_HOST", "")
15
  SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
 
17
 
18
  @app.route("/", methods=["GET", "POST"])
19
  def index():
20
+ global ee_model, ee_tokenizer, ee_config
21
 
22
  if request.method == "POST":
23
+ ee_model_name = request.form["ee_model_name"].strip()
24
+ hf_token = request.form["hf_token"].strip()
25
 
26
+ try:
27
+ login(token=hf_token)
 
28
 
29
+ ee_model = AutoModelForCausalLM.from_pretrained(
30
+ ee_model_name, torch_dtype=torch.float16,
31
+ device_map="auto", trust_remote_code=True
32
+ )
33
+ ee_tokenizer = AutoTokenizer.from_pretrained(
34
+ ee_model_name, trust_remote_code=True
35
+ )
 
 
 
 
 
36
 
37
+ from huggingface_hub import hf_hub_download
38
+ config_path = hf_hub_download(ee_model_name, "ee_config.json")
39
+ with open(config_path) as f:
40
+ ee_config = json.load(f)
41
 
42
+ flash(f"✅ Model loaded: {ee_model_name}", "success")
43
+ flash("Point your Client Space to this Space's URL below.", "info")
 
44
 
45
+ except Exception as e:
46
+ flash(f"Error: {str(e)}", "danger")
47
 
48
  return render_template(
49
  "index.html",
50
  server_ready=(ee_model is not None),
51
+ model_name=ee_config["original_model"] if ee_config else None,
52
  space_url=SPACE_URL,
53
  )
54
 
55
 
56
  @app.route("/generate", methods=["POST"])
57
  def generate():
58
+ """
59
+ Receives sigma-encrypted embeddings + optional past_key_values.
60
+ Returns last hidden state (still in sigma-space) + new KV cache.
61
+ Does NOT run lm_head — that stays on the client.
62
+ Server never sees token IDs, logits, or plaintext.
63
+ """
64
  if ee_model is None:
65
+ return jsonify({"error": "Server not started yet"}), 400
66
 
67
  try:
68
  data = request.json
 
 
 
69
  model_dtype = next(ee_model.parameters()).dtype
70
 
71
+ inputs_embeds = torch.tensor(data["inputs_embeds"]).to(
 
72
  dtype=model_dtype, device=ee_model.device
73
+ )
 
 
74
 
75
  attention_mask = torch.tensor(
76
+ data.get("attention_mask", [[1] * inputs_embeds.shape[1]])
77
  ).to(device=ee_model.device)
78
 
79
+ past_key_values = None
80
+ if data.get("past_key_values"):
81
+ past_key_values = tuple(
82
+ tuple(
83
+ torch.tensor(t).to(dtype=model_dtype, device=ee_model.device)
84
+ for t in layer
85
+ )
86
+ for layer in data["past_key_values"]
87
+ )
88
 
89
  with torch.no_grad():
90
+ out = ee_model(
91
+ inputs_embeds=inputs_embeds,
92
  attention_mask=attention_mask,
93
+ past_key_values=past_key_values,
94
+ use_cache=True,
95
+ output_hidden_states=True,
 
 
96
  )
97
 
98
+ # Return final hidden state in sigma-space — client decrypts + runs lm_head
99
+ last_hidden = out.hidden_states[-1] # (1, seq_len, hidden)
100
+
101
+ new_past = [
102
+ [t.cpu().tolist() for t in layer]
103
+ for layer in out.past_key_values
104
+ ]
105
 
106
+ return jsonify({
107
+ "last_hidden": last_hidden.cpu().tolist(),
108
+ "past_key_values": new_past,
109
+ })
110
 
111
  except Exception as e:
112
+ import traceback
113
  return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
114
 
115