Estazz committed on
Commit
52f33d3
·
verified ·
1 Parent(s): f16af1f

Update ai_service.py

Browse files
Files changed (1) hide show
  1. ai_service.py +169 -65
ai_service.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- AI 服务模块 - 处理与 AI 模型的交互
3
  """
4
  import dashscope
5
  from dashscope import Generation
@@ -9,34 +9,13 @@ from cache_manager import request_cache
9
  from security import input_validator
10
 
11
 
12
- def design_poker_game(message, history, uploaded_files, custom_prompt_text, prompt_mode):
 
13
  """
14
- 设计扑克游戏主要函数
15
-
16
- Args:
17
- message: 用户输入的消息
18
- history: 对话历史
19
- uploaded_files: 上传的文件
20
- custom_prompt_text: 自定义提示词
21
- prompt_mode: 提示词模式
22
-
23
- Returns:
24
- str: AI 的回复内容
25
  """
26
- # 输入验证
27
- is_valid, error_msg = input_validator.validate_message(message)
28
- if not is_valid:
29
- return f"❌ 输入验证失败:{error_msg}"
30
-
31
- is_valid, error_msg = input_validator.validate_custom_prompt(custom_prompt_text)
32
- if not is_valid:
33
- return f"❌ 自定义提示词验证失败:{error_msg}"
34
-
35
- is_valid, error_msg = input_validator.validate_file_list(uploaded_files)
36
- if not is_valid:
37
- return f"❌ 文件验证失败:{error_msg}"
38
  # 1) 选择 System Prompt
39
- base_sys = SYSTEM_PROMPT.strip()
40
  user_sys = (custom_prompt_text or "").strip()
41
  mode = (prompt_mode or "覆盖默认SYSTEM_PROMPT").strip()
42
 
@@ -46,10 +25,9 @@ def design_poker_game(message, history, uploaded_files, custom_prompt_text, prom
46
  else:
47
  system_to_use = base_sys
48
 
49
- # 2) 基础 system
50
  messages = [{"role": "system", "content": system_to_use}]
51
 
52
- # 3) 注入上传的 GDL(作为第二条 system)
53
  gdl_spec = load_gdl_text(uploaded_files)
54
  if gdl_spec:
55
  messages.append({
@@ -58,39 +36,94 @@ def design_poker_game(message, history, uploaded_files, custom_prompt_text, prom
58
  + gdl_spec + "\n</GDL_SPEC>"
59
  })
60
 
61
- # 4) 追加历史对话
62
- for human, assistant in history:
63
- messages.append({"role": "user", "content": human})
64
- messages.append({"role": "assistant", "content": assistant})
 
 
65
 
66
- # 5) 当前输入
67
  messages.append({"role": "user", "content": message})
68
 
69
- # 6) 检查缓存并调用模型
70
- # 注意:对于包含历史对话的请求,我们只缓存没有历史对话的请求
71
- if len(history) == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  cached_response = request_cache.get(messages)
73
  if cached_response:
74
  return cached_response
75
-
76
  response = _call_ai_model(messages)
77
-
78
- # 缓存没有历史对话的响应
79
- if len(history) == 0 and response and not response.startswith("❌") and not response.startswith("💥"):
80
  request_cache.set(messages, response)
81
-
82
  return response
83
 
84
 
85
  def _call_ai_model(messages):
86
  """
87
- 调用 AI 模型
88
-
89
- Args:
90
- messages: 消息列表
91
-
92
- Returns:
93
- str: AI 回复内容
94
  """
95
  try:
96
  response = Generation.call(
@@ -99,10 +132,10 @@ def _call_ai_model(messages):
99
  temperature=TEMPERATURE,
100
  top_p=TOP_P,
101
  max_tokens=MAX_TOKENS,
102
- result_format='message',
103
- enable_thinking=False
104
  )
105
-
106
  if response.status_code == 200:
107
  content = response.output.choices[0].message.content
108
  if not content or content.strip() == "":
@@ -110,7 +143,7 @@ def _call_ai_model(messages):
110
  return content
111
  else:
112
  return _handle_api_error(response)
113
-
114
  except ConnectionError as e:
115
  return f"🌐 网络连接错误:{str(e)}\n\n请检查网络连接是否正常。"
116
  except TimeoutError as e:
@@ -123,17 +156,10 @@ def _call_ai_model(messages):
123
  def _handle_api_error(response):
124
  """
125
  处理 API 错误
