import logging
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor
from json import JSONDecodeError

import requests
from flask import Flask, Response, jsonify, request, stream_with_context
from flask_cors import CORS

from auth_utils import AuthManager
from constants import (
    CONTENT_TYPE_EVENT_STREAM,
    DEFAULT_AUTH_EMAIL,
    DEFAULT_AUTH_PASSWORD,
    DEFAULT_NOTDIAMOND_URL,
    DEFAULT_PORT,
    DEFAULT_TEMPERATURE,
    MAX_WORKERS,
    SYSTEM_MESSAGE_CONTENT,
    USER_AGENT,
)
from model_info import MODEL_INFO
from utils import count_message_tokens, handle_non_stream_response, generate_stream_response

# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialise the Flask application; CORS is open to every origin.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})

# Thread pool used to issue upstream requests, plus other module-level settings.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
proxy_url = os.getenv('PROXY_URL')
# Comma-separated upstream endpoints; surrounding whitespace and empty entries
# (e.g. "a, b," or an empty env var) are dropped, falling back to the default.
NOTDIAMOND_URLS = [
    url.strip()
    for url in os.getenv('NOTDIAMOND_URLS', DEFAULT_NOTDIAMOND_URL).split(',')
    if url.strip()
] or [DEFAULT_NOTDIAMOND_URL]

# Media type that marks a successful upstream streaming reply.
_EVENT_STREAM_MEDIA_TYPE = 'text/event-stream'

# Initialise the authentication manager from env vars (with defaults).
auth_manager = AuthManager(
    os.getenv("AUTH_EMAIL", DEFAULT_AUTH_EMAIL),
    os.getenv("AUTH_PASSWORD", DEFAULT_AUTH_PASSWORD),
)


def get_notdiamond_url():
    """Return one notdiamond URL chosen at random from NOTDIAMOND_URLS."""
    return random.choice(NOTDIAMOND_URLS)


def get_notdiamond_headers():
    """Build the HTTP headers for a notdiamond API request.

    If the auth manager has no JWT, log in once and read it again before
    building the ``authorization`` header.
    """
    jwt = auth_manager.get_jwt_value()
    if not jwt:
        auth_manager.login()
        jwt = auth_manager.get_jwt_value()
        if not jwt:
            logger.warning("No JWT available after login; upstream request will likely be rejected")
    return {
        'accept': CONTENT_TYPE_EVENT_STREAM,
        'accept-language': 'zh-CN,zh;q=0.9',
        'content-type': 'application/json',
        'user-agent': USER_AGENT,
        'authorization': f'Bearer {jwt}'
    }


def build_payload(request_data, model_id):
    """Build the upstream request payload from a chat-completion request.

    Args:
        request_data: Parsed JSON body of the incoming request (not mutated).
        model_id: Client-facing model id; translated through
            ``MODEL_INFO[model_id]['mapping']`` when such a mapping exists.

    Returns:
        A new dict with every key of *request_data* except ``stream``, where
        ``messages`` has a default system message prepended if none is present,
        ``model`` is the mapped id and ``temperature`` falls back to
        ``DEFAULT_TEMPERATURE``.
    """
    # Copy so inserting the system message does not mutate the caller's list.
    messages = list(request_data.get('messages') or [])
    if not any(message.get('role') == 'system' for message in messages):
        messages.insert(0, {
            "role": "system",
            "content": SYSTEM_MESSAGE_CONTENT
        })

    mapping = MODEL_INFO.get(model_id, {}).get('mapping', model_id)
    payload = {
        key: value for key, value in request_data.items()
        if key not in ('stream',)
    }
    payload['messages'] = messages
    payload['model'] = mapping
    payload['temperature'] = request_data.get('temperature', DEFAULT_TEMPERATURE)
    return payload


def _is_event_stream(response):
    """Return True when *response* is HTTP 200 with a ``text/event-stream`` body.

    Media-type parameters such as ``; charset=utf-8`` are ignored.
    """
    content_type = response.headers.get('Content-Type') or ''
    media_type = content_type.split(';', 1)[0].strip().lower()
    return response.status_code == 200 and media_type == _EVENT_STREAM_MEDIA_TYPE


