File size: 7,002 Bytes
9af3cb3 45ea54e eb71efa 45ea54e 16a4437 eb71efa 9af3cb3 16a4437 eb71efa 9af3cb3 16a4437 45ea54e 9af3cb3 45ea54e 9af3cb3 16a4437 45ea54e eb71efa 16a4437 eb71efa 16a4437 eb71efa 16a4437 eb71efa 9af3cb3 eb71efa 45ea54e 16a4437 45ea54e 9af3cb3 eb71efa 45ea54e eb71efa 16a4437 eb71efa 45ea54e 9af3cb3 45ea54e eb71efa 529baee 9af3cb3 45ea54e 9af3cb3 eb71efa 9af3cb3 16a4437 9af3cb3 45ea54e 16a4437 9af3cb3 16a4437 9af3cb3 16a4437 9af3cb3 45ea54e 9af3cb3 45ea54e 9af3cb3 45ea54e eb71efa 16a4437 eb71efa 9af3cb3 45ea54e eb71efa 9af3cb3 eb71efa 45ea54e eb71efa 45ea54e eb71efa 9af3cb3 45ea54e eb71efa 9af3cb3 45ea54e eb71efa 9af3cb3 eb71efa 16a4437 9af3cb3 16a4437 9af3cb3 16a4437 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
#!/usr/bin/env python3
# ─── GEVENT MONKEY-PATCH (must run before socket/ssl are imported) ───────────
from gevent import monkey
monkey.patch_all()
# ─── IMPORTS ──────────────────────────────────────────────────────────────────
import hmac
import json
import os
import random
import string
import time
from functools import wraps

import requests
from flask import Flask, Response, jsonify, request
# ─── ENV & CONFIG ─────────────────────────────────────────────────────────────
def _required_env(name, missing_message):
    """Return environment variable *name*, aborting startup when it is unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(missing_message)
    return value


# Required settings: importing this module raises RuntimeError without them.
API_KEY = _required_env("API_KEY", "Missing API_KEY env var")
MODEL_BASE_URL = _required_env("MODEL_BASE_URL", "Missing MODEL_BASE_URL env var")
# Optional: only used to render the URLs shown on the index page.
SPACE_URL = os.environ.get("SPACE_URL", "")

app = Flask(__name__)
# Emit JSON keys in the insertion order written in the views instead of sorting them.
app.json.sort_keys = False
# ─── GLOBAL ERROR HANDLER ─────────────────────────────────────────────────────
@app.errorhandler(Exception)
def handle_all_errors(e):
    """Render any exception raised while handling a request as ``{"error": str(e)}``.

    Flask also routes werkzeug HTTP errors (404, 405, the 400/415 raised by
    ``request.get_json()``, ...) to a handler registered for ``Exception``.
    Those keep their own status code and are not logged as crashes; every
    other exception is logged with its traceback and reported as a 500.
    """
    status = getattr(e, "code", None)
    # Duck-typed werkzeug HTTPException check (avoids a direct werkzeug import):
    # HTTP errors carry an int ``code`` and a ``get_response`` method.
    if (
        isinstance(status, int)
        and 400 <= status < 600
        and callable(getattr(e, "get_response", None))
    ):
        return jsonify({"error": str(e)}), status
    app.logger.exception(e)
    return jsonify({"error": str(e)}), 500
# ─── API-KEY DECORATOR ────────────────────────────────────────────────────────
def require_api_key(f):
    """Wrap view *f* so it only runs for requests carrying the configured ``API_KEY``.

    The key is read from ``X-API-Key`` or, failing that, from ``Authorization``;
    a leading ``Bearer`` scheme (matched case-insensitively) is stripped.
    A missing key yields a 401 JSON error, a wrong key a 403 JSON error.

    CORS preflight (``OPTIONS``) requests are passed straight to the view:
    browsers never attach credentials or custom headers to a preflight, so
    authenticating it would break every cross-origin caller. A wrapped view
    must therefore not do privileged work when handling ``OPTIONS``.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        if request.method == "OPTIONS":
            return f(*args, **kwargs)
        key = request.headers.get("X-API-Key") or request.headers.get("Authorization", "")
        scheme, sep, token = key.partition(" ")
        if sep and scheme.lower() == "bearer":
            key = token.strip()
        if not key:
            return jsonify({"error": "API key missing"}), 401
        # Constant-time comparison so response timing does not reveal how much of
        # the key matched; bytes because compare_digest rejects non-ASCII str.
        if not hmac.compare_digest(key.encode("utf-8"), API_KEY.encode("utf-8")):
            return jsonify({"error": "Invalid API key"}), 403
        return f(*args, **kwargs)
    return decorated
