"""
AI 服务模块 - 处理与 AI 模型的交互(支持原生流式输出)
"""
from openai import OpenAI
from config import API_KEY, API_BASE_URL, MODEL_NAME, MAX_TOKENS, TEMPERATURE, TOP_P, SYSTEM_PROMPT
from file_handler import load_gdl_text
from cache_manager import request_cache
from security import input_validator
from default_content import get_default_gdl, get_default_prompt, get_default_example_gdl
# Initialize the OpenAI client (DeepSeek V3 exposes an OpenAI-compatible API,
# so the stock SDK works with a custom base_url).
client = OpenAI(
    api_key=API_KEY,
    base_url=API_BASE_URL
)
# ========== 公共小工具 ==========
def _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode):
    """
    Assemble the chat ``messages`` list shared by the streaming and
    non-streaming paths, so both send identical prompts to the model.
    """
    # 1) Pick the system prompt: user-provided text (or the bundled default
    #    prompt) either replaces or is merged with the built-in SYSTEM_PROMPT.
    default_sys = (SYSTEM_PROMPT or "").strip()
    custom_sys = (custom_prompt_text or "").strip()
    selected_mode = (prompt_mode or "覆盖默认SYSTEM_PROMPT").strip()

    if not custom_sys:
        fallback_prompt = get_default_prompt()
        if fallback_prompt:
            custom_sys = fallback_prompt

    if custom_sys:
        wrapped = f"\n{custom_sys}\n"
        if selected_mode.startswith("合并"):
            system_text = default_sys + "\n\n" + wrapped
        else:
            system_text = wrapped
    else:
        system_text = default_sys

    messages = [{"role": "system", "content": system_text}]

    # 2) Inject the uploaded GDL spec (falling back to the bundled default
    #    when no file was uploaded) as a second system message.
    spec = load_gdl_text(uploaded_files)
    if not spec:
        spec = get_default_gdl()
    if spec:
        messages.append({
            "role": "system",
            "content": "以下为用户上传的麻将游戏通用语言(mGDL)规范或示例,请在设计与输出中严格遵循:\n\n"
                       + spec + "\n"
        })

    # 2.5) Always attach the bundled example GDL document for reference.
    example = get_default_example_gdl()
    if example:
        messages.append({
            "role": "system",
            "content": "以下为示例 GDL 文档,供您参考设计时使用:\n\n" + example + "\n"
        })

    # 3) Replay prior conversation turns.
    for user_turn, bot_turn in (history or []):
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if bot_turn:
            messages.append({"role": "assistant", "content": bot_turn})

    # 4) Finally, the current user input.
    messages.append({"role": "user", "content": message})
    return messages
def _yield_chunks(text, step=40):
"""把整段文本切成小块,伪流式输出。"""
s = str(text or "")
for i in range(0, len(s), step):
yield s[i:i + step]
# ========== 非流式(保留你原实现,便于兼容) ==========
def design_mahjong_game(message, history, uploaded_files, custom_prompt_text, prompt_mode):
    """
    Design a mahjong game variant (non-streaming path).

    Validates the inputs, assembles the prompt messages, optionally serves a
    cached answer (only for fresh conversations with no history), and calls
    the AI model.

    Returns the model's reply, or an emoji-prefixed error string when
    validation or the API call fails.
    """
    # Input validation (mirrors the streaming path).
    is_valid, error_msg = input_validator.validate_message(message)
    if not is_valid:
        return f"❌ 输入验证失败:{error_msg}"
    is_valid, error_msg = input_validator.validate_custom_prompt(custom_prompt_text)
    if not is_valid:
        return f"❌ 自定义提示词验证失败:{error_msg}"
    is_valid, error_msg = input_validator.validate_file_list(uploaded_files)
    if not is_valid:
        return f"❌ 文件验证失败:{error_msg}"

    messages = _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode)

    # Use the cache only when the conversation has no history.
    no_history = len(history or []) == 0
    if no_history:
        cached_response = request_cache.get(messages)
        if cached_response:
            return cached_response

    response = _call_ai_model(messages)

    # Do not cache failures. _call_ai_model can also return transient
    # network/timeout/empty-content notices ("🌐", "⏰", "🤔"); the previous
    # prefix check ("❌", "💥") mistakenly let those into the cache, where
    # they would be replayed for identical future requests.
    if no_history and response and not response.startswith(("❌", "💥", "🌐", "⏰", "🤔")):
        request_cache.set(messages, response)
    return response
