|
|
import logging |
|
|
import os |
|
|
import random |
|
|
import time |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
from json import JSONDecodeError |
|
|
|
|
|
import requests |
|
|
from flask import Flask, Response, jsonify, request, stream_with_context |
|
|
from flask_cors import CORS |
|
|
|
|
|
from auth_utils import AuthManager |
|
|
from constants import ( |
|
|
CONTENT_TYPE_EVENT_STREAM, |
|
|
DEFAULT_AUTH_EMAIL, |
|
|
DEFAULT_AUTH_PASSWORD, |
|
|
DEFAULT_NOTDIAMOND_URL, |
|
|
DEFAULT_PORT, |
|
|
DEFAULT_TEMPERATURE, |
|
|
MAX_WORKERS, |
|
|
SYSTEM_MESSAGE_CONTENT, |
|
|
USER_AGENT, |
|
|
) |
|
|
from model_info import MODEL_INFO |
|
|
from utils import count_message_tokens, handle_non_stream_response, generate_stream_response |
|
|
|
|
|
|
|
|
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application; CORS accepts every origin on every route.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})

# Shared thread pool on which upstream HTTP calls are submitted.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

# Optional proxy address taken from the environment.
# NOTE(review): nothing in the visible part of this module reads proxy_url —
# confirm whether it was meant to be passed to the upstream requests.
proxy_url = os.getenv('PROXY_URL')

# Comma-separated list of upstream endpoints. Entries are stripped and empty
# ones dropped so that values such as "a, b" or "a,b," never produce a URL
# with stray whitespace or an empty URL; an unset, empty or all-blank variable
# falls back to the built-in default.
NOTDIAMOND_URLS = [
    url.strip()
    for url in (os.getenv('NOTDIAMOND_URLS') or DEFAULT_NOTDIAMOND_URL).split(',')
    if url.strip()
] or [DEFAULT_NOTDIAMOND_URL]

# Credential holder used to obtain and refresh the upstream JWT.
auth_manager = AuthManager(
    os.getenv("AUTH_EMAIL", DEFAULT_AUTH_EMAIL),
    os.getenv("AUTH_PASSWORD", DEFAULT_AUTH_PASSWORD),
)
|
|
|
|
|
def get_notdiamond_url():
    """Pick one notdiamond endpoint at random.

    Returns:
        One element of the module-level ``NOTDIAMOND_URLS`` list.
    """
    chosen_url = random.choice(NOTDIAMOND_URLS)
    return chosen_url
|
|
|
|
|
def get_notdiamond_headers():
    """Assemble the HTTP headers sent with a notdiamond API request.

    The JWT is read from ``auth_manager``; when that yields a falsy value,
    ``auth_manager.login()`` is called once and the JWT is read again.

    Returns:
        dict: Header names mapped to their values, including the
        ``authorization`` bearer header built from the JWT.
    """
    token = auth_manager.get_jwt_value()
    if not token:
        # No token available yet: log in, then fetch it a second time.
        auth_manager.login()
        token = auth_manager.get_jwt_value()

    headers = {}
    headers['accept'] = CONTENT_TYPE_EVENT_STREAM
    headers['accept-language'] = 'zh-CN,zh;q=0.9'
    headers['content-type'] = 'application/json'
    headers['user-agent'] = USER_AGENT
    headers['authorization'] = 'Bearer {}'.format(token)
    return headers
|
|
|
|
|
def build_payload(request_data, model_id):
    """Build the JSON payload forwarded to the notdiamond API.

    Args:
        request_data: Mapping parsed from the client's request body.
        model_id: Model identifier supplied by the client.

    Returns:
        dict: A copy of ``request_data`` without the ``stream`` key, in which
        ``messages`` is guaranteed to contain a system message, ``model`` is
        the ``mapping`` entry of ``MODEL_INFO[model_id]`` (or ``model_id``
        itself when there is none), and ``temperature`` falls back to
        ``DEFAULT_TEMPERATURE`` when the client did not send one.
    """
    # Copy the list so the caller's request_data['messages'] is never mutated;
    # a missing or null 'messages' is treated as an empty conversation.
    messages = list(request_data.get('messages') or [])

    has_system_message = any(
        message.get('role') == 'system' for message in messages
    )
    if not has_system_message:
        # Prepend the default system prompt by building a new list.
        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE_CONTENT}
        ] + messages

    mapping = MODEL_INFO.get(model_id, {}).get('mapping', model_id)

    # 'stream' is the only client key that is not passed through.
    payload = {
        key: value
        for key, value in request_data.items()
        if key != 'stream'
    }
    payload['messages'] = messages
    payload['model'] = mapping
    payload['temperature'] = request_data.get('temperature', DEFAULT_TEMPERATURE)

    return payload
