Rfym21 commited on
Commit
b8e982a
·
verified ·
1 Parent(s): f6631ce

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -187
app.py DELETED
@@ -1,187 +0,0 @@
1
- import logging
2
- import os
3
- import random
4
- import time
5
- from concurrent.futures import ThreadPoolExecutor
6
- from json import JSONDecodeError
7
-
8
- import requests
9
- from flask import Flask, Response, jsonify, request, stream_with_context
10
- from flask_cors import CORS
11
-
12
- from auth_utils import AuthManager
13
- from constants import (
14
- CONTENT_TYPE_EVENT_STREAM,
15
- DEFAULT_AUTH_EMAIL,
16
- DEFAULT_AUTH_PASSWORD,
17
- DEFAULT_NOTDIAMOND_URL,
18
- DEFAULT_PORT,
19
- DEFAULT_TEMPERATURE,
20
- MAX_WORKERS,
21
- SYSTEM_MESSAGE_CONTENT,
22
- USER_AGENT,
23
- )
24
- from model_info import MODEL_INFO
25
- from utils import count_message_tokens, handle_non_stream_response, generate_stream_response
26
-
27
- # 配置日志
28
- logging.basicConfig(level=logging.INFO)
29
- logger = logging.getLogger(__name__)
30
-
31
- # 初始化 Flask 应用
32
- app = Flask(__name__)
33
- CORS(app, resources={r"/*": {"origins": "*"}})
34
-
35
- # 初始化线程池和其他全局变量
36
- executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
37
- proxy_url = os.getenv('PROXY_URL')
38
- NOTDIAMOND_URLS = os.getenv('NOTDIAMOND_URLS', DEFAULT_NOTDIAMOND_URL).split(',')
39
-
40
- # 初始化认证管理器
41
- auth_manager = AuthManager(
42
- os.getenv("AUTH_EMAIL", DEFAULT_AUTH_EMAIL),
43
- os.getenv("AUTH_PASSWORD", DEFAULT_AUTH_PASSWORD),
44
- )
45
-
46
- def get_notdiamond_url():
47
- """随机选择并返回一个 notdiamond URL。"""
48
- return random.choice(NOTDIAMOND_URLS)
49
-
50
- def get_notdiamond_headers():
51
- """返回用于 notdiamond API 请求的头信息。"""
52
- jwt = auth_manager.get_jwt_value()
53
- if not jwt:
54
- auth_manager.login()
55
- jwt = auth_manager.get_jwt_value()
56
-
57
- return {
58
- 'accept': CONTENT_TYPE_EVENT_STREAM,
59
- 'accept-language': 'zh-CN,zh;q=0.9',
60
- 'content-type': 'application/json',
61
- 'user-agent': USER_AGENT,
62
- 'authorization': f'Bearer {jwt}'
63
- }
64
-
65
- def build_payload(request_data, model_id):
66
- """构建请求有效负载。"""
67
- messages = request_data.get('messages', [])
68
-
69
- if not any(message.get('role') == 'system' for message in messages):
70
- system_message = {
71
- "role": "system",
72
- "content": SYSTEM_MESSAGE_CONTENT
73
- }
74
- messages.insert(0, system_message)
75
-
76
- mapping = MODEL_INFO.get(model_id, {}).get('mapping', model_id)
77
- payload = {
78
- key: value for key, value in request_data.items()
79
- if key not in ('stream',)
80
- }
81
- payload['messages'] = messages
82
- payload['model'] = mapping
83
- payload['temperature'] = request_data.get('temperature', DEFAULT_TEMPERATURE)
84
-
85
- return payload
86
-
87
- def make_request(payload):
88
- """发送请求并处理可能的认证刷新。"""
89
- url = get_notdiamond_url()
90
-
91
- for _ in range(3): # 最多尝试3次
92
- headers = get_notdiamond_headers()
93
- response = executor.submit(
94
- requests.post,
95
- url,
96
- headers=headers,
97
- json=payload,
98
- stream=True
99
- ).result()
100
-
101
- if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
102
- return response
103
-
104
- auth_manager.refresh_user_token()
105
-
106
- return response # 如果所有尝试都失败,返回最后一次的响应
107
-
108
- @app.route('/v1/models', methods=['GET'])
109
- def proxy_models():
110
- """返回可用模型列表。"""
111
- models = [
112
- {
113
- "id": model_id,
114
- "object": "model",
115
- "created": int(time.time()),
116
- "owned_by": "notdiamond",
117
- "permission": [],
118
- "root": model_id,
119
- "parent": None,
120
- } for model_id in MODEL_INFO.keys()
121
- ]
122
- return jsonify({
123
- "object": "list",
124
- "data": models
125
- })
126
-
127
- @app.route('/v2/chat/completions', methods=['POST'])
128
- def handle_request():
129
- """处理聊天完成请求。"""
130
- try:
131
- request_data = request.get_json()
132
- model_id = request_data.get('model', '')
133
- stream = request_data.get('stream', False)
134
-
135
- prompt_tokens = count_message_tokens(
136
- request_data.get('messages', []),
137
- model_id
138
- )
139
-
140
- payload = build_payload(request_data, model_id)
141
- response = make_request(payload)
142
-
143
- if stream:
144
- return Response(
145
- stream_with_context(generate_stream_response(response, model_id, prompt_tokens)),
146
- content_type=CONTENT_TYPE_EVENT_STREAM
147
- )
148
- else:
149
- return handle_non_stream_response(response, model_id, prompt_tokens)
150
-
151
- except requests.RequestException as e:
152
- logger.error("Request error: %s", str(e), exc_info=True)
153
- return jsonify({
154
- 'error': {
155
- 'message': 'Error communicating with the API',
156
- 'type': 'api_error',
157
- 'param': None,
158
- 'code': None,
159
- 'details': str(e)
160
- }
161
- }), 503
162
- except JSONDecodeError as e:
163
- logger.error("JSON decode error: %s", str(e), exc_info=True)
164
- return jsonify({
165
- 'error': {
166
- 'message': 'Invalid JSON in request',
167
- 'type': 'invalid_request_error',
168
- 'param': None,
169
- 'code': None,
170
- 'details': str(e)
171
- }
172
- }), 400
173
- except Exception as e:
174
- logger.error("Unexpected error: %s", str(e), exc_info=True)
175
- return jsonify({
176
- 'error': {
177
- 'message': 'Internal Server Error',
178
- 'type': 'server_error',
179
- 'param': None,
180
- 'code': None,
181
- 'details': str(e)
182
- }
183
- }), 500
184
-
185
- if __name__ == "__main__":
186
- port = int(os.environ.get("PORT", DEFAULT_PORT))
187
- app.run(debug=False, host='0.0.0.0', port=port, threaded=True)