Spaces:
Runtime error
Runtime error
File size: 10,893 Bytes
9e03a34 4998893 9e03a34 4998893 9e03a34 4998893 9e03a34 4998893 9e03a34 9942d7a 9e03a34 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 | """
utils.py - StoryWeaver 工具函数模块
职责:
1. 加载环境变量,初始化 OpenAI 兼容客户端 (Qwen API)
2. 提供通用的 API 调用封装函数(带重试机制)
3. 提供 JSON 安全解析工具(从 LLM 输出中提取结构化数据)
"""
import os
import re
import json
import time
import logging
from typing import Any, Optional
from dotenv import load_dotenv
try:
from openai import OpenAI
_OPENAI_IMPORT_ERROR: Optional[Exception] = None
except ImportError as exc: # pragma: no cover - depends on local env
OpenAI = None # type: ignore[assignment]
_OPENAI_IMPORT_ERROR = exc
# ============================================================
# 日志配置
# ============================================================
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("StoryWeaver")
# ============================================================
# 环境变量加载 & API 客户端初始化
# ============================================================
# 从项目根目录的 .env 文件加载环境变量
load_dotenv()
# 严禁硬编码 API Key —— 仅通过环境变量读取
QWEN_API_KEY: str = os.getenv("QWEN_API_KEY", "")
if not QWEN_API_KEY or QWEN_API_KEY == "sk-xxxxxx":
logger.warning(
"⚠️ QWEN_API_KEY 未设置或仍为模板值!"
"请在 .env 文件中填写有效的 API Key。"
)
# 使用 OpenAI 兼容格式连接 Qwen API
# base_url 指向通义千问的 OpenAI 兼容端点
_client: Optional[Any] = None
def get_client() -> Any:
"""
获取全局 OpenAI 客户端(懒加载单例)。
使用兼容格式调用 Qwen API。
"""
global _client
if OpenAI is None:
raise RuntimeError(
"未安装 openai 依赖,无法初始化 Qwen 客户端。"
"请先执行 `pip install -r requirements.txt`。"
) from _OPENAI_IMPORT_ERROR
if _client is None:
_client = OpenAI(
api_key=QWEN_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
return _client
# ============================================================
# 默认模型配置
# ============================================================
# 使用 qwen2.5-14b-instruct 以获得最快的响应速度
DEFAULT_MODEL: str = "qwen2.5-14b-instruct"
# ============================================================
# 通用 API 调用封装(带重试 & 错误处理)
# ============================================================
def call_qwen(
messages: list[dict[str, str]],
model: str = DEFAULT_MODEL,
temperature: float = 0.8,
max_tokens: int = 2000,
max_retries: int = 3,
retry_delay: float = 1.0,
) -> str:
"""
调用 Qwen API 的通用封装函数。
设计思路:
- 使用 OpenAI 兼容格式,方便后续切换模型
- 内置指数退避重试机制,应对网络波动和限流
- 返回纯文本内容,JSON 解析交给调用方处理
Args:
messages: OpenAI 格式的消息列表 [{"role": "system", "content": "..."}, ...]
model: 模型名称,默认 qwen-plus
temperature: 生成温度,越高越有创意(0.0-2.0)
max_tokens: 最大生成 token 数
max_retries: 最大重试次数
retry_delay: 初始重试间隔(秒),每次翻倍
Returns:
模型生成的文本内容
Raises:
Exception: 重试耗尽后抛出最后一次异常
"""
client = get_client()
last_exception: Optional[Exception] = None
for attempt in range(1, max_retries + 1):
try:
logger.info(f"调用 Qwen API (尝试 {attempt}/{max_retries}),模型: {model}")
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.choices[0].message.content.strip()
logger.info(f"API 调用成功,响应长度: {len(content)} 字符")
return content
except Exception as e:
last_exception = e
logger.warning(f"API 调用失败 (尝试 {attempt}/{max_retries}): {e}")
if attempt < max_retries:
sleep_time = retry_delay * (2 ** (attempt - 1))
logger.info(f"等待 {sleep_time:.1f} 秒后重试...")
time.sleep(sleep_time)
# 重试耗尽,抛出异常
raise RuntimeError(
f"Qwen API 调用在 {max_retries} 次尝试后仍然失败: {last_exception}"
)
def call_qwen_stream(
messages: list[dict[str, str]],
model: str = DEFAULT_MODEL,
temperature: float = 0.8,
max_tokens: int = 2000,
):
"""
调用 Qwen API 的流式版本,逐块 yield 文本内容。
使用 stream=True,让用户在 AI 生成过程中就能看到文字逐步出现,
大幅改善感知延迟。
Args:
messages: OpenAI 格式的消息列表
model: 模型名称
temperature: 生成温度
max_tokens: 最大生成 token 数
Yields:
每次生成的文本片段(str)
"""
client = get_client()
logger.info(f"调用 Qwen 流式 API,模型: {model}")
try:
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
)
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
except Exception as e:
logger.error(f"流式 API 调用失败: {e}")
raise
# ============================================================
# JSON 安全解析工具
# ============================================================
def extract_json_from_text(text: str) -> Optional[dict | list]:
"""
从 LLM 输出的文本中提取 JSON 数据。
设计思路:
LLM 有时会在 JSON 前后附加说明文字,或使用 ```json 代码块包裹。
此函数通过多种策略尝试提取有效 JSON:
1. 先尝试直接解析整段文本
2. 再尝试提取 ```json ... ``` 代码块
3. 最后尝试匹配第一个 { ... } 或 [ ... ] 结构
Args:
text: LLM 返回的原始文本
Returns:
解析后的 dict/list,解析失败返回 None
"""
if not text:
return None
# 策略1: 直接解析(LLM 可能返回纯 JSON)
try:
return json.loads(text.strip())
except json.JSONDecodeError:
pass
# 策略2: 提取 ```json ... ``` 代码块
code_block_pattern = r"```(?:json)?\s*\n?(.*?)\n?\s*```"
matches = re.findall(code_block_pattern, text, re.DOTALL)
for match in matches:
try:
return json.loads(match.strip())
except json.JSONDecodeError:
continue
# 策略3: 匹配第一个完整的 JSON 对象 { ... }
brace_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
brace_matches = re.findall(brace_pattern, text, re.DOTALL)
for match in brace_matches:
try:
return json.loads(match)
except json.JSONDecodeError:
continue
# 策略4: 匹配嵌套更深的 JSON(贪婪匹配从第一个 { 到最后一个 })
deep_match = re.search(r"\{.*\}", text, re.DOTALL)
if deep_match:
try:
return json.loads(deep_match.group())
except json.JSONDecodeError:
pass
# 策略5: 匹配 JSON 数组 [ ... ]
array_match = re.search(r"\[.*\]", text, re.DOTALL)
if array_match:
try:
return json.loads(array_match.group())
except json.JSONDecodeError:
pass
logger.warning(f"无法从文本中提取 JSON: {text[:200]}...")
return None
def safe_json_call(
messages: list[dict[str, str]],
model: str = DEFAULT_MODEL,
temperature: float = 0.3,
max_tokens: int = 2000,
max_retries: int = 3,
) -> Optional[dict | list]:
"""
调用 Qwen API 并安全地解析返回的 JSON。
设计思路:
- 将 API 调用与 JSON 解析合为一步
- 如果第一次解析失败,会额外重试(重新调用 API)
- temperature 默认较低 (0.3),让 JSON 输出更稳定
Args:
messages: 消息列表
model: 模型名称
temperature: 生成温度(JSON 输出建议低温)
max_tokens: 最大 token 数
max_retries: JSON 解析失败时的额外重试次数
Returns:
解析后的 dict/list,全部失败返回 None
"""
for attempt in range(1, max_retries + 1):
try:
raw_text = call_qwen(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
)
result = extract_json_from_text(raw_text)
if result is not None:
return result
logger.warning(
f"JSON 解析失败 (尝试 {attempt}/{max_retries}),原始文本: {raw_text[:300]}..."
)
except Exception as e:
logger.error(f"safe_json_call 异常 (尝试 {attempt}/{max_retries}): {e}")
logger.error(f"safe_json_call 在 {max_retries} 次尝试后仍无法获取有效 JSON")
return None
# ============================================================
# 辅助工具函数
# ============================================================
def clamp(value: int, min_val: int, max_val: int) -> int:
"""将数值限制在 [min_val, max_val] 范围内"""
return max(min_val, min(max_val, value))
def format_dict_for_prompt(data: dict, indent: int = 0) -> str:
"""
将字典格式化为易读的 Prompt 文本。
用于将状态数据注入 System Prompt。
"""
lines = []
prefix = " " * indent
for key, value in data.items():
if isinstance(value, dict):
lines.append(f"{prefix}{key}:")
lines.append(format_dict_for_prompt(value, indent + 1))
elif isinstance(value, list):
if value:
lines.append(f"{prefix}{key}: {', '.join(str(v) for v in value)}")
else:
lines.append(f"{prefix}{key}: 无")
else:
lines.append(f"{prefix}{key}: {value}")
return "\n".join(lines)
|