# AWS / app.py — Claude API proxy (Hugging Face Space by LAJILAODEEAIQ)
# Commit: "Create app.py" (7c3a45a, verified)
import os
import time
import logging
from flask import Flask, request, Response
import requests
import json
import concurrent.futures
# Logging configuration — console (StreamHandler) output only, avoiding file
# handlers and the permission problems they cause on read-only hosts.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Global time budget (seconds) for any upstream request, streaming or not.
REQUEST_TIMEOUT = 90
def estimate_tokens(text):
    """Conservatively estimate the token count of *text* (1 token ≈ 2 chars)."""
    if not text:
        return 0
    return len(text) // 2
def truncate_messages(messages, max_tokens=15000):
    """
    Truncate a chat message list so its estimated token count fits *max_tokens*.

    Priority order:
      1. Always keep the last two messages. If even those two exceed the
         budget, trim their content (keeping the tail of the final message,
         since the newest text matters most) and return only them.
      2. Otherwise spend the remaining budget on messages from the start of
         the conversation, partially truncating the first one that does not
         fully fit, then stop.

    Args:
        messages: list of message dicts with an optional "content" field.
                  A missing content, or one explicitly set to None (as in
                  OpenAI-style assistant tool-call messages), counts as "".
        max_tokens: budget in estimate_tokens() units.

    Returns:
        A list of message dicts. Trimmed messages are shallow copies; untouched
        ones are shared with the input. The input list itself is not mutated.
    """
    if not messages:
        logger.info("消息列表为空,无需截断")
        return messages

    def _content(msg):
        # Normalize: "content": None must not break len()/slicing below.
        return msg.get("content") or ""

    original_count = len(messages)
    original_tokens = sum(estimate_tokens(_content(msg)) for msg in messages)
    logger.info(f"原始消息数量: {original_count}, 原始token估算: {original_tokens}")
    if original_tokens <= max_tokens:
        logger.info(f"原始token({original_tokens}) <= 限制({max_tokens}),无需截断")
        return messages
    if len(messages) <= 2:
        logger.info(f"消息数量 <= 2,直接返回,发送token: {original_tokens}")
        return messages

    # 1. Highest priority: keep the last two messages.
    last_two = messages[-2:]
    last_two_tokens = sum(estimate_tokens(_content(msg)) for msg in last_two)
    logger.info(f"最后两条消息token: {last_two_tokens}")
    if last_two_tokens > max_tokens:
        logger.warning(f"最后两条消息({last_two_tokens} tokens)超过限制({max_tokens} tokens),需要截断")
        # Split the character budget evenly between the two messages
        # (token budget * 2 chars/token, per estimate_tokens' heuristic).
        chars_per_msg = (max_tokens * 2) // len(last_two)
        truncated_last = []
        for i, msg in enumerate(last_two):
            content = _content(msg)
            if len(content) > chars_per_msg:
                if i == len(last_two) - 1:
                    # Final message: keep its tail — the newest text.
                    truncated_content = "..." + content[-chars_per_msg+3:]
                else:
                    truncated_content = content[:chars_per_msg-3] + "..."
                msg = {**msg, "content": truncated_content}
            truncated_last.append(msg)
        final_tokens = sum(estimate_tokens(_content(msg)) for msg in truncated_last)
        logger.info(f"截断后仅保留最后两条消息,发送token: {final_tokens}")
        return truncated_last

    # 2. Spend the rest of the budget on the start of the conversation.
    available_for_start = max_tokens - last_two_tokens
    logger.info(f"可用于开始对话的token容量: {available_for_start}")
    start_messages = messages[:-2]
    preserved_start = []
    current_tokens = 0
    for msg in start_messages:
        content = _content(msg)
        msg_tokens = estimate_tokens(content)
        if current_tokens + msg_tokens <= available_for_start:
            preserved_start.append(msg)
            current_tokens += msg_tokens
        else:
            # Partially include this message only if a meaningful amount
            # (> 100 chars) of budget remains; then stop either way.
            remaining_tokens = available_for_start - current_tokens
            remaining_chars = remaining_tokens * 2
            if remaining_chars > 100:
                truncated_content = content[:remaining_chars-3] + "..."
                truncated_msg = {**msg, "content": truncated_content}
                preserved_start.append(truncated_msg)
                current_tokens += estimate_tokens(truncated_content)
            break

    result = preserved_start + last_two
    final_count = len(result)
    final_tokens = sum(estimate_tokens(_content(msg)) for msg in result)
    logger.info(f"截断完成: {original_count} -> {final_count} 条消息, {original_tokens} -> {final_tokens} tokens")
    return result
def process_request(url, headers, json_data, timeout):
    """POST *json_data* to *url*; return the Response, or None on timeout/error."""
    try:
        resp = requests.post(url, headers=headers, json=json_data, timeout=timeout)
    except requests.exceptions.Timeout:
        logger.warning("上游请求超时")
        return None
    except requests.exceptions.RequestException as e:
        logger.error(f"上游请求异常: {e}")
        return None
    return resp
