# GLM4 / main.py
# (Hosting-page residue — "Hivra's picture / Update main.py / 529baee verified" —
#  kept here as a comment so the file remains valid Python.)
#!/usr/bin/env python3
# ─── PATCH SSL EARLY ─────────────────────────────────────────────────────────────
# gevent monkey-patching must run before anything else imports the stdlib
# socket/ssl machinery (notably `requests` below), so blocking I/O cooperates
# with gevent's event loop. Do not reorder these lines.
from gevent import monkey
monkey.patch_all()
# ─── STANDARD IMPORTS ───────────────────────────────────────────────────────────
import os
import random
import string
import time
import json
import requests
from functools import wraps
from flask import Flask, request, Response, jsonify
# ─── ENV & CONFIG ────────────────────────────────────────────────────────────────
# API_KEY: shared secret clients must present on every API call. Fail fast at
# startup rather than serving an unauthenticated proxy.
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    raise RuntimeError("Missing API_KEY env var")
# MODEL_BASE_URL: base URL of the upstream Gradio queue serving the model.
MODEL_BASE_URL = os.getenv("MODEL_BASE_URL")
if not MODEL_BASE_URL:
    raise RuntimeError("Missing MODEL_BASE_URL env var")
# SPACE_URL: public URL of this proxy; only used in the index page text.
SPACE_URL = os.getenv("SPACE_URL", "")
app = Flask(__name__)
# Preserve insertion order of JSON keys so responses look like OpenAI's.
app.json.sort_keys = False
# ─── GLOBAL ERROR HANDLER ────────────────────────────────────────────────────────
@app.errorhandler(Exception)
def handle_all_errors(e):
    """Log any unhandled exception and return it as a JSON error body.

    Flask routes werkzeug ``HTTPException`` subclasses (404, 405, ...) through
    a generic ``Exception`` handler as well, so preserve their own status code
    instead of collapsing every error into a 500.
    """
    app.logger.exception(e)
    # HTTPException carries an integer ``code`` attribute; plain exceptions don't.
    status = getattr(e, "code", None)
    if not isinstance(status, int):
        status = 500
    return jsonify({"error": str(e)}), status
# ─── API‑KEY DECORATOR ───────────────────────────────────────────────────────────
def require_api_key(f):
    """Decorator: reject requests lacking a valid API key.

    The key is accepted either in an ``X-API-Key`` header or as an
    ``Authorization: Bearer <key>`` header. Responds 401 when missing,
    403 when it does not match ``API_KEY``.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        import hmac  # stdlib; function-scope import keeps the module header unchanged
        key = request.headers.get("X-API-Key") or request.headers.get("Authorization", "")
        if key.startswith("Bearer "):
            key = key.split(" ", 1)[1]
        if not key:
            return jsonify({"error": "API key missing"}), 401
        # Constant-time comparison: plain ``!=`` on a secret leaks its length
        # and prefix through response timing.
        if not hmac.compare_digest(key, API_KEY):
            return jsonify({"error": "Invalid API key"}), 403
        return f(*args, **kwargs)
    return decorated
# ─── MODEL LIST ─────────────────────────────────────────────────────────────────
@app.route('/api/v1/models', methods=['GET', 'POST'])
@app.route('/v1/models', methods=['GET', 'POST'])
@require_api_key
def model_list():
    """Return the OpenAI-style model list this proxy advertises."""
    created = int(time.time())
    models = [
        {"id": model_id, "object": "model", "created": created, "owned_by": "tastypear"}
        for model_id in ("glm-4", "gpt-3.5-turbo")
    ]
    return jsonify({"object": "list", "data": models})
# ─── INDEX ──────────────────────────────────────────────────────────────────────
@app.route("/", methods=["GET"])
def index():
return Response(
f'ZhipuAI GLM-4 OpenAI Compatible API<br><br>'
f'Set "{SPACE_URL}/api" as proxy in your Chatbot.<br><br>'
f'Full API: {SPACE_URL}/api/v1/chat/completions'
)
# ─── CHAT COMPLETIONS ────────────────────────────────────────────────────────────
@app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
@app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
@require_api_key
def chat_completions():
if request.method == "OPTIONS":
return Response(
headers={"Access-Control-Allow-Origin":"*","Access-Control-Allow-Headers":"*"}
)
data = request.get_json() or {}
msgs = data.get("messages")
if not msgs:
return jsonify({"error": "Missing 'messages' field"}), 400
system = None
history = []
for i, m in enumerate(msgs[:-1]):
if m.get("role") == "system":
system = m.get("content", system)
elif m.get("role") == "user":
nxt = msgs[i+1].get("role")
if nxt == "assistant":
history.append([m.get("content",""), msgs[i+1].get("content","")])
else:
history.append([m.get("content",""), ""])
prompt = msgs[-1].get("content","")
session_hash = "".join(random.choices(string.ascii_lowercase+string.digits, k=11))
payload = {"data":[prompt, history, system], "fn_index":0, "session_hash":session_hash}
def generate():
requests.post(f"{MODEL_BASE_URL}/queue/join", json=payload)
url = f"{MODEL_BASE_URL}/queue/data?session_hash={session_hash}"
resp = requests.get(url, stream=True)
start_ts = int(time.time())
for line in resp.iter_lines():
if not line: continue
msg = json.loads(line.decode("utf-8")[6:])
if msg["msg"] == "process_starts":
chunk = make_chunk({}, start=True, ts=start_ts)
elif msg["msg"] == "process_generating":
chunk = make_chunk(msg, start=False, ts=start_ts)
elif msg["msg"] == "process_completed":
yield "data: [DONE]"
break
else:
continue
yield f"data: {json.dumps(chunk)}\n\n"
return Response(
generate(),
mimetype="text/event-stream",
headers={"Access-Control-Allow-Origin":"*","Access-Control-Allow-Headers":"*"},
)
def make_chunk(data, start=False, ts=None):
    """Build one OpenAI ``chat.completion.chunk`` dict from an upstream event.

    With ``start=True`` the opening role delta is produced; otherwise the
    newest token from the Gradio payload is emitted as a content delta, or
    ``finish_reason="stop"`` when the payload carries no output.
    """
    created = int(time.time()) if ts is None else ts
    choice = {"index": 0, "finish_reason": None}
    chunk = {
        "id": "chatcmpl",
        "object": "chat.completion.chunk",
        "created": created,
        "model": "glm-4",
        "choices": [choice],
    }
    if start:
        choice["delta"] = {"role": "assistant", "content": ""}
        return chunk
    history = data.get("output", {}).get("data", [None, None])[1] or []
    if history:
        # Newest token: last character cell of the last [user, bot] pair.
        choice["delta"] = {"content": history[-1][-1]}
    else:
        choice["finish_reason"] = "stop"
    return chunk
# ─── RUN ────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
args = parser.parse_args()
# Dev: use Flask’s debug server so you see tracebacks in browser
app.run(host=args.host, port=args.port, debug=True)
# Prod: swap to gevent server
# gevent.pywsgi.WSGIServer((args.host, args.port), app).serve_forever()