LAJILAODEEAIQ committed on
Commit
7c3a45a
·
verified ·
1 Parent(s): 23e2633

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -0
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import logging
4
+ from flask import Flask, request, Response
5
+ import requests
6
+ import json
7
+ import concurrent.futures
8
+
9
# Configure logging: console output only, which avoids file-permission problems.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Global timeout budget for one proxied request, in seconds.
REQUEST_TIMEOUT = 90
21
+
22
def estimate_tokens(text):
    """Estimate the token count of message content.

    Uses a conservative heuristic of 1 token per 2 characters.

    Args:
        text: Message content. Usually a string; may also be ``None``/empty,
            or a list of content parts (strings, or dicts carrying a
            ``"text"`` string) as used by multi-part chat messages.

    Returns:
        The estimated number of tokens as an int (0 for empty content).
    """
    if not text:
        return 0
    if isinstance(text, (list, tuple)):
        # Multi-part content: count only the textual parts. Measuring the
        # length of the list itself would estimate by element count.
        char_count = 0
        for part in text:
            piece = part.get("text") if isinstance(part, dict) else part
            if isinstance(piece, str):
                char_count += len(piece)
        return char_count // 2
    return len(text) // 2
25
+
26
def truncate_messages(messages, max_tokens=15000):
    """Trim a message list so its estimated token total fits ``max_tokens``.

    Strategy:
      1. If the list is empty, already fits, or has at most two messages,
         it is returned unchanged.
      2. The last two messages always have priority. If they alone exceed
         the budget, only they are returned, each cut to an equal share of
         the character budget (the final message keeps its tail, the other
         keeps its head).
      3. Otherwise messages from the start of the conversation are kept in
         order until the remaining budget is used up; the first message
         that does not fit is cut to the remaining budget when more than
         100 characters are left, and everything after it is dropped.

    The input list and its message dicts are never mutated; truncated
    messages are shallow copies with a new ``"content"`` value.

    Args:
        messages: List of message dicts; ``"content"`` is read from each.
        max_tokens: Token budget as measured by ``estimate_tokens``.

    Returns:
        The original list, or a new list of (possibly truncated) messages.
    """
    def total_tokens(batch):
        return sum(estimate_tokens(item.get("content", "")) for item in batch)

    if not messages:
        logger.info("消息列表为空,无需截断")
        return messages

    original_count = len(messages)
    original_tokens = total_tokens(messages)
    logger.info(f"原始消息数量: {original_count}, 原始token估算: {original_tokens}")

    if original_tokens <= max_tokens:
        logger.info(f"原始token({original_tokens}) <= 限制({max_tokens}),无需截断")
        return messages

    if len(messages) <= 2:
        logger.info(f"消息数量 <= 2,直接返回,发送token: {original_tokens}")
        return messages

    # 1. Highest priority: keep the last two messages.
    last_two = messages[-2:]
    last_two_tokens = total_tokens(last_two)
    logger.info(f"最后两条消息token: {last_two_tokens}")

    if last_two_tokens > max_tokens:
        logger.warning(f"最后两条消息({last_two_tokens} tokens)超过限制({max_tokens} tokens),需要截断")

        chars_per_msg = (max_tokens * 2) // len(last_two)
        # Characters of real content kept next to the "..." marker. Clamped
        # at zero so a tiny budget cannot turn the tail slice start
        # non-negative and keep almost the whole message.
        keep = max(chars_per_msg - 3, 0)
        final_index = len(last_two) - 1

        truncated_last = []
        for i, msg in enumerate(last_two):
            content = msg.get("content", "")
            # Only plain strings can be sliced; None or multi-part content
            # is passed through untouched instead of raising TypeError.
            if isinstance(content, str) and len(content) > chars_per_msg:
                if i == final_index:
                    # The newest message keeps its ending.
                    truncated_content = "..." + content[len(content) - keep:]
                else:
                    truncated_content = content[:keep] + "..."
                msg = {**msg, "content": truncated_content}
            truncated_last.append(msg)

        final_tokens = total_tokens(truncated_last)
        logger.info(f"截断后仅保留最后两条消息,发送token: {final_tokens}")
        return truncated_last

    # 2. Spend the remaining budget on messages from the conversation start.
    available_for_start = max_tokens - last_two_tokens
    logger.info(f"可用于开始对话的token容量: {available_for_start}")

    preserved_start = []
    current_tokens = 0

    for msg in messages[:-2]:
        content = msg.get("content", "")
        msg_tokens = estimate_tokens(content)

        if current_tokens + msg_tokens <= available_for_start:
            preserved_start.append(msg)
            current_tokens += msg_tokens
            continue

        # First message that does not fit: keep a truncated head only when a
        # meaningful amount of budget (> 100 chars) is left, then stop.
        remaining_chars = (available_for_start - current_tokens) * 2
        if isinstance(content, str) and remaining_chars > 100:
            truncated_content = content[:remaining_chars - 3] + "..."
            preserved_start.append({**msg, "content": truncated_content})
        break

    result = preserved_start + last_two
    final_count = len(result)
    final_tokens = total_tokens(result)

    logger.info(f"截断完成: {original_count} -> {final_count} 条消息, {original_tokens} -> {final_tokens} tokens")

    return result
