File size: 5,293 Bytes
d53e842
0b5dc29
 
 
 
 
 
 
d53e842
0b5dc29
d53e842
 
0b5dc29
 
d53e842
0b5dc29
 
 
 
 
 
 
d53e842
 
 
725b17a
 
 
0b5dc29
 
 
 
d53e842
 
0b5dc29
d53e842
0b5dc29
725b17a
d53e842
 
 
0b5dc29
d53e842
 
0b5dc29
 
d53e842
0b5dc29
 
d53e842
0b5dc29
d53e842
 
 
0b5dc29
d53e842
 
0b5dc29
c4765e9
0b5dc29
 
c4765e9
 
0b5dc29
 
c4765e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce0ae73
 
0b5dc29
c4765e9
 
 
 
 
 
 
 
ce0ae73
c4765e9
 
 
 
 
 
 
 
 
ce0ae73
 
c4765e9
ce0ae73
0b5dc29
c4765e9
ce0ae73
 
c4765e9
 
 
 
 
0b5dc29
 
d53e842
 
0b5dc29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import hmac
import logging
import os

import requests
from dotenv import load_dotenv
from flask import Flask, jsonify, request, Response, stream_with_context

load_dotenv()

logging.basicConfig(level=logging.INFO)  # Use the logging module throughout.

# Base URL of the upstream API. A trailing "/" is stripped so that appending
# "/models" or "/chat/completions" never produces a "//" path.
API_BASE_URL = os.getenv("API_BASE_URL", "").rstrip("/")
# Shared secret clients must send as "Authorization: Bearer <AUTH_KEY>".
AUTH_KEY = os.getenv("AUTH_KEY", "callme")
# Comma-separated list of upstream API keys, handed out round-robin by getAPI_KEY().
API_KEYS = os.getenv("API_KEYS")

# Split on commas, trimming surrounding whitespace and dropping empty entries so
# values like "k1, k2," neither yield keys with stray spaces nor an empty key.
# An unset/empty API_KEYS gives an empty list instead of an error.
API_KEY_LIST = [k.strip() for k in API_KEYS.split(",") if k.strip()] if API_KEYS else []

if not API_BASE_URL:
    logging.warning("API_BASE_URL is not configured; upstream request URLs will be invalid.")
if AUTH_KEY == "callme":
    logging.warning("AUTH_KEY is using the built-in default; set AUTH_KEY in the environment.")

# Position in API_KEY_LIST of the next key getAPI_KEY() will return.
key_index = 0

app = Flask(__name__)

def getAPI_KEY():
    """Return the next configured upstream API key, cycling round-robin.

    Reads the key at the module-level ``key_index`` cursor, then advances the
    cursor (wrapping to 0 after the last key) so successive calls walk through
    ``API_KEY_LIST`` in order.

    Returns:
        The selected key, or ``""`` (after logging a warning) when no keys are
        configured. Raising an exception would be an alternative design.
    """
    global key_index
    total = len(API_KEY_LIST)
    if total == 0:
        logging.warning("API_KEYS is not configured.")
        return ""
    chosen = API_KEY_LIST[key_index]
    key_index = (key_index + 1) % total
    return chosen

@app.before_request
def check_api_key():
    """Authenticate every request against the shared ``AUTH_KEY`` secret.

    Runs as a Flask ``before_request`` hook. The ``Authorization`` header must
    equal ``"Bearer " + AUTH_KEY`` exactly; otherwise a JSON 403 response is
    returned. Returns ``None`` when the header matches.
    """
    provided = request.headers.get("Authorization", "")
    expected = "Bearer " + AUTH_KEY
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of the secret matched. Compared as UTF-8 bytes because
    # compare_digest rejects non-ASCII str arguments.
    if not hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8")):
        return jsonify({"success": False, "message": "Unauthorized: Invalid API key"}), 403

