File size: 9,775 Bytes
7c3a45a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import os
import time
import logging
from flask import Flask, request, Response
import requests
import json
import concurrent.futures

# Logging setup — console handler only, to avoid file-permission problems
# in restricted deployment environments.
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Global timeout budget for upstream requests (seconds).
REQUEST_TIMEOUT = 90

def estimate_tokens(text):
    """Return a conservative token estimate for *text* (1 token ≈ 2 chars).

    Empty or falsy input (``""``, ``None``) counts as zero tokens.
    """
    if not text:
        return 0
    return len(text) // 2

def _sum_tokens(msgs):
    """Total estimated tokens across the ``content`` of every message."""
    return sum(estimate_tokens(msg.get("content", "")) for msg in msgs)

def truncate_messages(messages, max_tokens=15000):
    """Truncate a chat-message list so its estimated token total fits *max_tokens*.

    Priority order:
      1. Always keep the last two messages; if even those two exceed the
         budget, truncate their content (head of the list keeps its prefix,
         the final message keeps its suffix — the most recent text).
      2. Fill any remaining budget with messages from the start of the list,
         partially truncating the first message that does not fully fit.

    Args:
        messages: list of dicts, each expected to carry a ``content`` string
            (missing ``content`` counts as empty).
        max_tokens: token budget, using estimate_tokens()'s 2-chars-per-token rule.

    Returns:
        A list of message dicts. Original dicts are never mutated; any
        truncated message is a shallow copy with a shortened ``content``.
    """
    if not messages:
        logger.info("消息列表为空,无需截断")
        return messages

    original_count = len(messages)
    original_tokens = _sum_tokens(messages)
    logger.info(f"原始消息数量: {original_count}, 原始token估算: {original_tokens}")

    if original_tokens <= max_tokens:
        logger.info(f"原始token({original_tokens}) <= 限制({max_tokens}),无需截断")
        return messages

    if len(messages) <= 2:
        logger.info(f"消息数量 <= 2,直接返回,发送token: {original_tokens}")
        return messages

    # 1. Highest priority: keep the last two messages.
    last_two = messages[-2:]
    last_two_tokens = _sum_tokens(last_two)
    logger.info(f"最后两条消息token: {last_two_tokens}")

    if last_two_tokens > max_tokens:
        logger.warning(f"最后两条消息({last_two_tokens} tokens)超过限制({max_tokens} tokens),需要截断")

        # Split the character budget evenly between the two messages.
        chars_per_msg = (max_tokens * 2) // len(last_two)

        truncated_last = []
        for i, msg in enumerate(last_two):
            content = msg.get("content", "")
            if len(content) > chars_per_msg:
                # Clamp the kept length at zero: with a tiny budget the old
                # slice arithmetic (e.g. content[-0:]) kept the whole string.
                keep = max(chars_per_msg - 3, 0)
                if i == len(last_two) - 1:
                    # Final message: keep the tail — the most recent text.
                    truncated_content = "..." + (content[-keep:] if keep else "")
                else:
                    truncated_content = content[:keep] + "..."
                msg = {**msg, "content": truncated_content}
            truncated_last.append(msg)

        final_tokens = _sum_tokens(truncated_last)
        logger.info(f"截断后仅保留最后两条消息,发送token: {final_tokens}")
        return truncated_last

    available_for_start = max_tokens - last_two_tokens
    logger.info(f"可用于开始对话的token容量: {available_for_start}")

    preserved_start = []
    current_tokens = 0

    for msg in messages[:-2]:
        content = msg.get("content", "")
        msg_tokens = estimate_tokens(content)

        if current_tokens + msg_tokens <= available_for_start:
            preserved_start.append(msg)
            current_tokens += msg_tokens
        else:
            # Partially include this message only if a meaningful amount
            # (>100 chars) of budget remains, then stop scanning.
            remaining_chars = (available_for_start - current_tokens) * 2
            if remaining_chars > 100:
                truncated_content = content[:remaining_chars - 3] + "..."
                truncated_msg = {**msg, "content": truncated_content}
                preserved_start.append(truncated_msg)
                current_tokens += estimate_tokens(truncated_content)
            break

    result = preserved_start + last_two
    final_count = len(result)
    final_tokens = _sum_tokens(result)

    logger.info(f"截断完成: {original_count} -> {final_count} 条消息, {original_tokens} -> {final_tokens} tokens")

    return result

def process_request(url, headers, json_data, timeout):
    """POST *json_data* to *url* and return the Response, or None on failure.

    Timeouts and any other request-level error are logged and swallowed;
    the caller distinguishes failure by the None return.
    """
    try:
        upstream_response = requests.post(url, headers=headers, json=json_data, timeout=timeout)
    except requests.exceptions.Timeout:
        logger.warning("上游请求超时")
        return None
    except requests.exceptions.RequestException as e:
        logger.error(f"上游请求异常: {e}")
        return None
    return upstream_response

