Hivra committed on
Commit
9af3cb3
·
verified ·
1 Parent(s): 16a4437

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -40
main.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import os
2
  import random
3
  import string
@@ -6,10 +12,8 @@ import json
6
  import requests
7
  from functools import wraps
8
  from flask import Flask, request, Response, jsonify
9
- import gevent.pywsgi
10
- from gevent import monkey; monkey.patch_all()
11
 
12
- # —— Load & check env vars ——
13
  API_KEY = os.getenv("API_KEY")
14
  if not API_KEY:
15
  raise RuntimeError("Missing API_KEY env var")
@@ -20,17 +24,16 @@ if not MODEL_BASE_URL:
20
 
21
  SPACE_URL = os.getenv("SPACE_URL", "")
22
 
23
- # —— Flask setup ——
24
  app = Flask(__name__)
25
  app.json.sort_keys = False
26
 
27
- # —— Error handler to show real errors ——
28
  @app.errorhandler(Exception)
29
  def handle_all_errors(e):
30
  app.logger.exception(e)
31
  return jsonify({"error": str(e)}), 500
32
 
33
- # —— API‑key decorator ——
34
  def require_api_key(f):
35
  @wraps(f)
36
  def decorated(*args, **kwargs):
@@ -44,7 +47,7 @@ def require_api_key(f):
44
  return f(*args, **kwargs)
45
  return decorated
46
 
47
- # —— Model list endpoint ——
48
  @app.route('/api/v1/models', methods=['GET', 'POST'])
49
  @app.route('/v1/models', methods=['GET', 'POST'])
50
  @require_api_key
@@ -58,7 +61,7 @@ def model_list():
58
  ]
59
  })
60
 
61
- # —— Home page ——
62
  @app.route("/", methods=["GET"])
63
  def index():
64
  return Response(
@@ -67,7 +70,7 @@ def index():
67
  f'Full API: {SPACE_URL}/api/v1/chat/completions'
68
  )
69
 
70
- # —— Chat completions ——
71
  @app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
72
  @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
73
  @require_api_key
@@ -78,43 +81,39 @@ def chat_completions():
78
  )
79
 
80
  data = request.get_json() or {}
81
- messages = data.get("messages")
82
- if not messages:
83
  return jsonify({"error": "Missing 'messages' field"}), 400
84
 
85
- # parse messages
86
  system = "You are a helpful assistant."
87
- chat_history = []
88
- for i, msg in enumerate(messages[:-1]):
89
- role = msg.get("role")
90
- if role == "system":
91
- system = msg.get("content", system)
92
- elif role == "user":
93
- next_role = messages[i+1].get("role")
94
- if next_role == "assistant":
95
- chat_history.append([msg.get("content",""), messages[i+1].get("content","")])
96
  else:
97
- chat_history.append([msg.get("content",""), ""])
98
 
99
- prompt = messages[-1].get("content","")
100
  session_hash = "".join(random.choices(string.ascii_lowercase+string.digits, k=11))
101
- json_prompt = {"data":[prompt, chat_history, system], "fn_index":0, "session_hash":session_hash}
102
 
103
  def generate():
104
- # enqueue
105
- requests.post(f"{MODEL_BASE_URL}/queue/join", json=json_prompt)
106
  url = f"{MODEL_BASE_URL}/queue/data?session_hash={session_hash}"
107
  resp = requests.get(url, stream=True)
108
- start_time = int(time.time())
109
 
110
  for line in resp.iter_lines():
111
- if not line:
112
- continue
113
  msg = json.loads(line.decode("utf-8")[6:])
114
  if msg["msg"] == "process_starts":
115
- chunk = gen_res_data({}, start=True, time_now=start_time)
116
  elif msg["msg"] == "process_generating":
117
- chunk = gen_res_data(msg, start=False, time_now=start_time)
118
  elif msg["msg"] == "process_completed":
119
  yield "data: [DONE]"
120
  break
@@ -128,26 +127,27 @@ def chat_completions():
128
  headers={"Access-Control-Allow-Origin":"*","Access-Control-Allow-Headers":"*"},
129
  )
130
 
131
- def gen_res_data(data, start=False, time_now=None):
132
- if time_now is None:
133
- time_now = int(time.time())
134
  base = {
135
  "id": "chatcmpl",
136
  "object": "chat.completion.chunk",
137
- "created": time_now,
138
  "model": "glm-4",
139
  "choices": [{"index": 0, "finish_reason": None}]
140
  }
141
  if start:
142
  base["choices"][0]["delta"] = {"role": "assistant", "content": ""}
143
  else:
144
- chat_pair = data.get("output",{}).get("data",[None,None])[1]
145
- if not chat_pair:
146
  base["choices"][0]["finish_reason"] = "stop"
147
  else:
148
- base["choices"][0]["delta"] = {"content": chat_pair[-1][-1]}
149
  return base
150
 
 
151
  if __name__ == "__main__":
152
  import argparse
153
  parser = argparse.ArgumentParser()
@@ -155,7 +155,7 @@ if __name__ == "__main__":
155
  parser.add_argument("--port", type=int, default=7860)
156
  args = parser.parse_args()
157
 
158
- # turn on debug so you see errors in browser
159
  app.run(host=args.host, port=args.port, debug=True)
160
- # once it works, you can switch back to:
161
  # gevent.pywsgi.WSGIServer((args.host, args.port), app).serve_forever()
 
