broadfield-dev commited on
Commit
5cad3e1
·
verified ·
1 Parent(s): 3383b9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -14,7 +14,6 @@ ee_tokenizer = None
14
  ee_config = None
15
  loaded_model_name = None
16
 
17
- # Detect HF Space URL automatically
18
  SPACE_HOST = os.environ.get("SPACE_HOST", "")
19
  SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
20
 
@@ -63,7 +62,6 @@ def index():
63
  )
64
 
65
 
66
- # === INFERENCE ENDPOINT ===
67
  @app.route("/generate", methods=["POST"])
68
  def generate():
69
  if ee_model is None:
@@ -74,17 +72,18 @@ def generate():
74
  if data is None:
75
  return jsonify({"error": "Request body must be JSON"}), 400
76
 
77
- # Determine the model's actual dtype so we always match it
78
  model_dtype = next(ee_model.parameters()).dtype
79
 
80
- # Build tensors, cast to model dtype + move to device in one step
81
  encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(
82
  dtype=model_dtype, device=ee_model.device
83
  ) # (1, seq_len, hidden)
84
 
 
 
85
  attention_mask = torch.tensor(
86
- data.get("attention_mask", [[1] * encrypted_embeds.shape[1]])
87
- ).to(device=ee_model.device) # stays int64, that's correct
88
 
89
  max_new = int(data.get("max_new_tokens", 256))
90
 
@@ -99,7 +98,11 @@ def generate():
99
  pad_token_id=ee_tokenizer.eos_token_id,
100
  )
101
 
102
- return jsonify({"generated_ids": output_ids[0].tolist()})
 
 
 
 
103
 
104
  except Exception as e:
105
  return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
 
14
  ee_config = None
15
  loaded_model_name = None
16
 
 
17
  SPACE_HOST = os.environ.get("SPACE_HOST", "")
18
  SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
19
 
 
62
  )
63
 
64
 
 
65
  @app.route("/generate", methods=["POST"])
66
  def generate():
67
  if ee_model is None:
 
72
  if data is None:
73
  return jsonify({"error": "Request body must be JSON"}), 400
74
 
 
75
  model_dtype = next(ee_model.parameters()).dtype
76
 
77
+ # Cast incoming embeddings to model dtype + move to device
78
  encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(
79
  dtype=model_dtype, device=ee_model.device
80
  ) # (1, seq_len, hidden)
81
 
82
+ input_seq_len = encrypted_embeds.shape[1]
83
+
84
  attention_mask = torch.tensor(
85
+ data.get("attention_mask", [[1] * input_seq_len])
86
+ ).to(device=ee_model.device)
87
 
88
  max_new = int(data.get("max_new_tokens", 256))
89
 
 
98
  pad_token_id=ee_tokenizer.eos_token_id,
99
  )
100
 
101
+ # output_ids includes the full sequence; return only the newly generated tokens
102
+ # (the client sent embeddings, not IDs, so output starts at position 0)
103
+ new_ids = output_ids[0].tolist()
104
+
105
+ return jsonify({"generated_ids": new_ids})
106
 
107
  except Exception as e:
108
  return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500