@app.route("/v1/models", methods=['GET'])
def getModels():
    """Proxy ``GET /v1/models`` to the upstream API using the next rotated key.

    Returns:
        On a 2xx upstream reply, the upstream body, status code and
        Content-Type (defaulting to ``application/json``). Otherwise a JSON
        error with status 500 if:
        - no API key is configured;
        - the upstream request fails (connection error, timeout or non-2xx status);
        - any other unexpected error occurs.
    """
    api_key = getAPI_KEY()
    if not api_key:
        return jsonify({"success": False, "message": "API Key not available"}), 500

    headers = {
        "Authorization": "Bearer " + api_key,
        "Content-Type": "application/json",
    }
    try:
        # Bounded wait so a hung upstream cannot block this worker indefinitely.
        response = requests.get(API_BASE_URL + "/models", headers=headers, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logging.error("Get models error. %s", e)
        return jsonify({"success": False, "message": f"Failed to fetch models: {e}"}), 500
    except Exception as e:
        # Unexpected failure: keep the traceback in the log for debugging.
        logging.exception("An unexpected error occurred in getModels: %s", e)
        return jsonify({"success": False, "message": str(e)}), 500

    # Small, non-streaming payload: relay the complete body and its content type.
    response_headers = {'Content-Type': response.headers.get('content-type', 'application/json')}
    return (response.content, response.status_code, response_headers)

@app.route("/v1/chat/completions",methods=['POST'])
def chat():
    """Proxy ``POST /v1/chat/completions`` to the upstream API.

    The JSON request body is forwarded unchanged, authenticated with the next
    rotated upstream key. If the body's ``stream`` field is truthy (or absent;
    it defaults to True here), the upstream body is relayed chunk by chunk.
    Otherwise the complete upstream body is returned in one piece.

    Returns:
        On success, the upstream body with the upstream status code and
        Content-Type. Otherwise a JSON error:
        - 400 if the request body is not a JSON object;
        - 500 if no API key is configured;
        - 500 if the upstream request fails (connection error, timeout, non-2xx status);
        - 500 on any other unexpected error.
    """
    # (connect, read) timeouts in seconds. The read timeout bounds each wait for
    # upstream data, so it is generous: completions can be slow to start.
    upstream_timeout = (10, 600)

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"success": False, "message": "Request body must be a JSON object"}), 400

    api_key = getAPI_KEY()
    if not api_key:
        return jsonify({"success": False, "message": "API Key not available"}), 500

    headers = {
        "Authorization": "Bearer " + api_key,
        "Content-Type": "application/json",
    }
    stream_flag = data.get('stream', True)
    url = API_BASE_URL + "/chat/completions"

    try:
        if not stream_flag:
            # Non-streaming: requests downloads the whole body before returning.
            response = requests.post(url, headers=headers, json=data, stream=False,
                                     timeout=upstream_timeout)
            response.raise_for_status()
            return Response(response.content,
                            status=response.status_code,
                            content_type=response.headers.get('content-type'))

        # Streaming: open the upstream request and check its status *before*
        # returning a streaming Response. Once chunks have been sent, the status
        # code and headers can no longer be changed.
        upstream = requests.post(url, headers=headers, json=data, stream=True,
                                 timeout=upstream_timeout)
        try:
            upstream.raise_for_status()
        except requests.exceptions.HTTPError:
            # The streamed body was never consumed; release the connection.
            upstream.close()
            raise
        return Response(generate_from_response(upstream),
                        status=upstream.status_code,
                        content_type=upstream.headers.get('content-type'))

    except requests.exceptions.RequestException as e:
        logging.error("Initial upstream API request failed: %s", e)
        return jsonify({"success": False, "message": f"Upstream API request failed: {e}"}), 500
    except Exception as e:
        # Unexpected failure: keep the traceback in the log for debugging.
        logging.exception("Error setting up chat completion: %s", e)
        return jsonify({"success": False, "message": str(e)}), 500

def generate_from_response(upstream_response, chunk_size=1024):
    """Yield the body of a streamed upstream response, chunk by chunk.

    Args:
        upstream_response: A response opened with ``stream=True``. Any object
            with ``iter_content(chunk_size=...)`` and ``close()`` works.
        chunk_size: Chunk size passed to ``iter_content``; defaults to 1024.

    Yields:
        The upstream body chunks, unchanged.

    The upstream response is closed however iteration ends: normally, on an
    error, or when the consumer closes this generator early (e.g. the WSGI
    server closing the response iterable). This keeps its connection from
    being leaked.
    """
    try:
        yield from upstream_response.iter_content(chunk_size=chunk_size)
    finally:
        upstream_response.close()


if __name__ == '__main__':
    # Development entry point only; in production, serve the app with a WSGI
    # server such as Gunicorn instead.
    # The interactive debugger must not be exposed on 0.0.0.0 by default, so
    # debug mode is opt-in via FLASK_DEBUG=1 / true / yes.
    debug_mode = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    # Listening port; override with the PORT environment variable.
    port = int(os.getenv("PORT", "7860"))
    logging.info("Starting Flask app on port %d (debug=%s)...", port, debug_mode)
    app.run(host='0.0.0.0', port=port, debug=debug_mode)