def _call_ai_model(messages):
    """
    Invoke the AI model synchronously and return its reply text.

    On failure, returns a human-readable error string instead of raising,
    so callers can display it directly.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            max_tokens=MAX_TOKENS,
        )
        reply = completion.choices[0].message.content
        # Guard against empty / whitespace-only replies.
        if not reply or reply.strip() == "":
            return "🤔 AI 返回了空内容,请尝试重新发送或调整输入。"
        return reply
    except ConnectionError as e:
        # NOTE(review): the OpenAI SDK raises its own APIConnectionError /
        # APITimeoutError (not the builtins), so these two handlers may
        # rarely trigger; SDK errors fall through to the generic handler —
        # confirm against the installed SDK version.
        return f"🌐 网络连接错误:{str(e)}\n\n请检查网络连接是否正常。"
    except TimeoutError as e:
        return f"⏰ 请求超时:{str(e)}\n\n请稍后重试,或尝试减少输入内容长度。"
    except Exception as e:
        return f"💥 调用失败:{type(e).__name__}: {str(e)}\n\n请检查 API Key 是否正确,或网络是否通畅。"
# ========== 新增:原生流式 ==========
def design_mahjong_game_stream(message, history, uploaded_files, custom_prompt_text,
                               prompt_mode, *, flush_every=24):
    """
    Native streaming version of ``design_mahjong_game``.

    Yields text fragments (plain strings) as they arrive from the model.
    Validation failures and API errors are yielded as emoji-prefixed
    messages rather than raised, so the generator never interrupts the UI.

    ``flush_every`` (keyword-only, default 24) is the minimum number of
    buffered characters before a fragment is yielded — a simple throttle
    that reduces frontend refresh pressure. It was previously a hard-coded
    constant; the default preserves the old behavior.
    """
    # 1) Input validation (kept identical to the non-streaming path).
    is_valid, error_msg = input_validator.validate_message(message)
    if not is_valid:
        yield f"❌ 输入验证失败:{error_msg}"
        return
    is_valid, error_msg = input_validator.validate_custom_prompt(custom_prompt_text)
    if not is_valid:
        yield f"❌ 自定义提示词验证失败:{error_msg}"
        return
    is_valid, error_msg = input_validator.validate_file_list(uploaded_files)
    if not is_valid:
        yield f"❌ 文件验证失败:{error_msg}"
        return

    # 2) Assemble messages exactly as the non-streaming path does.
    messages = _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode)

    # 3) Serve a cache hit (fresh conversations only) as pseudo-streaming.
    no_hist = len(history or []) == 0
    if no_hist:
        cached = request_cache.get(messages)
        if cached:
            yield from _yield_chunks(cached, step=48)
            return

    # 4) Native streaming call (OpenAI-compatible API).
    buf = []  # full transcript, kept for the cache write below
    try:
        stream = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            max_tokens=MAX_TOKENS,
            stream=True,
        )
        pending = []       # fragments received but not yet flushed to the caller
        pending_len = 0
        for chunk in stream:
            # Extract the incremental text defensively: some frames are
            # control frames and carry no content.
            delta = None
            if chunk.choices and len(chunk.choices) > 0:
                delta_obj = chunk.choices[0].delta
                if delta_obj and hasattr(delta_obj, 'content'):
                    delta = delta_obj.content
            if not delta:
                continue
            buf.append(delta)
            pending.append(delta)
            pending_len += len(delta)
            # Throttle: only yield once enough characters have accumulated.
            if pending_len >= flush_every:
                yield "".join(pending)
                pending.clear()
                pending_len = 0
        # Flush whatever remains after the stream ends.
        if pending:
            yield "".join(pending)

        # 5) Cache the full reply (fresh conversation & normal content only).
        #    Exceptions jump past this, so partial output is never cached.
        full = "".join(buf).strip()
        if no_hist and full and not full.startswith(("❌", "💥")):
            request_cache.set(messages, full)
    except ConnectionError as e:
        yield f"\n🌐 网络连接错误:{str(e)}"
    except TimeoutError as e:
        yield f"\n⏰ 请求超时:{str(e)}"
    except Exception as e:
        # Surface the error as text instead of re-raising, so the consuming
        # generator/UI is not interrupted by an exception.
        yield f"\n💥 流式调用失败:{type(e).__name__}: {e}"