|
|
|
|
|
|
|
|
# gevent must monkey-patch the stdlib (sockets, ssl, threading, ...) BEFORE
# any other module imports them, which is why this runs ahead of all other
# imports in the file.
from gevent import monkey


monkey.patch_all()
|
|
|
|
|
|
|
|
import hmac
import json
import os
import random
import string
import time
from functools import wraps

import requests
from flask import Flask, request, Response, jsonify
|
|
|
|
|
|
|
|
def _required_env(name):
    """Return the value of environment variable *name*, failing fast if unset."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Missing {name} env var")
    return value


# Mandatory configuration — abort at import time rather than at first request.
API_KEY = _required_env("API_KEY")
MODEL_BASE_URL = _required_env("MODEL_BASE_URL")

# Public URL of this deployment; only used for the hints on the index page.
SPACE_URL = os.getenv("SPACE_URL", "")

app = Flask(__name__)
app.json.sort_keys = False  # preserve insertion order of keys in JSON responses
|
|
|
|
|
|
|
|
@app.errorhandler(Exception)
def handle_all_errors(e):
    """Catch-all handler: log the exception and answer with a JSON 500."""
    # NOTE(review): registering a handler for Exception also intercepts
    # Flask's HTTPExceptions (404/405, ...) and rewrites them as 500s —
    # confirm that is intended.
    message = str(e)
    app.logger.exception(e)
    return jsonify({"error": message}), 500
|
|
|
|
|
|
|
|
def require_api_key(f):
    """Decorator: reject requests that do not carry the configured API key.

    The key may arrive via the ``X-API-Key`` header or as an
    ``Authorization: Bearer <key>`` header. Responds 401 when no key is
    supplied and 403 when the key does not match ``API_KEY``.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        key = request.headers.get("X-API-Key") or request.headers.get("Authorization", "")
        if key.startswith("Bearer "):
            key = key.split(" ", 1)[1]
        if not key:
            return jsonify({"error": "API key missing"}), 401
        # Fix: use a constant-time comparison so the secret cannot be
        # recovered through a timing side channel ("!=" short-circuits on
        # the first differing byte).
        if not hmac.compare_digest(key.encode(), API_KEY.encode()):
            return jsonify({"error": "Invalid API key"}), 403
        return f(*args, **kwargs)
    return decorated
|
|
|
|
|
|
|
|
@app.route('/api/v1/models', methods=['GET', 'POST'])
@app.route('/v1/models', methods=['GET', 'POST'])
@require_api_key
def model_list():
    """Return the available models as an OpenAI-style list object."""
    created = int(time.time())
    models = [
        {"id": model_id, "object": "model", "created": created, "owned_by": "tastypear"}
        for model_id in ("glm-4", "gpt-3.5-turbo")
    ]
    return jsonify({"object": "list", "data": models})