# ─── MODEL LIST ───────────────────────────────────────────────────────────────
@app.route('/api/v1/models', methods=['GET', 'POST'])
@app.route('/v1/models', methods=['GET', 'POST'])
@require_api_key
def model_list():
    """Return the fixed OpenAI-style model catalogue (``object: list``).

    Both advertised ids share one ``created`` timestamp taken at request time.
    """
    created_at = int(time.time())
    models = [
        {"id": model_id, "object": "model", "created": created_at, "owned_by": "tastypear"}
        for model_id in ("glm-4", "gpt-3.5-turbo")
    ]
    body = {"object": "list", "data": models}
    return jsonify(body)
# ─── INDEX ────────────────────────────────────────────────────────────────────
@app.route("/", methods=["GET"])
def index():
    """Serve a short landing page showing the proxy base URL and the full endpoint."""
    lines = [
        "ZhipuAI GLM-4 OpenAI Compatible API",
        f'Set "{SPACE_URL}/api" as proxy in your Chatbot.',
        f"Full API: {SPACE_URL}/api/v1/chat/completions",
    ]
    return Response("<br><br>".join(lines))
# ─── CHAT COMPLETIONS ─────────────────────────────────────────────────────────
# Upstream timeouts in seconds, as (connect, read). The event stream has no read
# timeout because generation may pause for an unknown time between events.
_JOIN_TIMEOUT = (10, 60)
_STREAM_TIMEOUT = (10, None)


@app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
@app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
@require_api_key
def chat_completions():
    """Proxy an OpenAI-style chat completion to the upstream Gradio queue.

    The request's ``messages`` list is flattened into the ``[prompt, history,
    system]`` triple posted to ``{MODEL_BASE_URL}/queue/join``. The upstream
    ``/queue/data`` event stream is then relayed as OpenAI
    ``chat.completion.chunk`` server-sent events terminated by ``data: [DONE]``.
    The response is always streamed; the request's ``stream`` flag is not read.

    Returns:
        An empty CORS response for ``OPTIONS``; a 400 JSON error when
        ``messages`` is missing or is not a list of objects; otherwise a
        CORS-enabled ``text/event-stream`` response.

    Raises:
        requests.RequestException: if joining the upstream queue fails (the
            global error handler turns it into a JSON 500).
    """
    # A "*" in Access-Control-Allow-Headers never covers Authorization, so it is
    # listed explicitly for browser clients that send a Bearer key.
    cors = {"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*, Authorization"}
    if request.method == "OPTIONS":
        return Response(headers=cors)

    # silent=True: a non-JSON or malformed body becomes the 400 below instead of
    # an HTTP exception raised by get_json().
    data = request.get_json(silent=True)
    msgs = data.get("messages") if isinstance(data, dict) else None
    if not msgs:
        return jsonify({"error": "Missing 'messages' field"}), 400
    if not isinstance(msgs, list) or not all(isinstance(m, dict) for m in msgs):
        return jsonify({"error": "'messages' must be a list of objects"}), 400

    system, history, prompt = _split_messages(msgs)
    session_hash = "".join(random.choices(string.ascii_lowercase + string.digits, k=11))
    payload = {"data": [prompt, history, system], "fn_index": 0, "session_hash": session_hash}

    # Join before the response starts, so an upstream failure is reported as a
    # JSON error instead of a 200 stream that ends without any data.
    with requests.post(
        f"{MODEL_BASE_URL}/queue/join", json=payload, timeout=_JOIN_TIMEOUT
    ) as joined:
        joined.raise_for_status()

    return Response(
        _relay_queue_events(session_hash),
        mimetype="text/event-stream",
        headers=cors,
    )


def _split_messages(msgs):
    """Flatten OpenAI messages into ``(system, history, prompt)`` for the upstream call.

    The last message's content is the prompt. Among the earlier messages the
    last system message wins. Each user message becomes a ``[user, reply]``
    pair, where reply is the immediately following assistant message's content
    ("" when the next message is not from the assistant).
    """
    system = None
    history = []
    for m, nxt in zip(msgs, msgs[1:]):
        role = m.get("role")
        if role == "system":
            system = m.get("content", system)
        elif role == "user":
            reply = nxt.get("content", "") if nxt.get("role") == "assistant" else ""
            history.append([m.get("content", ""), reply])
    prompt = msgs[-1].get("content", "")
    return system, history, prompt


