| import logging |
| import os |
| import random |
| import time |
| from concurrent.futures import ThreadPoolExecutor |
| from json import JSONDecodeError |
|
|
| import requests |
| from flask import Flask, Response, jsonify, request, stream_with_context |
| from flask_cors import CORS |
|
|
| from auth_utils import AuthManager |
| from constants import ( |
| CONTENT_TYPE_EVENT_STREAM, |
| DEFAULT_AUTH_EMAIL, |
| DEFAULT_AUTH_PASSWORD, |
| DEFAULT_NOTDIAMOND_URL, |
| DEFAULT_PORT, |
| DEFAULT_TEMPERATURE, |
| MAX_WORKERS, |
| SYSTEM_MESSAGE_CONTENT, |
| USER_AGENT, |
| ) |
| from model_info import MODEL_INFO |
| from utils import count_message_tokens, handle_non_stream_response, generate_stream_response |
|
|
| |
# Process-wide logging at INFO level; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application. CORS is enabled for every route and every origin.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})

# Shared thread pool sized by MAX_WORKERS.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# Optional outbound proxy taken from the PROXY_URL environment variable
# (None when unset).
proxy_url = os.getenv('PROXY_URL')
# Upstream endpoints, given as a comma-separated NOTDIAMOND_URLS variable.
# Entries are whitespace-stripped and blank entries dropped so values such as
# "https://a, https://b," work; if nothing usable remains, fall back to the
# packaged default instead of leaving an empty string in the pool.
NOTDIAMOND_URLS = [
    candidate.strip()
    for candidate in os.getenv('NOTDIAMOND_URLS', DEFAULT_NOTDIAMOND_URL).split(',')
    if candidate.strip()
] or [DEFAULT_NOTDIAMOND_URL]

# Credentials come from AUTH_EMAIL / AUTH_PASSWORD, falling back to the
# defaults in `constants`.
# NOTE(review): shipping default credentials in `constants` is risky; prefer
# requiring the environment variables in production.
auth_manager = AuthManager(
    os.getenv("AUTH_EMAIL", DEFAULT_AUTH_EMAIL),
    os.getenv("AUTH_PASSWORD", DEFAULT_AUTH_PASSWORD),
)
|
|
def get_notdiamond_url():
    """Return one notdiamond endpoint, chosen uniformly at random.

    The candidates are the entries of the module-level ``NOTDIAMOND_URLS``
    list, which is built from the ``NOTDIAMOND_URLS`` environment variable.
    """
    endpoints = NOTDIAMOND_URLS
    chosen = random.choice(endpoints)
    return chosen
|
|
def get_notdiamond_headers():
    """Assemble the HTTP headers for a notdiamond API call.

    Reads the cached JWT from ``auth_manager``. When that value is missing or
    empty, it logs in first and reads the token again. The token is sent as a
    ``Bearer`` authorization header, together with fixed accept,
    accept-language, content-type and user-agent values.

    Returns:
        dict mapping header names to header values.
    """
    token = auth_manager.get_jwt_value()
    if not token:
        # No usable cached token: log in, then read the freshly issued one.
        auth_manager.login()
        token = auth_manager.get_jwt_value()
    # NOTE(review): if login() still leaves no JWT, the authorization header
    # below carries a falsy value (e.g. "Bearer None"); the upstream will
    # presumably reject it.

    header_pairs = [
        ('accept', CONTENT_TYPE_EVENT_STREAM),
        ('accept-language', 'zh-CN,zh;q=0.9'),
        ('content-type', 'application/json'),
        ('user-agent', USER_AGENT),
        ('authorization', f'Bearer {token}'),
    ]
    return dict(header_pairs)
|
|
def build_payload(request_data, model_id):
    """Build the JSON body that is forwarded to the notdiamond API.

    Steps:
    - Copy every key of ``request_data`` except ``stream``.
    - Prepend a system message (``SYSTEM_MESSAGE_CONTENT``) when the message
      list has none.
    - Replace ``model`` with the ``mapping`` entry from ``MODEL_INFO``,
      keeping ``model_id`` unchanged when there is no mapping.
    - Fill in ``DEFAULT_TEMPERATURE`` when no temperature was given.

    Args:
        request_data: Parsed JSON object of the incoming chat request.
        model_id: Client-facing model name.

    Returns:
        A new dict to send upstream. Neither ``request_data`` nor its
        ``messages`` list is modified.
    """
    # Copy the list: inserting the system prompt must not mutate the caller's
    # request data. `or []` also tolerates an explicit "messages": null.
    messages = list(request_data.get('messages') or [])

    if not any(message.get('role') == 'system' for message in messages):
        messages.insert(0, {"role": "system", "content": SYSTEM_MESSAGE_CONTENT})

    payload = {key: value for key, value in request_data.items() if key != 'stream'}
    payload['messages'] = messages
    payload['model'] = MODEL_INFO.get(model_id, {}).get('mapping', model_id)

    # Treat an explicit null the same as a missing temperature.
    temperature = request_data.get('temperature')
    payload['temperature'] = DEFAULT_TEMPERATURE if temperature is None else temperature

    return payload