126
-
127
- Args:
128
- response: API 响应对象
129
-
130
- Returns:
131
- str: 错误信息
132
  """
133
- # 兼容不同类型的 code/status_code(有些为字符串,有些为整型)
134
- status_code = getattr(response, 'status_code', None)
135
- code_raw = getattr(response, 'code', None)
136
- message = getattr(response, 'message', '')
137
 
138
  error_msg = f"❌ API 错误:{code_raw} - {message}"
139
 
@@ -145,13 +171,91 @@ def _handle_api_error(response):
145
 
146
  code_int = _as_int(code_raw)
147
 
148
- if (status_code == 401) or (code_int == 401) or (str(code_raw) == '401'):
149
  error_msg += "\n\n💡 提示:请检查 API Key 是否正确设置。"
150
- elif (status_code == 429) or (code_int == 429) or (str(code_raw) == '429'):
151
  error_msg += "\n\n💡 提示:请求过于频繁,请稍后再试。"
152
  else:
153
- # 服务器错误(5xx)判定:优先使用 status_code,其次尝试解析 code
154
  if (isinstance(status_code, int) and status_code >= 500) or (code_int is not None and code_int >= 500):
155
  error_msg += "\n\n💡 提示:服务器错误,请稍后重试。"
156
 
157
  return error_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ AI 服务模块 - 处理与 AI 模型的交互(支持原生流式输出)
3
  """
4
  import dashscope
5
  from dashscope import Generation
 
9
  from security import input_validator
10
 
11
 
12
+ # ========== 公共小工具 ==========
13
+ def _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode):
14
  """
15
+ 组装 messages,保证与非流式/流式两条路径提示词一致
 
 
 
 
 
 
 
 
 
 
16
  """
 
 
 
 
 
 
 
 
 
 
 
 
17
  # 1) 选择 System Prompt
18
+ base_sys = (SYSTEM_PROMPT or "").strip()
19
  user_sys = (custom_prompt_text or "").strip()
20
  mode = (prompt_mode or "覆盖默认SYSTEM_PROMPT").strip()
21
 
 
25
  else:
26
  system_to_use = base_sys
27
 
 
28
  messages = [{"role": "system", "content": system_to_use}]
29
 
30
+ # 2) 注入上传的 GDL(作为第二条 system)
31
  gdl_spec = load_gdl_text(uploaded_files)
32
  if gdl_spec:
33
  messages.append({
 
36
  + gdl_spec + "\n</GDL_SPEC>"
37
  })
38
 
39
+ # 3) 追加历史对话
40
+ for human, assistant in (history or []):
41
+ if human:
42
+ messages.append({"role": "user", "content": human})
43
+ if assistant:
44
+ messages.append({"role": "assistant", "content": assistant})
45
 
46
+ # 4) 当前输入
47
  messages.append({"role": "user", "content": message})
48
 
49
+ return messages
50
+
51
+
52
+ def _yield_chunks(text, step=40):
53
+ """把整段文本切成小块,伪流式输出。"""
54
+ s = str(text or "")
55
+ for i in range(0, len(s), step):
56
+ yield s[i:i + step]
57
+
58
+
59
+ def _extract_stream_delta(resp):
60
+ """
61
+ 尽量兼容不同 dashscope 小版本的流式返回结构,提取“增量文本”
62
+ 常见字段:resp.output_text 或 resp.output.choices[0].delta/message/content
63
+ """
64
+ delta = None
65
+ # 优先简单字段
66
+ if hasattr(resp, "output_text") and resp.output_text:
67
+ return resp.output_text
68
+
69
+ out = getattr(resp, "output", None)
70
+ if isinstance(out, dict):
71
+ choices = out.get("choices") or []
72
+ if choices:
73
+ c0 = choices[0] or {}
74
+ # 1) delta 路径
75
+ d = c0.get("delta")
76
+ if isinstance(d, dict):
77
+ delta = d.get("content") or d.get("text") or None
78
+ elif d:
79
+ delta = str(d)
80
+ # 2) message 路径(有些版本直接不断给 message.content)
81
+ if not delta and isinstance(c0.get("message"), dict):
82
+ delta = c0["message"].get("content")
83
+ # 3) content 直给
84
+ if not delta:
85
+ delta = c0.get("content")
86
+
87
+ return delta
88
+
89
+
90
+ # ========== 非流式(保留你原实现,便于兼容) ==========
91
def design_poker_game(message, history, uploaded_files, custom_prompt_text, prompt_mode):
    """Main entry point for designing a poker game (non-streaming).

    Validates the inputs, assembles the prompt messages, consults the
    request cache (history-free requests only) and finally calls the model.

    Args:
        message: The user's current input.
        history: List of (user, assistant) turn pairs, or None.
        uploaded_files: Uploaded GDL files, validated then injected as a
            second system message by _prepare_messages.
        custom_prompt_text: Optional user-supplied system prompt.
        prompt_mode: How the custom prompt combines with the default one.

    Returns:
        str: The model's reply, or an error string starting with an emoji
        marker ("❌"/"💥") when validation or the API call fails.
    """
    # Run the three input checks lazily, bailing out on the first failure
    # so later validators are never invoked after an earlier one fails.
    for check, prefix in (
        (lambda: input_validator.validate_message(message), "❌ 输入验证失败:"),
        (lambda: input_validator.validate_custom_prompt(custom_prompt_text), "❌ 自定义提示词验证失败:"),
        (lambda: input_validator.validate_file_list(uploaded_files), "❌ 文件验证失败:"),
    ):
        ok, err = check()
        if not ok:
            return f"{prefix}{err}"

    messages = _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode)

    # Caching policy: only requests without conversation history are cached.
    cacheable = not history
    if cacheable:
        hit = request_cache.get(messages)
        if hit:
            return hit

    reply = _call_ai_model(messages)

    # Don't cache empty replies or error strings.
    if cacheable and reply and not reply.startswith(("❌", "💥")):
        request_cache.set(messages, reply)

    return reply
