Hivra committed on
Commit
45ea54e
Β·
verified Β·
1 Parent(s): ed35952

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +97 -92
main.py CHANGED
@@ -1,57 +1,74 @@
1
- import gevent.pywsgi
2
- from gevent import monkey;monkey.patch_all()
3
- from flask import Flask, request, Response, jsonify
4
- import argparse
5
- import requests
6
  import random
7
  import string
8
  import time
9
  import json
10
- import os
 
 
 
 
11
 
12
  app = Flask(__name__)
13
  app.json.sort_keys = False
14
 
15
- parser = argparse.ArgumentParser(description="An example of Zhipu GLM-4 with a similar API to OAI.")
16
- parser.add_argument("--host", type=str, help="Set the ip address.(default: 0.0.0.0)", default='0.0.0.0')
17
- parser.add_argument("--port", type=int, help="Set the port.(default: 7860)", default=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  args = parser.parse_args()
19
 
20
- base_url = os.getenv('MODEL_BASE_URL')
 
21
 
22
  @app.route('/api/v1/models', methods=["GET", "POST"])
23
- @app.route('/v1/models', methods=["GET", "POST"])
 
24
  def model_list():
25
  time_now = int(time.time())
26
- model_list = {
27
  "object": "list",
28
  "data": [
29
- {
30
- "id": "glm-4",
31
- "object": "model",
32
- "created": time_now,
33
- "owned_by": "tastypear"
34
- },
35
- {
36
- "id": "gpt-3.5-turbo",
37
- "object": "model",
38
- "created": time_now,
39
- "owned_by": "tastypear"
40
- }
41
  ]
42
  }
43
- return jsonify(model_list)
 
44
 
45
  @app.route("/", methods=["GET"])
46
  def index():
47
- return Response(f'ZhipuAI GLM-4 OpenAI Compatible API<br><br>'+
48
- f'Set "{os.getenv("SPACE_URL")}/api" as proxy (or API Domain) in your Chatbot.<br><br>'+
49
- f'The complete API is: {os.getenv("SPACE_URL")}/api/v1/chat/completions')
 
 
 
50
 
51
  @app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
52
- @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
 
53
  def chat_completions():
54
-
55
  if request.method == "OPTIONS":
56
  return Response(
57
  headers={
@@ -60,66 +77,56 @@ def chat_completions():
60
  }
61
  )
62
 
63
- data = request.get_json()
 
 
64
 
65
  # reorganize data
66
  system = "You are a helpful assistant."
67
  chat_history = []
68
- prompt = ""
69
-
70
- if "messages" in data:
71
- messages = data["messages"]
72
- message_size = len(messages)
73
-
74
- prompt = messages[-1].get("content")
75
- for i in range(message_size - 1):
76
- role_this = messages[i].get("role")
77
- role_next = messages[i + 1].get("role")
78
- if role_this == "system":
79
- system = messages[i].get("content")
80
- elif role_this == "user":
81
- if role_next == "assistant":
82
- chat_history.append(
83
- [messages[i].get("content"), messages[i + 1].get("content")]
84
- )
85
- else:
86
- chat_history.append([messages[i].get("content"), " "])
87
-
88
- # print(f'{system = }')
89
- # print(f'{chat_history = }')
90
- # print(f'{prompt = }')
91
-
92
- fn_index = 0
93
-
94
- # gen a random char(11) hash
95
- chars = string.ascii_lowercase + string.digits
96
- session_hash = "".join(random.choice(chars) for _ in range(11))
97
-
98
- json_prompt = {
99
- "data": [prompt, chat_history, system],
100
- "fn_index": fn_index,
101
- "session_hash": session_hash,
102
- }
103
 
104
  def generate():
105
- response = requests.post(f"{base_url}/queue/join", json=json_prompt)
 
106
  url = f"{base_url}/queue/data?session_hash={session_hash}"
107
- data = requests.get(url, stream=True)
108
-
109
- time_now = int(time.time())
110
-
111
- for line in data.iter_lines():
112
- if line:
113
- decoded_line = line.decode("utf-8")
114
- json_line = json.loads(decoded_line[6:])
115
- if json_line["msg"] == "process_starts":
116
- res_data = gen_res_data({}, time_now=time_now, start=True)
117
- yield f"data: {json.dumps(res_data)}\n\n"
118
- elif json_line["msg"] == "process_generating":
119
- res_data = gen_res_data(json_line, time_now=time_now)
120
- yield f"data: {json.dumps(res_data)}\n\n"
121
- elif json_line["msg"] == "process_completed":
122
- yield "data: [DONE]"
 
123
 
124
  return Response(
125
  generate(),
@@ -132,25 +139,23 @@ def chat_completions():
132
 
133
 
134
  def gen_res_data(data, time_now=0, start=False):
135
- res_data = {
136
  "id": "chatcmpl",
137
  "object": "chat.completion.chunk",
138
  "created": time_now,
139
  "model": "glm-4",
140
- "choices": [{"index": 0, "finish_reason": None}],
141
  }
142
-
143
  if start:
144
- res_data["choices"][0]["delta"] = {"role": "assistant", "content": ""}
145
  else:
146
  chat_pair = data["output"]["data"][1]
147
- if chat_pair == []:
148
- res_data["choices"][0]["finish_reason"] = "stop"
149
  else:
150
- res_data["choices"][0]["delta"] = {"content": chat_pair[-1][-1]}
151
- return res_data
152
 
153
 
154
  if __name__ == "__main__":
155
- # app.run(host=args.host, port=args.port, debug=True)
156
  gevent.pywsgi.WSGIServer((args.host, args.port), app).serve_forever()
 
1
+ import os
 
 
 
 
2
  import random
3
  import string
4
  import time
5
  import json
6
+ import requests
7
+ from functools import wraps
8
+ from flask import Flask, request, Response, jsonify, abort
9
+ import gevent.pywsgi
10
+ from gevent import monkey; monkey.patch_all()
11
 
12
# Flask application; keep JSON keys in insertion order so responses
# mirror the OpenAI payload layout.
app = Flask(__name__)
app.json.sort_keys = False

# ——— Load and enforce API key ———
# Fail fast at import time: the proxy must never start unauthenticated.
API_KEY = os.environ.get("API_KEY")
if not API_KEY:
    raise RuntimeError("Missing API_KEY in env")
19
+
20
def require_api_key(f):
    """Decorator: reject requests that lack the configured API key.

    The key may arrive in an ``X-API-Key`` header or as
    ``Authorization: Bearer <key>``.  Responds 401 when no key is
    supplied and 403 when the key does not match ``API_KEY``.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        import hmac  # local import: keeps this fix self-contained in the decorator

        # look in X-API-Key header or Authorization: Bearer ...
        key = request.headers.get("X-API-Key") or request.headers.get("Authorization", "")
        if key.startswith("Bearer "):
            key = key.split(" ", 1)[1]
        if not key:
            return jsonify({"error": "API key missing"}), 401
        # compare_digest: constant-time comparison so attackers cannot
        # recover the key byte-by-byte via response-timing differences
        # (plain `key != API_KEY` short-circuits on first mismatch).
        if not hmac.compare_digest(key, API_KEY):
            return jsonify({"error": "Invalid API key"}), 403
        return f(*args, **kwargs)
    return decorated
33
+
34
+ # β€”β€”β€” Your existing arg parsing & base_url β€”β€”β€”
35
+ import argparse
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument("--host", default="0.0.0.0")
38
+ parser.add_argument("--port", type=int, default=7860)
39
  args = parser.parse_args()
40
 
41
+ base_url = os.getenv("MODEL_BASE_URL")
42
+
43
 
44
@app.route('/api/v1/models', methods=["GET", "POST"])
@app.route('/v1/models', methods=["GET", "POST"])
@require_api_key
def model_list():
    """Return the OpenAI-style model list (both IDs are served by GLM-4)."""
    created = int(time.time())
    entries = [
        {"id": model_id, "object": "model", "created": created, "owned_by": "tastypear"}
        for model_id in ("glm-4", "gpt-3.5-turbo")
    ]
    return jsonify({"object": "list", "data": entries})
57
+
58
 
59
  @app.route("/", methods=["GET"])
60
  def index():
61
+ return Response(
62
+ f'ZhipuAI GLM-4 OpenAI Compatible API<br><br>'
63
+ f'Set "{os.getenv("SPACE_URL")}/api" as proxy in your Chatbot.<br><br>'
64
+ f'Full API: {os.getenv("SPACE_URL")}/api/v1/chat/completions'
65
+ )
66
+
67
 
68
  @app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
69
+ @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
70
+ @require_api_key
71
  def chat_completions():
 
72
  if request.method == "OPTIONS":
73
  return Response(
74
  headers={
 
77
  }
78
  )
79
 
80
+ data = request.get_json() or {}
81
+ if "messages" not in data:
82
+ return jsonify({"error": "Missing 'messages' field"}), 400
83
 
84
  # reorganize data
85
  system = "You are a helpful assistant."
86
  chat_history = []
87
+ messages = data["messages"]
88
+ prompt = messages[-1].get("content", "")
89
+
90
+ for i in range(len(messages) - 1):
91
+ r0 = messages[i].get("role")
92
+ r1 = messages[i+1].get("role")
93
+ if r0 == "system":
94
+ system = messages[i]["content"]
95
+ elif r0 == "user":
96
+ if r1 == "assistant":
97
+ chat_history.append([messages[i]["content"], messages[i+1]["content"]])
98
+ else:
99
+ chat_history.append([messages[i]["content"], " "])
100
+
101
+ # random session id
102
+ session_hash = "".join(random.choices(string.ascii_lowercase + string.digits, k=11))
103
+ json_prompt = {
104
+ "data": [prompt, chat_history, system],
105
+ "fn_index": 0,
106
+ "session_hash": session_hash,
107
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  def generate():
110
+ # enqueue job
111
+ requests.post(f"{base_url}/queue/join", json=json_prompt)
112
  url = f"{base_url}/queue/data?session_hash={session_hash}"
113
+ stream = requests.get(url, stream=True)
114
+
115
+ start_time = int(time.time())
116
+ for line in stream.iter_lines():
117
+ if not line:
118
+ continue
119
+ msg = json.loads(line.decode("utf-8")[6:])
120
+ if msg["msg"] == "process_starts":
121
+ chunk = gen_res_data({}, time_now=start_time, start=True)
122
+ elif msg["msg"] == "process_generating":
123
+ chunk = gen_res_data(msg, time_now=start_time)
124
+ elif msg["msg"] == "process_completed":
125
+ yield "data: [DONE]"
126
+ break
127
+ else:
128
+ continue
129
+ yield f"data: {json.dumps(chunk)}\n\n"
130
 
131
  return Response(
132
  generate(),
 
139
 
140
 
141
def gen_res_data(data, time_now=0, start=False):
    """Build one OpenAI ``chat.completion.chunk`` dict for the SSE stream.

    start=True  -> opening chunk with an empty assistant delta.
    start=False -> reads ``data["output"]["data"][1]`` (list of
                   [user, assistant] pairs): an empty list means the
                   stream finished (finish_reason "stop"); otherwise the
                   delta carries the latest assistant text.
    """
    choice = {"index": 0, "finish_reason": None}
    chunk = {
        "id": "chatcmpl",
        "object": "chat.completion.chunk",
        "created": time_now,
        "model": "glm-4",
        "choices": [choice],
    }
    if start:
        choice["delta"] = {"role": "assistant", "content": ""}
    else:
        pair_list = data["output"]["data"][1]
        if pair_list:
            # last pair, last element: the assistant's text so far
            choice["delta"] = {"content": pair_list[-1][-1]}
        else:
            choice["finish_reason"] = "stop"
    return chunk
158
 
159
 
160
if __name__ == "__main__":
    # Serve via gevent's WSGI server (monkey-patched sockets, no debug mode).
    server = gevent.pywsgi.WSGIServer((args.host, args.port), app)
    server.serve_forever()