Spaces:
Runtime error
Runtime error
| """ | |
| utils.py - StoryWeaver 工具函数模块 | |
| 职责: | |
| 1. 加载环境变量,初始化 OpenAI 兼容客户端 (Qwen API) | |
| 2. 提供通用的 API 调用封装函数(带重试机制) | |
| 3. 提供 JSON 安全解析工具(从 LLM 输出中提取结构化数据) | |
| """ | |
| import os | |
| import re | |
| import json | |
| import time | |
| import logging | |
| from typing import Any, Optional | |
| from dotenv import load_dotenv | |
| try: | |
| from openai import OpenAI | |
| _OPENAI_IMPORT_ERROR: Optional[Exception] = None | |
| except ImportError as exc: # pragma: no cover - depends on local env | |
| OpenAI = None # type: ignore[assignment] | |
| _OPENAI_IMPORT_ERROR = exc | |
| # ============================================================ | |
| # 日志配置 | |
| # ============================================================ | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", | |
| ) | |
| logger = logging.getLogger("StoryWeaver") | |
| # ============================================================ | |
| # 环境变量加载 & API 客户端初始化 | |
| # ============================================================ | |
| # 从项目根目录的 .env 文件加载环境变量 | |
| load_dotenv() | |
| # 严禁硬编码 API Key —— 仅通过环境变量读取 | |
| QWEN_API_KEY: str = os.getenv("QWEN_API_KEY", "") | |
| if not QWEN_API_KEY or QWEN_API_KEY == "sk-xxxxxx": | |
| logger.warning( | |
| "⚠️ QWEN_API_KEY 未设置或仍为模板值!" | |
| "请在 .env 文件中填写有效的 API Key。" | |
| ) | |
| # 使用 OpenAI 兼容格式连接 Qwen API | |
| # base_url 指向通义千问的 OpenAI 兼容端点 | |
| _client: Optional[Any] = None | |
| def get_client() -> Any: | |
| """ | |
| 获取全局 OpenAI 客户端(懒加载单例)。 | |
| 使用兼容格式调用 Qwen API。 | |
| """ | |
| global _client | |
| if OpenAI is None: | |
| raise RuntimeError( | |
| "未安装 openai 依赖,无法初始化 Qwen 客户端。" | |
| "请先执行 `pip install -r requirements.txt`。" | |
| ) from _OPENAI_IMPORT_ERROR | |
| if _client is None: | |
| _client = OpenAI( | |
| api_key=QWEN_API_KEY, | |
| base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", | |
| ) | |
| return _client | |
| # ============================================================ | |
| # 默认模型配置 | |
| # ============================================================ | |
| # 使用 qwen2.5-14b-instruct 以获得最快的响应速度 | |
| DEFAULT_MODEL: str = "qwen2.5-14b-instruct" | |
| # ============================================================ | |
| # 通用 API 调用封装(带重试 & 错误处理) | |
| # ============================================================ | |
| def call_qwen( | |
| messages: list[dict[str, str]], | |
| model: str = DEFAULT_MODEL, | |
| temperature: float = 0.8, | |
| max_tokens: int = 2000, | |
| max_retries: int = 3, | |
| retry_delay: float = 1.0, | |
| ) -> str: | |
| """ | |
| 调用 Qwen API 的通用封装函数。 | |
| 设计思路: | |
| - 使用 OpenAI 兼容格式,方便后续切换模型 | |
| - 内置指数退避重试机制,应对网络波动和限流 | |
| - 返回纯文本内容,JSON 解析交给调用方处理 | |
| Args: | |
| messages: OpenAI 格式的消息列表 [{"role": "system", "content": "..."}, ...] | |
| model: 模型名称,默认 qwen-plus | |
| temperature: 生成温度,越高越有创意(0.0-2.0) | |
| max_tokens: 最大生成 token 数 | |
| max_retries: 最大重试次数 | |
| retry_delay: 初始重试间隔(秒),每次翻倍 | |
| Returns: | |
| 模型生成的文本内容 | |
| Raises: | |
| Exception: 重试耗尽后抛出最后一次异常 | |
| """ | |
| client = get_client() | |
| last_exception: Optional[Exception] = None | |
| for attempt in range(1, max_retries + 1): | |
| try: | |
| logger.info(f"调用 Qwen API (尝试 {attempt}/{max_retries}),模型: {model}") | |
| response = client.chat.completions.create( | |
| model=model, | |
| messages=messages, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| ) | |
| content = response.choices[0].message.content.strip() | |
| logger.info(f"API 调用成功,响应长度: {len(content)} 字符") | |
| return content | |
| except Exception as e: | |
| last_exception = e | |
| logger.warning(f"API 调用失败 (尝试 {attempt}/{max_retries}): {e}") | |
| if attempt < max_retries: | |
| sleep_time = retry_delay * (2 ** (attempt - 1)) | |
| logger.info(f"等待 {sleep_time:.1f} 秒后重试...") | |
| time.sleep(sleep_time) | |
| # 重试耗尽,抛出异常 | |
| raise RuntimeError( | |
| f"Qwen API 调用在 {max_retries} 次尝试后仍然失败: {last_exception}" | |
| ) | |
| def call_qwen_stream( | |
| messages: list[dict[str, str]], | |
| model: str = DEFAULT_MODEL, | |
| temperature: float = 0.8, | |
| max_tokens: int = 2000, | |
| ): | |
| """ | |
| 调用 Qwen API 的流式版本,逐块 yield 文本内容。 | |
| 使用 stream=True,让用户在 AI 生成过程中就能看到文字逐步出现, | |
| 大幅改善感知延迟。 | |
| Args: | |
| messages: OpenAI 格式的消息列表 | |
| model: 模型名称 | |
| temperature: 生成温度 | |
| max_tokens: 最大生成 token 数 | |
| Yields: | |
| 每次生成的文本片段(str) | |
| """ | |
| client = get_client() | |
| logger.info(f"调用 Qwen 流式 API,模型: {model}") | |
| try: | |
| response = client.chat.completions.create( | |
| model=model, | |
| messages=messages, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| stream=True, | |
| ) | |
| for chunk in response: | |
| if chunk.choices and chunk.choices[0].delta.content: | |
| yield chunk.choices[0].delta.content | |
| except Exception as e: | |
| logger.error(f"流式 API 调用失败: {e}") | |
| raise | |
| # ============================================================ | |
| # JSON 安全解析工具 | |
| # ============================================================ | |
| def extract_json_from_text(text: str) -> Optional[dict | list]: | |
| """ | |
| 从 LLM 输出的文本中提取 JSON 数据。 | |
| 设计思路: | |
| LLM 有时会在 JSON 前后附加说明文字,或使用 ```json 代码块包裹。 | |
| 此函数通过多种策略尝试提取有效 JSON: | |
| 1. 先尝试直接解析整段文本 | |
| 2. 再尝试提取 ```json ... ``` 代码块 | |
| 3. 最后尝试匹配第一个 { ... } 或 [ ... ] 结构 | |
| Args: | |
| text: LLM 返回的原始文本 | |
| Returns: | |
| 解析后的 dict/list,解析失败返回 None | |
| """ | |
| if not text: | |
| return None | |
| # 策略1: 直接解析(LLM 可能返回纯 JSON) | |
| try: | |
| return json.loads(text.strip()) | |
| except json.JSONDecodeError: | |
| pass | |
| # 策略2: 提取 ```json ... ``` 代码块 | |
| code_block_pattern = r"```(?:json)?\s*\n?(.*?)\n?\s*```" | |
| matches = re.findall(code_block_pattern, text, re.DOTALL) | |
| for match in matches: | |
| try: | |
| return json.loads(match.strip()) | |
| except json.JSONDecodeError: | |
| continue | |
| # 策略3: 匹配第一个完整的 JSON 对象 { ... } | |
| brace_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}" | |
| brace_matches = re.findall(brace_pattern, text, re.DOTALL) | |
| for match in brace_matches: | |
| try: | |
| return json.loads(match) | |
| except json.JSONDecodeError: | |
| continue | |
| # 策略4: 匹配嵌套更深的 JSON(贪婪匹配从第一个 { 到最后一个 }) | |
| deep_match = re.search(r"\{.*\}", text, re.DOTALL) | |
| if deep_match: | |
| try: | |
| return json.loads(deep_match.group()) | |
| except json.JSONDecodeError: | |
| pass | |
| # 策略5: 匹配 JSON 数组 [ ... ] | |
| array_match = re.search(r"\[.*\]", text, re.DOTALL) | |
| if array_match: | |
| try: | |
| return json.loads(array_match.group()) | |
| except json.JSONDecodeError: | |
| pass | |
| logger.warning(f"无法从文本中提取 JSON: {text[:200]}...") | |
| return None | |
| def safe_json_call( | |
| messages: list[dict[str, str]], | |
| model: str = DEFAULT_MODEL, | |
| temperature: float = 0.3, | |
| max_tokens: int = 2000, | |
| max_retries: int = 3, | |
| ) -> Optional[dict | list]: | |
| """ | |
| 调用 Qwen API 并安全地解析返回的 JSON。 | |
| 设计思路: | |
| - 将 API 调用与 JSON 解析合为一步 | |
| - 如果第一次解析失败,会额外重试(重新调用 API) | |
| - temperature 默认较低 (0.3),让 JSON 输出更稳定 | |
| Args: | |
| messages: 消息列表 | |
| model: 模型名称 | |
| temperature: 生成温度(JSON 输出建议低温) | |
| max_tokens: 最大 token 数 | |
| max_retries: JSON 解析失败时的额外重试次数 | |
| Returns: | |
| 解析后的 dict/list,全部失败返回 None | |
| """ | |
| for attempt in range(1, max_retries + 1): | |
| try: | |
| raw_text = call_qwen( | |
| messages=messages, | |
| model=model, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| ) | |
| result = extract_json_from_text(raw_text) | |
| if result is not None: | |
| return result | |
| logger.warning( | |
| f"JSON 解析失败 (尝试 {attempt}/{max_retries}),原始文本: {raw_text[:300]}..." | |
| ) | |
| except Exception as e: | |
| logger.error(f"safe_json_call 异常 (尝试 {attempt}/{max_retries}): {e}") | |
| logger.error(f"safe_json_call 在 {max_retries} 次尝试后仍无法获取有效 JSON") | |
| return None | |
| # ============================================================ | |
| # 辅助工具函数 | |
| # ============================================================ | |
| def clamp(value: int, min_val: int, max_val: int) -> int: | |
| """将数值限制在 [min_val, max_val] 范围内""" | |
| return max(min_val, min(max_val, value)) | |
| def format_dict_for_prompt(data: dict, indent: int = 0) -> str: | |
| """ | |
| 将字典格式化为易读的 Prompt 文本。 | |
| 用于将状态数据注入 System Prompt。 | |
| """ | |
| lines = [] | |
| prefix = " " * indent | |
| for key, value in data.items(): | |
| if isinstance(value, dict): | |
| lines.append(f"{prefix}{key}:") | |
| lines.append(format_dict_for_prompt(value, indent + 1)) | |
| elif isinstance(value, list): | |
| if value: | |
| lines.append(f"{prefix}{key}: {', '.join(str(v) for v in value)}") | |
| else: | |
| lines.append(f"{prefix}{key}: 无") | |
| else: | |
| lines.append(f"{prefix}{key}: {value}") | |
| return "\n".join(lines) | |