# NOTE(review): the three lines that stood here ("Spaces:" / "Sleeping" /
# "Sleeping") are HuggingFace Spaces page-banner residue captured with the
# source, not Python — kept only as this comment so the file parses.
| import os | |
| import time | |
| import logging | |
| from flask import Flask, request, Response | |
| import requests | |
| import json | |
| import concurrent.futures | |
# Logging setup: console (stream) handler only, to avoid file-permission
# problems on read-only / sandboxed deployments.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Global timeout budget (seconds) applied to upstream requests.
REQUEST_TIMEOUT = 90
def estimate_tokens(text):
    """Return a rough token estimate for *text* (conservative: ~2 chars/token).

    Empty or None input counts as zero tokens.
    """
    if not text:
        return 0
    return len(text) // 2
def truncate_messages(messages, max_tokens=15000):
    """Truncate a chat message list so its estimated token total fits *max_tokens*.

    Priority order: the last two messages are always kept (content-truncated
    only when they alone exceed the budget); then as many messages from the
    start of the conversation as still fit, with at most one partial message.
    Returns the original list when no truncation is needed; otherwise a new
    list (dicts needing truncation are shallow-copied, others are shared).
    """
    if not messages:
        logger.info("消息列表为空,无需截断")
        return messages

    def _content(msg):
        # Fix: a message may carry an explicit "content": None, in which case
        # msg.get("content", "") still returns None; estimate_tokens() guards
        # None but len()/slicing below would raise TypeError. Normalise to "".
        return msg.get("content") or ""

    original_count = len(messages)
    original_tokens = sum(estimate_tokens(_content(msg)) for msg in messages)
    logger.info(f"原始消息数量: {original_count}, 原始token估算: {original_tokens}")

    if original_tokens <= max_tokens:
        logger.info(f"原始token({original_tokens}) <= 限制({max_tokens}),无需截断")
        return messages
    if len(messages) <= 2:
        logger.info(f"消息数量 <= 2,直接返回,发送token: {original_tokens}")
        return messages

    # 1. Highest priority: keep the last two messages.
    last_two = messages[-2:]
    last_two_tokens = sum(estimate_tokens(_content(msg)) for msg in last_two)
    logger.info(f"最后两条消息token: {last_two_tokens}")

    if last_two_tokens > max_tokens:
        # Even the last two messages blow the budget: hard-truncate their
        # content and send only those two.
        logger.warning(f"最后两条消息({last_two_tokens} tokens)超过限制({max_tokens} tokens),需要截断")
        chars_per_msg = (max_tokens * 2) // len(last_two)
        truncated_last = []
        for i, msg in enumerate(last_two):
            content = _content(msg)
            if len(content) > chars_per_msg:
                if i == len(last_two) - 1:
                    # Final (most recent) message: keep its tail, which is
                    # usually the part the model must answer.
                    truncated_content = "..." + content[-chars_per_msg + 3:]
                else:
                    truncated_content = content[:chars_per_msg - 3] + "..."
                msg = {**msg, "content": truncated_content}
            truncated_last.append(msg)
        final_tokens = sum(estimate_tokens(_content(msg)) for msg in truncated_last)
        logger.info(f"截断后仅保留最后两条消息,发送token: {final_tokens}")
        return truncated_last

    # 2. Fill the remaining budget with messages from the start of the chat.
    available_for_start = max_tokens - last_two_tokens
    logger.info(f"可用于开始对话的token容量: {available_for_start}")

    preserved_start = []
    current_tokens = 0
    for msg in messages[:-2]:
        content = _content(msg)
        msg_tokens = estimate_tokens(content)
        if current_tokens + msg_tokens <= available_for_start:
            preserved_start.append(msg)
            current_tokens += msg_tokens
        else:
            # Partially include this message only when a meaningful amount
            # (> 100 chars) of the budget remains; then stop.
            remaining_chars = (available_for_start - current_tokens) * 2
            if remaining_chars > 100:
                truncated_content = content[:remaining_chars - 3] + "..."
                truncated_msg = {**msg, "content": truncated_content}
                preserved_start.append(truncated_msg)
                current_tokens += estimate_tokens(truncated_content)
            break

    result = preserved_start + last_two
    final_count = len(result)
    final_tokens = sum(estimate_tokens(_content(msg)) for msg in result)
    logger.info(f"截断完成: {original_count} -> {final_count} 条消息, {original_tokens} -> {final_tokens} tokens")
    return result
def process_request(url, headers, json_data, timeout):
    """POST *json_data* to *url*; return the Response, or None on timeout/error."""
    try:
        resp = requests.post(url, headers=headers, json=json_data, timeout=timeout)
    except requests.exceptions.Timeout:
        logger.warning("上游请求超时")
        return None
    except requests.exceptions.RequestException as exc:
        logger.error(f"上游请求异常: {exc}")
        return None
    return resp
