File size: 7,002 Bytes
9af3cb3
 
 
 
 
 
45ea54e
eb71efa
 
 
 
45ea54e
 
16a4437
eb71efa
9af3cb3
16a4437
 
 
 
 
 
 
 
 
 
eb71efa
 
 
9af3cb3
16a4437
 
 
 
45ea54e
9af3cb3
45ea54e
 
 
 
 
 
 
 
 
 
 
 
 
9af3cb3
16a4437
 
45ea54e
eb71efa
16a4437
 
eb71efa
 
16a4437
 
eb71efa
16a4437
eb71efa
9af3cb3
eb71efa
 
45ea54e
 
16a4437
 
45ea54e
 
9af3cb3
eb71efa
45ea54e
 
eb71efa
 
 
16a4437
eb71efa
 
45ea54e
9af3cb3
 
45ea54e
eb71efa
529baee
9af3cb3
 
 
 
 
 
 
 
45ea54e
9af3cb3
eb71efa
9af3cb3
16a4437
9af3cb3
45ea54e
16a4437
9af3cb3
16a4437
 
9af3cb3
16a4437
 
9af3cb3
45ea54e
 
9af3cb3
45ea54e
9af3cb3
45ea54e
 
 
 
 
 
eb71efa
 
 
 
16a4437
eb71efa
 
9af3cb3
 
 
45ea54e
eb71efa
 
9af3cb3
eb71efa
45ea54e
eb71efa
 
45ea54e
eb71efa
9af3cb3
 
45ea54e
eb71efa
9af3cb3
45ea54e
eb71efa
9af3cb3
eb71efa
16a4437
 
 
 
 
 
9af3cb3
16a4437
9af3cb3
16a4437
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
# ─── PATCH SSL EARLY ─────────────────────────────────────────────────────────────
from gevent import monkey
monkey.patch_all()

# ─── STANDARD IMPORTS ───────────────────────────────────────────────────────────
import os
import random
import string
import time
import json
import requests
from functools import wraps
from flask import Flask, request, Response, jsonify

# ─── ENV & CONFIG ────────────────────────────────────────────────────────────────
# Shared secret that callers must send (X-API-Key or Authorization header);
# the process refuses to start without it.
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    raise RuntimeError("Missing API_KEY env var")

# Base URL of the upstream model service; the "/queue/join" and "/queue/data"
# endpoints are addressed relative to it. Required.
MODEL_BASE_URL = os.getenv("MODEL_BASE_URL")
if not MODEL_BASE_URL:
    raise RuntimeError("Missing MODEL_BASE_URL env var")

# Public URL of this deployment, only interpolated into the index page text.
# Optional; defaults to an empty string.
SPACE_URL = os.getenv("SPACE_URL", "")

app = Flask(__name__)
# Emit JSON keys in insertion order rather than sorted alphabetically.
app.json.sort_keys = False

# ─── GLOBAL ERROR HANDLER ────────────────────────────────────────────────────────
@app.errorhandler(Exception)
def handle_all_errors(e):
    """Turn any exception that escapes a view into a JSON error response.

    A handler registered for ``Exception`` also receives the HTTP errors that
    Flask/Werkzeug raise themselves (404 for unknown routes, 405, 400 for a
    malformed JSON body, ...). Those keep their own status code and headers;
    only the body of 4xx/5xx responses is replaced with JSON. Every other
    exception is logged with its traceback and reported as a 500.

    Args:
        e: The exception raised while handling the request.

    Returns:
        A response object for HTTP errors, or a ``(response, 500)`` tuple for
        unexpected exceptions.
    """
    # Duck-typed detection of werkzeug's HTTPException (integer ``code`` plus
    # ``get_response()``), so no additional import is required here.
    status = getattr(e, "code", None)
    build_response = getattr(e, "get_response", None)
    if isinstance(status, int) and callable(build_response):
        response = build_response()
        if status >= 400:
            response.data = json.dumps({"error": str(e)})
            response.content_type = "application/json"
        # Redirect-style HTTP exceptions (3xx) are passed through untouched.
        return response

    app.logger.exception(e)
    return jsonify({"error": str(e)}), 500

