Story_Weaver / utils.py
wzh0617's picture
Upload 12 files
8bdaafd verified
"""
utils.py - StoryWeaver 工具函数模块
职责:
1. 加载环境变量,初始化 OpenAI 兼容客户端 (Qwen API)
2. 提供通用的 API 调用封装函数(带重试机制)
3. 提供 JSON 安全解析工具(从 LLM 输出中提取结构化数据)
"""
import os
import re
import json
import time
import logging
from typing import Any, Optional
from dotenv import load_dotenv
try:
from openai import OpenAI
_OPENAI_IMPORT_ERROR: Optional[Exception] = None
except ImportError as exc: # pragma: no cover - depends on local env
OpenAI = None # type: ignore[assignment]
_OPENAI_IMPORT_ERROR = exc
# ============================================================
# 日志配置
# ============================================================
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("StoryWeaver")
# ============================================================
# 环境变量加载 & API 客户端初始化
# ============================================================
# 从项目根目录的 .env 文件加载环境变量
load_dotenv()
# 严禁硬编码 API Key —— 仅通过环境变量读取
QWEN_API_KEY: str = os.getenv("QWEN_API_KEY", "")
if not QWEN_API_KEY or QWEN_API_KEY == "sk-xxxxxx":
logger.warning(
"⚠️ QWEN_API_KEY 未设置或仍为模板值!"
"请在 .env 文件中填写有效的 API Key。"
)
# 使用 OpenAI 兼容格式连接 Qwen API
# base_url 指向通义千问的 OpenAI 兼容端点
_client: Optional[Any] = None
def get_client() -> Any:
"""
获取全局 OpenAI 客户端(懒加载单例)。
使用兼容格式调用 Qwen API。
"""
global _client
if OpenAI is None:
raise RuntimeError(
"未安装 openai 依赖,无法初始化 Qwen 客户端。"
"请先执行 `pip install -r requirements.txt`。"
) from _OPENAI_IMPORT_ERROR
if _client is None:
_client = OpenAI(
api_key=QWEN_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
return _client
# ============================================================
# 默认模型配置
# ============================================================
# 使用 qwen2.5-14b-instruct 以获得最快的响应速度
DEFAULT_MODEL: str = "qwen2.5-14b-instruct"
# ============================================================
# 通用 API 调用封装(带重试 & 错误处理)
# ============================================================
def call_qwen(
messages: list[dict[str, str]],
model: str = DEFAULT_MODEL,
temperature: float = 0.8,
max_tokens: int = 2000,
max_retries: int = 3,
retry_delay: float = 1.0,
) -> str:
"""
调用 Qwen API 的通用封装函数。
设计思路:
- 使用 OpenAI 兼容格式,方便后续切换模型
- 内置指数退避重试机制,应对网络波动和限流
- 返回纯文本内容,JSON 解析交给调用方处理
Args:
messages: OpenAI 格式的消息列表 [{"role": "system", "content": "..."}, ...]
model: 模型名称,默认 qwen-plus
temperature: 生成温度,越高越有创意(0.0-2.0)
max_tokens: 最大生成 token 数
max_retries: 最大重试次数
retry_delay: 初始重试间隔(秒),每次翻倍
Returns:
模型生成的文本内容
Raises:
Exception: 重试耗尽后抛出最后一次异常
"""
client = get_client()
last_exception: Optional[Exception] = None
for attempt in range(1, max_retries + 1):
try:
logger.info(f"调用 Qwen API (尝试 {attempt}/{max_retries}),模型: {model}")
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.choices[0].message.content.strip()
logger.info(f"API 调用成功,响应长度: {len(content)} 字符")
return content
except Exception as e:
last_exception = e
logger.warning(f"API 调用失败 (尝试 {attempt}/{max_retries}): {e}")
if attempt < max_retries:
sleep_time = retry_delay * (2 ** (attempt - 1))
logger.info(f"等待 {sleep_time:.1f} 秒后重试...")
time.sleep(sleep_time)
# 重试耗尽,抛出异常
raise RuntimeError(
f"Qwen API 调用在 {max_retries} 次尝试后仍然失败: {last_exception}"
)
def call_qwen_stream(
messages: list[dict[str, str]],
model: str = DEFAULT_MODEL,
temperature: float = 0.8,
max_tokens: int = 2000,
):
"""
调用 Qwen API 的流式版本,逐块 yield 文本内容。
使用 stream=True,让用户在 AI 生成过程中就能看到文字逐步出现,
大幅改善感知延迟。
Args:
messages: OpenAI 格式的消息列表
model: 模型名称
temperature: 生成温度
max_tokens: 最大生成 token 数
Yields:
每次生成的文本片段(str)
"""
client = get_client()
logger.info(f"调用 Qwen 流式 API,模型: {model}")
try:
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
)
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
except Exception as e:
logger.error(f"流式 API 调用失败: {e}")
raise
# ============================================================
# JSON 安全解析工具
# ============================================================
def extract_json_from_text(text: str) -> Optional[dict | list]:
"""
从 LLM 输出的文本中提取 JSON 数据。
设计思路:
LLM 有时会在 JSON 前后附加说明文字,或使用 ```json 代码块包裹。
此函数通过多种策略尝试提取有效 JSON:
1. 先尝试直接解析整段文本
2. 再尝试提取 ```json ... ``` 代码块
3. 最后尝试匹配第一个 { ... } 或 [ ... ] 结构
Args:
text: LLM 返回的原始文本
Returns:
解析后的 dict/list,解析失败返回 None
"""
if not text:
return None
# 策略1: 直接解析(LLM 可能返回纯 JSON)
try:
return json.loads(text.strip())
except json.JSONDecodeError:
pass
# 策略2: 提取 ```json ... ``` 代码块
code_block_pattern = r"```(?:json)?\s*\n?(.*?)\n?\s*```"
matches = re.findall(code_block_pattern, text, re.DOTALL)
for match in matches:
try:
return json.loads(match.strip())
except json.JSONDecodeError:
continue
# 策略3: 匹配第一个完整的 JSON 对象 { ... }
brace_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
brace_matches = re.findall(brace_pattern, text, re.DOTALL)
for match in brace_matches:
try:
return json.loads(match)
except json.JSONDecodeError:
continue
# 策略4: 匹配嵌套更深的 JSON(贪婪匹配从第一个 { 到最后一个 })
deep_match = re.search(r"\{.*\}", text, re.DOTALL)
if deep_match:
try:
return json.loads(deep_match.group())
except json.JSONDecodeError:
pass
# 策略5: 匹配 JSON 数组 [ ... ]
array_match = re.search(r"\[.*\]", text, re.DOTALL)
if array_match:
try:
return json.loads(array_match.group())
except json.JSONDecodeError:
pass
logger.warning(f"无法从文本中提取 JSON: {text[:200]}...")
return None
def safe_json_call(
messages: list[dict[str, str]],
model: str = DEFAULT_MODEL,
temperature: float = 0.3,
max_tokens: int = 2000,
max_retries: int = 3,
) -> Optional[dict | list]:
"""
调用 Qwen API 并安全地解析返回的 JSON。
设计思路:
- 将 API 调用与 JSON 解析合为一步
- 如果第一次解析失败,会额外重试(重新调用 API)
- temperature 默认较低 (0.3),让 JSON 输出更稳定
Args:
messages: 消息列表
model: 模型名称
temperature: 生成温度(JSON 输出建议低温)
max_tokens: 最大 token 数
max_retries: JSON 解析失败时的额外重试次数
Returns:
解析后的 dict/list,全部失败返回 None
"""
for attempt in range(1, max_retries + 1):
try:
raw_text = call_qwen(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
)
result = extract_json_from_text(raw_text)
if result is not None:
return result
logger.warning(
f"JSON 解析失败 (尝试 {attempt}/{max_retries}),原始文本: {raw_text[:300]}..."
)
except Exception as e:
logger.error(f"safe_json_call 异常 (尝试 {attempt}/{max_retries}): {e}")
logger.error(f"safe_json_call 在 {max_retries} 次尝试后仍无法获取有效 JSON")
return None
# ============================================================
# 辅助工具函数
# ============================================================
def clamp(value: int, min_val: int, max_val: int) -> int:
"""将数值限制在 [min_val, max_val] 范围内"""
return max(min_val, min(max_val, value))
def format_dict_for_prompt(data: dict, indent: int = 0) -> str:
"""
将字典格式化为易读的 Prompt 文本。
用于将状态数据注入 System Prompt。
"""
lines = []
prefix = " " * indent
for key, value in data.items():
if isinstance(value, dict):
lines.append(f"{prefix}{key}:")
lines.append(format_dict_for_prompt(value, indent + 1))
elif isinstance(value, list):
if value:
lines.append(f"{prefix}{key}: {', '.join(str(v) for v in value)}")
else:
lines.append(f"{prefix}{key}: 无")
else:
lines.append(f"{prefix}{key}: {value}")
return "\n".join(lines)