Hivra committed on
Commit
16a4437
·
verified ·
1 Parent(s): 45ea54e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +71 -71
main.py CHANGED
@@ -5,22 +5,35 @@ import time
5
  import json
6
  import requests
7
  from functools import wraps
8
- from flask import Flask, request, Response, jsonify, abort
9
  import gevent.pywsgi
10
  from gevent import monkey; monkey.patch_all()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  app = Flask(__name__)
13
  app.json.sort_keys = False
14
 
15
- # β€”β€”β€” Load and enforce API key β€”β€”β€”
16
- API_KEY = os.getenv("API_KEY")
17
- if not API_KEY:
18
- raise RuntimeError("Missing API_KEY in env")
 
19
 
 
20
  def require_api_key(f):
21
  @wraps(f)
22
  def decorated(*args, **kwargs):
23
- # look in X-API-Key header or Authorization: Bearer ...
24
  key = request.headers.get("X-API-Key") or request.headers.get("Authorization", "")
25
  if key.startswith("Bearer "):
26
  key = key.split(" ", 1)[1]
@@ -31,96 +44,77 @@ def require_api_key(f):
31
  return f(*args, **kwargs)
32
  return decorated
33
 
34
- # β€”β€”β€” Your existing arg parsing & base_url β€”β€”β€”
35
- import argparse
36
- parser = argparse.ArgumentParser()
37
- parser.add_argument("--host", default="0.0.0.0")
38
- parser.add_argument("--port", type=int, default=7860)
39
- args = parser.parse_args()
40
-
41
- base_url = os.getenv("MODEL_BASE_URL")
42
-
43
-
44
- @app.route('/api/v1/models', methods=["GET", "POST"])
45
- @app.route('/v1/models', methods=["GET", "POST"])
46
  @require_api_key
47
  def model_list():
48
- time_now = int(time.time())
49
- models = {
50
  "object": "list",
51
  "data": [
52
- {"id": "glm-4", "object": "model", "created": time_now, "owned_by": "tastypear"},
53
- {"id": "gpt-3.5-turbo","object": "model", "created": time_now, "owned_by": "tastypear"}
54
  ]
55
- }
56
- return jsonify(models)
57
-
58
 
 
59
  @app.route("/", methods=["GET"])
60
  def index():
61
  return Response(
62
  f'ZhipuAI GLM-4 OpenAI Compatible API<br><br>'
63
- f'Set "{os.getenv("SPACE_URL")}/api" as proxy in your Chatbot.<br><br>'
64
- f'Full API: {os.getenv("SPACE_URL")}/api/v1/chat/completions'
65
  )
66
 
67
-
68
  @app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
69
  @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
70
  @require_api_key
71
  def chat_completions():
72
  if request.method == "OPTIONS":
73
  return Response(
74
- headers={
75
- "Access-Control-Allow-Origin": "*",
76
- "Access-Control-Allow-Headers": "*",
77
- }
78
  )
79
 
80
  data = request.get_json() or {}
81
- if "messages" not in data:
 
82
  return jsonify({"error": "Missing 'messages' field"}), 400
83
 
84
- # reorganize data
85
  system = "You are a helpful assistant."
86
  chat_history = []
87
- messages = data["messages"]
88
- prompt = messages[-1].get("content", "")
89
-
90
- for i in range(len(messages) - 1):
91
- r0 = messages[i].get("role")
92
- r1 = messages[i+1].get("role")
93
- if r0 == "system":
94
- system = messages[i]["content"]
95
- elif r0 == "user":
96
- if r1 == "assistant":
97
- chat_history.append([messages[i]["content"], messages[i+1]["content"]])
98
  else:
99
- chat_history.append([messages[i]["content"], " "])
100
-
101
- # random session id
102
- session_hash = "".join(random.choices(string.ascii_lowercase + string.digits, k=11))
103
- json_prompt = {
104
- "data": [prompt, chat_history, system],
105
- "fn_index": 0,
106
- "session_hash": session_hash,
107
- }
108
 
