import os
import time
import logging
from flask import Flask, request, Response
import requests
import json
import concurrent.futures

# Logging: console output only, to avoid file-permission problems.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Global timeout (seconds): the overall deadline for one proxied request.
REQUEST_TIMEOUT = 90

# Upstream chat-completions endpoint; can be overridden with UPSTREAM_URL.
UPSTREAM_URL = os.environ.get(
    'UPSTREAM_URL', "https://aiho.st/proxy/aws/claude/chat/completions"
)

# Estimated-token budget for the outgoing message list; can be overridden
# with MAX_INPUT_TOKENS.
MAX_INPUT_TOKENS = int(os.environ.get('MAX_INPUT_TOKENS', 15000))

# Seconds since the last non-empty chunk after which a stream is abandoned.
STREAM_IDLE_TIMEOUT = 10

# Incoming request headers that are NOT forwarded upstream (compared
# case-insensitively). Content-Length / Transfer-Encoding describe the
# client's body framing, which no longer applies once the JSON body is
# re-serialized (and possibly truncated); requests computes its own.
_EXCLUDED_REQUEST_HEADERS = frozenset({
    'host', 'user-agent', 'accept-encoding', 'accept', 'connection',
    'content-length', 'transfer-encoding',
})


def estimate_tokens(text):
    """Estimate the token count of message content (heuristic: 1 token ≈ 2 chars).

    Args:
        text: A string, or a list of content parts such as
            ``[{"type": "text", "text": "..."}, ...]``. For a list only the
            string ``text`` fields are counted. None, empty values and any
            other type count as 0.

    Returns:
        The estimated number of tokens (int >= 0).
    """
    if not text:
        return 0
    if isinstance(text, str):
        return len(text) // 2
    if isinstance(text, list):
        # Non-text parts (e.g. images) are not counted, so multimodal input
        # is underestimated.
        chars = sum(
            len(part['text'])
            for part in text
            if isinstance(part, dict) and isinstance(part.get('text'), str)
        )
        return chars // 2
    return 0


def _truncate_text(content, max_chars, keep_tail):
    """Shorten a string to at most ``max_chars`` characters plus a "..." marker.

    ``keep_tail=True`` keeps the end ("..." + tail); otherwise the beginning
    is kept (head + "..."). Strings already within ``max_chars`` and any
    non-string content are returned unchanged (the same object).
    """
    if not isinstance(content, str) or len(content) <= max_chars:
        return content
    keep = max(max_chars - 3, 0)  # room left once the "..." marker is added
    if keep_tail:
        # Explicit start index: content[-0:] would return the whole string.
        return "..." + content[len(content) - keep:]
    return content[:keep] + "..."


def truncate_messages(messages, max_tokens=15000):
    """Trim a chat message list so its estimated token count fits ``max_tokens``.

    Priority order:
      1. The last two messages are always kept. If they alone exceed the
         budget, everything else is dropped and each of them is cut to an
         equal share of the budget (the older keeps its head, the newest
         keeps its tail).
      2. The remaining budget is filled with messages from the start of the
         conversation, in order. The first one that does not fit is cut to
         the space left (only when more than 100 characters remain), and
         everything after it, up to the last two, is dropped.

    Truncated messages are shallow copies; the input list and dicts are not
    mutated. Non-string content cannot be shortened, so the result may still
    exceed the budget when such content dominates.

    Args:
        messages: List of message dicts with an optional "content" key.
        max_tokens: Budget in estimated tokens (see estimate_tokens).

    Returns:
        ``messages`` itself when empty or already within budget, otherwise a
        new list.
    """
    if not messages:
        logger.info("消息列表为空,无需截断")
        return messages

    original_count = len(messages)
    original_tokens = sum(estimate_tokens(msg.get("content", "")) for msg in messages)
    logger.info(f"原始消息数量: {original_count}, 原始token估算: {original_tokens}")

    if original_tokens <= max_tokens:
        logger.info(f"原始token({original_tokens}) <= 限制({max_tokens}),无需截断")
        return messages

    # Any over-budget list, including one of only one or two messages, goes
    # through the same truncation below.

    # 1. Highest priority: keep the last two messages.
    last_two = messages[-2:]
    last_two_tokens = sum(estimate_tokens(msg.get("content", "")) for msg in last_two)
    logger.info(f"最后两条消息token: {last_two_tokens}")

    if last_two_tokens > max_tokens:
        logger.warning(f"最后两条消息({last_two_tokens} tokens)超过限制({max_tokens} tokens),需要截断")
        chars_per_msg = (max_tokens * 2) // len(last_two)
        newest_index = len(last_two) - 1
        truncated_last = []
        for i, msg in enumerate(last_two):
            content = msg.get("content", "")
            # Newest message keeps its tail, the one before it keeps its head.
            shortened = _truncate_text(content, chars_per_msg, keep_tail=(i == newest_index))
            if shortened is not content:  # _truncate_text returns the same object when unchanged
                msg = {**msg, "content": shortened}
            truncated_last.append(msg)
        final_tokens = sum(estimate_tokens(msg.get("content", "")) for msg in truncated_last)
        logger.info(f"截断后仅保留最后两条消息,发送token: {final_tokens}")
        return truncated_last

    # 2. Fill the remaining budget from the start of the conversation.
    available_for_start = max_tokens - last_two_tokens
    logger.info(f"可用于开始对话的token容量: {available_for_start}")

    preserved_start = []
    current_tokens = 0
    for msg in messages[:-2]:
        content = msg.get("content", "")
        msg_tokens = estimate_tokens(content)
        if current_tokens + msg_tokens <= available_for_start:
            preserved_start.append(msg)
            current_tokens += msg_tokens
            continue
        # First message that does not fit: keep a truncated head of it when a
        # meaningful amount of room (> 100 chars) remains, then stop.
        remaining_chars = (available_for_start - current_tokens) * 2
        if isinstance(content, str) and remaining_chars > 100:
            truncated_content = _truncate_text(content, remaining_chars, keep_tail=False)
            preserved_start.append({**msg, "content": truncated_content})
        break

    result = preserved_start + last_two
    final_count = len(result)
    final_tokens = sum(estimate_tokens(msg.get("content", "")) for msg in result)
    logger.info(f"截断完成: {original_count} -> {final_count} 条消息, {original_tokens} -> {final_tokens} tokens")
    return result


