broadfield-dev committed on
Commit
0e77718
·
verified ·
1 Parent(s): e132351

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import numpy as np
5
+ import requests, json
6
+ from huggingface_hub import hf_hub_download
7
+
8
# Flask application object; the "/" route in this file is registered on it.
app = Flask(import_name=__name__)
9
+
10
import functools


@functools.lru_cache(maxsize=2)
def _load_ee_assets(ee_model_name):
    """Load the tokenizer and token-embedding layer for one EE model, memoized.

    Loading a causal LM is by far the most expensive step of a request, so
    the result is cached per model name. The cache is deliberately tiny
    because every entry keeps an embedding table resident in memory; a model
    updated on the Hub is only picked up after it is evicted or the process
    restarts. Failed loads raise and are not cached.

    Args:
        ee_model_name: Hub repo id (or local path) of the EE model.

    Returns:
        A ``(tokenizer, embed_layer)`` tuple.
    """
    # SECURITY NOTE(review): trust_remote_code=True runs Python shipped in the
    # named repo, and the name comes straight from the request form. Restrict
    # the accepted model names before exposing this app to untrusted users.
    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # ee_config.json is not consumed by this client, but the download and
    # parse are kept so a repo without a valid file still fails the request.
    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path, encoding="utf-8") as f:
        json.load(f)

    # Only the embedding layer is kept; the rest of the model becomes
    # unreferenced as soon as this function returns.
    model = AutoModelForCausalLM.from_pretrained(
        ee_model_name, trust_remote_code=True, device_map="cpu"
    )
    return tokenizer, model.model.embed_tokens


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the client page and, on POST, run one remote generation.

    On POST the form fields ``server_url``, ``prompt``, ``ee_seed``,
    ``ee_model_name`` and ``max_tokens`` are read. The prompt is tokenized and
    embedded locally with the EE model's embedding layer, the embeddings are
    posted to ``{server_url}/generate``, and the returned ``generated_ids``
    are decoded back to text.

    Returns:
        The rendered ``client.html`` template. ``result`` is ``None`` on GET,
        the decoded text on success, or an ``"Error: ..."`` string when
        anything in the pipeline fails.
    """
    result = None
    if request.method == "POST":
        form = request.form
        # Plain key access on purpose: a missing field is raised here, outside
        # the try below, exactly as in the original handler.
        server_url = form["server_url"].rstrip("/")
        prompt = form["prompt"]
        raw_seed = form["ee_seed"]
        ee_model_name = form["ee_model_name"]
        raw_max_tokens = form["max_tokens"]

        try:
            # Integer parsing lives inside the try so a non-numeric value is
            # reported on the page instead of surfacing as an unhandled error.
            # ee_seed is validated only; nothing below consumes it.
            int(raw_seed)
            max_tokens = int(raw_max_tokens)

            tokenizer, embed_layer = _load_ee_assets(ee_model_name)

            # Tokenize and look up embeddings. Per the original author's note
            # these are already "encrypted" because the EE model ships a
            # transformed embedding table.
            inputs = tokenizer(prompt, return_tensors="pt")
            with torch.no_grad():
                embeds = embed_layer(inputs.input_ids)

            payload = {
                "encrypted_embeds": embeds.tolist(),
                "attention_mask": inputs.attention_mask.tolist(),
                "max_new_tokens": max_tokens,
            }
            resp = requests.post(f"{server_url}/generate", json=payload, timeout=180)
            resp.raise_for_status()
            gen_ids = resp.json()["generated_ids"]

            result = tokenizer.decode(gen_ids, skip_special_tokens=True)

        except Exception as e:
            # Top-level request boundary: keep the traceback in the server log
            # and show the user a short message.
            app.logger.exception("EE generation request failed")
            result = f"Error: {e}"

    return render_template("client.html", result=result)
53
+
54
if __name__ == "__main__":
    # Script entry point: listen on every interface, port 7860.
    listen_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**listen_options)