@app.route('/v1/chat/completions', methods=['POST'])
def claude():
    """Proxy an OpenAI-style chat-completions request to the upstream endpoint.

    Flow: validate the JSON body, truncate the message history to ~15k
    estimated tokens, strip headers that shouldn't be forwarded, then relay
    the request. Streaming requests are forwarded chunk-by-chunk; others are
    returned whole. Error mapping: 400 missing/invalid JSON, 502 upstream
    failure, 504 proxy-side timeout, 500 unexpected exception.
    """
    start_time = time.time()
    
    try:
        json_data = request.get_json()
        if not json_data:
            logger.error("缺少JSON数据")
            return Response(json.dumps({"error": "Missing JSON data"}), 400, content_type='application/json')
        
        model = json_data.get('model', 'unknown')
        stream_flag = json_data.get('stream', False)
        logger.info(f"收到请求 - 模型: {model}, 流式: {stream_flag}")
        
        # Shrink the conversation history before forwarding.
        if 'messages' in json_data and json_data['messages']:
            original_msg_count = len(json_data['messages'])
            json_data['messages'] = truncate_messages(json_data['messages'], max_tokens=15000)
            final_msg_count = len(json_data['messages'])
            
            if original_msg_count != final_msg_count:
                logger.info(f"消息截断: {original_msg_count} -> {final_msg_count}")
        
        # Drop headers that requests sets itself (Host, Content-Length) or
        # that shouldn't be forwarded verbatim to the upstream host.
        headers = {key: value for key, value in request.headers if
                   key not in ['Host', 'User-Agent', "Accept-Encoding", "Accept", "Connection", "Content-Length"]}
        
        url = "https://aiho.st/proxy/aws/claude/chat/completions"
        
        if stream_flag:
            # Streaming path: send from a worker thread so that
            # future.result(timeout=...) can bound the connect/headers phase.
            with concurrent.futures.ThreadPoolExecutor() as executor:
                # NOTE(review): this Session is never explicitly closed;
                # only the response is closed in generate()'s finally.
                session = requests.Session()
                req = requests.Request('POST', url, headers=headers, json=json_data).prepare()
                
                future = executor.submit(session.send, req, stream=True, timeout=REQUEST_TIMEOUT - 5)
                try:
                    response = future.result(timeout=REQUEST_TIMEOUT)
                    logger.info("流式请求成功发送")
                except concurrent.futures.TimeoutError:
                    logger.error("流式请求处理超时")
                    return Response(json.dumps({"error": "Request processing timeout"}), 504, content_type='application/json')
                
                if response is None:
                    logger.error("上游服务器错误")
                    return Response(json.dumps({"error": "Upstream server error"}), 502, content_type='application/json')
                
                def generate():
                    # Relay upstream bytes to the client in 4 KiB chunks.
                    last_chunk_time = time.time()
                    chunk_count = 0
                    try:
                        for chunk in response.iter_content(chunk_size=4096):
                            if chunk:
                                last_chunk_time = time.time()
                                chunk_count += 1
                                yield chunk
                                
                                # Stop relaying once the whole request has
                                # exceeded the global time budget.
                                current_time = time.time()
                                if current_time - start_time > REQUEST_TIMEOUT:
                                    logger.warning("达到全局超时限制")
                                    break
                                    
                                # NOTE(review): last_chunk_time is reset just
                                # before the yield above, so this measures how
                                # long the downstream consumer held the
                                # generator — not the gap between upstream
                                # chunks. Confirm that is the intent.
                                if time.time() - last_chunk_time > 10:
                                    logger.warning("数据块间隔超时")
                                    break
                    finally:
                        # Always release the upstream connection.
                        response.close()
                        total_time = time.time() - start_time
                        logger.info(f"流式响应完成 - 数据块数: {chunk_count}, 总耗时: {total_time:.2f}秒")
                
                return Response(generate(), content_type=response.headers.get('Content-Type', 'application/octet-stream'))
        else:
            # Non-streaming path: same timeout discipline, whole-body relay.
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(process_request, url, headers, json_data, REQUEST_TIMEOUT - 5)
                try:
                    response = future.result(timeout=REQUEST_TIMEOUT)
                    total_time = time.time() - start_time
                    logger.info(f"普通请求完成 - 耗时: {total_time:.2f}秒")
                except concurrent.futures.TimeoutError:
                    logger.error("普通请求处理超时")
                    return Response(json.dumps({"error": "Request processing timeout"}), 504, content_type='application/json')
                
                # process_request returns None for any upstream failure.
                if response is None:
                    logger.error("上游服务器错误")
                    return Response(json.dumps({"error": "Upstream server error"}), 502, content_type='application/json')
                
                return Response(response.content, content_type=response.headers.get('Content-Type', 'application/json'))
    
    except Exception as e:
        logger.error(f"内部服务器错误: {e}", exc_info=True)
        return Response(json.dumps({"error": "Internal server error"}), 500, content_type='application/json')

@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint: report status and the configured timeout."""
    payload = {
        "status": "healthy", 
        "timeout": REQUEST_TIMEOUT
    }
    return Response(json.dumps(payload), 200, content_type='application/json')

@app.route('/', methods=['GET'])
def index():
    """Root endpoint: describe the service and its available routes."""
    info = {
        "message": "Claude API Proxy",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "health": "/health"
        }
    }
    return Response(json.dumps(info), 200, content_type='application/json')

if __name__ == '__main__':
    # Honor the platform-provided PORT (e.g. container hosts), default 7860.
    listen_port = int(os.environ.get('PORT', 7860))
    logger.info(f"Flask应用启动在端口 {listen_port}")
    app.run(debug=False, host='0.0.0.0', port=listen_port, threaded=True)