@app.route('/v1/chat/completions', methods=['POST'])
def claude():
    """Proxy an OpenAI-style chat-completions request to the upstream Claude endpoint.

    Truncates the message history to a ~15k-token budget, strips headers that
    must be regenerated for the forwarded request, then POSTs upstream —
    relaying the body as a stream when the client sent stream=True, otherwise
    buffering the full response. Upstream failures map to 502, timeouts to 504.
    """
    start_time = time.time()
    try:
        json_data = request.get_json()
        if not json_data:
            logger.error("缺少JSON数据")
            return Response(json.dumps({"error": "Missing JSON data"}), 400, content_type='application/json')
        model = json_data.get('model', 'unknown')
        stream_flag = json_data.get('stream', False)
        logger.info(f"收到请求 - 模型: {model}, 流式: {stream_flag}")
        if 'messages' in json_data and json_data['messages']:
            original_msg_count = len(json_data['messages'])
            json_data['messages'] = truncate_messages(json_data['messages'], max_tokens=15000)
            final_msg_count = len(json_data['messages'])
            if original_msg_count != final_msg_count:
                logger.info(f"消息截断: {original_msg_count} -> {final_msg_count}")
        # Forward the client's headers minus the ones the outgoing request must
        # regenerate (Content-Length changed with the truncated body; Host etc.
        # belong to the new connection). NOTE(review): the filter is
        # case-sensitive — assumes Flask reports these in canonical casing.
        headers = {key: value for key, value in request.headers if
                   key not in ['Host', 'User-Agent', "Accept-Encoding", "Accept", "Connection", "Content-Length"]}
        url = "https://aiho.st/proxy/aws/claude/chat/completions"
        if stream_flag:
            with concurrent.futures.ThreadPoolExecutor() as executor:
                session = requests.Session()
                req = requests.Request('POST', url, headers=headers, json=json_data).prepare()
                # Send in a worker thread so the total wait (connect + first
                # headers) can be bounded with future.result(timeout=...).
                future = executor.submit(session.send, req, stream=True, timeout=REQUEST_TIMEOUT - 5)
                try:
                    response = future.result(timeout=REQUEST_TIMEOUT)
                    logger.info("流式请求成功发送")
                except concurrent.futures.TimeoutError:
                    logger.error("流式请求处理超时")
                    return Response(json.dumps({"error": "Request processing timeout"}), 504, content_type='application/json')
                if response is None:
                    logger.error("上游服务器错误")
                    return Response(json.dumps({"error": "Upstream server error"}), 502, content_type='application/json')

                def generate():
                    # Relay upstream chunks; bail out when the global deadline
                    # passes or no chunk arrives for 10s (stalled upstream).
                    last_chunk_time = time.time()
                    chunk_count = 0
                    try:
                        for chunk in response.iter_content(chunk_size=4096):
                            if chunk:
                                last_chunk_time = time.time()
                                chunk_count += 1
                                yield chunk
                            current_time = time.time()
                            if current_time - start_time > REQUEST_TIMEOUT:
                                logger.warning("达到全局超时限制")
                                break
                            if time.time() - last_chunk_time > 10:
                                logger.warning("数据块间隔超时")
                                break
                    finally:
                        # Always release the upstream connection, even if the
                        # client disconnects mid-stream.
                        response.close()
                        total_time = time.time() - start_time
                        logger.info(f"流式响应完成 - 数据块数: {chunk_count}, 总耗时: {total_time:.2f}秒")
                return Response(generate(), content_type=response.headers.get('Content-Type', 'application/octet-stream'))
        else:
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(process_request, url, headers, json_data, REQUEST_TIMEOUT - 5)
                try:
                    response = future.result(timeout=REQUEST_TIMEOUT)
                    total_time = time.time() - start_time
                    logger.info(f"普通请求完成 - 耗时: {total_time:.2f}秒")
                except concurrent.futures.TimeoutError:
                    logger.error("普通请求处理超时")
                    return Response(json.dumps({"error": "Request processing timeout"}), 504, content_type='application/json')
                # process_request returns None on upstream timeout/error.
                if response is None:
                    logger.error("上游服务器错误")
                    return Response(json.dumps({"error": "Upstream server error"}), 502, content_type='application/json')
                return Response(response.content, content_type=response.headers.get('Content-Type', 'application/json'))
    except Exception as e:
        # Boundary handler: log with traceback and return a generic 500.
        logger.error(f"内部服务器错误: {e}", exc_info=True)
        return Response(json.dumps({"error": "Internal server error"}), 500, content_type='application/json')
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint: report service status and the configured timeout."""
    payload = {
        "status": "healthy",
        "timeout": REQUEST_TIMEOUT
    }
    return Response(json.dumps(payload), 200, content_type='application/json')
@app.route('/', methods=['GET'])
def index():
    """Root endpoint: describe the service and list its available routes."""
    info = {
        "message": "Claude API Proxy",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "health": "/health"
        }
    }
    return Response(json.dumps(info), 200, content_type='application/json')
if __name__ == '__main__':
    # PORT comes from the hosting platform; 7860 is the Hugging Face Spaces
    # default.
    port = int(os.environ.get('PORT', 7860))
    logger.info(f"Flask应用启动在端口 {port}")
    # threaded=True lets the built-in dev server handle concurrent requests.
    app.run(debug=False, host='0.0.0.0', port=port, threaded=True)