def make_request(payload, max_attempts=3):
    """POST *payload* upstream, refreshing the auth token between failed attempts.

    One URL is picked per call and reused for every attempt. The request runs on
    the shared ``executor`` with ``stream=True``; ``PROXY_URL``, when set, is used
    as the HTTP(S) proxy.

    Args:
        payload: JSON-serialisable request body.
        max_attempts: Maximum number of attempts (at least one is always made).

    Returns:
        The first successful streaming response, or, if every attempt fails,
        the last response received (left open for the caller to handle).

    Raises:
        requests.RequestException: propagated from ``requests.post``.
    """
    url = get_notdiamond_url()
    proxies = {'http': proxy_url, 'https': proxy_url} if proxy_url else None
    attempts = max(1, max_attempts)
    response = None
    for attempt in range(1, attempts + 1):
        headers = get_notdiamond_headers()
        response = executor.submit(
            requests.post, url, headers=headers, json=payload, stream=True, proxies=proxies
        ).result()
        if _is_event_stream(response):
            return response
        auth_manager.refresh_user_token()
        if attempt < attempts:
            # With stream=True the pooled connection stays checked out until the
            # body is read or the response closed; release it before retrying.
            response.close()
    # If every attempt failed, return the last response.
    return response


@app.route('/v1/models', methods=['GET'])
def proxy_models():
    """Return the list of available models in an OpenAI-style ``list`` object."""
    created = int(time.time())
    models = [
        {
            "id": model_id,
            "object": "model",
            "created": created,
            "owned_by": "notdiamond",
            "permission": [],
            "root": model_id,
            "parent": None,
        } for model_id in MODEL_INFO.keys()
    ]
    return jsonify({
        "object": "list",
        "data": models
    })


def _error_response(message, error_type, details, status):
    """Build an OpenAI-style ``(json_body, status)`` error tuple."""
    return jsonify({
        'error': {
            'message': message,
            'type': error_type,
            'param': None,
            'code': None,
            'details': details
        }
    }), status


@app.route('/v1/chat/completions', methods=['POST'])
def handle_request():
    """Handle a chat-completion request, streaming or not.

    Returns a 400 for a missing, malformed, or non-object JSON body, a 503 when
    talking to the upstream API fails, and a 500 for any other error.
    """
    try:
        # silent=True: Flask otherwise raises werkzeug BadRequest /
        # UnsupportedMediaType (not JSONDecodeError) on bad bodies, which would
        # fall through to the generic 500 handler below.
        request_data = request.get_json(silent=True)
        if not isinstance(request_data, dict):
            return _error_response(
                'Invalid JSON in request',
                'invalid_request_error',
                'Request body must be a JSON object',
                400,
            )

        model_id = request_data.get('model', '')
        stream = request_data.get('stream', False)

        prompt_tokens = count_message_tokens(
            request_data.get('messages', []),
            model_id
        )

        payload = build_payload(request_data, model_id)
        response = make_request(payload)

        if stream:
            return Response(
                stream_with_context(generate_stream_response(response, model_id, prompt_tokens)),
                content_type=CONTENT_TYPE_EVENT_STREAM
            )
        return handle_non_stream_response(response, model_id, prompt_tokens)

    except requests.RequestException as e:
        logger.exception("Request error: %s", str(e))
        return _error_response('Error communicating with the API', 'api_error', str(e), 503)
    except JSONDecodeError as e:
        logger.exception("JSON decode error: %s", str(e))
        return _error_response('Invalid JSON in request', 'invalid_request_error', str(e), 400)
    except Exception as e:
        # Top-level boundary: log with traceback and return a generic 500.
        logger.exception("Unexpected error: %s", str(e))
        return _error_response('Internal Server Error', 'server_error', str(e), 500)


if __name__ == "__main__":
    port = int(os.environ.get("PORT", DEFAULT_PORT))
    app.run(debug=False, host='0.0.0.0', port=port, threaded=True)