# ─── API‑KEY DECORATOR ───────────────────────────────────────────────────────────
def require_api_key(f):
    """Decorate a view so it only runs when the caller presents ``API_KEY``.

    The key is read from the ``X-API-Key`` header, falling back to
    ``Authorization`` (an optional ``Bearer `` prefix is stripped).

    ``OPTIONS`` requests are passed straight through: browsers never attach
    credentials to CORS preflight requests, so rejecting them would make the
    wrapped endpoint unusable from a browser.

    Args:
        f: The view function to protect.

    Returns:
        The wrapped view. It responds 401 when no key is supplied and 403
        when the supplied key does not match.
    """
    import hmac  # local import keeps this block self-contained

    @wraps(f)
    def decorated(*args, **kwargs):
        if request.method == "OPTIONS":
            return f(*args, **kwargs)

        key = request.headers.get("X-API-Key") or request.headers.get("Authorization", "")
        if key.startswith("Bearer "):
            key = key[len("Bearer "):]
        key = key.strip()
        if not key:
            return jsonify({"error": "API key missing"}), 401
        # Constant-time comparison so response timing does not leak how much
        # of the key matched. Bytes are used because compare_digest rejects
        # non-ASCII str arguments.
        if not hmac.compare_digest(key.encode("utf-8"), API_KEY.encode("utf-8")):
            return jsonify({"error": "Invalid API key"}), 403
        return f(*args, **kwargs)
    return decorated

# ─── MODEL LIST ─────────────────────────────────────────────────────────────────
@app.route('/api/v1/models', methods=['GET', 'POST'])
@app.route('/v1/models',      methods=['GET', 'POST'])
@require_api_key
def model_list():
    """Return the advertised models in OpenAI ``/v1/models`` list format.

    Each entry is stamped with the current Unix time as its ``created`` value.
    """
    created_at = int(time.time())
    entries = [
        {"id": model_id, "object": "model", "created": created_at, "owned_by": "tastypear"}
        for model_id in ("glm-4", "gpt-3.5-turbo")
    ]
    return jsonify({"object": "list", "data": entries})

# ─── INDEX ──────────────────────────────────────────────────────────────────────
@app.route("/", methods=["GET"])
def index():
    """Serve a short landing page that tells users which URLs to configure."""
    api_root = f"{SPACE_URL}/api"
    page = "".join((
        'ZhipuAI GLM-4 OpenAI Compatible API<br><br>',
        f'Set "{api_root}" as proxy in your Chatbot.<br><br>',
        f'Full API: {api_root}/v1/chat/completions',
    ))
    return Response(page)

