File size: 10,893 Bytes
9e03a34
 
 
 
 
 
 
 
 
 
 
 
4998893
 
 
 
 
 
 
 
 
 
 
9e03a34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4998893
 
 
 
9e03a34
 
4998893
 
 
 
 
 
 
 
 
 
9e03a34
 
 
 
 
 
 
 
 
4998893
 
9e03a34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9942d7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e03a34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""

utils.py - StoryWeaver 工具函数模块



职责:

1. 加载环境变量,初始化 OpenAI 兼容客户端 (Qwen API)

2. 提供通用的 API 调用封装函数(带重试机制)

3. 提供 JSON 安全解析工具(从 LLM 输出中提取结构化数据)

"""

import os
import re
import json
import time
import logging
from typing import Any, Optional
from dotenv import load_dotenv

try:
    from openai import OpenAI
    _OPENAI_IMPORT_ERROR: Optional[Exception] = None
except ImportError as exc:  # pragma: no cover - depends on local env
    OpenAI = None  # type: ignore[assignment]
    _OPENAI_IMPORT_ERROR = exc

# ============================================================
# 日志配置
# ============================================================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("StoryWeaver")

# ============================================================
# 环境变量加载 & API 客户端初始化
# ============================================================

# 从项目根目录的 .env 文件加载环境变量
load_dotenv()

# 严禁硬编码 API Key —— 仅通过环境变量读取
QWEN_API_KEY: str = os.getenv("QWEN_API_KEY", "")

if not QWEN_API_KEY or QWEN_API_KEY == "sk-xxxxxx":
    logger.warning(
        "⚠️  QWEN_API_KEY 未设置或仍为模板值!"
        "请在 .env 文件中填写有效的 API Key。"
    )

# 使用 OpenAI 兼容格式连接 Qwen API
# base_url 指向通义千问的 OpenAI 兼容端点
_client: Optional[Any] = None


def get_client() -> Any:
    """

    获取全局 OpenAI 客户端(懒加载单例)。

    使用兼容格式调用 Qwen API。

    """
    global _client
    if OpenAI is None:
        raise RuntimeError(
            "未安装 openai 依赖,无法初始化 Qwen 客户端。"
            "请先执行 `pip install -r requirements.txt`。"
        ) from _OPENAI_IMPORT_ERROR
    if _client is None:
        _client = OpenAI(
            api_key=QWEN_API_KEY,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )
    return _client


# ============================================================
# 默认模型配置
# ============================================================
# 使用 qwen2.5-14b-instruct 以获得最快的响应速度
DEFAULT_MODEL: str = "qwen2.5-14b-instruct"


# ============================================================
# 通用 API 调用封装(带重试 & 错误处理)
# ============================================================


def call_qwen(

    messages: list[dict[str, str]],

    model: str = DEFAULT_MODEL,

    temperature: float = 0.8,

    max_tokens: int = 2000,

    max_retries: int = 3,

    retry_delay: float = 1.0,

) -> str:
    """

    调用 Qwen API 的通用封装函数。



    设计思路:

    - 使用 OpenAI 兼容格式,方便后续切换模型

    - 内置指数退避重试机制,应对网络波动和限流

    - 返回纯文本内容,JSON 解析交给调用方处理



    Args:

        messages: OpenAI 格式的消息列表 [{"role": "system", "content": "..."}, ...]

        model: 模型名称,默认 qwen-plus

        temperature: 生成温度,越高越有创意(0.0-2.0)

        max_tokens: 最大生成 token 数

        max_retries: 最大重试次数

        retry_delay: 初始重试间隔(秒),每次翻倍



    Returns:

        模型生成的文本内容



    Raises:

        Exception: 重试耗尽后抛出最后一次异常

    """
    client = get_client()
    last_exception: Optional[Exception] = None

    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"调用 Qwen API (尝试 {attempt}/{max_retries}),模型: {model}")
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            content = response.choices[0].message.content.strip()
            logger.info(f"API 调用成功,响应长度: {len(content)} 字符")
            return content

        except Exception as e:
            last_exception = e
            logger.warning(f"API 调用失败 (尝试 {attempt}/{max_retries}): {e}")
            if attempt < max_retries:
                sleep_time = retry_delay * (2 ** (attempt - 1))
                logger.info(f"等待 {sleep_time:.1f} 秒后重试...")
                time.sleep(sleep_time)

    # 重试耗尽,抛出异常
    raise RuntimeError(
        f"Qwen API 调用在 {max_retries} 次尝试后仍然失败: {last_exception}"
    )


def call_qwen_stream(

    messages: list[dict[str, str]],

    model: str = DEFAULT_MODEL,

    temperature: float = 0.8,

    max_tokens: int = 2000,

):
    """

    调用 Qwen API 的流式版本,逐块 yield 文本内容。



    使用 stream=True,让用户在 AI 生成过程中就能看到文字逐步出现,

    大幅改善感知延迟。



    Args:

        messages: OpenAI 格式的消息列表

        model: 模型名称

        temperature: 生成温度

        max_tokens: 最大生成 token 数



    Yields:

        每次生成的文本片段(str)

    """
    client = get_client()
    logger.info(f"调用 Qwen 流式 API,模型: {model}")
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )
        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
    except Exception as e:
        logger.error(f"流式 API 调用失败: {e}")
        raise


# ============================================================
# JSON 安全解析工具
# ============================================================


def extract_json_from_text(text: str) -> Optional[dict | list]:
    """

    从 LLM 输出的文本中提取 JSON 数据。



    设计思路:

    LLM 有时会在 JSON 前后附加说明文字,或使用 ```json 代码块包裹。

    此函数通过多种策略尝试提取有效 JSON:

    1. 先尝试直接解析整段文本

    2. 再尝试提取 ```json ... ``` 代码块

    3. 最后尝试匹配第一个 { ... } 或 [ ... ] 结构



    Args:

        text: LLM 返回的原始文本



    Returns:

        解析后的 dict/list,解析失败返回 None

    """
    if not text:
        return None

    # 策略1: 直接解析(LLM 可能返回纯 JSON)
    try:
        return json.loads(text.strip())
    except json.JSONDecodeError:
        pass

    # 策略2: 提取 ```json ... ``` 代码块
    code_block_pattern = r"```(?:json)?\s*\n?(.*?)\n?\s*```"
    matches = re.findall(code_block_pattern, text, re.DOTALL)
    for match in matches:
        try:
            return json.loads(match.strip())
        except json.JSONDecodeError:
            continue

    # 策略3: 匹配第一个完整的 JSON 对象 { ... }
    brace_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
    brace_matches = re.findall(brace_pattern, text, re.DOTALL)
    for match in brace_matches:
        try:
            return json.loads(match)
        except json.JSONDecodeError:
            continue

    # 策略4: 匹配嵌套更深的 JSON(贪婪匹配从第一个 { 到最后一个 })
    deep_match = re.search(r"\{.*\}", text, re.DOTALL)
    if deep_match:
        try:
            return json.loads(deep_match.group())
        except json.JSONDecodeError:
            pass

    # 策略5: 匹配 JSON 数组 [ ... ]
    array_match = re.search(r"\[.*\]", text, re.DOTALL)
    if array_match:
        try:
            return json.loads(array_match.group())
        except json.JSONDecodeError:
            pass

    logger.warning(f"无法从文本中提取 JSON: {text[:200]}...")
    return None


def safe_json_call(

    messages: list[dict[str, str]],

    model: str = DEFAULT_MODEL,

    temperature: float = 0.3,

    max_tokens: int = 2000,

    max_retries: int = 3,

) -> Optional[dict | list]:
    """

    调用 Qwen API 并安全地解析返回的 JSON。



    设计思路:

    - 将 API 调用与 JSON 解析合为一步

    - 如果第一次解析失败,会额外重试(重新调用 API)

    - temperature 默认较低 (0.3),让 JSON 输出更稳定



    Args:

        messages: 消息列表

        model: 模型名称

        temperature: 生成温度(JSON 输出建议低温)

        max_tokens: 最大 token 数

        max_retries: JSON 解析失败时的额外重试次数



    Returns:

        解析后的 dict/list,全部失败返回 None

    """
    for attempt in range(1, max_retries + 1):
        try:
            raw_text = call_qwen(
                messages=messages,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            result = extract_json_from_text(raw_text)
            if result is not None:
                return result
            logger.warning(
                f"JSON 解析失败 (尝试 {attempt}/{max_retries}),原始文本: {raw_text[:300]}..."
            )
        except Exception as e:
            logger.error(f"safe_json_call 异常 (尝试 {attempt}/{max_retries}): {e}")

    logger.error(f"safe_json_call 在 {max_retries} 次尝试后仍无法获取有效 JSON")
    return None


# ============================================================
# 辅助工具函数
# ============================================================


def clamp(value: int, min_val: int, max_val: int) -> int:
    """将数值限制在 [min_val, max_val] 范围内"""
    return max(min_val, min(max_val, value))


def format_dict_for_prompt(data: dict, indent: int = 0) -> str:
    """

    将字典格式化为易读的 Prompt 文本。

    用于将状态数据注入 System Prompt。

    """
    lines = []
    prefix = "  " * indent
    for key, value in data.items():
        if isinstance(value, dict):
            lines.append(f"{prefix}{key}:")
            lines.append(format_dict_for_prompt(value, indent + 1))
        elif isinstance(value, list):
            if value:
                lines.append(f"{prefix}{key}: {', '.join(str(v) for v in value)}")
            else:
                lines.append(f"{prefix}{key}: 无")
        else:
            lines.append(f"{prefix}{key}: {value}")
    return "\n".join(lines)