| """ | |
| AI 服务模块 - 处理与 AI 模型的交互(支持原生流式输出) | |
| """ | |
from openai import OpenAI

from config import (
    API_KEY,
    API_BASE_URL,
    MODEL_NAME,
    MAX_TOKENS,
    TEMPERATURE,
    TOP_P,
    SYSTEM_PROMPT,
    ENABLE_REFERENCE_RETRIEVAL,
    REFERENCE_MAX_VARIANTS,
    INJECT_REFERENCE_MGDL,
    INJECT_ALL_EXAMPLE_GDL,
    ENABLE_OUTPUT_VALIDATION,
    ENABLE_AUTO_REPAIR,
)
from file_handler import load_gdl_text
from cache_manager import request_cache
from security import input_validator
from default_content import get_default_gdl, get_default_prompt, get_default_example_gdl
from reference_retriever import build_reference_pack
from output_validator import validate_mahjong_response, format_issues_for_llm

# Initialize the OpenAI client (DeepSeek V3 compatible endpoint).
# NOTE(review): created at import time; assumes API_KEY/API_BASE_URL are
# already set when this module loads — confirm against config module.
client = OpenAI(
    api_key=API_KEY,
    base_url=API_BASE_URL
)
| def _is_analyse_mode(messages): | |
| """ | |
| Analyse 模式下按约定不输出 mGDL,因此应跳过“缺 mGDL”的静态校验提示, | |
| 否则会在对话末尾产生误导性的 NO_MGDL 警告。 | |
| """ | |
| try: | |
| for msg in reversed(messages or []): | |
| if msg.get("role") == "user": | |
| content = msg.get("content", "") or "" | |
| return "<ANALYSE_MODE>true</ANALYSE_MODE>" in content | |
| except Exception: | |
| return False | |
| return False | |
# ========== Shared helpers ==========
def _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode):
    """
    Assemble the ``messages`` list so that the streaming and non-streaming
    paths build exactly the same prompt.

    Args:
        message: Current user input (plain text).
        history: Iterable of (human, assistant) turn pairs; may be None.
        uploaded_files: Files handed to ``load_gdl_text`` to extract a GDL spec.
        custom_prompt_text: Optional team prompt supplied by the user.
        prompt_mode: Default "覆盖默认SYSTEM_PROMPT" (override); a value
            starting with "合并" merges the team prompt with SYSTEM_PROMPT.

    Returns:
        A list of ``{"role", "content"}`` dicts for the chat-completions API.
    """
    # 1) Choose the system prompt.
    base_sys = (SYSTEM_PROMPT or "").strip()
    user_sys = (custom_prompt_text or "").strip()
    mode = (prompt_mode or "覆盖默认SYSTEM_PROMPT").strip()
    # Fall back to the bundled default prompt when the user supplied none.
    if not user_sys:
        default_prompt = get_default_prompt()
        if default_prompt:
            user_sys = default_prompt
    if user_sys:
        wrapped_user_sys = f"<TEAM_PROMPT>\n{user_sys}\n</TEAM_PROMPT>"
        # "合并" (merge) appends the team prompt to the base system prompt;
        # any other mode replaces the base prompt entirely.
        system_to_use = (base_sys + "\n\n" + wrapped_user_sys) if mode.startswith("合并") else wrapped_user_sys
    else:
        system_to_use = base_sys
    messages = [{"role": "system", "content": system_to_use}]
    # 2) Inject the uploaded GDL spec as a second system message,
    #    falling back to the bundled default GDL when nothing was uploaded.
    gdl_spec = load_gdl_text(uploaded_files)
    if not gdl_spec:
        gdl_spec = get_default_gdl()
    if gdl_spec:
        messages.append({
            "role": "system",
            "content": "以下为用户上传的麻将游戏通用语言(mGDL)规范或示例,请在设计与输出中严格遵循:\n<GDL_SPEC>\n"
                       + gdl_spec + "\n</GDL_SPEC>"
        })
    # 2.5) Inject a "reference variant pack" (RAG-lite: a few most-relevant
    #      examples rather than stacking every example into the context).
    if ENABLE_REFERENCE_RETRIEVAL:
        pack = build_reference_pack(
            message,
            max_variants=REFERENCE_MAX_VARIANTS,
            include_mechanism_library=True,
            include_mgdl=INJECT_REFERENCE_MGDL,
        )
        mech = (pack.get("mechanism_library") or "").strip()
        if mech:
            messages.append({
                "role": "system",
                "content": "以下为【麻将机制说明/机制词典】,创新时优先从中挑选可落地机制再做组合:\n"
                           "<MECHANISM_LIBRARY>\n" + mech + "\n</MECHANISM_LIBRARY>"
            })
        picked_names = (pack.get("picked_names") or "").strip()
        ref_md = (pack.get("reference_md") or "").strip()
        if ref_md:
            messages.append({
                "role": "system",
                "content": "以下为【参考玩法自然语言规则(主真理)】。当参考玩法的 .md 与 .txt 有冲突时,以 .md 为准:\n"
                           "本轮命中参考玩法:" + (picked_names or "(未命中,使用兜底样例)") + "\n"
                           "<REFERENCE_VARIANTS_MD>\n" + ref_md + "\n</REFERENCE_VARIANTS_MD>"
            })
        # Reference-variant mGDL is off by default: mGDL works best as a
        # syntax/output-format constraint, while the .md files carry semantics.
        if INJECT_REFERENCE_MGDL:
            ref_mgdl = (pack.get("reference_mgdl") or "").strip()
            if ref_mgdl:
                messages.append({
                    "role": "system",
                    "content": "以下为【参考玩法 mGDL(辅语法翻译)】。仅用于学习如何用 v1.3 语法表达规则:\n"
                               "<REFERENCE_VARIANTS_MGDL>\n" + ref_mgdl + "\n</REFERENCE_VARIANTS_MGDL>"
                })
    # Optional: still inject every example mGDL (not recommended — it easily
    # dilutes the model's attention).
    # NOTE(review): source indentation was lost; this branch is placed at
    # function level because it does not use `pack` and the flag is
    # independent of ENABLE_REFERENCE_RETRIEVAL — confirm intended nesting.
    if INJECT_ALL_EXAMPLE_GDL:
        example_gdl = get_default_example_gdl()
        if example_gdl:
            messages.append({
                "role": "system",
                "content": "以下为【全量示例 mGDL】(注意:过多示例可能稀释注意力;优先使用上面的“参考玩法包”):\n<EXAMPLE_GDL_ALL>\n"
                           + example_gdl + "\n</EXAMPLE_GDL_ALL>"
            })
    # 3) Append prior conversation turns.
    for human, assistant in (history or []):
        if human:
            messages.append({"role": "user", "content": human})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    # 4) Current user input goes last.
    messages.append({"role": "user", "content": message})
    return messages