def _parse_sse_data(raw):
    """Decode one raw SSE line into the JSON object of its ``data:`` field.

    Returns None for anything the caller should skip: blank separators,
    comments and other SSE fields, malformed JSON, or non-object payloads.
    """
    if not raw:
        return None
    text = raw.decode("utf-8")
    if not text.startswith("data:"):
        return None
    try:
        # json.loads ignores the optional space after "data:".
        event = json.loads(text[len("data:"):])
    except json.JSONDecodeError:
        app.logger.warning("Skipping malformed upstream event: %r", text)
        return None
    return event if isinstance(event, dict) else None


def _relay_queue_events(session_hash):
    """Yield OpenAI SSE events translated from the upstream ``/queue/data`` stream.

    ``process_starts`` becomes the role-announcing chunk, ``process_generating``
    a content chunk, and ``process_completed`` ends the stream with
    ``data: [DONE]``. Other event types are skipped.
    """
    url = f"{MODEL_BASE_URL}/queue/data?session_hash={session_hash}"
    # The ``with`` closes the upstream connection even when this generator is
    # closed early (e.g. the client disconnects mid-stream).
    with requests.get(url, stream=True, timeout=_STREAM_TIMEOUT) as resp:
        resp.raise_for_status()
        start_ts = int(time.time())
        for raw in resp.iter_lines():
            event = _parse_sse_data(raw)
            if event is None:
                continue
            kind = event.get("msg")
            if kind == "process_starts":
                chunk = make_chunk({}, start=True, ts=start_ts)
            elif kind == "process_generating":
                chunk = make_chunk(event, start=False, ts=start_ts)
            elif kind == "process_completed":
                # An SSE event ends with a blank line; without it conforming
                # clients discard the final [DONE] event at end of stream.
                yield "data: [DONE]\n\n"
                return
            else:
                continue
            yield f"data: {json.dumps(chunk)}\n\n"
def make_chunk(data, start=False, ts=None):
    """Translate one upstream queue event into an OpenAI ``chat.completion.chunk`` dict.

    Args:
        data: Decoded upstream event. Ignored when ``start`` is true; otherwise
            the update list is read from ``data["output"]["data"][1]``
            (presumably the chatbot update list of the upstream Gradio app;
            TODO confirm).
        start: Emit the opening chunk that announces the assistant role.
        ts: ``created`` timestamp in Unix seconds; defaults to the current time.

    Returns:
        A chunk dict with a single choice. That choice's ``delta`` carries the
        last element of the last update entry. When there is no update, it has
        an empty ``delta`` and ``finish_reason`` set to ``"stop"``.
    """
    if ts is None:
        ts = int(time.time())
    choice = {"index": 0, "finish_reason": None}
    chunk = {
        "id": "chatcmpl",
        "object": "chat.completion.chunk",
        "created": ts,
        "model": "glm-4",
        "choices": [choice],
    }
    if start:
        choice["delta"] = {"role": "assistant", "content": ""}
        return chunk
    # Tolerate a missing/None ``output`` or a short/None ``data`` list instead of
    # raising mid-stream, where the error could never reach the client.
    output = data.get("output") or {}
    out_data = output.get("data") or []
    updates = out_data[1] if len(out_data) > 1 else None
    if not updates:
        choice["finish_reason"] = "stop"
        # OpenAI chunks always carry a ``delta``; some clients index it unconditionally.
        choice["delta"] = {}
    else:
        choice["delta"] = {"content": updates[-1][-1]}
    return chunk
# ─── RUN ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="ZhipuAI GLM-4 OpenAI-compatible API proxy")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    server = parser.add_mutually_exclusive_group()
    server.add_argument(
        "--debug",
        action="store_true",
        help="run Flask's development server with debugger and reloader "
        "(never expose this on a public interface)",
    )
    server.add_argument(
        "--gevent",
        action="store_true",
        help="serve with gevent's WSGIServer (production)",
    )
    args = parser.parse_args()

    if args.gevent:
        from gevent.pywsgi import WSGIServer

        WSGIServer((args.host, args.port), app).serve_forever()
    else:
        # Debug mode is opt-in: the Werkzeug debugger can execute arbitrary Python
        # from the browser, and the default host binds every network interface.
        app.run(host=args.host, port=args.port, debug=args.debug)
|