103
+
104
def process_request(url, headers, json_data, timeout):
    """POST ``json_data`` to ``url`` and return the upstream response.

    Returns ``None`` (after logging) when the request times out or fails
    with any other ``requests`` exception.
    """
    try:
        upstream = requests.post(url, headers=headers, json=json_data, timeout=timeout)
    except requests.exceptions.Timeout:
        logger.warning("上游请求超时")
        upstream = None
    except requests.exceptions.RequestException as exc:
        logger.error(f"上游请求异常: {exc}")
        upstream = None
    return upstream
114
+
115
def _json_error(message, status):
    """Return a JSON ``{"error": message}`` response with the given HTTP status."""
    return Response(json.dumps({"error": message}), status, content_type='application/json')


def _call_with_deadline(func, *args, on_late_result=None):
    """Run ``func(*args)`` in a worker thread, waiting at most REQUEST_TIMEOUT seconds.

    Raises ``concurrent.futures.TimeoutError`` when the deadline passes. In
    that case ``on_late_result`` (if given) is attached as a done-callback so
    the caller can release whatever the worker eventually produces.
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(func, *args)
    try:
        return future.result(timeout=REQUEST_TIMEOUT)
    except concurrent.futures.TimeoutError:
        if on_late_result is not None:
            future.add_done_callback(on_late_result)
        raise
    finally:
        # wait=False: using the executor as a context manager would join the
        # worker on exit and block the timeout path until the call finished.
        executor.shutdown(wait=False)


@app.route('/v1/chat/completions', methods=['POST'])
def claude():
    """Proxy a chat-completions request to the upstream endpoint.

    The JSON body's ``messages`` are trimmed with ``truncate_messages``
    before forwarding. Request headers are forwarded except for a fixed set
    of hop/client headers. When ``stream`` is truthy the upstream body is
    relayed chunk by chunk; otherwise it is returned in one piece. The
    upstream HTTP status code is passed through to the client.

    Error responses (JSON ``{"error": ...}``):
        400 - body is missing, not valid JSON, or not a JSON object
        502 - the upstream request failed
        504 - the upstream request exceeded REQUEST_TIMEOUT
        500 - any other unexpected error
    """
    start_time = time.time()

    try:
        # silent=True: malformed or non-JSON bodies yield None (-> 400) rather
        # than raising into the generic 500 handler below.
        json_data = request.get_json(silent=True)
        if not json_data or not isinstance(json_data, dict):
            logger.error("缺少JSON数据")
            return _json_error("Missing JSON data", 400)

        model = json_data.get('model', 'unknown')
        stream_flag = json_data.get('stream', False)
        logger.info(f"收到请求 - 模型: {model}, 流式: {stream_flag}")

        if 'messages' in json_data and json_data['messages']:
            original_msg_count = len(json_data['messages'])
            json_data['messages'] = truncate_messages(json_data['messages'], max_tokens=15000)
            final_msg_count = len(json_data['messages'])

            if original_msg_count != final_msg_count:
                logger.info(f"消息截断: {original_msg_count} -> {final_msg_count}")

        # Drop headers that must be regenerated for the upstream connection.
        headers = {key: value for key, value in request.headers if
                   key not in ['Host', 'User-Agent', "Accept-Encoding", "Accept", "Connection", "Content-Length"]}

        url = "https://aiho.st/proxy/aws/claude/chat/completions"

        if stream_flag:
            session = requests.Session()
            req = requests.Request('POST', url, headers=headers, json=json_data).prepare()

            def open_stream():
                # Map upstream failures to None so they are reported as 502,
                # mirroring process_request() on the non-streaming path.
                try:
                    return session.send(req, stream=True, timeout=REQUEST_TIMEOUT - 5)
                except requests.exceptions.Timeout:
                    logger.warning("上游请求超时")
                except requests.exceptions.RequestException as e:
                    logger.error(f"上游请求异常: {e}")
                return None

            def discard_late_stream(done):
                # The worker finished after a 504 was already returned:
                # release the connection nobody is going to read.
                if not done.cancelled() and done.exception() is None:
                    late = done.result()
                    if late is not None:
                        late.close()
                session.close()

            try:
                response = _call_with_deadline(open_stream, on_late_result=discard_late_stream)
            except concurrent.futures.TimeoutError:
                logger.error("流式请求处理超时")
                return _json_error("Request processing timeout", 504)

            if response is None:
                session.close()
                logger.error("上游服务器错误")
                return _json_error("Upstream server error", 502)

            logger.info("流式请求成功发送")

            def generate():
                last_chunk_time = time.time()
                chunk_count = 0
                try:
                    for chunk in response.iter_content(chunk_size=4096):
                        if chunk:
                            last_chunk_time = time.time()
                            chunk_count += 1
                            yield chunk

                        # Stop relaying once the overall request budget is spent.
                        current_time = time.time()
                        if current_time - start_time > REQUEST_TIMEOUT:
                            logger.warning("达到全局超时限制")
                            break

                        # NOTE(review): this runs after the generator resumes
                        # from the yield, so it measures time since the chunk
                        # was handed to the client - confirm that is intended.
                        if time.time() - last_chunk_time > 10:
                            logger.warning("数据块间隔超时")
                            break
                finally:
                    response.close()
                    session.close()
                    total_time = time.time() - start_time
                    logger.info(f"流式响应完成 - 数据块数: {chunk_count}, 总耗时: {total_time:.2f}秒")

            return Response(
                generate(),
                status=response.status_code,
                content_type=response.headers.get('Content-Type', 'application/octet-stream'),
            )

        try:
            response = _call_with_deadline(process_request, url, headers, json_data, REQUEST_TIMEOUT - 5)
        except concurrent.futures.TimeoutError:
            logger.error("普通请求处理超时")
            return _json_error("Request processing timeout", 504)

        total_time = time.time() - start_time
        logger.info(f"普通请求完成 - 耗时: {total_time:.2f}秒")

        if response is None:
            logger.error("上游服务器错误")
            return _json_error("Upstream server error", 502)

        return Response(
            response.content,
            status=response.status_code,
            content_type=response.headers.get('Content-Type', 'application/json'),
        )

    except Exception as e:
        logger.error(f"内部服务器错误: {e}", exc_info=True)
        return _json_error("Internal server error", 500)
203
+
204
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint: reports a healthy status and the configured timeout."""
    payload = {"status": "healthy", "timeout": REQUEST_TIMEOUT}
    body = json.dumps(payload)
    return Response(body, 200, content_type='application/json')
211
+
212
@app.route('/', methods=['GET'])
def index():
    """Landing page: names the service and lists its endpoints as JSON."""
    endpoints = {"chat": "/v1/chat/completions", "health": "/health"}
    payload = {"message": "Claude API Proxy", "endpoints": endpoints}
    return Response(json.dumps(payload), 200, content_type='application/json')
222
+
223
if __name__ == '__main__':
    # The listening port comes from the PORT environment variable (default 7860).
    port = int(os.environ.get('PORT', 7860))
    logger.info(f"Flask应用启动在端口 {port}")
    app.run(host='0.0.0.0', port=port, threaded=True, debug=False)