| def _yield_chunks(text, step=40): | |
| """把整段文本切成小块,伪流式输出。""" | |
| s = str(text or "") | |
| for i in range(0, len(s), step): | |
| yield s[i:i + step] | |
# ========== Non-streaming (original implementation kept for compatibility) ==========
def design_mahjong_game(message, history, uploaded_files, custom_prompt_text, prompt_mode):
    """
    Main entry point for designing a mahjong variant (non-streaming).

    Validates the inputs, builds the prompt, consults the cache (first turn
    only), calls the model, and optionally runs one minimal auto-repair pass
    when static output validation finds issues.

    Returns:
        The model's text, or a user-facing error string prefixed with an
        emoji marker ("❌"/"💥") on failure.
    """
    # Input validation: message, custom prompt, then the uploaded file list.
    is_valid, error_msg = input_validator.validate_message(message)
    if not is_valid:
        return f"❌ 输入验证失败:{error_msg}"
    is_valid, error_msg = input_validator.validate_custom_prompt(custom_prompt_text)
    if not is_valid:
        return f"❌ 自定义提示词验证失败:{error_msg}"
    is_valid, error_msg = input_validator.validate_file_list(uploaded_files)
    if not is_valid:
        return f"❌ 文件验证失败:{error_msg}"
    messages = _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode)
    # Cache is consulted only when there is no history (first turn).
    if len(history or []) == 0:
        cached_response = request_cache.get(messages)
        if cached_response:
            return cached_response
    response = _call_ai_model(messages)
    # Output validation + optional minimal repair loop (non-streaming only,
    # so the interactive streaming experience is unaffected). Skipped in
    # Analyse mode, which by convention produces no mGDL.
    if (not _is_analyse_mode(messages)) and ENABLE_OUTPUT_VALIDATION and response and not response.startswith(("❌", "💥")):
        issues = validate_mahjong_response(response)
        if issues and ENABLE_AUTO_REPAIR:
            fix_instructions = format_issues_for_llm(issues)
            repair_messages = list(messages)
            repair_messages.append({
                "role": "user",
                "content": "下面是你刚刚的输出。请只做【最小修改】来修复这些问题:\n"
                           + fix_instructions
                           + "\n\n【原输出】\n"
                           + response
                           + "\n\n修复要求:\n"
                             "1) 不要引入新机制,除非为满足守恒/模块完整性而必须。\n"
                             "2) 保持规则风味不变,只补齐缺失模块/替换占位符/补全必要字段。\n"
                             "3) 修复后重新输出完整结果(自然语言规则 + mGDL + 自检报告)。"
            })
            repaired = _call_ai_model(repair_messages)
            # Keep the repaired output only when it is not an error string.
            if repaired and not repaired.startswith(("❌", "💥")):
                response = repaired
    # Store successful first-turn responses in the cache.
    if len(history or []) == 0 and response and not response.startswith(("❌", "💥")):
        request_cache.set(messages, response)
    return response
