mistpe committed on
Commit
2e4ce5e
·
verified ·
1 Parent(s): 1a7cbdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -153
app.py CHANGED
@@ -6,10 +6,12 @@ import os
6
  import json
7
  from openai import OpenAI
8
  from dotenv import load_dotenv
 
9
  import re
10
  import threading
11
  import logging
12
  from datetime import datetime
 
13
  from concurrent.futures import ThreadPoolExecutor
14
  import queue
15
  import uuid
@@ -19,7 +21,6 @@ import struct
19
  import random
20
  import string
21
  import requests
22
- from typing import Optional, Dict, Any
23
 
24
  logging.basicConfig(
25
  level=logging.INFO,
@@ -46,13 +47,13 @@ IMAGE_MODEL_KEY = os.getenv("IMAGE_MODEL_KEY")
46
  client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
47
  executor = ThreadPoolExecutor(max_workers=10)
48
 
49
- # Define tools for the AI model
50
  TOOLS = [
51
  {
52
  "type": "function",
53
  "function": {
54
  "name": "generate_image",
55
- "description": "Generate an image based on text description and return its markdown URL",
56
  "parameters": {
57
  "type": "object",
58
  "properties": {
@@ -105,8 +106,8 @@ class AsyncResponse:
105
  self.error = None
106
  self.create_time = time.time()
107
  self.timeout = 3600
108
- self.media_id = None
109
- self.response_type = "text"
110
 
111
  def is_expired(self):
112
  return time.time() - self.create_time > self.timeout
@@ -170,71 +171,6 @@ class SessionManager:
170
  del self.sessions[user_id]
171
  logging.info(f"已清理过期会话: {user_id}")
172
 
173
- class ImageService:
174
- @staticmethod
175
- def generate_image(prompt: str) -> str:
176
- try:
177
- logging.info(f"开始生成图片,提示词: {prompt}")
178
-
179
- response = requests.post(
180
- IMAGE_MODEL_URL,
181
- headers={
182
- 'Content-Type': 'application/json',
183
- 'Authorization': f'Bearer {IMAGE_MODEL_KEY}'
184
- },
185
- json={
186
- "model": "grok-latest-image",
187
- "messages": [{
188
- "role": "user",
189
- "content": prompt
190
- }],
191
- "stream": False
192
- }
193
- )
194
-
195
- logging.info(f"图片生成服务响应状态码: {response.status_code}")
196
- response.raise_for_status()
197
-
198
- result = response.json()
199
- logging.info(f"图片生成服务响应内容: {json.dumps(result, ensure_ascii=False)}")
200
-
201
- if not result.get('choices') or not result['choices'][0].get('message', {}).get('content'):
202
- raise ValueError("Invalid response format")
203
-
204
- image_url = result['choices'][0]['message']['content']
205
- logging.info(f"成功获取图片URL: {image_url}")
206
-
207
- return image_url
208
- except Exception as e:
209
- logging.error(f"Image generation error: {str(e)}")
210
- raise
211
-
212
- @staticmethod
213
- def get_media_id(image_url: str) -> str:
214
- try:
215
- logging.info(f"开始下载图片: {image_url}")
216
- image_response = requests.get(image_url)
217
- image_response.raise_for_status()
218
- image_data = image_response.content
219
-
220
- logging.info("开始上传图片到微信服务器")
221
- upload_url = f'https://api.weixin.qq.com/cgi-bin/media/upload?access_token={TOKEN}&type=image'
222
- files = {'media': ('image.jpg', image_data, 'image/jpeg')}
223
- response = requests.post(upload_url, files=files)
224
- response.raise_for_status()
225
-
226
- result = response.json()
227
- logging.info(f"微信服务器响应: {json.dumps(result, ensure_ascii=False)}")
228
-
229
- if 'media_id' not in result:
230
- raise ValueError(f"Failed to get media_id: {result}")
231
-
232
- logging.info(f"成功获取media_id: {result['media_id']}")
233
- return result['media_id']
234
- except Exception as e:
235
- logging.error(f"WeChat media upload error: {str(e)}")
236
- raise
237
-
238
  def convert_markdown_to_wechat(md_text):
239
  if not md_text:
240
  return md_text
@@ -285,7 +221,6 @@ def generate_response_xml(to_user, from_user, content, response_type='text', med
285
  nonce = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
286
 
287
  if response_type == 'image' and media_id:
288
- logging.info(f"生成图片消息响应,media_id: {media_id}")
289
  xml_content = f'''
290
  <xml>
291
  <ToUserName><![CDATA[{to_user}]]></ToUserName>
@@ -299,7 +234,6 @@ def generate_response_xml(to_user, from_user, content, response_type='text', med
299
  '''
300
  else:
301
  formatted_content = convert_markdown_to_wechat(content)
302
- logging.info(f"生成文本消息响应: {formatted_content}")
303
  xml_content = f'''
304
  <xml>
305
  <ToUserName><![CDATA[{to_user}]]></ToUserName>
@@ -331,81 +265,69 @@ def generate_response_xml(to_user, from_user, content, response_type='text', med
331
  response.content_type = 'application/xml'
332
  return response
333
 
334
- def generate_initial_response():
335
- return "您的请求正在处理中,请回复'查询'获取结果"
336
-
337
- def split_message(message, max_length=500):
338
- return [message[i:i+max_length] for i in range(0, len(message), max_length)]
339
-
340
- def append_status_message(content, has_pending_parts=False, is_processing=False):
341
- if "您的请求正在处理中" in content:
342
- return content + "\n\n-------------------\n发送'新对话'开始新的对话"
343
-
344
- status_message = "\n\n-------------------"
345
- if is_processing:
346
- status_message += "\n请回复'查询'获取结果"
347
- elif has_pending_parts:
348
- status_message += "\n当前消息已截断,发送'继续'查看后续内容"
349
- status_message += "\n发送'新对话'开始新的对话"
350
- return content + status_message
351
-
352
- def process_ai_response(messages):
353
  try:
354
- logging.info("开始处理AI响应")
355
- completion = client.chat.completions.create(
356
  model="o3-mini",
357
  messages=messages,
358
  tools=TOOLS,
359
- tool_choice="auto"
 
360
  )
361
- logging.info("收到AI响应")
362
-
363
- # Handle tool calls if present
364
- if completion.choices[0].message.tool_calls:
365
- logging.info("检测到工具调用")
366
- for tool_call in completion.choices[0].message.tool_calls:
367
- if tool_call.function.name == "generate_image":
368
- try:
369
- logging.info("开始执行图片生成")
370
- args = json.loads(tool_call.function.arguments)
371
- # Generate image and get markdown URL
372
- image_url = ImageService.generate_image(args['prompt'])
373
-
374
- # Get WeChat media_id
375
- media_id = ImageService.get_media_id(image_url)
376
-
377
- messages.append({
378
- "role": "assistant",
379
- "content": f"已生成图片"
380
- })
381
-
382
- return {
383
- "type": "image",
384
- "content": None,
385
- "media_id": media_id
386
- }
387
- except Exception as e:
388
- logging.error(f"图片生成过程失败: {str(e)}")
389
- return {
390
- "type": "text",
391
- "content": f"抱歉,图片生成失败:{str(e)}",
392
- "media_id": None
393
- }
394
-
395
- # Handle normal text response
396
- logging.info("处理普通文本响应")
397
- response_content = completion.choices[0].message.content
398
- messages.append({
399
- "role": "assistant",
400
- "content": response_content
401
- })
 
 
 
 
 
 
402
  return {
403
  "type": "text",
404
- "content": response_content,
405
- "media_id": None
406
  }
407
  except Exception as e:
408
- logging.error(f"AI处理错误: {str(e)}")
409
  raise
410
 
411
  def handle_async_task(session, task_id, messages):
@@ -414,20 +336,42 @@ def handle_async_task(session, task_id, messages):
414
  if task_id not in session.response_queue:
415
  return
416
 
417
- result = process_ai_response(messages)
418
- logging.info(f"异步任务处理完成: {task_id}")
419
 
420
  if task_id in session.response_queue and not session.response_queue[task_id].is_expired():
421
  session.response_queue[task_id].status = "completed"
422
- session.response_queue[task_id].result = result.get("content")
423
- session.response_queue[task_id].response_type = result.get("type")
424
- session.response_queue[task_id].media_id = result.get("media_id")
 
 
 
 
 
425
  except Exception as e:
426
  logging.error(f"异步任务处理失败: {str(e)}")
427
  if task_id in session.response_queue:
428
  session.response_queue[task_id].status = "failed"
429
  session.response_queue[task_id].error = str(e)
430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  session_manager = SessionManager()
432
 
433
  @app.route('/api/wx', methods=['GET', 'POST'])
@@ -509,8 +453,10 @@ def wechatai():
509
  )
510
 
511
  if task_response.status == "completed":
512
- response_type = task_response.response_type
513
- if response_type == "image":
 
 
514
  return generate_response_xml(
515
  from_user,
516
  to_user,
@@ -564,27 +510,21 @@ def wechatai():
564
  encrypt_type=encrypt_type
565
  )
566
 
567
- # Regular message processing
568
- logging.info("准备开始处理用户消息")
569
  session.messages.append({"role": "user", "content": user_content})
570
 
571
  task_id = str(uuid.uuid4())
572
  session.current_task = task_id
573
  session.response_queue[task_id] = AsyncResponse()
574
 
575
- # Submit task to executor
576
- logging.info(f"提交任务到执行器: {task_id}")
577
  executor.submit(handle_async_task, session, task_id, session.messages.copy())
578
 
579
- # Return immediate response
580
- initial_response = generate_response_xml(
581
  from_user,
582
  to_user,
583
  append_status_message(generate_initial_response(), is_processing=True),
584
  encrypt_type=encrypt_type
585
  )
586
- logging.info("返回初始响应给用户")
587
- return initial_response
588
 
589
  except Exception as e:
590
  logging.error(f"处理请求时出错: {str(e)}")
 
6
  import json
7
  from openai import OpenAI
8
  from dotenv import load_dotenv
9
+ from markdown import markdown
10
  import re
11
  import threading
12
  import logging
13
  from datetime import datetime
14
+ import asyncio
15
  from concurrent.futures import ThreadPoolExecutor
16
  import queue
17
  import uuid
 
21
  import random
22
  import string
23
  import requests
 
24
 
25
  logging.basicConfig(
26
  level=logging.INFO,
 
47
  client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
48
  executor = ThreadPoolExecutor(max_workers=10)
49
 
50
+ # Define tools for image generation
51
  TOOLS = [
52
  {
53
  "type": "function",
54
  "function": {
55
  "name": "generate_image",
56
+ "description": "Generate an image based on text description",
57
  "parameters": {
58
  "type": "object",
59
  "properties": {
 
106
  self.error = None
107
  self.create_time = time.time()
108
  self.timeout = 3600
109
+ self.response_type = "text" # Can be "text" or "image"
110
+ self.media_id = None # For image responses
111
 
112
  def is_expired(self):
113
  return time.time() - self.create_time > self.timeout
 
171
  del self.sessions[user_id]
172
  logging.info(f"已清理过期会话: {user_id}")
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  def convert_markdown_to_wechat(md_text):
175
  if not md_text:
176
  return md_text
 
221
  nonce = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
222
 
223
  if response_type == 'image' and media_id:
 
224
  xml_content = f'''
225
  <xml>
226
  <ToUserName><![CDATA[{to_user}]]></ToUserName>
 
234
  '''
235
  else:
236
  formatted_content = convert_markdown_to_wechat(content)
 
237
  xml_content = f'''
238
  <xml>
239
  <ToUserName><![CDATA[{to_user}]]></ToUserName>
 
265
  response.content_type = 'application/xml'
266
  return response
267
 
268
+ def process_long_running_task(messages):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  try:
270
+ logging.info("开始调用AI服务")
271
+ response = client.chat.completions.create(
272
  model="o3-mini",
273
  messages=messages,
274
  tools=TOOLS,
275
+ tool_choice="auto",
276
+ timeout=60
277
  )
278
+ logging.info("AI服务响应成功")
279
+
280
+ if response.choices[0].message.tool_calls:
281
+ logging.info("检测到tool调用")
282
+ tool_call = response.choices[0].message.tool_calls[0]
283
+ if tool_call.function.name == "generate_image":
284
+ logging.info("开始处理图片生成请求")
285
+ args = json.loads(tool_call.function.arguments)
286
+ image_response = requests.post(
287
+ IMAGE_MODEL_URL,
288
+ headers={
289
+ 'Content-Type': 'application/json',
290
+ 'Authorization': f'Bearer {IMAGE_MODEL_KEY}'
291
+ },
292
+ json={
293
+ "model": "grok-latest-image",
294
+ "messages": [{
295
+ "role": "user",
296
+ "content": args['prompt']
297
+ }]
298
+ }
299
+ )
300
+ image_response.raise_for_status()
301
+ result = image_response.json()
302
+ logging.info("图片生成成功,准备下载图片")
303
+
304
+ image_url = result['choices'][0]['message']['content']
305
+ img_response = requests.get(image_url)
306
+ img_response.raise_for_status()
307
+
308
+ logging.info("开始上传图片到微信服务器")
309
+ upload_url = f'https://api.weixin.qq.com/cgi-bin/media/upload?access_token={TOKEN}&type=image'
310
+ files = {'media': ('image.jpg', img_response.content, 'image/jpeg')}
311
+ upload_response = requests.post(upload_url, files=files)
312
+ upload_response.raise_for_status()
313
+ media_result = upload_response.json()
314
+
315
+ if 'media_id' not in media_result:
316
+ raise ValueError(f"Failed to get media_id: {media_result}")
317
+
318
+ logging.info(f"图片上传成功,media_id: {media_result['media_id']}")
319
+ return {
320
+ "type": "image",
321
+ "media_id": media_result['media_id']
322
+ }
323
+
324
+ logging.info("返回文本响应")
325
  return {
326
  "type": "text",
327
+ "content": response.choices[0].message.content
 
328
  }
329
  except Exception as e:
330
+ logging.error(f"API调用错误: {str(e)}")
331
  raise
332
 
333
  def handle_async_task(session, task_id, messages):
 
336
  if task_id not in session.response_queue:
337
  return
338
 
339
+ result = process_long_running_task(messages)
 
340
 
341
  if task_id in session.response_queue and not session.response_queue[task_id].is_expired():
342
  session.response_queue[task_id].status = "completed"
343
+ session.response_queue[task_id].response_type = result.get("type", "text")
344
+ if result["type"] == "image":
345
+ session.response_queue[task_id].media_id = result["media_id"]
346
+ session.response_queue[task_id].result = None
347
+ messages.append({"role": "assistant", "content": "图片已生成"})
348
+ else:
349
+ session.response_queue[task_id].result = result["content"]
350
+ messages.append({"role": "assistant", "content": result["content"]})
351
  except Exception as e:
352
  logging.error(f"异步任务处理失败: {str(e)}")
353
  if task_id in session.response_queue:
354
  session.response_queue[task_id].status = "failed"
355
  session.response_queue[task_id].error = str(e)
356
 
357
+ def generate_initial_response():
358
+ return "您的请求正在处理中,请回复'查询'获取结果"
359
+
360
+ def split_message(message, max_length=500):
361
+ return [message[i:i+max_length] for i in range(0, len(message), max_length)]
362
+
363
+ def append_status_message(content, has_pending_parts=False, is_processing=False):
364
+ if "您的请求正在处理中" in content:
365
+ return content + "\n\n-------------------\n发送'新对话'开始新的对话"
366
+
367
+ status_message = "\n\n-------------------"
368
+ if is_processing:
369
+ status_message += "\n请回复'查询'获取结果"
370
+ elif has_pending_parts:
371
+ status_message += "\n当前消息已截断,发送'继续'查看后续内容"
372
+ status_message += "\n发送'新对话'开始新的对话"
373
+ return content + status_message
374
+
375
  session_manager = SessionManager()
376
 
377
  @app.route('/api/wx', methods=['GET', 'POST'])
 
453
  )
454
 
455
  if task_response.status == "completed":
456
+ if task_response.response_type == "image":
457
+ logging.info("返回图片响应")
458
+ del session.response_queue[session.current_task]
459
+ session.current_task = None
460
  return generate_response_xml(
461
  from_user,
462
  to_user,
 
510
  encrypt_type=encrypt_type
511
  )
512
 
 
 
513
  session.messages.append({"role": "user", "content": user_content})
514
 
515
  task_id = str(uuid.uuid4())
516
  session.current_task = task_id
517
  session.response_queue[task_id] = AsyncResponse()
518
 
 
 
519
  executor.submit(handle_async_task, session, task_id, session.messages.copy())
520
 
521
+ logging.info("返回初始响应")
522
+ return generate_response_xml(
523
  from_user,
524
  to_user,
525
  append_status_message(generate_initial_response(), is_processing=True),
526
  encrypt_type=encrypt_type
527
  )
 
 
528
 
529
  except Exception as e:
530
  logging.error(f"处理请求时出错: {str(e)}")