Update app.py
Browse files
app.py
CHANGED
|
@@ -1,23 +1,18 @@
|
|
| 1 |
-
from flask import Flask,jsonify,request,Response, stream_with_context
|
| 2 |
import requests
|
| 3 |
import os
|
| 4 |
import logging
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
| 7 |
-
# 加载.env文件中的环境变量
|
| 8 |
load_dotenv()
|
| 9 |
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
API_BASE_URL = os.getenv("API_BASE_URL","")
|
| 15 |
-
|
| 16 |
-
AUTH_KEY = os.getenv("AUTH_KEY","callme")
|
| 17 |
-
|
| 18 |
API_KEYS = os.getenv("API_KEYS")
|
| 19 |
|
| 20 |
-
API_KEY_LIST = API_KEYS.split(",")
|
| 21 |
|
| 22 |
key_index = 0
|
| 23 |
|
|
@@ -25,95 +20,92 @@ app = Flask(__name__)
|
|
| 25 |
|
| 26 |
def getAPI_KEY():
|
| 27 |
global key_index
|
| 28 |
-
|
| 29 |
-
|
|
|
|
| 30 |
key = API_KEY_LIST[key_index]
|
| 31 |
-
|
| 32 |
-
#再计算下一个 index,使用取模运算来自动循环
|
| 33 |
key_index = (key_index + 1) % len(API_KEY_LIST)
|
| 34 |
-
|
| 35 |
return key
|
| 36 |
|
| 37 |
@app.before_request
|
| 38 |
def check_api_key():
|
| 39 |
key = request.headers.get("Authorization")
|
| 40 |
-
if key != "Bearer "+AUTH_KEY:
|
| 41 |
-
return jsonify({"success":False,"message": "Unauthorized: Invalid API key"}), 403
|
| 42 |
|
| 43 |
-
@app.route("/v1/models",methods=['GET'])
|
| 44 |
def getModels():
|
| 45 |
api_key = getAPI_KEY()
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
headers = {
|
| 48 |
-
"Authorization":"Bearer "+api_key,
|
| 49 |
-
"Content-Type":"application/json"
|
| 50 |
}
|
| 51 |
try:
|
| 52 |
-
response = requests.get(API_BASE_URL+"/models",headers=headers)
|
| 53 |
response.raise_for_status()
|
| 54 |
response_headers = {'Content-Type': response.headers.get('content-type', 'application/json')}
|
|
|
|
| 55 |
return (response.content, response.status_code, response_headers)
|
|
|
|
|
|
|
|
|
|
| 56 |
except Exception as e:
|
| 57 |
-
logging("
|
|
|
|
| 58 |
|
| 59 |
-
@app.route("/v1/chat/completions",methods=['POST'])
|
| 60 |
def chat():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
headers = {
|
| 62 |
-
"Authorization":"Bearer "+
|
| 63 |
-
"Content-Type":"application/json"
|
| 64 |
}
|
| 65 |
data = request.get_json()
|
| 66 |
-
stream_flag = data.get('stream', False)
|
| 67 |
-
|
| 68 |
-
def generate():
|
| 69 |
-
try:
|
| 70 |
-
with requests.post(API_BASE_URL+"/chat/completions", headers=headers, json=data, stream=stream_flag) as response:
|
| 71 |
-
response.raise_for_status() # 检查上游请求是否成功
|
| 72 |
-
for chunk in response.iter_content(chunk_size=1024):
|
| 73 |
-
yield chunk
|
| 74 |
-
except requests.exceptions.RequestException as e:
|
| 75 |
-
logging.error("Request to upstream API failed: %s", e)
|
| 76 |
-
# 在这里处理上游请求失败,例如可以 yield 一个错误消息或者抛出异常
|
| 77 |
-
# 但请注意,一旦开始 yield 数据,就不能改变 HTTP 状态码和头部了
|
| 78 |
-
yield b'{"error": "Upstream API request failed"}' # 作为 JSON 错误返回
|
| 79 |
-
except Exception as e:
|
| 80 |
-
logging.error("Unexpected error during streaming: %s", e)
|
| 81 |
-
yield b'{"error": "Internal server error during streaming"}'
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
try:
|
| 85 |
-
#
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
else:
|
| 94 |
-
#
|
| 95 |
-
#
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
initial_response.raise_for_status() # 检查初始请求是否成功
|
| 99 |
-
|
| 100 |
-
return Response(generate_from_response(initial_response),
|
| 101 |
-
status=initial_response.status_code,
|
| 102 |
-
content_type=initial_response.headers.get('content-type'))
|
| 103 |
|
| 104 |
except requests.exceptions.RequestException as e:
|
| 105 |
-
logging.error("
|
| 106 |
return jsonify({"success": False, "message": f"Upstream API request failed: {e}"}), 500
|
| 107 |
except Exception as e:
|
| 108 |
-
logging.error("
|
| 109 |
return jsonify({"success": False, "message": str(e)}), 500
|
| 110 |
|
| 111 |
-
def generate_from_response(upstream_response):
|
| 112 |
-
# 这是一个辅助函数,用于将上游响应的迭代器包装成一个生成器
|
| 113 |
-
for chunk in upstream_response.iter_content(chunk_size=1024):
|
| 114 |
-
yield chunk
|
| 115 |
-
|
| 116 |
|
| 117 |
if __name__ == '__main__':
|
| 118 |
-
print("
|
|
|
|
| 119 |
app.run(host='0.0.0.0', port=7860, debug=True)
|
|
|
|
| 1 |
+
from flask import Flask, jsonify, request, Response, stream_with_context
|
| 2 |
import requests
|
| 3 |
import os
|
| 4 |
import logging
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
|
|
|
| 7 |
load_dotenv()
|
| 8 |
|
| 9 |
+
logging.basicConfig(level=logging.INFO) # 统一使用 logging 模块
|
| 10 |
|
| 11 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "")
|
| 12 |
+
AUTH_KEY = os.getenv("AUTH_KEY", "callme")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
API_KEYS = os.getenv("API_KEYS")
|
| 14 |
|
| 15 |
+
API_KEY_LIST = API_KEYS.split(",") if API_KEYS else [] # 确保 API_KEYS 为空时不会出错
|
| 16 |
|
| 17 |
key_index = 0
|
| 18 |
|
|
|
|
| 20 |
|
| 21 |
def getAPI_KEY():
|
| 22 |
global key_index
|
| 23 |
+
if not API_KEY_LIST:
|
| 24 |
+
logging.warning("API_KEYS is not configured.")
|
| 25 |
+
return "" # 或者抛出异常
|
| 26 |
key = API_KEY_LIST[key_index]
|
|
|
|
|
|
|
| 27 |
key_index = (key_index + 1) % len(API_KEY_LIST)
|
|
|
|
| 28 |
return key
|
| 29 |
|
| 30 |
@app.before_request
|
| 31 |
def check_api_key():
|
| 32 |
key = request.headers.get("Authorization")
|
| 33 |
+
if key != "Bearer " + AUTH_KEY:
|
| 34 |
+
return jsonify({"success": False, "message": "Unauthorized: Invalid API key"}), 403
|
| 35 |
|
| 36 |
+
@app.route("/v1/models", methods=['GET'])
|
| 37 |
def getModels():
|
| 38 |
api_key = getAPI_KEY()
|
| 39 |
+
if not api_key:
|
| 40 |
+
return jsonify({"success": False, "message": "API Key not available"}), 500
|
| 41 |
+
|
| 42 |
headers = {
|
| 43 |
+
"Authorization": "Bearer " + api_key,
|
| 44 |
+
"Content-Type": "application/json"
|
| 45 |
}
|
| 46 |
try:
|
| 47 |
+
response = requests.get(API_BASE_URL + "/models", headers=headers, timeout=30) # 增加超时
|
| 48 |
response.raise_for_status()
|
| 49 |
response_headers = {'Content-Type': response.headers.get('content-type', 'application/json')}
|
| 50 |
+
# 直接返回内容和头部,这里不需要流式处理
|
| 51 |
return (response.content, response.status_code, response_headers)
|
| 52 |
+
except requests.exceptions.RequestException as e: # 捕获requests特有异常
|
| 53 |
+
logging.error("Get models error. %s", e)
|
| 54 |
+
return jsonify({"success": False, "message": f"Failed to fetch models: {e}"}), 500
|
| 55 |
except Exception as e:
|
| 56 |
+
logging.error("An unexpected error occurred in getModels: %s", e)
|
| 57 |
+
return jsonify({"success": False, "message": str(e)}), 500
|
| 58 |
|
| 59 |
+
@app.route("/v1/chat/completions", methods=['POST'])
|
| 60 |
def chat():
|
| 61 |
+
api_key = getAPI_KEY()
|
| 62 |
+
if not api_key:
|
| 63 |
+
return jsonify({"success": False, "message": "API Key not available"}), 500
|
| 64 |
+
|
| 65 |
headers = {
|
| 66 |
+
"Authorization": "Bearer " + api_key,
|
| 67 |
+
"Content-Type": "application/json"
|
| 68 |
}
|
| 69 |
data = request.get_json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
# 重点在这里:根据客户端请求中的 'stream' 字段决定是否进行流式转发
|
| 72 |
+
# 如果客户端没有提供 'stream' 字段,我们假设它需要非流式响应(或者默认值取决于上游API的约定)
|
| 73 |
+
# 但为了明确支持非流式,这里我们设为 False
|
| 74 |
+
client_wants_stream = data.get('stream', False) # 客户端请求中 stream 的值
|
| 75 |
|
| 76 |
try:
|
| 77 |
+
# 使用 requests.post 发送请求到上游 API
|
| 78 |
+
# 上游 API 的 'stream' 参数应该与客户端请求的 'stream' 字段保持一致
|
| 79 |
+
upstream_response = requests.post(
|
| 80 |
+
API_BASE_URL + "/chat/completions",
|
| 81 |
+
headers=headers,
|
| 82 |
+
json=data, # 客户端请求的 payload,包括 stream 字段
|
| 83 |
+
stream=client_wants_stream, # 控制 requests 是否以流式接收上游响应
|
| 84 |
+
timeout= (600 if client_wants_stream else 60) # 流式请求可以有更长的超时
|
| 85 |
+
)
|
| 86 |
+
upstream_response.raise_for_status() # 检查上游 API 响应的 HTTP 状态码
|
| 87 |
+
|
| 88 |
+
# 根据客户端是否想要流式响应来处理
|
| 89 |
+
if client_wants_stream:
|
| 90 |
+
# 流式响应:使用 stream_with_context 逐块发送
|
| 91 |
+
return Response(stream_with_context(upstream_response.iter_content(chunk_size=1024)),
|
| 92 |
+
status=upstream_response.status_code,
|
| 93 |
+
content_type=upstream_response.headers.get('content-type', 'application/json'))
|
| 94 |
else:
|
| 95 |
+
# 非流式响应:直接返回完整的响应内容
|
| 96 |
+
# 这里 upstream_response.content 会等待所有数据接收完毕
|
| 97 |
+
response_headers = {'Content-Type': upstream_response.headers.get('content-type', 'application/json')}
|
| 98 |
+
return (upstream_response.content, upstream_response.status_code, response_headers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
except requests.exceptions.RequestException as e:
|
| 101 |
+
logging.error("Chat completion request error to upstream API: %s", e)
|
| 102 |
return jsonify({"success": False, "message": f"Upstream API request failed: {e}"}), 500
|
| 103 |
except Exception as e:
|
| 104 |
+
logging.error("An unexpected error occurred in chat completion: %s", e)
|
| 105 |
return jsonify({"success": False, "message": str(e)}), 500
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
if __name__ == '__main__':
|
| 109 |
+
print("Starting Flask app...")
|
| 110 |
+
# 在生产环境,不建议使用 debug=True,且应通过 Gunicorn 等 WSGI 服务器运行
|
| 111 |
app.run(host='0.0.0.0', port=7860, debug=True)
|