109
- def generate():
110
- # enqueue job
111
- requests.post(f"{base_url}/queue/join", json=json_prompt)
112
- url = f"{base_url}/queue/data?session_hash={session_hash}"
113
- stream = requests.get(url, stream=True)
114
 
 
 
 
 
 
115
  start_time = int(time.time())
116
- for line in stream.iter_lines():
 
117
  if not line:
118
  continue
119
  msg = json.loads(line.decode("utf-8")[6:])
120
  if msg["msg"] == "process_starts":
121
- chunk = gen_res_data({}, time_now=start_time, start=True)
122
  elif msg["msg"] == "process_generating":
123
- chunk = gen_res_data(msg, time_now=start_time)
124
  elif msg["msg"] == "process_completed":
125
  yield "data: [DONE]"
126
  break
@@ -131,14 +125,12 @@ def chat_completions():
131
  return Response(
132
  generate(),
133
  mimetype="text/event-stream",
134
- headers={
135
- "Access-Control-Allow-Origin": "*",
136
- "Access-Control-Allow-Headers": "*",
137
- },
138
  )
139
 
140
-
141
- def gen_res_data(data, time_now=0, start=False):
 
142
  base = {
143
  "id": "chatcmpl",
144
  "object": "chat.completion.chunk",
@@ -149,13 +141,21 @@ def gen_res_data(data, time_now=0, start=False):
149
  if start:
150
  base["choices"][0]["delta"] = {"role": "assistant", "content": ""}
151
  else:
152
- chat_pair = data["output"]["data"][1]
153
  if not chat_pair:
154
  base["choices"][0]["finish_reason"] = "stop"
155
  else:
156
  base["choices"][0]["delta"] = {"content": chat_pair[-1][-1]}
157
  return base
158
 
159
-
160
  if __name__ == "__main__":
161
- gevent.pywsgi.WSGIServer((args.host, args.port), app).serve_forever()
 
 
 
 
 
 
 
 
 
 
5
  import json
6
  import requests
7
  from functools import wraps
8
+ from flask import Flask, request, Response, jsonify
9
  import gevent.pywsgi
10
  from gevent import monkey; monkey.patch_all()
11
 
12
# —— Load & check env vars ——
def _required_env(name):
    """Return the value of environment variable *name*, failing fast if unset."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Missing {name} env var")
    return value

API_KEY = _required_env("API_KEY")
MODEL_BASE_URL = _required_env("MODEL_BASE_URL")

# Optional: public URL of this Space, used only for the index page text.
SPACE_URL = os.getenv("SPACE_URL", "")
23
# —— Flask setup ——
app = Flask(__name__)
app.json.sort_keys = False  # keep OpenAI-style key order in JSON responses

# —— Error handler to surface real errors as JSON ——
@app.errorhandler(Exception)
def handle_all_errors(e):
    """Log unexpected errors and return them as JSON instead of an HTML 500 page.

    Werkzeug HTTPExceptions (404, 405, ...) carry a deliberate status code and
    body; per the Flask error-handling docs, a generic ``Exception`` handler
    also receives them, so they are passed through unchanged instead of being
    flattened into a 500.
    """
    from werkzeug.exceptions import HTTPException  # ships with flask
    if isinstance(e, HTTPException):
        return e
    app.logger.exception(e)
    return jsonify({"error": str(e)}), 500
32
 
33
+ # β€”β€” API‑key decorator β€”β€”
34
  def require_api_key(f):
35
  @wraps(f)
36
  def decorated(*args, **kwargs):
 
37
  key = request.headers.get("X-API-Key") or request.headers.get("Authorization", "")
38
  if key.startswith("Bearer "):
39
  key = key.split(" ", 1)[1]
 
44
  return f(*args, **kwargs)
45
  return decorated
46
 
47
# —— Model list endpoint ——
@app.route('/api/v1/models', methods=['GET', 'POST'])
@app.route('/v1/models', methods=['GET', 'POST'])
@require_api_key
def model_list():
    """Return the OpenAI-compatible /v1/models listing."""
    created = int(time.time())
    entries = [
        {"id": model_id, "object": "model", "created": created, "owned_by": "tastypear"}
        for model_id in ("glm-4", "gpt-3.5-turbo")
    ]
    return jsonify({"object": "list", "data": entries})
 
 
60
 
61
# —— Home page ——
@app.route("/", methods=["GET"])
def index():
    """Landing page with quick proxy-setup instructions."""
    parts = [
        'ZhipuAI GLM-4 OpenAI Compatible API<br><br>',
        f'Set "{SPACE_URL}/api" as proxy in your Chatbot.<br><br>',
        f'Full API: {SPACE_URL}/api/v1/chat/completions',
    ]
    return Response("".join(parts))
69
 
70
+ # β€”β€” Chat completions β€”β€”
71
  @app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
72
  @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
73
  @require_api_key
74
  def chat_completions():
75
  if request.method == "OPTIONS":
76
  return Response(
77
+ headers={"Access-Control-Allow-Origin":"*","Access-Control-Allow-Headers":"*"}
 
 
 
78
  )
79
 
80
  data = request.get_json() or {}
81
+ messages = data.get("messages")
82
+ if not messages:
83
  return jsonify({"error": "Missing 'messages' field"}), 400
84
 
85
+ # parse messages
86
  system = "You are a helpful assistant."
87
  chat_history = []
88
+ for i, msg in enumerate(messages[:-1]):
89
+ role = msg.get("role")
90
+ if role == "system":
91
+ system = msg.get("content", system)
92
+ elif role == "user":
93
+ next_role = messages[i+1].get("role")
94
+ if next_role == "assistant":
95
+ chat_history.append([msg.get("content",""), messages[i+1].get("content","")])
 
 
 
96
  else:
97
+ chat_history.append([msg.get("content",""), ""])
 
 
 
 
 
 
 
 
98
 
99
+ prompt = messages[-1].get("content","")
100
+ session_hash = "".join(random.choices(string.ascii_lowercase+string.digits, k=11))
101
+ json_prompt = {"data":[prompt, chat_history, system], "fn_index":0, "session_hash":session_hash}
 
 
102
 
103
+ def generate():
104
+ # enqueue
105
+ requests.post(f"{MODEL_BASE_URL}/queue/join", json=json_prompt)
106
+ url = f"{MODEL_BASE_URL}/queue/data?session_hash={session_hash}"
107
+ resp = requests.get(url, stream=True)
108
  start_time = int(time.time())
109
+
110
+ for line in resp.iter_lines():
111
  if not line:
112
  continue
113
  msg = json.loads(line.decode("utf-8")[6:])
114
  if msg["msg"] == "process_starts":
115
+ chunk = gen_res_data({}, start=True, time_now=start_time)
116
  elif msg["msg"] == "process_generating":
117
+ chunk = gen_res_data(msg, start=False, time_now=start_time)
118
  elif msg["msg"] == "process_completed":
119
  yield "data: [DONE]"
120
  break
 
125
  return Response(
126
  generate(),
127
  mimetype="text/event-stream",
128
+ headers={"Access-Control-Allow-Origin":"*","Access-Control-Allow-Headers":"*"},
 
 
 
129
  )
130
 
131
+ def gen_res_data(data, start=False, time_now=None):
132
+ if time_now is None:
133
+ time_now = int(time.time())
134
  base = {
135
  "id": "chatcmpl",
136
  "object": "chat.completion.chunk",
 
141
  if start:
142
  base["choices"][0]["delta"] = {"role": "assistant", "content": ""}
143
  else:
144
+ chat_pair = data.get("output",{}).get("data",[None,None])[1]
145
  if not chat_pair:
146
  base["choices"][0]["finish_reason"] = "stop"
147
  else:
148
  base["choices"][0]["delta"] = {"content": chat_pair[-1][-1]}
149
  return base
150
 
 
151
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()

    # debug=True enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution in the browser — never leave it on for a
    # server bound to 0.0.0.0.  Opt in explicitly via FLASK_DEBUG=1 while
    # developing; otherwise serve with the gevent production server.
    if os.getenv("FLASK_DEBUG") == "1":
        app.run(host=args.host, port=args.port, debug=True)
    else:
        gevent.pywsgi.WSGIServer((args.host, args.port), app).serve_forever()