def _call_ai_model(messages):
    """
    Call the AI model (non-streaming).

    Returns:
        The completion text, or a user-facing error string prefixed with an
        emoji marker on failure (never raises).
    """
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            max_tokens=MAX_TOKENS,
        )
        content = response.choices[0].message.content
        # Guard against empty completions so callers always get visible text.
        if not content or content.strip() == "":
            return "🤔 AI 返回了空内容,请尝试重新发送或调整输入。"
        return content
    except ConnectionError as e:
        # NOTE(review): these are the *builtin* ConnectionError/TimeoutError;
        # the OpenAI SDK typically raises openai.APIConnectionError /
        # openai.APITimeoutError, which fall through to the generic handler
        # below — confirm whether these two branches are reachable.
        return f"🌐 网络连接错误:{str(e)}\n\n请检查网络连接是否正常。"
    except TimeoutError as e:
        return f"⏰ 请求超时:{str(e)}\n\n请稍后重试,或尝试减少输入内容长度。"
    except Exception as e:
        error_type = type(e).__name__
        error_msg = str(e)
        return f"💥 调用失败:{error_type}: {error_msg}\n\n请检查 API Key 是否正确,或网络是否通畅。"
# ========== New: native streaming ==========
def design_mahjong_game_stream(message, history, uploaded_files, custom_prompt_text, prompt_mode):
    """
    Native streaming entry point: yields text fragments (str) as they arrive.

    Mirrors ``design_mahjong_game``: same validation and prompt assembly,
    cache replay via pseudo-streaming on a first-turn cache hit, then a
    streaming model call with light output throttling. Errors are yielded
    as text instead of raised, so the generator is never interrupted.
    """
    # 1) Input validation (identical to the non-streaming path).
    is_valid, error_msg = input_validator.validate_message(message)
    if not is_valid:
        yield f"❌ 输入验证失败:{error_msg}"
        return
    is_valid, error_msg = input_validator.validate_custom_prompt(custom_prompt_text)
    if not is_valid:
        yield f"❌ 自定义提示词验证失败:{error_msg}"
        return
    is_valid, error_msg = input_validator.validate_file_list(uploaded_files)
    if not is_valid:
        yield f"❌ 文件验证失败:{error_msg}"
        return
    # 2) Assemble messages (exactly the same as the non-streaming path).
    messages = _prepare_messages(message, history, uploaded_files, custom_prompt_text, prompt_mode)
    # 3) First-turn cache hit: replay the cached answer as pseudo-streaming.
    no_hist = len(history or []) == 0
    if no_hist:
        cached = request_cache.get(messages)
        if cached:
            for piece in _yield_chunks(cached, step=48):
                yield piece
            return
    # 4) Native streaming call (OpenAI-compatible API).
    buf = []  # accumulates the full response for validation/caching
    try:
        stream = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            max_tokens=MAX_TOKENS,
            stream=True,
        )
        # Simple character-level throttling: accumulate a few characters
        # before flushing, reducing front-end update churn.
        cache_piece = []
        cache_len = 0
        FLUSH_EVERY = 24  # flush once every N characters; tune as needed
        for chunk in stream:
            # Extract the incremental text defensively.
            delta = None
            if chunk.choices and len(chunk.choices) > 0:
                delta_obj = chunk.choices[0].delta
                if delta_obj and hasattr(delta_obj, 'content'):
                    delta = delta_obj.content
            # Some frames are control frames carrying no text.
            if not delta:
                continue
            buf.append(delta)
            cache_piece.append(delta)
            cache_len += len(delta)
            # Throttle: only yield once enough characters accumulated.
            if cache_len >= FLUSH_EVERY:
                text_chunk = "".join(cache_piece)
                cache_piece.clear()
                cache_len = 0
                yield text_chunk
        # Stream finished: flush any remainder below the threshold.
        if cache_piece:
            yield "".join(cache_piece)
        # 5) Post-stream static validation hint (no auto-repair here) and
        #    cache write (first turn with normal content only).
        full = "".join(buf).strip()
        if (not _is_analyse_mode(messages)) and ENABLE_OUTPUT_VALIDATION and full and not full.startswith(("❌", "💥")):
            issues = validate_mahjong_response(full)
            if issues:
                hint = format_issues_for_llm(issues)
                yield "\n\n---\n⚠️ 输出静态校验发现潜在问题(建议让模型按最小修改修复后再导出):\n" + hint + "\n"
        if no_hist and full and not full.startswith(("❌", "💥")):
            request_cache.set(messages, full)
    except ConnectionError as e:
        # NOTE(review): builtin ConnectionError/TimeoutError — the OpenAI SDK
        # usually raises its own exception types, which land in the generic
        # handler below; confirm these branches are reachable.
        yield f"\n🌐 网络连接错误:{str(e)}"
    except TimeoutError as e:
        yield f"\n⏰ 请求超时:{str(e)}"
    except Exception as e:
        # Surface the exception message as text rather than raising, so the
        # generator is not interrupted mid-stream.
        yield f"\n💥 流式调用失败:{type(e).__name__}: {e}"