Delete app.py
Browse files
app.py
DELETED
|
@@ -1,187 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import os
|
| 3 |
-
import random
|
| 4 |
-
import time
|
| 5 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 6 |
-
from json import JSONDecodeError
|
| 7 |
-
|
| 8 |
-
import requests
|
| 9 |
-
from flask import Flask, Response, jsonify, request, stream_with_context
|
| 10 |
-
from flask_cors import CORS
|
| 11 |
-
|
| 12 |
-
from auth_utils import AuthManager
|
| 13 |
-
from constants import (
|
| 14 |
-
CONTENT_TYPE_EVENT_STREAM,
|
| 15 |
-
DEFAULT_AUTH_EMAIL,
|
| 16 |
-
DEFAULT_AUTH_PASSWORD,
|
| 17 |
-
DEFAULT_NOTDIAMOND_URL,
|
| 18 |
-
DEFAULT_PORT,
|
| 19 |
-
DEFAULT_TEMPERATURE,
|
| 20 |
-
MAX_WORKERS,
|
| 21 |
-
SYSTEM_MESSAGE_CONTENT,
|
| 22 |
-
USER_AGENT,
|
| 23 |
-
)
|
| 24 |
-
from model_info import MODEL_INFO
|
| 25 |
-
from utils import count_message_tokens, handle_non_stream_response, generate_stream_response
|
| 26 |
-
|
| 27 |
-
# 配置日志
|
| 28 |
-
logging.basicConfig(level=logging.INFO)
|
| 29 |
-
logger = logging.getLogger(__name__)
|
| 30 |
-
|
| 31 |
-
# 初始化 Flask 应用
|
| 32 |
-
app = Flask(__name__)
|
| 33 |
-
CORS(app, resources={r"/*": {"origins": "*"}})
|
| 34 |
-
|
| 35 |
-
# 初始化线程池和其他全局变量
|
| 36 |
-
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
| 37 |
-
proxy_url = os.getenv('PROXY_URL')
|
| 38 |
-
NOTDIAMOND_URLS = os.getenv('NOTDIAMOND_URLS', DEFAULT_NOTDIAMOND_URL).split(',')
|
| 39 |
-
|
| 40 |
-
# 初始化认证管理器
|
| 41 |
-
auth_manager = AuthManager(
|
| 42 |
-
os.getenv("AUTH_EMAIL", DEFAULT_AUTH_EMAIL),
|
| 43 |
-
os.getenv("AUTH_PASSWORD", DEFAULT_AUTH_PASSWORD),
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
def get_notdiamond_url():
|
| 47 |
-
"""随机选择并返回一个 notdiamond URL。"""
|
| 48 |
-
return random.choice(NOTDIAMOND_URLS)
|
| 49 |
-
|
| 50 |
-
def get_notdiamond_headers():
|
| 51 |
-
"""返回用于 notdiamond API 请求的头信息。"""
|
| 52 |
-
jwt = auth_manager.get_jwt_value()
|
| 53 |
-
if not jwt:
|
| 54 |
-
auth_manager.login()
|
| 55 |
-
jwt = auth_manager.get_jwt_value()
|
| 56 |
-
|
| 57 |
-
return {
|
| 58 |
-
'accept': CONTENT_TYPE_EVENT_STREAM,
|
| 59 |
-
'accept-language': 'zh-CN,zh;q=0.9',
|
| 60 |
-
'content-type': 'application/json',
|
| 61 |
-
'user-agent': USER_AGENT,
|
| 62 |
-
'authorization': f'Bearer {jwt}'
|
| 63 |
-
}
|
| 64 |
-
|
| 65 |
-
def build_payload(request_data, model_id):
|
| 66 |
-
"""构建请求有效负载。"""
|
| 67 |
-
messages = request_data.get('messages', [])
|
| 68 |
-
|
| 69 |
-
if not any(message.get('role') == 'system' for message in messages):
|
| 70 |
-
system_message = {
|
| 71 |
-
"role": "system",
|
| 72 |
-
"content": SYSTEM_MESSAGE_CONTENT
|
| 73 |
-
}
|
| 74 |
-
messages.insert(0, system_message)
|
| 75 |
-
|
| 76 |
-
mapping = MODEL_INFO.get(model_id, {}).get('mapping', model_id)
|
| 77 |
-
payload = {
|
| 78 |
-
key: value for key, value in request_data.items()
|
| 79 |
-
if key not in ('stream',)
|
| 80 |
-
}
|
| 81 |
-
payload['messages'] = messages
|
| 82 |
-
payload['model'] = mapping
|
| 83 |
-
payload['temperature'] = request_data.get('temperature', DEFAULT_TEMPERATURE)
|
| 84 |
-
|
| 85 |
-
return payload
|
| 86 |
-
|
| 87 |
-
def make_request(payload):
|
| 88 |
-
"""发送请求并处理可能的认证刷新。"""
|
| 89 |
-
url = get_notdiamond_url()
|
| 90 |
-
|
| 91 |
-
for _ in range(3): # 最多尝试3次
|
| 92 |
-
headers = get_notdiamond_headers()
|
| 93 |
-
response = executor.submit(
|
| 94 |
-
requests.post,
|
| 95 |
-
url,
|
| 96 |
-
headers=headers,
|
| 97 |
-
json=payload,
|
| 98 |
-
stream=True
|
| 99 |
-
).result()
|
| 100 |
-
|
| 101 |
-
if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
|
| 102 |
-
return response
|
| 103 |
-
|
| 104 |
-
auth_manager.refresh_user_token()
|
| 105 |
-
|
| 106 |
-
return response # 如果所有尝试都失败,返回最后一次的响应
|
| 107 |
-
|
| 108 |
-
@app.route('/v1/models', methods=['GET'])
|
| 109 |
-
def proxy_models():
|
| 110 |
-
"""返回可用模型列表。"""
|
| 111 |
-
models = [
|
| 112 |
-
{
|
| 113 |
-
"id": model_id,
|
| 114 |
-
"object": "model",
|
| 115 |
-
"created": int(time.time()),
|
| 116 |
-
"owned_by": "notdiamond",
|
| 117 |
-
"permission": [],
|
| 118 |
-
"root": model_id,
|
| 119 |
-
"parent": None,
|
| 120 |
-
} for model_id in MODEL_INFO.keys()
|
| 121 |
-
]
|
| 122 |
-
return jsonify({
|
| 123 |
-
"object": "list",
|
| 124 |
-
"data": models
|
| 125 |
-
})
|
| 126 |
-
|
| 127 |
-
@app.route('/v2/chat/completions', methods=['POST'])
|
| 128 |
-
def handle_request():
|
| 129 |
-
"""处理聊天完成请求。"""
|
| 130 |
-
try:
|
| 131 |
-
request_data = request.get_json()
|
| 132 |
-
model_id = request_data.get('model', '')
|
| 133 |
-
stream = request_data.get('stream', False)
|
| 134 |
-
|
| 135 |
-
prompt_tokens = count_message_tokens(
|
| 136 |
-
request_data.get('messages', []),
|
| 137 |
-
model_id
|
| 138 |
-
)
|
| 139 |
-
|
| 140 |
-
payload = build_payload(request_data, model_id)
|
| 141 |
-
response = make_request(payload)
|
| 142 |
-
|
| 143 |
-
if stream:
|
| 144 |
-
return Response(
|
| 145 |
-
stream_with_context(generate_stream_response(response, model_id, prompt_tokens)),
|
| 146 |
-
content_type=CONTENT_TYPE_EVENT_STREAM
|
| 147 |
-
)
|
| 148 |
-
else:
|
| 149 |
-
return handle_non_stream_response(response, model_id, prompt_tokens)
|
| 150 |
-
|
| 151 |
-
except requests.RequestException as e:
|
| 152 |
-
logger.error("Request error: %s", str(e), exc_info=True)
|
| 153 |
-
return jsonify({
|
| 154 |
-
'error': {
|
| 155 |
-
'message': 'Error communicating with the API',
|
| 156 |
-
'type': 'api_error',
|
| 157 |
-
'param': None,
|
| 158 |
-
'code': None,
|
| 159 |
-
'details': str(e)
|
| 160 |
-
}
|
| 161 |
-
}), 503
|
| 162 |
-
except JSONDecodeError as e:
|
| 163 |
-
logger.error("JSON decode error: %s", str(e), exc_info=True)
|
| 164 |
-
return jsonify({
|
| 165 |
-
'error': {
|
| 166 |
-
'message': 'Invalid JSON in request',
|
| 167 |
-
'type': 'invalid_request_error',
|
| 168 |
-
'param': None,
|
| 169 |
-
'code': None,
|
| 170 |
-
'details': str(e)
|
| 171 |
-
}
|
| 172 |
-
}), 400
|
| 173 |
-
except Exception as e:
|
| 174 |
-
logger.error("Unexpected error: %s", str(e), exc_info=True)
|
| 175 |
-
return jsonify({
|
| 176 |
-
'error': {
|
| 177 |
-
'message': 'Internal Server Error',
|
| 178 |
-
'type': 'server_error',
|
| 179 |
-
'param': None,
|
| 180 |
-
'code': None,
|
| 181 |
-
'details': str(e)
|
| 182 |
-
}
|
| 183 |
-
}), 500
|
| 184 |
-
|
| 185 |
-
if __name__ == "__main__":
|
| 186 |
-
port = int(os.environ.get("PORT", DEFAULT_PORT))
|
| 187 |
-
app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|