broadfield-dev committed on
Commit ba9a967 · verified · 1 Parent(s): 71e6b73

Update app.py

Files changed (1)
  1. app.py +34 -15
app.py CHANGED
@@ -1,21 +1,30 @@
 from flask import Flask, render_template, request, flash, jsonify
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import login, whoami, HfApi
-import numpy as np
+from huggingface_hub import login
 import os, json
 
 app = Flask(__name__)
 app.secret_key = os.urandom(24)
 
-# Global for running server mode
+# Globals for running server mode
 ee_model = None
 ee_tokenizer = None
 ee_config = None
+loaded_model_name = None
+
+# Detect the HF Space URL automatically, fallback to localhost
+SPACE_HOST = os.environ.get("SPACE_HOST", "")
+if SPACE_HOST:
+    SPACE_URL = f"https://{SPACE_HOST}"
+else:
+    SPACE_URL = "http://localhost:7860"
+
 
 @app.route("/", methods=["GET", "POST"])
 def index():
-    global ee_model, ee_tokenizer, ee_config
+    global ee_model, ee_tokenizer, ee_config, loaded_model_name
+
     if request.method == "POST":
         action = request.form.get("action")
 
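The `SPACE_URL` detection added above relies on `SPACE_HOST`, an environment variable that Hugging Face Spaces sets to the Space's public hostname; outside a Space the variable is absent and the app falls back to the local development URL. A minimal sketch of the two cases (the hostname value is hypothetical):

import os

# Outside a Space: SPACE_HOST is unset, so the URL falls back to localhost.
os.environ.pop("SPACE_HOST", None)
host = os.environ.get("SPACE_HOST", "")
url = f"https://{host}" if host else "http://localhost:7860"
assert url == "http://localhost:7860"

# On a Space: the platform injects the public host (hypothetical value).
os.environ["SPACE_HOST"] = "broadfield-dev-ee-server.hf.space"
host = os.environ.get("SPACE_HOST", "")
url = f"https://{host}" if host else "http://localhost:7860"
assert url == "https://broadfield-dev-ee-server.hf.space"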
@@ -25,40 +34,49 @@ def index():
 
         try:
             login(token=hf_token)
-            global ee_model, ee_tokenizer, ee_config
 
             ee_model = AutoModelForCausalLM.from_pretrained(
                 ee_model_name,
                 torch_dtype=torch.float16,
-                #load_in_4bit=True,
                 device_map="auto",
                 trust_remote_code=True
             )
-            ee_tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)
+            ee_tokenizer = AutoTokenizer.from_pretrained(
+                ee_model_name, trust_remote_code=True
+            )
 
-            # Load config
+            # Load EE config
             from huggingface_hub import hf_hub_download
             config_path = hf_hub_download(ee_model_name, "ee_config.json")
             with open(config_path) as f:
                 ee_config = json.load(f)
 
-            flash(f"✅ Server ready! Model loaded: {ee_model_name}", "success")
-            flash("Now use the Client Space and point it to this Space's URL", "info")
+            loaded_model_name = ee_model_name
+            flash(f"Model loaded successfully: {ee_model_name}", "success")
+            flash("Point your Client Space to this Space's URL below.", "info")
 
         except Exception as e:
             flash(f"Error: {str(e)}", "danger")
 
-    return render_template("index.html")
+    return render_template(
+        "index.html",
+        server_ready=(ee_model is not None),
+        model_name=loaded_model_name,
+        space_url=SPACE_URL,
+    )
+
 
-# === INFERENCE ENDPOINT (always available when model is loaded) ===
+# === INFERENCE ENDPOINT ===
 @app.route("/generate", methods=["POST"])
 def generate():
     if ee_model is None:
         return jsonify({"error": "Server not started yet"}), 400
 
     data = request.json
-    encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(ee_model.device)  # (1, seq, hidden)
-    attention_mask = torch.tensor(data.get("attention_mask", [[1]*encrypted_embeds.shape[1]])).to(ee_model.device)
+    encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(ee_model.device)
+    attention_mask = torch.tensor(
+        data.get("attention_mask", [[1] * encrypted_embeds.shape[1]])
+    ).to(ee_model.device)
     max_new = int(data.get("max_new_tokens", 256))
 
     with torch.no_grad():
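As reworked above, `/generate` expects a JSON body with `encrypted_embeds` shaped `(1, seq, hidden)`, an optional `attention_mask` (defaulting to all ones), and an optional `max_new_tokens`, and it answers with `generated_ids`. A minimal client sketch against the endpoint, assuming `requests` is available (the server URL, sequence length, and hidden size are placeholders; the hidden size must match the loaded model's embedding width):

import requests

SERVER_URL = "https://broadfield-dev-ee-server.hf.space"  # placeholder Server Space URL
seq_len, hidden_size = 8, 4096  # hidden_size must match the model

payload = {
    # Dummy all-zero embeddings; a real client sends its encrypted embeddings.
    "encrypted_embeds": [[[0.0] * hidden_size for _ in range(seq_len)]],
    "attention_mask": [[1] * seq_len],
    "max_new_tokens": 32,
}

resp = requests.post(f"{SERVER_URL}/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["generated_ids"])  # raw token ids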
@@ -69,10 +87,11 @@ def generate():
             do_sample=True,
             temperature=0.7,
             top_p=0.9,
-            pad_token_id=ee_tokenizer.eos_token_id
+            pad_token_id=ee_tokenizer.eos_token_id,
         )
 
     return jsonify({"generated_ids": output_ids[0].tolist()})
 
+
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860)
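The body of the `with torch.no_grad():` block falls outside the hunks, so the exact call is not shown; judging from the visible keyword arguments it is presumably a `generate` call driven by `inputs_embeds` rather than token ids. A hedged reconstruction under that assumption (variable names follow the endpoint code; whether the repo's `trust_remote_code` model class customizes `generate` is unknown):

# Hypothetical reconstruction of the elided call inside /generate:
with torch.no_grad():
    output_ids = ee_model.generate(
        inputs_embeds=encrypted_embeds,   # (1, seq, hidden) float16 tensor
        attention_mask=attention_mask,
        max_new_tokens=max_new,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=ee_tokenizer.eos_token_id,
    )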
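Since the endpoint returns raw token ids, the caller needs the matching tokenizer to recover text. A short follow-up sketch, assuming the client can load the same repo (the repo id is a placeholder; use whatever model the server loaded):

from transformers import AutoTokenizer

ee_model_name = "broadfield-dev/ee-model"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

generated_ids = [1, 2, 3]  # stand-in for the ids returned by /generate
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(text)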