1
+ #!/usr/bin/env python3
2
+ # ─── PATCH SSL EARLY ─────────────────────────────────────────────────────────────
3
+ from gevent import monkey
4
+ monkey.patch_all()
5
+
6
+ # ─── STANDARD IMPORTS ───────────────────────────────────────────────────────────
7
  import os
8
  import random
9
  import string
 
12
  import requests
13
  from functools import wraps
14
  from flask import Flask, request, Response, jsonify
 
 
15
 
16
+ # ─── ENV & CONFIG ────────────────────────────────────────────────────────────────
17
  API_KEY = os.getenv("API_KEY")
18
  if not API_KEY:
19
  raise RuntimeError("Missing API_KEY env var")
 
24
 
25
  SPACE_URL = os.getenv("SPACE_URL", "")
26
 
 
27
  app = Flask(__name__)
28
  app.json.sort_keys = False
29
 
30
+ # ─── GLOBAL ERROR HANDLER ────────────────────────────────────────────────────────
31
  @app.errorhandler(Exception)
32
  def handle_all_errors(e):
33
  app.logger.exception(e)
34
  return jsonify({"error": str(e)}), 500
35
 
36
+ # ─── API‑KEY DECORATOR ───────────────────────────────────────────────────────────
37
  def require_api_key(f):
38
  @wraps(f)
39
  def decorated(*args, **kwargs):
 
47
  return f(*args, **kwargs)
48
  return decorated
49
 
50
+ # ─── MODEL LIST ─────────────────────────────────────────────────────────────────
51
  @app.route('/api/v1/models', methods=['GET', 'POST'])
52
  @app.route('/v1/models', methods=['GET', 'POST'])
53
  @require_api_key
 
61
  ]
62
  })
63
 
64
+ # ─── INDEX ──────────────────────────────────────────────────────────────────────
65
  @app.route("/", methods=["GET"])
66
  def index():
67
  return Response(
 
70
  f'Full API: {SPACE_URL}/api/v1/chat/completions'
71
  )
72
 
73
+ # ─── CHAT COMPLETIONS ────────────────────────────────────────────────────────────
74
  @app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
75
  @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
76
  @require_api_key
 
81
  )
82
 
83
  data = request.get_json() or {}
84
+ msgs = data.get("messages")
85
+ if not msgs:
86
  return jsonify({"error": "Missing 'messages' field"}), 400
87
 
 
88
  system = "You are a helpful assistant."
89
+ history = []
90
+ for i, m in enumerate(msgs[:-1]):
91
+ if m.get("role") == "system":
92
+ system = m.get("content", system)
93
+ elif m.get("role") == "user":
94
+ nxt = msgs[i+1].get("role")
95
+ if nxt == "assistant":
96
+ history.append([m.get("content",""), msgs[i+1].get("content","")])
 
97
  else:
98
+ history.append([m.get("content",""), ""])
99
 
100
+ prompt = msgs[-1].get("content","")
101
  session_hash = "".join(random.choices(string.ascii_lowercase+string.digits, k=11))
102
+ payload = {"data":[prompt, history, system], "fn_index":0, "session_hash":session_hash}
103
 
104
  def generate():
105
+ requests.post(f"{MODEL_BASE_URL}/queue/join", json=payload)
 
106
  url = f"{MODEL_BASE_URL}/queue/data?session_hash={session_hash}"
107
  resp = requests.get(url, stream=True)
108
+ start_ts = int(time.time())
109
 
110
  for line in resp.iter_lines():
111
+ if not line: continue
 
112
  msg = json.loads(line.decode("utf-8")[6:])
113
  if msg["msg"] == "process_starts":
114
+ chunk = make_chunk({}, start=True, ts=start_ts)
115
  elif msg["msg"] == "process_generating":
116
+ chunk = make_chunk(msg, start=False, ts=start_ts)
117
  elif msg["msg"] == "process_completed":
118
  yield "data: [DONE]"
119
  break
 
127
  headers={"Access-Control-Allow-Origin":"*","Access-Control-Allow-Headers":"*"},
128
  )
129
 
130
+ def make_chunk(data, start=False, ts=None):
131
+ if ts is None:
132
+ ts = int(time.time())
133
  base = {
134
  "id": "chatcmpl",
135
  "object": "chat.completion.chunk",
136
+ "created": ts,
137
  "model": "glm-4",
138
  "choices": [{"index": 0, "finish_reason": None}]
139
  }
140
  if start:
141
  base["choices"][0]["delta"] = {"role": "assistant", "content": ""}
142
  else:
143
+ pair = data.get("output",{}).get("data",[None,None])[1] or []
144
+ if not pair:
145
  base["choices"][0]["finish_reason"] = "stop"
146
  else:
147
+ base["choices"][0]["delta"] = {"content": pair[-1][-1]}
148
  return base
149
 
150
+ # ─── RUN ────────────────────────────────────────────────────────────────────────
151
  if __name__ == "__main__":
152
  import argparse
153
  parser = argparse.ArgumentParser()
 
155
  parser.add_argument("--port", type=int, default=7860)
156
  args = parser.parse_args()
157
 
158
+ # Dev: use Flask’s debug server so you see tracebacks in browser
159
  app.run(host=args.host, port=args.port, debug=True)
160
+ # Prod: swap to gevent server
161
  # gevent.pywsgi.WSGIServer((args.host, args.port), app).serve_forever()