mistpe commited on
Commit
d37e498
·
verified ·
1 Parent(s): c6deae5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +465 -0
app.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, make_response
2
+ import hashlib
3
+ import time
4
+ import xml.etree.ElementTree as ET
5
+ import os
6
+ import json
7
+ from openai import OpenAI
8
+ from dotenv import load_dotenv
9
+ from markdown import markdown
10
+ import re
11
+ import threading
12
+ import logging
13
+ from datetime import datetime
14
+ import asyncio
15
+ from concurrent.futures import ThreadPoolExecutor
16
+ import queue
17
+ import uuid
18
+ import base64
19
+ from Crypto.Cipher import AES
20
+ import struct
21
+ import random
22
+ import string
23
+
24
+ logging.basicConfig(
25
+ level=logging.INFO,
26
+ format='%(asctime)s - %(levelname)s - %(message)s',
27
+ handlers=[
28
+ logging.FileHandler('wechat_service.log'),
29
+ logging.StreamHandler()
30
+ ]
31
+ )
32
+
33
+ load_dotenv()
34
+
35
+ app = Flask(__name__)
36
+
37
+ TOKEN = os.getenv('TOKEN')
38
+ ENCODING_AES_KEY = os.getenv('ENCODING_AES_KEY')
39
+ APPID = os.getenv('APPID')
40
+ API_KEY = os.getenv("API_KEY")
41
+ BASE_URL = os.getenv("OPENAI_BASE_URL")
42
+
43
+ client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
44
+ executor = ThreadPoolExecutor(max_workers=10)
45
+
46
+ class WeChatCrypto:
47
+ def __init__(self, key, app_id):
48
+ self.key = base64.b64decode(key + '=')
49
+ self.app_id = app_id
50
+
51
+ def encrypt(self, text):
52
+ # 生成随机16字节字符串
53
+ random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
54
+ text_bytes = text.encode('utf-8')
55
+
56
+ # 构造明文字符串
57
+ msg_len = struct.pack('>I', len(text_bytes))
58
+ message = random_str.encode('utf-8') + msg_len + text_bytes + self.app_id.encode('utf-8')
59
+
60
+ # 补位
61
+ pad_len = 32 - (len(message) % 32)
62
+ message += chr(pad_len).encode('utf-8') * pad_len
63
+
64
+ # 加密
65
+ cipher = AES.new(self.key, AES.MODE_CBC, self.key[:16])
66
+ encrypted = cipher.encrypt(message)
67
+ return base64.b64encode(encrypted).decode('utf-8')
68
+
69
+ def decrypt(self, encrypted_text):
70
+ # Base64解码
71
+ encrypted_data = base64.b64decode(encrypted_text)
72
+
73
+ # 解密
74
+ cipher = AES.new(self.key, AES.MODE_CBC, self.key[:16])
75
+ decrypted = cipher.decrypt(encrypted_data)
76
+
77
+ # 获取填充长度
78
+ pad_len = decrypted[-1]
79
+ if not isinstance(pad_len, int):
80
+ pad_len = ord(pad_len)
81
+ content = decrypted[16:-pad_len]
82
+
83
+ # 获取消息长度
84
+ msg_len = struct.unpack('>I', content[:4])[0]
85
+ xml_content = content[4:msg_len + 4].decode('utf-8')
86
+ app_id = content[msg_len + 4:].decode('utf-8')
87
+
88
+ if app_id != self.app_id:
89
+ raise ValueError('Invalid AppID')
90
+
91
+ return xml_content
92
+
93
+ class AsyncResponse:
94
+ def __init__(self):
95
+ self.status = "processing"
96
+ self.result = None
97
+ self.error = None
98
+ self.create_time = time.time()
99
+ self.timeout = 3600
100
+
101
+ def is_expired(self):
102
+ return time.time() - self.create_time > self.timeout
103
+
104
+ class UserSession:
105
+ def __init__(self):
106
+ self.messages = [{"role": "system", "content": "你是HXIAO公众号的智能助手,这一个用来分享与学习人工智能的公众号,我们的目标是专注AI应用的简单研究与实践。致力于分享切实可行的技术方案,希望让复杂的技术变得简单易懂。也喜欢用通俗的语言来解释专业概念,让技术真正服务于每个学习者"}]
107
+ self.pending_parts = []
108
+ self.last_active = time.time()
109
+ self.current_task = None
110
+ self.response_queue = {}
111
+ self.session_timeout = 3600
112
+
113
+ def is_expired(self):
114
+ return time.time() - self.last_active > self.session_timeout
115
+
116
+ def cleanup_expired_tasks(self):
117
+ expired_tasks = [
118
+ task_id for task_id, response in self.response_queue.items()
119
+ if response.is_expired()
120
+ ]
121
+ for task_id in expired_tasks:
122
+ del self.response_queue[task_id]
123
+ if self.current_task == task_id:
124
+ self.current_task = None
125
+
126
+ class SessionManager:
127
+ def __init__(self):
128
+ self.sessions = {}
129
+ self._lock = threading.Lock()
130
+ self.crypto = WeChatCrypto(ENCODING_AES_KEY, APPID)
131
+
132
+ def get_session(self, user_id):
133
+ with self._lock:
134
+ current_time = time.time()
135
+ if user_id in self.sessions:
136
+ session = self.sessions[user_id]
137
+ if session.is_expired():
138
+ session = UserSession()
139
+ else:
140
+ session.cleanup_expired_tasks()
141
+ else:
142
+ session = UserSession()
143
+ session.last_active = current_time
144
+ self.sessions[user_id] = session
145
+ return session
146
+
147
+ def clear_session(self, user_id):
148
+ with self._lock:
149
+ if user_id in self.sessions:
150
+ self.sessions[user_id] = UserSession()
151
+
152
+ def cleanup_expired_sessions(self):
153
+ with self._lock:
154
+ current_time = time.time()
155
+ expired_users = [
156
+ user_id for user_id, session in self.sessions.items()
157
+ if session.is_expired()
158
+ ]
159
+ for user_id in expired_users:
160
+ del self.sessions[user_id]
161
+ logging.info(f"已清理过期会话: {user_id}")
162
+
163
+ session_manager = SessionManager()
164
+
165
+ def convert_markdown_to_wechat(md_text):
166
+ if not md_text:
167
+ return md_text
168
+
169
+ md_text = re.sub(r'^# (.*?)$', r'【标题】\1', md_text, flags=re.MULTILINE)
170
+ md_text = re.sub(r'^## (.*?)$', r'【子标题】\1', md_text, flags=re.MULTILINE)
171
+ md_text = re.sub(r'^### (.*?)$', r'【小标题】\1', md_text, flags=re.MULTILINE)
172
+ md_text = re.sub(r'\*\*(.*?)\*\*', r'『\1』', md_text)
173
+ md_text = re.sub(r'\*(.*?)\*', r'「\1」', md_text)
174
+ md_text = re.sub(r'`(.*?)`', r'「\1」', md_text)
175
+ md_text = re.sub(r'^\- ', '• ', md_text, flags=re.MULTILINE)
176
+ md_text = re.sub(r'^\d\. ', '○ ', md_text, flags=re.MULTILINE)
177
+ md_text = re.sub(r'```[\w]*\n(.*?)```', r'【代码开始】\n\1\n【代码结束】', md_text, flags=re.DOTALL)
178
+ md_text = re.sub(r'^> (.*?)$', r'▎\1', md_text, flags=re.MULTILINE)
179
+ md_text = re.sub(r'^-{3,}$', r'—————————', md_text, flags=re.MULTILINE)
180
+ md_text = re.sub(r'\[(.*?)\]\((.*?)\)', r'\1(\2)', md_text)
181
+ md_text = re.sub(r'\n{3,}', '\n\n', md_text)
182
+
183
+ return md_text
184
+
185
+ def verify_signature(signature, timestamp, nonce, token):
186
+ items = [token, timestamp, nonce]
187
+ items.sort()
188
+ temp_str = ''.join(items)
189
+ hash_sha1 = hashlib.sha1(temp_str.encode('utf-8')).hexdigest()
190
+ return hash_sha1 == signature
191
+
192
+ def verify_msg_signature(msg_signature, timestamp, nonce, token, encrypt_msg):
193
+ items = [token, timestamp, nonce, encrypt_msg]
194
+ items.sort()
195
+ temp_str = ''.join(items)
196
+ hash_sha1 = hashlib.sha1(temp_str.encode('utf-8')).hexdigest()
197
+ return hash_sha1 == msg_signature
198
+
199
+ def parse_xml_message(xml_content):
200
+ root = ET.fromstring(xml_content)
201
+ return {
202
+ 'content': root.find('Content').text if root.find('Content') is not None else '',
203
+ 'from_user': root.find('FromUserName').text,
204
+ 'to_user': root.find('ToUserName').text,
205
+ 'msg_id': root.find('MsgId').text if root.find('MsgId') is not None else '',
206
+ 'create_time': root.find('CreateTime').text,
207
+ 'msg_type': root.find('MsgType').text
208
+ }
209
+
210
+ def generate_response_xml(to_user, from_user, content, encrypt_type='aes'):
211
+ formatted_content = convert_markdown_to_wechat(content)
212
+ timestamp = str(int(time.time()))
213
+ nonce = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
214
+
215
+ if encrypt_type == 'aes':
216
+ xml_content = f'''
217
+ <xml>
218
+ <ToUserName><![CDATA[{to_user}]]></ToUserName>
219
+ <FromUserName><![CDATA[{from_user}]]></FromUserName>
220
+ <CreateTime>{timestamp}</CreateTime>
221
+ <MsgType><![CDATA[text]]></MsgType>
222
+ <Content><![CDATA[{formatted_content}]]></Content>
223
+ </xml>
224
+ '''
225
+
226
+ # 加密消息内容
227
+ encrypted = session_manager.crypto.encrypt(xml_content)
228
+
229
+ # 生成签名
230
+ signature_list = [TOKEN, timestamp, nonce, encrypted]
231
+ signature_list.sort()
232
+ msg_signature = hashlib.sha1(''.join(signature_list).encode('utf-8')).hexdigest()
233
+
234
+ response_xml = f'''
235
+ <xml>
236
+ <Encrypt><![CDATA[{encrypted}]]></Encrypt>
237
+ <MsgSignature><![CDATA[{msg_signature}]]></MsgSignature>
238
+ <TimeStamp>{timestamp}</TimeStamp>
239
+ <Nonce><![CDATA[{nonce}]]></Nonce>
240
+ </xml>
241
+ '''
242
+ else:
243
+ response_xml = f'''
244
+ <xml>
245
+ <ToUserName><![CDATA[{to_user}]]></ToUserName>
246
+ <FromUserName><![CDATA[{from_user}]]></FromUserName>
247
+ <CreateTime>{timestamp}</CreateTime>
248
+ <MsgType><![CDATA[text]]></MsgType>
249
+ <Content><![CDATA[{formatted_content}]]></Content>
250
+ </xml>
251
+ '''
252
+
253
+ response = make_response(response_xml)
254
+ response.content_type = 'application/xml'
255
+ return response
256
+
257
+ def process_long_running_task(messages):
258
+ try:
259
+ response = client.chat.completions.create(
260
+ model="o3-mini",
261
+ messages=messages,
262
+ timeout=60
263
+ )
264
+ return response.choices[0].message.content
265
+ except Exception as e:
266
+ logging.error(f"API调用错误: {str(e)}")
267
+ raise
268
+
269
+ def handle_async_task(session, task_id, messages):
270
+ try:
271
+ if task_id not in session.response_queue:
272
+ return
273
+
274
+ result = process_long_running_task(messages)
275
+
276
+ if task_id in session.response_queue and not session.response_queue[task_id].is_expired():
277
+ session.response_queue[task_id].status = "completed"
278
+ session.response_queue[task_id].result = result
279
+ except Exception as e:
280
+ if task_id in session.response_queue:
281
+ session.response_queue[task_id].status = "failed"
282
+ session.response_queue[task_id].error = str(e)
283
+
284
+ def generate_initial_response():
285
+ return "您的请求正在处理中,请回复'查询'获取结果"
286
+
287
+ def split_message(message, max_length=500):
288
+ return [message[i:i+max_length] for i in range(0, len(message), max_length)]
289
+
290
+ def append_status_message(content, has_pending_parts=False, is_processing=False):
291
+ if "您的请求正在处理中" in content:
292
+ return content + "\n\n-------------------\n发送'新对话'开始新的对话"
293
+
294
+ status_message = "\n\n-------------------"
295
+ if is_processing:
296
+ status_message += "\n请回复'查询'获取结果"
297
+ elif has_pending_parts:
298
+ status_message += "\n当前消息已截断,发送'继续'查看后续内容"
299
+ status_message += "\n发送'新对话'开始新的对话"
300
+ return content + status_message
301
+
302
+ @app.route('/api/wx', methods=['GET', 'POST'])
303
+ def wechatai():
304
+ if request.method == 'GET':
305
+ signature = request.args.get('signature')
306
+ timestamp = request.args.get('timestamp')
307
+ nonce = request.args.get('nonce')
308
+ echostr = request.args.get('echostr')
309
+
310
+ if verify_signature(signature, timestamp, nonce, TOKEN):
311
+ return echostr
312
+ return 'error', 403
313
+
314
+ try:
315
+ encrypt_type = request.args.get('encrypt_type', '')
316
+
317
+ if encrypt_type == 'aes':
318
+ msg_signature = request.args.get('msg_signature')
319
+ timestamp = request.args.get('timestamp')
320
+ nonce = request.args.get('nonce')
321
+
322
+ # 解析加密的XML
323
+ xml_tree = ET.fromstring(request.data)
324
+ encrypted_text = xml_tree.find('Encrypt').text
325
+
326
+ # 验证消息签名
327
+ if not verify_msg_signature(msg_signature, timestamp, nonce, TOKEN, encrypted_text):
328
+ return 'Invalid signature', 403
329
+
330
+ # 解密消息
331
+ decrypted_xml = session_manager.crypto.decrypt(encrypted_text)
332
+ message_data = parse_xml_message(decrypted_xml)
333
+ else:
334
+ message_data = parse_xml_message(request.data)
335
+
336
+ user_content = message_data['content'].strip()
337
+ from_user = message_data['from_user']
338
+ to_user = message_data['to_user']
339
+
340
+ logging.info(f"收到用户({from_user})消息: {user_content}")
341
+ session = session_manager.get_session(from_user)
342
+
343
+ if user_content == '新对话':
344
+ session_manager.clear_session(from_user)
345
+ return generate_response_xml(
346
+ from_user,
347
+ to_user,
348
+ append_status_message('已开始新的对话。请描述您的问题。'),
349
+ encrypt_type
350
+ )
351
+
352
+ if user_content == '继续':
353
+ if session.pending_parts:
354
+ next_part = session.pending_parts.pop(0)
355
+ has_more = bool(session.pending_parts)
356
+ return generate_response_xml(
357
+ from_user,
358
+ to_user,
359
+ append_status_message(next_part, has_more),
360
+ encrypt_type
361
+ )
362
+ return generate_response_xml(
363
+ from_user,
364
+ to_user,
365
+ append_status_message('没有更多内容了。请继续您的问题。'),
366
+ encrypt_type
367
+ )
368
+
369
+ if user_content == '查询':
370
+ if session.current_task:
371
+ task_response = session.response_queue.get(session.current_task)
372
+ if task_response:
373
+ if task_response.is_expired():
374
+ del session.response_queue[session.current_task]
375
+ session.current_task = None
376
+ return generate_response_xml(
377
+ from_user,
378
+ to_user,
379
+ append_status_message('请求已过期,请重新提问。'),
380
+ encrypt_type
381
+ )
382
+
383
+ if task_response.status == "completed":
384
+ response = task_response.result
385
+ del session.response_queue[session.current_task]
386
+ session.current_task = None
387
+ session.messages.append({"role": "assistant", "content": response})
388
+
389
+ if len(response) > 500:
390
+ parts = split_message(response)
391
+ first_part = parts.pop(0)
392
+ session.pending_parts = parts
393
+ return generate_response_xml(
394
+ from_user,
395
+ to_user,
396
+ append_status_message(first_part, True),
397
+ encrypt_type
398
+ )
399
+ return generate_response_xml(
400
+ from_user,
401
+ to_user,
402
+ append_status_message(response),
403
+ encrypt_type
404
+ )
405
+ elif task_response.status == "failed":
406
+ error_message = '处理过程中出现错误,请重新提问。'
407
+ del session.response_queue[session.current_task]
408
+ session.current_task = None
409
+ return generate_response_xml(
410
+ from_user,
411
+ to_user,
412
+ append_status_message(error_message),
413
+ encrypt_type
414
+ )
415
+ else:
416
+ return generate_response_xml(
417
+ from_user,
418
+ to_user,
419
+ append_status_message('正在处理中,请稍后再次查询。', is_processing=True),
420
+ encrypt_type
421
+ )
422
+ return generate_response_xml(
423
+ from_user,
424
+ to_user,
425
+ append_status_message('没有正在处理的请求。'),
426
+ encrypt_type
427
+ )
428
+
429
+ session.messages.append({"role": "user", "content": user_content})
430
+
431
+ task_id = str(uuid.uuid4())
432
+ session.current_task = task_id
433
+ session.response_queue[task_id] = AsyncResponse()
434
+
435
+ executor.submit(handle_async_task, session, task_id, session.messages.copy())
436
+
437
+ return generate_response_xml(
438
+ from_user,
439
+ to_user,
440
+ append_status_message(generate_initial_response(), is_processing=True),
441
+ encrypt_type
442
+ )
443
+
444
+ except Exception as e:
445
+ logging.error(f"处理请求时出错: {str(e)}")
446
+ return generate_response_xml(
447
+ message_data['from_user'],
448
+ message_data['to_user'],
449
+ append_status_message('抱歉,系统暂时出现问题,请稍后重试。'),
450
+ encrypt_type if 'encrypt_type' in locals() else ''
451
+ )
452
+
453
+ def cleanup_sessions():
454
+ while True:
455
+ time.sleep(3600) # 每小时清理一次
456
+ try:
457
+ session_manager.cleanup_expired_sessions()
458
+ except Exception as e:
459
+ logging.error(f"清理会话时出错: {str(e)}")
460
+
461
+ if __name__ == '__main__':
462
+ cleanup_thread = threading.Thread(target=cleanup_sessions, daemon=True)
463
+ cleanup_thread.start()
464
+
465
+ app.run(host='0.0.0.0', port=7860, debug=True)