# NOTE(review): no @app.route decorator was visible in the (format-mangled)
# source, yet index() documents this handler at /v1/chat/completions — the
# route is restored here; confirm the path against the original deployment.
@app.route('/v1/chat/completions', methods=['POST'])
def claude():
    """Proxy a chat-completions request to the upstream Claude endpoint.

    Truncates the message history to the token budget, strips headers that
    must be recomputed for the outbound request, and forwards either as a
    streamed or a plain POST. Returns 400 on missing JSON, 502 on upstream
    failure, 504 on timeout, 500 on unexpected internal errors.
    """
    start_time = time.time()
    try:
        json_data = request.get_json()
        if not json_data:
            logger.error("缺少JSON数据")
            return Response(json.dumps({"error": "Missing JSON data"}), 400,
                            content_type='application/json')

        model = json_data.get('model', 'unknown')
        stream_flag = json_data.get('stream', False)
        logger.info(f"收到请求 - 模型: {model}, 流式: {stream_flag}")

        # Trim history so the upstream request stays inside the token budget.
        if 'messages' in json_data and json_data['messages']:
            original_msg_count = len(json_data['messages'])
            json_data['messages'] = truncate_messages(json_data['messages'], max_tokens=15000)
            final_msg_count = len(json_data['messages'])
            if original_msg_count != final_msg_count:
                logger.info(f"消息截断: {original_msg_count} -> {final_msg_count}")

        # Drop hop-by-hop / recomputed headers; forward the rest verbatim.
        headers = {key: value for key, value in request.headers if
                   key not in ['Host', 'User-Agent', "Accept-Encoding", "Accept",
                               "Connection", "Content-Length"]}
        url = "https://aiho.st/proxy/aws/claude/chat/completions"

        if stream_flag:
            return _proxy_stream(url, headers, json_data, start_time)
        return _proxy_plain(url, headers, json_data, start_time)
    except Exception as e:
        logger.error(f"内部服务器错误: {e}", exc_info=True)
        return Response(json.dumps({"error": "Internal server error"}), 500,
                        content_type='application/json')


def _proxy_stream(url, headers, json_data, start_time):
    """Send the upstream request with stream=True and relay chunks to the client."""
    session = requests.Session()
    # Fix: the executor was previously used as a context manager, whose
    # __exit__ is shutdown(wait=True) — a 504 response was therefore delayed
    # until the hung upstream call finished. shutdown(wait=False) returns
    # immediately and lets the worker thread drain in the background.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    req = requests.Request('POST', url, headers=headers, json=json_data).prepare()
    future = executor.submit(session.send, req, stream=True, timeout=REQUEST_TIMEOUT - 5)
    try:
        response = future.result(timeout=REQUEST_TIMEOUT)
        logger.info("流式请求成功发送")
    except concurrent.futures.TimeoutError:
        logger.error("流式请求处理超时")
        executor.shutdown(wait=False)
        return Response(json.dumps({"error": "Request processing timeout"}), 504,
                        content_type='application/json')
    except requests.exceptions.RequestException as e:
        # Fix: upstream connection errors used to escape to the generic 500
        # handler; report them as 502 like the non-streaming path does.
        logger.error(f"上游请求异常: {e}")
        response = None
    executor.shutdown(wait=False)

    if response is None:
        session.close()  # fix: the session was never closed on failure
        logger.error("上游服务器错误")
        return Response(json.dumps({"error": "Upstream server error"}), 502,
                        content_type='application/json')

    def generate():
        # Relay chunks, enforcing a global deadline and a 10 s inter-chunk gap.
        last_chunk_time = time.time()
        chunk_count = 0
        try:
            for chunk in response.iter_content(chunk_size=4096):
                if chunk:
                    last_chunk_time = time.time()
                    chunk_count += 1
                    yield chunk
                current_time = time.time()
                if current_time - start_time > REQUEST_TIMEOUT:
                    logger.warning("达到全局超时限制")
                    break
                if time.time() - last_chunk_time > 10:
                    logger.warning("数据块间隔超时")
                    break
        finally:
            response.close()
            session.close()  # fix: release the pooled connection as well
            total_time = time.time() - start_time
            logger.info(f"流式响应完成 - 数据块数: {chunk_count}, 总耗时: {total_time:.2f}秒")

    return Response(generate(),
                    content_type=response.headers.get('Content-Type', 'application/octet-stream'))


def _proxy_plain(url, headers, json_data, start_time):
    """Send the upstream request without streaming and return the full body."""
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(process_request, url, headers, json_data, REQUEST_TIMEOUT - 5)
    try:
        response = future.result(timeout=REQUEST_TIMEOUT)
        total_time = time.time() - start_time
        logger.info(f"普通请求完成 - 耗时: {total_time:.2f}秒")
    except concurrent.futures.TimeoutError:
        logger.error("普通请求处理超时")
        # wait=False: answer the client now rather than after the hung call.
        executor.shutdown(wait=False)
        return Response(json.dumps({"error": "Request processing timeout"}), 504,
                        content_type='application/json')
    executor.shutdown(wait=False)

    if response is None:
        logger.error("上游服务器错误")
        return Response(json.dumps({"error": "Upstream server error"}), 502,
                        content_type='application/json')
    return Response(response.content,
                    content_type=response.headers.get('Content-Type', 'application/json'))
# NOTE(review): route restored from the endpoint list published by index()
# ("health": "/health") — no decorator was visible in the source; confirm.
@app.route('/health')
def health_check():
    """Health-check endpoint: reports liveness and the configured timeout."""
    return Response(json.dumps({
        "status": "healthy",
        "timeout": REQUEST_TIMEOUT
    }), 200, content_type='application/json')
# NOTE(review): root route restored — no decorator was visible in the
# (format-mangled) source; confirm against the original deployment.
@app.route('/')
def index():
    """Landing page: service name plus the available endpoints."""
    return Response(json.dumps({
        "message": "Claude API Proxy",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "health": "/health"
        }
    }), 200, content_type='application/json')
if __name__ == '__main__':
    # Development entry point: bind all interfaces; the PORT environment
    # variable overrides the default 7860 (HuggingFace Spaces convention).
    listen_port = int(os.environ.get('PORT', 7860))
    logger.info(f"Flask应用启动在端口 {listen_port}")
    app.run(debug=False, host='0.0.0.0', port=listen_port, threaded=True)