|
|
|
|
|
def make_request(payload):
    """POST the payload to notdiamond, refreshing the user token on failure.

    One endpoint is chosen up front and tried up to three times. An attempt
    succeeds when the status is 200 and the response media type is
    ``text/event-stream``. After every unsuccessful attempt the user token is
    refreshed via ``auth_manager`` before the next try.

    Args:
        payload: JSON-serialisable body to send.

    Returns:
        The first successful streaming response, or the response of the final
        attempt when none succeeded (the caller decides how to handle it).

    Raises:
        Any exception raised by ``requests.post`` is re-raised by
        ``Future.result()`` and propagates to the caller.
    """
    max_attempts = 3
    url = get_notdiamond_url()
    response = None

    for attempt in range(1, max_attempts + 1):
        headers = get_notdiamond_headers()
        # Submitted to the shared pool; .result() blocks until it completes.
        response = executor.submit(
            requests.post,
            url,
            headers=headers,
            json=payload,
            stream=True,
        ).result()

        # Compare only the media type: the header may legitimately carry
        # parameters (e.g. "; charset=utf-8") or different letter case.
        content_type = response.headers.get('Content-Type') or ''
        media_type = content_type.split(';', 1)[0].strip().lower()
        if response.status_code == 200 and media_type == 'text/event-stream':
            return response

        if attempt < max_attempts:
            # With stream=True the connection stays checked out until the body
            # is consumed or the response is closed; release it before the
            # retry. The final response is left open for the caller.
            response.close()

        # Refresh after every failed attempt, including the last one.
        auth_manager.refresh_user_token()

    return response
|
|
|
|
|
@app.route('/v1/models', methods=['GET'])
def proxy_models():
    """Return the list of available models.

    Returns:
        A JSON response shaped as ``{"object": "list", "data": [...]}`` with
        one entry for every key of ``MODEL_INFO``.
    """
    catalogue = []
    for name in MODEL_INFO:
        # Each entry is stamped with the time at which it is built.
        entry = {
            "id": name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "notdiamond",
            "permission": [],
            "root": name,
            "parent": None,
        }
        catalogue.append(entry)

    listing = {"object": "list", "data": catalogue}
    return jsonify(listing)
|
|
|
|
|
def _error_response(message, error_type, status, details):
    """Return the (JSON body, HTTP status) pair used for every error reply."""
    return jsonify({
        'error': {
            'message': message,
            'type': error_type,
            'param': None,
            'code': None,
            'details': details,
        }
    }), status


@app.route('/v1/chat/completions', methods=['POST'])
def handle_request():
    """Handle a chat-completion request.

    Parses the JSON body, counts the prompt tokens, builds the upstream
    payload and forwards it. When the client set ``stream`` the upstream
    response is relayed as an event stream; otherwise it is passed to
    ``handle_non_stream_response``.

    Returns:
        A Flask response on success, or a ``(json, status)`` error pair:
        400 for a body that is not a JSON object, 503 for
        ``requests.RequestException``, 400 for ``JSONDecodeError`` and 500 for
        anything else.
    """
    # silent=True makes malformed JSON yield None instead of raising
    # werkzeug's BadRequest, which the generic handler below would otherwise
    # report as a 500. Non-object bodies (null, list, string, ...) are
    # rejected here too, since the code below needs a mapping.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict):
        return _error_response(
            'Invalid JSON in request',
            'invalid_request_error',
            400,
            'Request body must be a JSON object',
        )

    try:
        model_id = request_data.get('model', '')
        stream = request_data.get('stream', False)

        # Count tokens before build_payload may add a default system message.
        prompt_tokens = count_message_tokens(
            request_data.get('messages', []),
            model_id,
        )

        payload = build_payload(request_data, model_id)
        response = make_request(payload)

        if stream:
            return Response(
                stream_with_context(
                    generate_stream_response(response, model_id, prompt_tokens)
                ),
                content_type=CONTENT_TYPE_EVENT_STREAM,
            )
        return handle_non_stream_response(response, model_id, prompt_tokens)

    except requests.RequestException as e:
        logger.exception("Request error: %s", str(e))
        return _error_response(
            'Error communicating with the API', 'api_error', 503, str(e)
        )
    except JSONDecodeError as e:
        # NOTE(review): request parsing no longer raises, so this can only come
        # from the helpers called above — confirm the message still fits.
        logger.exception("JSON decode error: %s", str(e))
        return _error_response(
            'Invalid JSON in request', 'invalid_request_error', 400, str(e)
        )
    except Exception as e:
        # Top-level boundary: log with traceback and answer with a 500.
        logger.exception("Unexpected error: %s", str(e))
        return _error_response(
            'Internal Server Error', 'server_error', 500, str(e)
        )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: listen on every interface, on the port named by the
    # PORT environment variable (DEFAULT_PORT when unset), with threaded
    # request handling and debug mode off.
    listen_port = int(os.getenv("PORT", DEFAULT_PORT))
    app.run(host='0.0.0.0', port=listen_port, threaded=True, debug=False)