|
|
|
|
|
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Landing page: brief usage hints pointing at the proxy endpoints."""
    parts = [
        'ZhipuAI GLM-4 OpenAI Compatible API',
        f'Set "{SPACE_URL}/api" as proxy in your Chatbot.',
        f'Full API: {SPACE_URL}/api/v1/chat/completions',
    ]
    return Response('<br><br>'.join(parts))
|
|
|
|
|
|
|
|
def _split_messages(msgs):
    """Split an OpenAI ``messages`` list into (system_prompt, gradio_history).

    Scans every message except the last (the live prompt). The latest
    ``system`` message wins; each ``user`` message is paired with the
    ``assistant`` reply immediately following it, or "" when there is none.
    """
    system = None
    history = []
    for i, msg in enumerate(msgs[:-1]):
        role = msg.get("role")
        if role == "system":
            # Keep the previous value when the message lacks a "content" key.
            system = msg.get("content", system)
        elif role == "user":
            follower = msgs[i + 1]
            reply = follower.get("content", "") if follower.get("role") == "assistant" else ""
            history.append([msg.get("content", ""), reply])
    return system, history


@app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
@app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
@require_api_key
def chat_completions():
    """OpenAI-compatible chat endpoint.

    Forwards the conversation to the Gradio queue API and streams the reply
    back to the client as OpenAI-style SSE ``chat.completion.chunk`` events.
    """
    # NOTE(review): browser CORS preflights carry no API key, yet
    # @require_api_key runs before this OPTIONS branch — confirm whether
    # browser clients are expected to pass preflight.
    if request.method == "OPTIONS":
        return Response(
            headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}
        )

    # silent=True: malformed JSON becomes {} and yields a clean 400 below
    # instead of an unhandled parse error.
    data = request.get_json(silent=True) or {}
    msgs = data.get("messages")
    if not msgs:
        return jsonify({"error": "Missing 'messages' field"}), 400

    system, history = _split_messages(msgs)
    prompt = msgs[-1].get("content", "")
    # Random session id tying the queue/join POST to the queue/data stream.
    session_hash = "".join(random.choices(string.ascii_lowercase + string.digits, k=11))
    payload = {"data": [prompt, history, system], "fn_index": 0, "session_hash": session_hash}

    def generate():
        # Enqueue the job; bounded timeout so a dead upstream cannot hang us.
        requests.post(f"{MODEL_BASE_URL}/queue/join", json=payload, timeout=30)
        url = f"{MODEL_BASE_URL}/queue/data?session_hash={session_hash}"
        start_ts = int(time.time())
        # Connect timeout only; the read side is a long-lived SSE stream.
        # `with` ensures the streaming connection is closed (was leaked).
        with requests.get(url, stream=True, timeout=(10, None)) as resp:
            for line in resp.iter_lines():
                # Only SSE data frames carry JSON; skip heartbeats and other
                # fields (assumes the upstream sends "data: " with one space).
                if not line or not line.startswith(b"data:"):
                    continue
                event = json.loads(line.decode("utf-8")[6:])
                kind = event.get("msg")
                if kind == "process_starts":
                    chunk = make_chunk({}, start=True, ts=start_ts)
                elif kind == "process_generating":
                    chunk = make_chunk(event, start=False, ts=start_ts)
                elif kind == "process_completed":
                    # Fix: terminate with a complete SSE frame ("\n\n");
                    # without it clients never see the [DONE] sentinel.
                    yield "data: [DONE]\n\n"
                    break
                else:
                    continue
                yield f"data: {json.dumps(chunk)}\n\n"

    return Response(
        generate(),
        mimetype="text/event-stream",
        headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"},
    )
|
|
|
|
|
def make_chunk(data, start=False, ts=None):
    """Build one OpenAI-style ``chat.completion.chunk`` dict.

    With ``start=True`` the chunk announces the assistant role with empty
    content. Otherwise the newest partial text is pulled out of the Gradio
    event in *data*; when no history pair is present the chunk signals
    ``finish_reason="stop"``. *ts* overrides the ``created`` timestamp.
    """
    chunk = {
        "id": "chatcmpl",
        "object": "chat.completion.chunk",
        "created": int(time.time()) if ts is None else ts,
        "model": "glm-4",
        "choices": [{"index": 0, "finish_reason": None}],
    }
    choice = chunk["choices"][0]
    if start:
        choice["delta"] = {"role": "assistant", "content": ""}
        return chunk
    pairs = data.get("output", {}).get("data", [None, None])[1] or []
    if pairs:
        # Latest assistant text lives in the last [user, assistant] pair.
        choice["delta"] = {"content": pairs[-1][-1]}
    else:
        choice["finish_reason"] = "stop"
    return chunk
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    # Fix: debug was hard-coded True. Werkzeug's debug mode exposes an
    # interactive debugger (remote code execution) on the bound interface
    # and its reloader misbehaves under gevent monkey-patching, so it must
    # be an explicit opt-in.
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    app.run(host=args.host, port=args.port, debug=args.debug)
|
|
|
|
|
|
|
|
|