Spaces:
Running
Running
ego
fix: improve graph generation reliability (strip thinking blocks, add retry, robust DOT extraction)
1fe2fca | import re | |
| from core.models import get_llm | |
| from langchain_core.output_parsers import StrOutputParser | |
| from prompts import GRAPH_PROMPT | |
| def _extract_dot(raw: str) -> str: | |
| """从模型输出中提取干净的 DOT 代码。 | |
| 处理以下情况: | |
| 1. <think>...</think> reasoning 块 (enable_thinking=True 时产生) | |
| 2. ```dot ... ``` 或 ``` ... ``` Markdown 代码块 | |
| 3. 纯文本中直接包含 digraph {...} | |
| """ | |
| # 1. 剥离 <think>...</think> 块(贪婪匹配,跨行) | |
| raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip() | |
| # 2. 优先提取 Markdown 代码块内容 | |
| md_match = re.search(r"```(?:dot)?\s*(digraph\s+\w+\s*\{.*?\})\s*```", raw, re.DOTALL) | |
| if md_match: | |
| return md_match.group(1).strip() | |
| # 3. 直接提取 digraph { ... } 块(处理嵌套花括号) | |
| digraph_match = re.search(r"(digraph\s+\w+\s*\{)", raw) | |
| if digraph_match: | |
| start = digraph_match.start() | |
| depth = 0 | |
| for i, ch in enumerate(raw[start:], start=start): | |
| if ch == "{": | |
| depth += 1 | |
| elif ch == "}": | |
| depth -= 1 | |
| if depth == 0: | |
| return raw[start : i + 1].strip() | |
| # 4. 兜底:清理 Markdown 标记后返回原文 | |
| cleaned = raw.replace("```dot", "").replace("```", "").strip() | |
| if "digraph" not in cleaned: | |
| cleaned = f"digraph G {{ {cleaned} }}" | |
| return cleaned | |
class KnowledgeGraphGenerator:
    """Generate Graphviz DOT knowledge graphs from text using the LLM from ``get_llm()``."""

    def __init__(self):
        self.llm = get_llm()

    def generate_graph(self, text: str, max_retries: int = 3) -> str:
        """Generate a DOT knowledge graph for *text*, retrying automatically on failure.

        Args:
            text: Source text, passed to ``GRAPH_PROMPT`` as the ``text`` variable.
            max_retries: Total number of LLM attempts; values below 1 are
                treated as 1 (previously they crashed with a TypeError).

        Returns:
            The DOT code extracted from the first attempt that contains both
            ``digraph`` and at least one ``->`` edge. If every attempt fails,
            a minimal placeholder graph whose "Reason" node shows the last error.
        """
        import logging

        logger = logging.getLogger(__name__)
        chain = GRAPH_PROMPT | self.llm | StrOutputParser()
        attempts = max(1, max_retries)
        last_error = ""
        for attempt in range(1, attempts + 1):
            try:
                raw = chain.invoke({"text": text})
                dot_code = _extract_dot(raw)
            except Exception as e:  # best-effort: any LLM/parsing failure triggers a retry
                last_error = f"Attempt {attempt}: {e}"
                logger.warning("Graph generation attempt %d/%d failed: %s", attempt, attempts, e)
                continue
            # Basic sanity check: must contain `digraph` and at least one edge.
            if "digraph" in dot_code and "->" in dot_code:
                return dot_code
            last_error = f"Attempt {attempt}: DOT output invalid (no edges found)."
            logger.warning("Graph generation attempt %d/%d produced no edges", attempt, attempts)
        # All attempts failed: return a minimal valid placeholder graph.
        return self._fallback_graph(last_error)

    @staticmethod
    def _fallback_graph(reason: str) -> str:
        """Build a minimal placeholder DOT graph that displays *reason* as a label.

        The reason is truncated to 80 characters and then escaped. Truncating
        first means an escape sequence is never cut in half. Escaping means a
        quote or trailing backslash in an error message cannot break the graph.
        """
        label = reason[:80].replace("\\", "\\\\").replace('"', '\\"')
        # Keep the label on a single line.
        label = " ".join(label.splitlines())
        return (
            'digraph G {\n'
            ' rankdir=LR;\n'
            ' node [style="filled", fillcolor="#FFEBEE", shape="box"];\n'
            ' "Generation Failed" -> "Please Retry" [label="error"];\n'
            f' "Reason" [label="{label}"];\n'
            '}'
        )