broadfield-dev committed on
Commit
3383b9c
·
verified ·
1 Parent(s): 3fb89d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -29
app.py CHANGED
# Flask app serving an "encrypted-embedding" (EE) causal LM over HTTP.
from flask import Flask, render_template, request, flash, jsonify
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import os
import json

app = Flask(__name__)
# Random per-process secret: flash()/session data does not survive restarts.
app.secret_key = os.urandom(24)

# Module-level state for running-server mode; populated when a model loads.
ee_model = None
ee_tokenizer = None
ee_config = None
loaded_model_name = None

# Public URL of this app: the HF Space host when deployed, localhost otherwise.
SPACE_HOST = os.environ.get("SPACE_HOST", "")
SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
24
  @app.route("/", methods=["GET", "POST"])
@@ -39,13 +37,12 @@ def index():
39
  ee_model_name,
40
  torch_dtype=torch.float16,
41
  device_map="auto",
42
- trust_remote_code=True
43
  )
44
  ee_tokenizer = AutoTokenizer.from_pretrained(
45
  ee_model_name, trust_remote_code=True
46
  )
47
 
48
- # Load EE config
49
  from huggingface_hub import hf_hub_download
50
  config_path = hf_hub_download(ee_model_name, "ee_config.json")
51
  with open(config_path) as f:
@@ -70,27 +67,42 @@ def index():
70
@app.route("/generate", methods=["POST"])
def generate():
    """Run generation on client-supplied (encrypted) input embeddings.

    JSON body:
      - "encrypted_embeds": required nested list — presumably shaped
        (1, seq_len, hidden); TODO confirm against the client encoder.
      - "attention_mask": optional, defaults to all-ones over seq_len.
      - "max_new_tokens": optional int, default 256.

    Returns JSON {"generated_ids": [...]} on success, or
    {"error": ...} with HTTP 400 on a bad request.
    """
    if ee_model is None:
        return jsonify({"error": "Server not started yet"}), 400

    data = request.json
    if data is None or "encrypted_embeds" not in data:
        # Missing/non-JSON body is a client error — fail fast with 400
        # instead of crashing on data["encrypted_embeds"].
        return jsonify({"error": "JSON body with 'encrypted_embeds' is required"}), 400

    # The model is loaded with torch_dtype=torch.float16, but torch.tensor()
    # over JSON floats produces float32/float64 — cast to the model's actual
    # dtype or generate() fails with a dtype mismatch.
    model_dtype = next(ee_model.parameters()).dtype
    encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(
        dtype=model_dtype, device=ee_model.device
    )
    # Attention mask stays an integer tensor; default marks every position real.
    attention_mask = torch.tensor(
        data.get("attention_mask", [[1] * encrypted_embeds.shape[1]])
    ).to(ee_model.device)
    max_new = int(data.get("max_new_tokens", 256))

    with torch.no_grad():
        output_ids = ee_model.generate(
            inputs_embeds=encrypted_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=ee_tokenizer.eos_token_id,
        )

    return jsonify({"generated_ids": output_ids[0].tolist()})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
 
96
  if __name__ == "__main__":
 
2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import traceback
import os
import json

# Flask application object; the secret key is regenerated each process,
# so flashed messages do not persist across restarts.
app = Flask(__name__)
app.secret_key = os.urandom(24)

# Lazily-populated module-level state for the loaded EE model.
ee_model = None
ee_tokenizer = None
ee_config = None
loaded_model_name = None

# Prefer the Hugging Face Space host when deployed, else the local dev URL.
SPACE_HOST = os.environ.get("SPACE_HOST", "")
SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
  @app.route("/", methods=["GET", "POST"])
 
37
  ee_model_name,
38
  torch_dtype=torch.float16,
39
  device_map="auto",
40
+ trust_remote_code=True,
41
  )
42
  ee_tokenizer = AutoTokenizer.from_pretrained(
43
  ee_model_name, trust_remote_code=True
44
  )
45
 
 
46
  from huggingface_hub import hf_hub_download
47
  config_path = hf_hub_download(ee_model_name, "ee_config.json")
48
  with open(config_path) as f:
 
67
@app.route("/generate", methods=["POST"])
def generate():
    """Generate token ids from client-supplied (encrypted) input embeddings.

    JSON body:
      - "encrypted_embeds": required nested list — presumably shaped
        (1, seq_len, hidden); TODO confirm against the client encoder.
      - "attention_mask": optional, defaults to all-ones over seq_len.
      - "max_new_tokens": optional int, default 256.

    Returns {"generated_ids": [...]} on success, or {"error": ...} with
    400 (client error) / 500 (unexpected failure) status.
    """
    if ee_model is None:
        return jsonify({"error": "Server not started yet — load a model first"}), 400

    try:
        data = request.json
        if data is None:
            return jsonify({"error": "Request body must be JSON"}), 400
        if "encrypted_embeds" not in data:
            # A missing key is a client error, not a server fault — report 400
            # instead of letting the KeyError fall through to the 500 handler.
            return jsonify({"error": "'encrypted_embeds' field is required"}), 400

        # Cast to the model's actual dtype (float16 when loaded with
        # torch_dtype=torch.float16) so generate() receives matching inputs;
        # torch.tensor() on JSON floats would otherwise yield float32/float64.
        model_dtype = next(ee_model.parameters()).dtype
        encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(
            dtype=model_dtype, device=ee_model.device
        )

        # Mask stays an integer tensor; the default marks every position real.
        attention_mask = torch.tensor(
            data.get("attention_mask", [[1] * encrypted_embeds.shape[1]])
        ).to(device=ee_model.device)

        max_new = int(data.get("max_new_tokens", 256))

        with torch.no_grad():
            output_ids = ee_model.generate(
                inputs_embeds=encrypted_embeds,
                attention_mask=attention_mask,
                max_new_tokens=max_new,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=ee_tokenizer.eos_token_id,
            )

        return jsonify({"generated_ids": output_ids[0].tolist()})

    except Exception as e:
        # NOTE(review): returning the full traceback aids debugging in a demo
        # Space but leaks internals to clients — consider logging server-side
        # only before any production use.
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
106
 
107
 
108
  if __name__ == "__main__":