122
 
123
 
124
  def _call_ai_model(messages):
125
  """
126
+ 调用 AI 模型(非流式)
 
 
 
 
 
 
127
  """
128
  try:
129
  response = Generation.call(
 
132
  temperature=TEMPERATURE,
133
  top_p=TOP_P,
134
  max_tokens=MAX_TOKENS,
135
+ result_format="message",
136
+ enable_thinking=False,
137
  )
138
+
139
  if response.status_code == 200:
140
  content = response.output.choices[0].message.content
141
  if not content or content.strip() == "":
 
143
  return content
144
  else:
145
  return _handle_api_error(response)
146
+
147
  except ConnectionError as e:
148
  return f"🌐 网络连接错误:{str(e)}\n\n请检查网络连接是否正常。"
149
  except TimeoutError as e:
 
156
  def _handle_api_error(response):
157
  """
158
  处理 API 错误
 
 
 
 
 
 
159
  """
160
+ status_code = getattr(response, "status_code", None)
161
+ code_raw = getattr(response, "code", None)
162
+ message = getattr(response, "message", "")
 
163
 
164
  error_msg = f"❌ API 错误:{code_raw} - {message}"
165
 
 
171
 
172
  code_int = _as_int(code_raw)
173
 
174
+ if (status_code == 401) or (code_int == 401) or (str(code_raw) == "401"):
175
  error_msg += "\n\n💡 提示:请检查 API Key 是否正确设置。"
176
+ elif (status_code == 429) or (code_int == 429) or (str(code_raw) == "429"):
177
  error_msg += "\n\n💡 提示:请求过于频繁,请稍后再试。"
178
  else:
 
179
  if (isinstance(status_code, int) and status_code >= 500) or (code_int is not None and code_int >= 500):
180
  error_msg += "\n\n💡 提示:服务器错误,请稍后重试。"
181
 
182
  return error_msg
183
+
184
+
185
+ # ========== 新增:原生流式 ==========
186
def design_poker_game_stream(message, history, uploaded_files, custom_prompt_text, prompt_mode):
    """Native streaming variant: yields text fragments as they arrive.

    Mirrors design_poker_game's validation, prompt assembly, call
    parameters and caching policy so both paths behave identically;
    app.py prefers this function to render output incrementally.

    Yields:
        str: Text fragments of the reply, or an error string starting
        with an emoji marker on failure.
    """
    # 1) Input validation — identical checks and messages to the
    #    non-streaming path; stop at the first failure.
    for check, prefix in (
        (lambda: input_validator.validate_message(message), "❌ 输入验证失败:"),
        (lambda: input_validator.validate_custom_prompt(custom_prompt_text), "❌ 自定义提示词验证失败:"),
        (lambda: input_validator.validate_file_list(uploaded_files), "❌ 文件验证失败:"),
    ):
        ok, err = check()
        if not ok:
            yield f"{prefix}{err}"
            return

    # 2) Assemble the prompt messages.
    messages = _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode)

    # 3) Cache hit (history-free requests only): replay the cached reply
    #    in slightly larger chunks so the UI keeps a streaming feel.
    no_hist = not history
    if no_hist:
        cached = request_cache.get(messages)
        if cached:
            yield from _yield_chunks(cached, step=48)
            return

    # 4) Native streaming call.
    pieces = []
    try:
        stream = Generation.call(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            max_tokens=MAX_TOKENS,
            result_format="message",   # keep parity with the non-streaming call
            enable_thinking=False,
            stream=True,
            incremental_output=True,   # ask the SDK for incremental deltas
        )

        for frame in stream:
            delta = _extract_stream_delta(frame)
            if delta:
                pieces.append(delta)
                yield delta            # push each fragment out immediately
                continue
            # Frames without text are either control/heartbeat frames
            # (skip) or error frames (report and stop).
            status_code = getattr(frame, "status_code", 200)
            if status_code and status_code != 200:
                yield f"\n{_handle_api_error(frame)}"
                return

        # 5) Done: cache the full reply (history-free, non-empty,
        #    non-error replies only — same policy as the sync path).
        full = "".join(pieces).strip()
        if no_hist and full and not full.startswith(("❌", "💥")):
            request_cache.set(messages, full)

    except ConnectionError as e:
        yield f"\n🌐 网络连接错误:{str(e)}"
    except TimeoutError as e:
        yield f"\n⏰ 请求超时:{str(e)}"
    except Exception as e:
        yield f"\n💥 流式调用失败:{type(e).__name__}: {e}"