def process_request(url, headers, json_data, timeout):
    """POST ``json_data`` to ``url`` and return the requests.Response.

    Returns None (after logging) when the upstream call times out or fails
    with any other requests exception.
    """
    try:
        return requests.post(url, headers=headers, json=json_data, timeout=timeout)
    except requests.exceptions.Timeout:
        logger.warning("上游请求超时")
        return None
    except requests.exceptions.RequestException as e:
        logger.error(f"上游请求异常: {e}")
        return None


def _json_error(message, status):
    """Build a JSON error response of the form {"error": message}."""
    return Response(json.dumps({"error": message}), status, content_type='application/json')


def _forward_headers(incoming):
    """Return the incoming headers minus _EXCLUDED_REQUEST_HEADERS (case-insensitive)."""
    return {key: value for key, value in incoming
            if key.lower() not in _EXCLUDED_REQUEST_HEADERS}


def _run_with_deadline(fn, *args, on_abandon=None, **kwargs):
    """Run ``fn(*args, **kwargs)`` in a worker thread, waiting at most REQUEST_TIMEOUT s.

    Returns fn's result and re-raises its exception. On deadline,
    concurrent.futures.TimeoutError is raised and the worker is abandoned;
    if given, ``on_abandon`` is registered as a done-callback on the future
    so the caller can release whatever the worker eventually produces.
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(fn, *args, **kwargs)
    try:
        return future.result(timeout=REQUEST_TIMEOUT)
    except concurrent.futures.TimeoutError:
        if on_abandon is not None:
            future.add_done_callback(on_abandon)
        raise
    finally:
        # wait=False: a `with` block would call shutdown(wait=True) and block
        # the request thread until the late upstream call itself finished,
        # defeating the deadline.
        executor.shutdown(wait=False)


@app.route('/v1/chat/completions', methods=['POST'])
def claude():
    """Proxy a chat-completions request to UPSTREAM_URL.

    The message list is truncated to MAX_INPUT_TOKENS estimated tokens. Most
    client headers are forwarded (see _EXCLUDED_REQUEST_HEADERS). The
    upstream status code and Content-Type are passed through. With
    "stream": true the upstream body is relayed chunk by chunk.

    Error responses: 400 missing/invalid JSON body, 502 upstream connection
    failure, 504 deadline exceeded, 500 unexpected internal error.
    """
    start_time = time.time()
    try:
        # silent=True: a malformed body or non-JSON Content-Type yields None
        # (answered with 400) instead of raising a werkzeug HTTPException
        # that the generic handler below would turn into a 500.
        json_data = request.get_json(silent=True)
        if not json_data:
            logger.error("缺少JSON数据")
            return _json_error("Missing JSON data", 400)
        if not isinstance(json_data, dict):
            logger.error("JSON数据不是对象")
            return _json_error("JSON body must be an object", 400)

        model = json_data.get('model', 'unknown')
        stream_flag = json_data.get('stream', False)
        logger.info(f"收到请求 - 模型: {model}, 流式: {stream_flag}")

        messages = json_data.get('messages')
        # Only a list of message objects can be truncated; anything else is
        # forwarded unchanged for the upstream to validate.
        if messages and isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
            original_msg_count = len(messages)
            json_data['messages'] = truncate_messages(messages, max_tokens=MAX_INPUT_TOKENS)
            final_msg_count = len(json_data['messages'])
            if original_msg_count != final_msg_count:
                logger.info(f"消息截断: {original_msg_count} -> {final_msg_count}")

        headers = _forward_headers(request.headers)
        url = UPSTREAM_URL

        if stream_flag:
            return _proxy_stream(url, headers, json_data, start_time)

        try:
            response = _run_with_deadline(process_request, url, headers, json_data, REQUEST_TIMEOUT - 5)
            total_time = time.time() - start_time
            logger.info(f"普通请求完成 - 耗时: {total_time:.2f}秒")
        except concurrent.futures.TimeoutError:
            logger.error("普通请求处理超时")
            return _json_error("Request processing timeout", 504)

        if response is None:
            logger.error("上游服务器错误")
            return _json_error("Upstream server error", 502)

        if response.status_code >= 400:
            logger.warning(f"上游返回错误状态: {response.status_code}")
        # Pass the upstream status through so clients see upstream errors.
        return Response(response.content, status=response.status_code,
                        content_type=response.headers.get('Content-Type', 'application/json'))

    except Exception as e:
        logger.error(f"内部服务器错误: {e}", exc_info=True)
        return _json_error("Internal server error", 500)


def _proxy_stream(url, headers, json_data, start_time):
    """Send a streaming upstream request and relay its body to the client.

    Used by claude() for "stream": true requests. Returns a Flask Response
    (502/504 JSON error, or the relayed stream with the upstream status).
    """
    session = requests.Session()
    prepared = requests.Request('POST', url, headers=headers, json=json_data).prepare()

    def _discard_late_response(future):
        # The deadline passed before upstream answered: release whatever the
        # abandoned worker eventually obtained, then the session.
        try:
            if not future.cancelled() and future.exception() is None:
                future.result().close()
        finally:
            session.close()

    try:
        response = _run_with_deadline(session.send, prepared, stream=True,
                                      timeout=REQUEST_TIMEOUT - 5,
                                      on_abandon=_discard_late_response)
        logger.info("流式请求成功发送")
    except concurrent.futures.TimeoutError:
        logger.error("流式请求处理超时")
        return _json_error("Request processing timeout", 504)
    except requests.exceptions.RequestException as e:
        session.close()
        logger.error(f"上游请求异常: {e}")
        logger.error("上游服务器错误")
        return _json_error("Upstream server error", 502)
    except Exception:
        session.close()
        raise

    if response.status_code >= 400:
        logger.warning(f"上游返回错误状态: {response.status_code}")

    def generate():
        last_chunk_time = time.time()
        chunk_count = 0
        try:
            for chunk in response.iter_content(chunk_size=4096):
                if chunk:
                    last_chunk_time = time.time()
                    chunk_count += 1
                    yield chunk
                now = time.time()
                if now - start_time > REQUEST_TIMEOUT:
                    logger.warning("达到全局超时限制")
                    break
                # NOTE(review): last_chunk_time is refreshed just before this
                # check for every non-empty chunk, so this mostly catches runs
                # of empty chunks or a slow consumer; a silent upstream blocks
                # inside iter_content until the read timeout given to
                # session.send expires.
                if now - last_chunk_time > STREAM_IDLE_TIMEOUT:
                    logger.warning("数据块间隔超时")
                    break
        except requests.exceptions.RequestException as e:
            # Headers are already sent, so the status can no longer change;
            # log and end the stream instead of raising into the WSGI server.
            logger.error(f"流式读取异常: {e}")
        finally:
            response.close()
            session.close()
            total_time = time.time() - start_time
            logger.info(f"流式响应完成 - 数据块数: {chunk_count}, 总耗时: {total_time:.2f}秒")

    return Response(generate(), status=response.status_code,
                    content_type=response.headers.get('Content-Type', 'application/octet-stream'))


@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint: reports status and the configured timeout."""
    return Response(json.dumps({
        "status": "healthy",
        "timeout": REQUEST_TIMEOUT
    }), 200, content_type='application/json')


@app.route('/', methods=['GET'])
def index():
    """Index endpoint: lists the available routes."""
    return Response(json.dumps({
        "message": "Claude API Proxy",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "health": "/health"
        }
    }), 200, content_type='application/json')


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    logger.info(f"Flask应用启动在端口 {port}")
    app.run(debug=False, host='0.0.0.0', port=port, threaded=True)