|
|
def make_request(payload, max_attempts=3, timeout=None):
    """POST ``payload`` to a notdiamond endpoint, refreshing the token between retries.

    One endpoint is picked per call with ``get_notdiamond_url()`` and reused
    for every attempt. An attempt succeeds when the upstream answers HTTP 200
    with a ``text/event-stream`` media type. Content-Type parameters such as
    ``; charset=utf-8`` are ignored. After a failed attempt that is not the
    last, the failed response is closed and
    ``auth_manager.refresh_user_token()`` is called before retrying.

    Args:
        payload: JSON-serialisable request body (e.g. from ``build_payload``).
        max_attempts: Total number of attempts; values below 1 are treated as 1.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) means
            no timeout, matching the previous behaviour.

    Returns:
        The successful streaming ``requests.Response``, or the last failed
        response when every attempt failed. The caller owns the returned
        response.

    Raises:
        requests.RequestException: Propagated from ``requests.post``.
    """
    url = get_notdiamond_url()
    # Apply the optional PROXY_URL to both schemes; None means no proxy.
    proxies = {'http': proxy_url, 'https': proxy_url} if proxy_url else None
    attempts = max(1, max_attempts)

    response = None
    for attempt in range(1, attempts + 1):
        # Called directly: submitting to the executor and immediately blocking
        # on .result() only added a thread hop and could queue behind other work.
        response = requests.post(
            url,
            headers=get_notdiamond_headers(),
            json=payload,
            stream=True,
            proxies=proxies,
            timeout=timeout,
        )

        media_type = response.headers.get('Content-Type', '').split(';', 1)[0].strip().lower()
        if response.status_code == 200 and media_type == 'text/event-stream':
            return response

        if attempt < attempts:
            # stream=True keeps the connection open until the body is read or
            # closed; release it before retrying.
            response.close()
            auth_manager.refresh_user_token()

    return response
|
|
@app.route('/v1/models', methods=['GET'])
def proxy_models():
    """List the models this proxy exposes, in OpenAI ``/v1/models`` format.

    One entry is produced per key of ``MODEL_INFO``, with ``owned_by`` set to
    ``"notdiamond"`` and ``created`` set to the current Unix time.

    Returns:
        A Flask JSON response of the form ``{"object": "list", "data": [...]}``.
    """
    # Read the clock once so every entry in a response carries the same value.
    created = int(time.time())
    models = [
        {
            "id": model_id,
            "object": "model",
            "created": created,
            "owned_by": "notdiamond",
            "permission": [],
            "root": model_id,
            "parent": None,
        }
        for model_id in MODEL_INFO
    ]
    return jsonify({
        "object": "list",
        "data": models,
    })
|
|
def _error_response(message, error_type, details, status):
    """Build an error reply in the ``{'error': {...}}`` shape used by the chat endpoint.

    Args:
        message: Short human-readable summary.
        error_type: Value for the ``type`` field (e.g. ``'api_error'``).
        details: Extra detail string, typically ``str(exc)``.
        status: HTTP status code to return.

    Returns:
        A ``(response, status)`` tuple that a Flask view can return.
    """
    return jsonify({
        'error': {
            'message': message,
            'type': error_type,
            'param': None,
            'code': None,
            'details': details,
        }
    }), status


@app.route('/v2/chat/completions', methods=['POST'])
def handle_request():
    """Handle a chat-completion request by proxying it to notdiamond.

    Steps:
    1. Validate that the body is a JSON object.
    2. Count prompt tokens on the client's messages.
    3. Build the upstream payload and send it with ``make_request``.
    4. Return either a streamed ``text/event-stream`` response (when
       ``stream`` is truthy) or the result of ``handle_non_stream_response``.

    Error replies:
    - 400 for a missing, malformed or non-object JSON body.
    - 503 for ``requests`` transport errors.
    - 502 for JSON decode errors raised while processing.
    - 500 for anything else.
    """
    # silent=True makes Flask return None instead of raising on a wrong
    # mimetype or malformed JSON. Otherwise those errors (and a JSON array or
    # null body hitting `.get`) fell through to the generic 500 handler.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict):
        return _error_response(
            'Invalid JSON in request',
            'invalid_request_error',
            'Request body must be a JSON object',
            400,
        )

    try:
        model_id = request_data.get('model', '')
        stream = request_data.get('stream', False)

        # Counted on the client's messages before build_payload may prepend
        # a system prompt.
        prompt_tokens = count_message_tokens(request_data.get('messages', []), model_id)

        payload = build_payload(request_data, model_id)
        response = make_request(payload)

        if stream:
            # NOTE: errors raised while the generator is consumed happen after
            # this view returns and are not caught by the handlers below.
            return Response(
                stream_with_context(generate_stream_response(response, model_id, prompt_tokens)),
                content_type=CONTENT_TYPE_EVENT_STREAM,
            )
        return handle_non_stream_response(response, model_id, prompt_tokens)

    except requests.RequestException as e:
        logger.exception("Request error: %s", e)
        return _error_response('Error communicating with the API', 'api_error', str(e), 503)
    except JSONDecodeError as e:
        # The request body was already validated above, so a decode error here
        # presumably comes from parsing the upstream reply: report 502, not a
        # client-side 400.
        logger.exception("JSON decode error: %s", e)
        return _error_response('Invalid JSON from upstream API', 'api_error', str(e), 502)
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        return _error_response('Internal Server Error', 'server_error', str(e), 500)
|
|
if __name__ == "__main__":
    # The PORT environment variable overrides the packaged default port.
    listen_port = int(os.getenv("PORT", DEFAULT_PORT))
    # Listen on all interfaces with Flask's threaded dev server, debug off.
    app.run(host='0.0.0.0', port=listen_port, debug=False, threaded=True)