# ─── CHAT COMPLETIONS ────────────────────────────────────────────────────────────
@app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
@app.route("/v1/chat/completions",      methods=["POST", "OPTIONS"])
@require_api_key
def chat_completions():
    """Proxy an OpenAI-style chat completion request to the upstream queue API.

    The request's ``messages`` are split into a system prompt, a list of
    ``[user, assistant]`` history pairs and the final prompt, submitted to
    ``{MODEL_BASE_URL}/queue/join``, and the events read back from
    ``{MODEL_BASE_URL}/queue/data`` are re-emitted as server-sent
    ``chat.completion.chunk`` events terminated by ``data: [DONE]``.

    Returns:
        * An empty response with CORS headers for ``OPTIONS``.
        * 400 JSON when the body is not JSON or ``messages`` is missing/invalid.
        * 502 JSON when the job cannot be submitted upstream.
        * Otherwise a ``text/event-stream`` response.
    """
    cors_headers = {"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}
    if request.method == "OPTIONS":
        return Response(headers=cors_headers)

    # silent=True: a malformed body or wrong Content-Type yields None and is
    # reported as a 400 below instead of surfacing as an unhandled error.
    body = request.get_json(silent=True)
    msgs = body.get("messages") if isinstance(body, dict) else None
    if not msgs:
        return jsonify({"error": "Missing 'messages' field"}), 400
    if not isinstance(msgs, list) or not all(isinstance(m, dict) for m in msgs):
        return jsonify({"error": "'messages' must be a list of objects"}), 400

    # Walk every message except the last together with its successor: the
    # last message is the prompt, earlier user turns become history pairs.
    system = None
    history = []
    for current, following in zip(msgs, msgs[1:]):
        role = current.get("role")
        if role == "system":
            system = current.get("content", system)
        elif role == "user":
            answered = following.get("role") == "assistant"
            reply = following.get("content", "") if answered else ""
            history.append([current.get("content", ""), reply])

    prompt = msgs[-1].get("content", "")
    session_hash = "".join(random.choices(string.ascii_lowercase + string.digits, k=11))
    payload = {"data": [prompt, history, system], "fn_index": 0, "session_hash": session_hash}

    # (connect, read) timeouts in seconds. The read timeout is generous
    # because the upstream may be silent while the job waits in its queue.
    timeout = (10, 300)

    # Submit the job before streaming starts so that a failure is reported
    # as a proper HTTP error rather than as a truncated event stream.
    try:
        join_resp = requests.post(f"{MODEL_BASE_URL}/queue/join", json=payload, timeout=timeout)
        join_resp.raise_for_status()
    except requests.RequestException:
        app.logger.exception("Failed to submit job to upstream queue")
        return jsonify({"error": "Upstream model service unavailable"}), 502

    def generate():
        start_ts = int(time.time())
        try:
            # The context manager releases the upstream connection even when
            # the client disconnects mid-stream.
            with requests.get(
                f"{MODEL_BASE_URL}/queue/data",
                params={"session_hash": session_hash},
                stream=True,
                timeout=timeout,
            ) as resp:
                resp.raise_for_status()
                for raw in resp.iter_lines():
                    if not raw:
                        continue
                    line = raw.decode("utf-8")
                    # Only SSE "data:" lines carry a JSON event; skip
                    # comments and any other field types.
                    if not line.startswith("data:"):
                        continue
                    try:
                        msg = json.loads(line[len("data:"):])
                    except ValueError:
                        continue
                    kind = msg.get("msg") if isinstance(msg, dict) else None
                    if kind == "process_completed":
                        break
                    if kind == "process_starts":
                        chunk = make_chunk({}, start=True, ts=start_ts)
                    elif kind == "process_generating":
                        chunk = make_chunk(msg, start=False, ts=start_ts)
                    else:
                        continue
                    yield f"data: {json.dumps(chunk)}\n\n"
        except requests.RequestException:
            # Headers are already sent, so the failure can only be reported
            # in-band before the stream is closed.
            app.logger.exception("Upstream stream failed for session %s", session_hash)
            error_event = json.dumps({"error": "Upstream stream failed"})
            yield f"data: {error_event}\n\n"
        # Always terminate the stream; each SSE event must end with a blank line.
        yield "data: [DONE]\n\n"

    return Response(
        generate(),
        mimetype="text/event-stream",
        headers=cors_headers,
    )

def make_chunk(data, start=False, ts=None):
    """Build one OpenAI-style ``chat.completion.chunk`` dictionary.

    Args:
        data: Upstream event. The text is taken from the last element of the
            last entry of ``data["output"]["data"][1]``. Ignored when
            ``start`` is true.
        start: When true, produce the opening chunk announcing the assistant
            role with empty content.
        ts: Value for the ``created`` field; defaults to the current Unix time.

    Returns:
        A chunk dict with a single choice. If the upstream event carries no
        conversation turns (missing, empty or malformed payload), the choice
        has ``finish_reason == "stop"`` and an empty ``delta``; otherwise
        ``delta`` holds the extracted content.
    """
    if ts is None:
        ts = int(time.time())
    choice = {"index": 0, "finish_reason": None}
    base = {
        "id": "chatcmpl",
        "object": "chat.completion.chunk",
        "created": ts,
        "model": "glm-4",
        "choices": [choice],
    }
    if start:
        choice["delta"] = {"role": "assistant", "content": ""}
        return base

    # Navigate defensively: any level of the upstream payload may be absent,
    # None, or shorter than expected.
    output = data.get("output") if isinstance(data, dict) else None
    payload = output.get("data") if isinstance(output, dict) else None
    has_turns_slot = isinstance(payload, (list, tuple)) and len(payload) > 1
    turns = payload[1] if has_turns_slot else None

    if not isinstance(turns, (list, tuple)) or not turns:
        choice["finish_reason"] = "stop"
        # Every chunk carries a delta; the final one is empty.
        choice["delta"] = {}
        return base

    last_turn = turns[-1]
    if isinstance(last_turn, (list, tuple)) and last_turn:
        content = last_turn[-1]
    else:
        content = ""
    choice["delta"] = {"content": content}
    return base

# ─── RUN ────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument(
        "--debug",
        action="store_true",
        help="run Flask's development server with the interactive debugger "
             "(never enable on a publicly reachable host)",
    )
    args = parser.parse_args()

    if args.debug:
        # Development only: the Werkzeug debugger allows arbitrary code
        # execution from the browser, so it must be requested explicitly.
        app.run(host=args.host, port=args.port, debug=True)
    else:
        # Default: serve through gevent's WSGI server, matching the
        # monkey-patching applied at the top of this file.
        from gevent.pywsgi import WSGIServer

        WSGIServer((args.host, args.port), app).serve_forever()