Spaces:
Sleeping
Sleeping
Update ai_service.py
Browse files- ai_service.py +169 -65
ai_service.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
AI 服务模块 - 处理与 AI 模型的交互
|
| 3 |
"""
|
| 4 |
import dashscope
|
| 5 |
from dashscope import Generation
|
|
@@ -9,34 +9,13 @@ from cache_manager import request_cache
|
|
| 9 |
from security import input_validator
|
| 10 |
|
| 11 |
|
| 12 |
-
|
|
|
|
| 13 |
"""
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
Args:
|
| 17 |
-
message: 用户输入的消息
|
| 18 |
-
history: 对话历史
|
| 19 |
-
uploaded_files: 上传的文件
|
| 20 |
-
custom_prompt_text: 自定义提示词
|
| 21 |
-
prompt_mode: 提示词模式
|
| 22 |
-
|
| 23 |
-
Returns:
|
| 24 |
-
str: AI 的回复内容
|
| 25 |
"""
|
| 26 |
-
# 输入验证
|
| 27 |
-
is_valid, error_msg = input_validator.validate_message(message)
|
| 28 |
-
if not is_valid:
|
| 29 |
-
return f"❌ 输入验证失败:{error_msg}"
|
| 30 |
-
|
| 31 |
-
is_valid, error_msg = input_validator.validate_custom_prompt(custom_prompt_text)
|
| 32 |
-
if not is_valid:
|
| 33 |
-
return f"❌ 自定义提示词验证失败:{error_msg}"
|
| 34 |
-
|
| 35 |
-
is_valid, error_msg = input_validator.validate_file_list(uploaded_files)
|
| 36 |
-
if not is_valid:
|
| 37 |
-
return f"❌ 文件验证失败:{error_msg}"
|
| 38 |
# 1) 选择 System Prompt
|
| 39 |
-
base_sys = SYSTEM_PROMPT.strip()
|
| 40 |
user_sys = (custom_prompt_text or "").strip()
|
| 41 |
mode = (prompt_mode or "覆盖默认SYSTEM_PROMPT").strip()
|
| 42 |
|
|
@@ -46,10 +25,9 @@ def design_poker_game(message, history, uploaded_files, custom_prompt_text, prom
|
|
| 46 |
else:
|
| 47 |
system_to_use = base_sys
|
| 48 |
|
| 49 |
-
# 2) 基础 system
|
| 50 |
messages = [{"role": "system", "content": system_to_use}]
|
| 51 |
|
| 52 |
-
#
|
| 53 |
gdl_spec = load_gdl_text(uploaded_files)
|
| 54 |
if gdl_spec:
|
| 55 |
messages.append({
|
|
@@ -58,39 +36,94 @@ def design_poker_game(message, history, uploaded_files, custom_prompt_text, prom
|
|
| 58 |
+ gdl_spec + "\n</GDL_SPEC>"
|
| 59 |
})
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
for human, assistant in history:
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
#
|
| 67 |
messages.append({"role": "user", "content": message})
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
cached_response = request_cache.get(messages)
|
| 73 |
if cached_response:
|
| 74 |
return cached_response
|
| 75 |
-
|
| 76 |
response = _call_ai_model(messages)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
if len(history) == 0 and response and not response.startswith("❌") and not response.startswith("💥"):
|
| 80 |
request_cache.set(messages, response)
|
| 81 |
-
|
| 82 |
return response
|
| 83 |
|
| 84 |
|
| 85 |
def _call_ai_model(messages):
|
| 86 |
"""
|
| 87 |
-
调用 AI 模型
|
| 88 |
-
|
| 89 |
-
Args:
|
| 90 |
-
messages: 消息列表
|
| 91 |
-
|
| 92 |
-
Returns:
|
| 93 |
-
str: AI 回复内容
|
| 94 |
"""
|
| 95 |
try:
|
| 96 |
response = Generation.call(
|
|
@@ -99,10 +132,10 @@ def _call_ai_model(messages):
|
|
| 99 |
temperature=TEMPERATURE,
|
| 100 |
top_p=TOP_P,
|
| 101 |
max_tokens=MAX_TOKENS,
|
| 102 |
-
result_format=
|
| 103 |
-
enable_thinking=False
|
| 104 |
)
|
| 105 |
-
|
| 106 |
if response.status_code == 200:
|
| 107 |
content = response.output.choices[0].message.content
|
| 108 |
if not content or content.strip() == "":
|
|
@@ -110,7 +143,7 @@ def _call_ai_model(messages):
|
|
| 110 |
return content
|
| 111 |
else:
|
| 112 |
return _handle_api_error(response)
|
| 113 |
-
|
| 114 |
except ConnectionError as e:
|
| 115 |
return f"🌐 网络连接错误:{str(e)}\n\n请检查网络连接是否正常。"
|
| 116 |
except TimeoutError as e:
|
|
@@ -123,17 +156,10 @@ def _call_ai_model(messages):
|
|
| 123 |
def _handle_api_error(response):
|
| 124 |
"""
|
| 125 |
处理 API 错误
|
| 126 |
-
|
| 127 |
-
Args:
|
| 128 |
-
response: API 响应对象
|
| 129 |
-
|
| 130 |
-
Returns:
|
| 131 |
-
str: 错误信息
|
| 132 |
"""
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
message = getattr(response, 'message', '')
|
| 137 |
|
| 138 |
error_msg = f"❌ API 错误:{code_raw} - {message}"
|
| 139 |
|
|
@@ -145,13 +171,91 @@ def _handle_api_error(response):
|
|
| 145 |
|
| 146 |
code_int = _as_int(code_raw)
|
| 147 |
|
| 148 |
-
if (status_code == 401) or (code_int == 401) or (str(code_raw) ==
|
| 149 |
error_msg += "\n\n💡 提示:请检查 API Key 是否正确设置。"
|
| 150 |
-
elif (status_code == 429) or (code_int == 429) or (str(code_raw) ==
|
| 151 |
error_msg += "\n\n💡 提示:请求过于频繁,请稍后再试。"
|
| 152 |
else:
|
| 153 |
-
# 服务器错误(5xx)判定:优先使用 status_code,其次尝试解析 code
|
| 154 |
if (isinstance(status_code, int) and status_code >= 500) or (code_int is not None and code_int >= 500):
|
| 155 |
error_msg += "\n\n💡 提示:服务器错误,请稍后重试。"
|
| 156 |
|
| 157 |
return error_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
AI 服务模块 - 处理与 AI 模型的交互(支持原生流式输出)
|
| 3 |
"""
|
| 4 |
import dashscope
|
| 5 |
from dashscope import Generation
|
|
|
|
| 9 |
from security import input_validator
|
| 10 |
|
| 11 |
|
| 12 |
+
# ========== 公共小工具 ==========
|
| 13 |
+
def _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode):
|
| 14 |
"""
|
| 15 |
+
组装 messages,保证与非流式/流式两条路径的提示词一致
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
# 1) 选择 System Prompt
|
| 18 |
+
base_sys = (SYSTEM_PROMPT or "").strip()
|
| 19 |
user_sys = (custom_prompt_text or "").strip()
|
| 20 |
mode = (prompt_mode or "覆盖默认SYSTEM_PROMPT").strip()
|
| 21 |
|
|
|
|
| 25 |
else:
|
| 26 |
system_to_use = base_sys
|
| 27 |
|
|
|
|
| 28 |
messages = [{"role": "system", "content": system_to_use}]
|
| 29 |
|
| 30 |
+
# 2) 注入上传的 GDL(作为第二条 system)
|
| 31 |
gdl_spec = load_gdl_text(uploaded_files)
|
| 32 |
if gdl_spec:
|
| 33 |
messages.append({
|
|
|
|
| 36 |
+ gdl_spec + "\n</GDL_SPEC>"
|
| 37 |
})
|
| 38 |
|
| 39 |
+
# 3) 追加历史对话
|
| 40 |
+
for human, assistant in (history or []):
|
| 41 |
+
if human:
|
| 42 |
+
messages.append({"role": "user", "content": human})
|
| 43 |
+
if assistant:
|
| 44 |
+
messages.append({"role": "assistant", "content": assistant})
|
| 45 |
|
| 46 |
+
# 4) 当前输入
|
| 47 |
messages.append({"role": "user", "content": message})
|
| 48 |
|
| 49 |
+
return messages
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _yield_chunks(text, step=40):
|
| 53 |
+
"""把整段文本切成小块,伪流式输出。"""
|
| 54 |
+
s = str(text or "")
|
| 55 |
+
for i in range(0, len(s), step):
|
| 56 |
+
yield s[i:i + step]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _extract_stream_delta(resp):
|
| 60 |
+
"""
|
| 61 |
+
尽量兼容不同 dashscope 小版本的流式返回结构,提取“增量文本”
|
| 62 |
+
常见字段:resp.output_text 或 resp.output.choices[0].delta/message/content
|
| 63 |
+
"""
|
| 64 |
+
delta = None
|
| 65 |
+
# 优先简单字段
|
| 66 |
+
if hasattr(resp, "output_text") and resp.output_text:
|
| 67 |
+
return resp.output_text
|
| 68 |
+
|
| 69 |
+
out = getattr(resp, "output", None)
|
| 70 |
+
if isinstance(out, dict):
|
| 71 |
+
choices = out.get("choices") or []
|
| 72 |
+
if choices:
|
| 73 |
+
c0 = choices[0] or {}
|
| 74 |
+
# 1) delta 路径
|
| 75 |
+
d = c0.get("delta")
|
| 76 |
+
if isinstance(d, dict):
|
| 77 |
+
delta = d.get("content") or d.get("text") or None
|
| 78 |
+
elif d:
|
| 79 |
+
delta = str(d)
|
| 80 |
+
# 2) message 路径(有些版本直接不断给 message.content)
|
| 81 |
+
if not delta and isinstance(c0.get("message"), dict):
|
| 82 |
+
delta = c0["message"].get("content")
|
| 83 |
+
# 3) content 直给
|
| 84 |
+
if not delta:
|
| 85 |
+
delta = c0.get("content")
|
| 86 |
+
|
| 87 |
+
return delta
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ========== 非流式(保留你原实现,便于兼容) ==========
|
| 91 |
+
def design_poker_game(message, history, uploaded_files, custom_prompt_text, prompt_mode):
|
| 92 |
+
"""
|
| 93 |
+
设计扑克游戏的主要函数(非流式)
|
| 94 |
+
"""
|
| 95 |
+
# 输入验证
|
| 96 |
+
is_valid, error_msg = input_validator.validate_message(message)
|
| 97 |
+
if not is_valid:
|
| 98 |
+
return f"❌ 输入验证失败:{error_msg}"
|
| 99 |
+
|
| 100 |
+
is_valid, error_msg = input_validator.validate_custom_prompt(custom_prompt_text)
|
| 101 |
+
if not is_valid:
|
| 102 |
+
return f"❌ 自定义提示词验证失败:{error_msg}"
|
| 103 |
+
|
| 104 |
+
is_valid, error_msg = input_validator.validate_file_list(uploaded_files)
|
| 105 |
+
if not is_valid:
|
| 106 |
+
return f"❌ 文件验证失败:{error_msg}"
|
| 107 |
+
|
| 108 |
+
messages = _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode)
|
| 109 |
+
|
| 110 |
+
# 仅在“无历史”时启用缓存(沿用你的策略)
|
| 111 |
+
if len(history or []) == 0:
|
| 112 |
cached_response = request_cache.get(messages)
|
| 113 |
if cached_response:
|
| 114 |
return cached_response
|
| 115 |
+
|
| 116 |
response = _call_ai_model(messages)
|
| 117 |
+
|
| 118 |
+
if len(history or []) == 0 and response and not response.startswith(("❌", "💥")):
|
|
|
|
| 119 |
request_cache.set(messages, response)
|
| 120 |
+
|
| 121 |
return response
|
| 122 |
|
| 123 |
|
| 124 |
def _call_ai_model(messages):
|
| 125 |
"""
|
| 126 |
+
调用 AI 模型(非流式)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
"""
|
| 128 |
try:
|
| 129 |
response = Generation.call(
|
|
|
|
| 132 |
temperature=TEMPERATURE,
|
| 133 |
top_p=TOP_P,
|
| 134 |
max_tokens=MAX_TOKENS,
|
| 135 |
+
result_format="message",
|
| 136 |
+
enable_thinking=False,
|
| 137 |
)
|
| 138 |
+
|
| 139 |
if response.status_code == 200:
|
| 140 |
content = response.output.choices[0].message.content
|
| 141 |
if not content or content.strip() == "":
|
|
|
|
| 143 |
return content
|
| 144 |
else:
|
| 145 |
return _handle_api_error(response)
|
| 146 |
+
|
| 147 |
except ConnectionError as e:
|
| 148 |
return f"🌐 网络连接错误:{str(e)}\n\n请检查网络连接是否正常。"
|
| 149 |
except TimeoutError as e:
|
|
|
|
| 156 |
def _handle_api_error(response):
|
| 157 |
"""
|
| 158 |
处理 API 错误
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
"""
|
| 160 |
+
status_code = getattr(response, "status_code", None)
|
| 161 |
+
code_raw = getattr(response, "code", None)
|
| 162 |
+
message = getattr(response, "message", "")
|
|
|
|
| 163 |
|
| 164 |
error_msg = f"❌ API 错误:{code_raw} - {message}"
|
| 165 |
|
|
|
|
| 171 |
|
| 172 |
code_int = _as_int(code_raw)
|
| 173 |
|
| 174 |
+
if (status_code == 401) or (code_int == 401) or (str(code_raw) == "401"):
|
| 175 |
error_msg += "\n\n💡 提示:请检查 API Key 是否正确设置。"
|
| 176 |
+
elif (status_code == 429) or (code_int == 429) or (str(code_raw) == "429"):
|
| 177 |
error_msg += "\n\n💡 提示:请求过于频繁,请稍后再试。"
|
| 178 |
else:
|
|
|
|
| 179 |
if (isinstance(status_code, int) and status_code >= 500) or (code_int is not None and code_int >= 500):
|
| 180 |
error_msg += "\n\n💡 提示:服务器错误,请稍后重试。"
|
| 181 |
|
| 182 |
return error_msg
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ========== 新增:原生流式 ==========
|
| 186 |
+
def design_poker_game_stream(message, history, uploaded_files, custom_prompt_text, prompt_mode):
|
| 187 |
+
"""
|
| 188 |
+
原生流式:逐段 yield 文本片段(字符串)
|
| 189 |
+
- 与 design_poker_game 的提示与参数保持一致
|
| 190 |
+
- app.py 会优先调用本函数实现“边生成边显示”
|
| 191 |
+
"""
|
| 192 |
+
# 1) 输入验证(与非流式一致)
|
| 193 |
+
is_valid, error_msg = input_validator.validate_message(message)
|
| 194 |
+
if not is_valid:
|
| 195 |
+
yield f"❌ 输入验证失败:{error_msg}"
|
| 196 |
+
return
|
| 197 |
+
|
| 198 |
+
is_valid, error_msg = input_validator.validate_custom_prompt(custom_prompt_text)
|
| 199 |
+
if not is_valid:
|
| 200 |
+
yield f"❌ 自定义提示词验证失败:{error_msg}"
|
| 201 |
+
return
|
| 202 |
+
|
| 203 |
+
is_valid, error_msg = input_validator.validate_file_list(uploaded_files)
|
| 204 |
+
if not is_valid:
|
| 205 |
+
yield f"❌ 文件验证失败:{error_msg}"
|
| 206 |
+
return
|
| 207 |
+
|
| 208 |
+
# 2) 组装 messages
|
| 209 |
+
messages = _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode)
|
| 210 |
+
|
| 211 |
+
# 3) 缓存命中(仅无历史时)
|
| 212 |
+
no_hist = len(history or []) == 0
|
| 213 |
+
if no_hist:
|
| 214 |
+
cached = request_cache.get(messages)
|
| 215 |
+
if cached:
|
| 216 |
+
for piece in _yield_chunks(cached, step=48): # 比非流式略大些片段,体感更顺
|
| 217 |
+
yield piece
|
| 218 |
+
return
|
| 219 |
+
|
| 220 |
+
# 4) 原生流式调用
|
| 221 |
+
buf = []
|
| 222 |
+
try:
|
| 223 |
+
resp_iter = Generation.call(
|
| 224 |
+
model=MODEL_NAME,
|
| 225 |
+
messages=messages,
|
| 226 |
+
temperature=TEMPERATURE,
|
| 227 |
+
top_p=TOP_P,
|
| 228 |
+
max_tokens=MAX_TOKENS,
|
| 229 |
+
result_format="message", # 与非流式保持一致
|
| 230 |
+
enable_thinking=False,
|
| 231 |
+
stream=True,
|
| 232 |
+
incremental_output=True, # 关键:��量输出
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
for resp in resp_iter:
|
| 236 |
+
# 有些帧可能是控制帧/心跳,直接跳过
|
| 237 |
+
delta = _extract_stream_delta(resp)
|
| 238 |
+
if not delta:
|
| 239 |
+
# 也可能是错误帧
|
| 240 |
+
status_code = getattr(resp, "status_code", 200)
|
| 241 |
+
if status_code and status_code != 200:
|
| 242 |
+
# 尽量提取错误信息并终止
|
| 243 |
+
err = _handle_api_error(resp)
|
| 244 |
+
yield f"\n{err}"
|
| 245 |
+
return
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
buf.append(delta)
|
| 249 |
+
yield delta # 每拿到一段就吐出去
|
| 250 |
+
|
| 251 |
+
# 5) 结束:写入缓存(仅无历史时 & 有内容 & 无错误提示)
|
| 252 |
+
full = "".join(buf).strip()
|
| 253 |
+
if no_hist and full and not full.startswith(("❌", "💥")):
|
| 254 |
+
request_cache.set(messages, full)
|
| 255 |
+
|
| 256 |
+
except ConnectionError as e:
|
| 257 |
+
yield f"\n🌐 网络连接错误:{str(e)}"
|
| 258 |
+
except TimeoutError as e:
|
| 259 |
+
yield f"\n⏰ 请求超时:{str(e)}"
|
| 260 |
+
except Exception as e:
|
| 261 |
+
yield f"\n💥 流式调用失败:{type(e).__name__}: {e}"
|