File size: 2,669 Bytes
4822069
ea8f8db
 
ba7bcd3
ea8f8db
1fe2fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea8f8db
 
1fe2fca
ea8f8db
1fe2fca
 
ba7bcd3
1fe2fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import re
from core.models import get_llm
from langchain_core.output_parsers import StrOutputParser
from prompts import GRAPH_PROMPT


def _closing_brace_index(text: str, open_idx: int, honor_quotes: bool) -> int:
    """Return the index of the "}" closing the "{" at *open_idx*, or -1.

    When *honor_quotes* is true, braces that appear inside double-quoted
    strings (including strings with backslash-escaped quotes) are ignored.
    """
    depth = 0
    in_string = False
    escaped = False
    for i in range(open_idx, len(text)):
        ch = text[i]
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"' and honor_quotes:
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return i
    return -1


def _extract_dot(raw: str) -> str:
    """Extract clean DOT source from raw model output.

    Handles, in order:
    1. ``<think>...</think>`` reasoning blocks, which are removed.
    2. Markdown code fences (```dot ... ``` or ``` ... ```) wrapping a digraph.
    3. A ``digraph ... { ... }`` block embedded directly in plain text.
    4. Fallback: Markdown fence markers are stripped and, if the text does
       not mention ``digraph`` at all, it is wrapped in ``digraph G { ... }``.

    The graph header may be named (``digraph G {``), quoted
    (``digraph "My Graph" {``) or anonymous (``digraph {``).

    Args:
        raw: Text returned by the model.

    Returns:
        The extracted (or wrapped) DOT source, stripped of surrounding
        whitespace.
    """
    # "digraph", an optional bare or double-quoted graph ID, then the opening
    # brace. The ID is optional because anonymous graphs are valid DOT.
    header_pattern = r'digraph(?:\s+(?:\w+|"(?:[^"\\]|\\.)*"))?\s*\{'

    # 1. Strip <think>...</think> blocks (non-greedy, spanning lines).
    raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()

    # 2. Prefer the contents of a Markdown code fence.
    md_match = re.search(
        r"```(?:dot)?\s*(" + header_pattern + r".*?\})\s*```", raw, re.DOTALL
    )
    if md_match:
        return md_match.group(1).strip()

    # 3. Extract a bare digraph block by matching its braces.
    header = re.search(header_pattern, raw)
    if header:
        # Start at the opening brace so a quoted graph ID cannot confuse
        # the scanner.
        open_idx = header.end() - 1
        end = _closing_brace_index(raw, open_idx, honor_quotes=True)
        if end == -1:
            # An unbalanced quote hid the closing brace; fall back to plain
            # brace counting rather than giving up on the block.
            end = _closing_brace_index(raw, open_idx, honor_quotes=False)
        if end != -1:
            return raw[header.start() : end + 1].strip()

    # 4. Fallback: drop Markdown markers and return what is left.
    cleaned = raw.replace("```dot", "").replace("```", "").strip()
    if "digraph" not in cleaned:
        cleaned = f"digraph G {{ {cleaned} }}"
    return cleaned


class KnowledgeGraphGenerator:
    """Generate DOT-format knowledge graphs from text via the configured LLM."""

    def __init__(self):
        self.llm = get_llm()

    @staticmethod
    def _fallback_graph(reason: object) -> str:
        """Build the minimal placeholder graph returned when generation fails.

        Args:
            reason: Description of the last failure, or ``None`` when no
                attempt was made. Only its first 80 characters are shown.

        Returns:
            DOT source for a small "Generation Failed" graph whose "Reason"
            node carries the (escaped) failure description.
        """
        text = "No generation attempt was made." if reason is None else str(reason)
        # Truncate before escaping so an escape sequence is never cut in half.
        text = text[:80]
        # Escape for a DOT double-quoted label: backslashes first, then
        # quotes, then turn real line breaks into DOT's "\n" label escape.
        text = text.replace("\\", "\\\\").replace('"', '\\"')
        text = text.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n")
        return (
            'digraph G {\n'
            '    rankdir=LR;\n'
            '    node [style="filled", fillcolor="#FFEBEE", shape="box"];\n'
            '    "Generation Failed" -> "Please Retry" [label="error"];\n'
            f'    "Reason" [label="{text}"];\n'
            '}'
        )

    def generate_graph(self, text: str, max_retries: int = 3) -> str:
        """Generate a DOT knowledge graph for *text*, retrying on failure.

        Args:
            text: Source text passed to the prompt as ``text``.
            max_retries: Maximum number of attempts. A value of zero or less
                skips generation and yields the placeholder graph.

        Returns:
            The extracted DOT source from the first attempt that contains
            both ``digraph`` and at least one ``->`` edge; otherwise a
            placeholder graph describing the last failure.
        """
        chain = GRAPH_PROMPT | self.llm | StrOutputParser()

        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                raw = chain.invoke({"text": text})
                dot_code = _extract_dot(raw)

                # Basic sanity check: must declare a digraph and have an edge.
                if "digraph" in dot_code and "->" in dot_code:
                    return dot_code

                last_error = f"Attempt {attempt}: DOT output invalid (no edges found)."
            except Exception as e:
                # Deliberate best effort: any failure counts as a failed
                # attempt and is reported in the placeholder graph.
                last_error = f"Attempt {attempt}: {e}"

        # Every attempt failed: return a minimal valid placeholder graph.
        return self._fallback_graph(last_error)