import json
import os
import re
import base64
from PIL import Image
from typing import Optional
from openai import OpenAI
from google import genai
from google.genai import types


def _make_gemini_client(config):
    """Build a ``genai.Client`` from *config*.

    Normalizes an OpenAI-style base URL: a trailing ``/v1`` segment is
    stripped, because the google-genai SDK expects the bare host.

    Args:
        config: Dict with ``api_keys.gemini_api_key`` and an optional
            ``api_base_url`` entry.

    Returns:
        A configured ``genai.Client`` instance.
    """
    raw_url = config.get('api_base_url', '').strip().rstrip("/")
    if raw_url.endswith("/v1"):
        base_url = raw_url[:-3].rstrip("/")  # drop the trailing /v1
    else:
        base_url = raw_url
    return genai.Client(
        api_key=config['api_keys']['gemini_api_key'],
        http_options={'base_url': base_url} if base_url else None
    )


# ========== Use Gemini to delete unneeded sections ==========
def clean_paper(markdown_path, clean_prompt, model, config):
    """Clean a paper's Markdown file via Google Gemini.

    Removes sections such as Abstract / Related Work / Appendix /
    References while keeping title, authors, Introduction, Methods,
    Experiments and Conclusion.

    Args:
        markdown_path: Path to the markdown file to clean.
        clean_prompt: Instruction prompt prepended to the paper text.
        model: Gemini model name.
        config: Dict providing API key and optional base URL.

    Returns:
        Path of the written ``*_cleaned.md`` file, or ``None`` if the
        API call failed.
    """
    client = _make_gemini_client(config)

    # --- Read the markdown file ---
    with open(markdown_path, "r", encoding="utf-8") as f:
        md_text = f.read().strip()

    full_prompt = (
        f"{clean_prompt}\n\n"
        "=== PAPER MARKDOWN TO CLEAN ===\n"
        f"\"\"\"{md_text}\"\"\"\n\n"
        "Return only the cleaned markdown, keeping all formatting identical to the original."
    )

    print("🧹 Sending markdown to Gemini for cleaning...")
    try:
        # --- Call the Gemini API (Client mode) ---
        resp = client.models.generate_content(
            model=model,
            contents=full_prompt,
            config=types.GenerateContentConfig(
                temperature=0.0
            )
        )
        cleaned_text = resp.text.strip()
    except Exception as e:
        print(f"❌ Gemini API Error: {e}")
        return None

    # --- Extract bare markdown (in case the model wrapped it in a
    #     ```markdown ... ``` fence) ---
    m = re.search(r"```markdown\s*([\s\S]*?)```", cleaned_text)
    if not m:
        m = re.search(r"```\s*([\s\S]*?)```", cleaned_text)
    cleaned_text = m.group(1).strip() if m else cleaned_text

    # --- Build the output path next to the input file ---
    dir_path = os.path.dirname(markdown_path)
    base_name = os.path.basename(markdown_path)
    name, ext = os.path.splitext(base_name)
    output_path = os.path.join(dir_path, f"{name}_cleaned{ext}")

    # --- Save the result ---
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(cleaned_text)

    print(f"✅ Cleaned markdown saved to: {output_path}")
    return output_path


# ========== Use Gemini to split the paper into sections ==========

# Matches numbered top-level headings such as "# 1 Introduction".
SECTION_RE = re.compile(r'^#\s+\d+(\s|$)', re.MULTILINE)


def sanitize_filename(name: str) -> str:
    """Remove characters illegal in filenames: / \\ : * ? \" < > |"""
    unsafe = r'\/:*?"<>|'
    return "".join(c for c in name if c not in unsafe).strip()


def split_paper(
    cleaned_md_path: str,   # input: A/auto/clean_paper.md
    prompt: str,
    separator: str = "===SPLIT===",
    model: str = "gemini-3.0-pro-preview",
    config: dict = None
):
    """Split a paper with Gemini and save every chunk as markdown.

    Output files are written to ``<parent-of-input-dir>/section_split_output/``.

    Args:
        cleaned_md_path: Path to the cleaned markdown file.
        prompt: Splitting instructions for the model.
        separator: Token the model is asked to emit between chunks.
        model: Gemini model name.
        config: Dict providing API key and optional base URL.

    Returns:
        List of paths of the saved section files (empty on failure).
    """
    # 1. Directory that contains the input file (the "auto" folder)
    auto_dir = os.path.dirname(os.path.abspath(cleaned_md_path))
    # 2. Its parent directory (i.e. A/)
    parent_dir = os.path.dirname(auto_dir)
    # 3. Final output directory A/section_split_output/
    output_dir = os.path.join(parent_dir, "section_split_output")
    os.makedirs(output_dir, exist_ok=True)

    # --- Read the markdown ---
    with open(cleaned_md_path, "r", encoding="utf-8") as f:
        markdown_text = f.read()

    # --- Initialize the Gemini client ---
    client = _make_gemini_client(config)

    # --- Extract top-level section headings as a hint for the model ---
    section_positions = [(m.start(), m.group()) for m in SECTION_RE.finditer(markdown_text)]
    auto_analysis = "Detected top-level sections:\n"
    for pos, sec in section_positions:
        auto_analysis += f"- position={pos}, heading='{sec.strip()}'\n"
    auto_analysis += "\nThese are for your reference. You MUST still split strictly by the rules.\n"

    # --- Build the prompt ---
    final_prompt = (
        prompt
        + "\n\n---\nBelow is an automatic analysis of top-level sections (for reference only):\n"
        + auto_analysis
        + "\n---\nHere is the FULL MARKDOWN PAPER:\n\n"
        + markdown_text
    )

    # --- Call Gemini ---
    try:
        response = client.models.generate_content(
            model=model,
            contents=final_prompt,
            config=types.GenerateContentConfig(
                temperature=0.0
            )
        )
        output_text = response.text
    except Exception as e:
        print(f"❌ Gemini Split Error: {e}")
        return []

    # --- Split on the separator, tolerating sloppy model output ---
    if not output_text:
        print("❌ Empty response from Gemini")
        return []

    chunks = [c.strip() for c in output_text.split(separator) if c.strip()]
    saved_paths = []

    # --- Save each chunk ---
    for idx, chunk in enumerate(chunks):
        lines = chunk.splitlines()
        # First non-blank line of the chunk
        first_line = next((ln.strip() for ln in lines if ln.strip()), "")
        # Derive a title from a heading, else from the first 20 chars
        if first_line.startswith("#"):
            title = first_line.lstrip("#").strip()
        else:
            title = first_line[:20].strip()

        stem = sanitize_filename(title)
        # FIX: a title made entirely of illegal characters (or an empty
        # chunk head) used to produce the filename ".md"; fall back to a
        # positional name instead.
        if not stem:
            stem = f"section_{idx}"
        filepath = os.path.join(output_dir, f"{stem}.md")
        # FIX: two chunks whose titles sanitize identically used to
        # silently overwrite each other; append a counter on collision.
        dup = 1
        while filepath in saved_paths:
            filepath = os.path.join(output_dir, f"{stem}_{dup}.md")
            dup += 1

        with open(filepath, "w", encoding="utf-8") as f:
            f.write(chunk)
        saved_paths.append(filepath)

    print(f"✅ Paper is splitted successfully")
    return saved_paths


# ========== Use Gemini to initialize dag.json ==========
def initialize_dag(markdown_path, initialize_dag_prompt, model, config=None):
    """Initialize the paper DAG with Gemini.

    Args:
        markdown_path: Path to the markdown file.
        initialize_dag_prompt: Prompt string for the model.
        model: Model name (e.g. a gemini flash/pro variant).
        config: Dict containing ``api_keys`` (and optional ``api_base_url``).

    Side effects:
        Writes ``dag.json`` next to the markdown file.

    Returns:
        The DAG as a Python dict.

    Raises:
        FileNotFoundError: If *markdown_path* does not exist.
        ValueError: If the model output cannot be parsed as JSON.
    """
    # --- load markdown ---
    if not os.path.exists(markdown_path):
        raise FileNotFoundError(f"Markdown not found: {markdown_path}")
    with open(markdown_path, "r", encoding="utf-8") as f:
        md_text = f.read()

    # --- Gemini Client Init (normalize an OpenAI-style /v1 base URL) ---
    raw_url = config.get('api_base_url', '').strip().rstrip("/")
    if raw_url.endswith("/v1"):
        base_url = raw_url[:-3].rstrip("/")  # drop the trailing /v1
    else:
        base_url = raw_url
    client = genai.Client(
        api_key=config['api_keys']['gemini_api_key'],
        http_options={'base_url': base_url} if base_url else None
    )

    # --- Gemini call: prompt + paper text as user content, with a
    #     system instruction and forced JSON output mode ---
    full_content = f"{initialize_dag_prompt}\n\n{md_text}"
    try:
        response = client.models.generate_content(
            model=model,
            contents=full_content,
            config=types.GenerateContentConfig(
                system_instruction="You are an expert academic document parser and structural analyzer.",
                temperature=0.0,
                response_mime_type="application/json"  # force JSON output mode
            )
        )
        raw_output = response.text.strip()
    except Exception as e:
        print(f"❌ Gemini API Error: {e}")
        raise

    # --- Extract JSON (strip possible markdown fences) ---
    # In JSON mode Gemini usually returns bare JSON; kept as a safety net.
    cleaned = raw_output
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`")
        if cleaned.lstrip().startswith("json"):
            cleaned = cleaned.split("\n", 1)[1]

    # Last safety: keep only the span between the first '{' and last '}'
    try:
        first = cleaned.index("{")
        last = cleaned.rindex("}")
        cleaned = cleaned[first:last + 1]
    except Exception:
        pass

    try:
        dag_data = json.loads(cleaned)
    except json.JSONDecodeError:
        print("⚠️ Standard JSON parsing failed. Attempting regex repair for backslashes...")
        # FIX: the previous "retry" re-parsed the identical string, which
        # could never succeed. Actually do what the message promises:
        # double every backslash that does not start a valid JSON escape
        # sequence, then parse again.
        repaired = re.sub(r'\\(?!["\\/bfnrtu])', r'\\\\', cleaned)
        try:
            dag_data = json.loads(repaired)
        except json.JSONDecodeError as e:
            raise ValueError(f"Gemini output is not valid JSON:\n{raw_output}") from e

    # --- Save dag.json next to the input markdown ---
    out_dir = os.path.dirname(markdown_path)
    out_path = os.path.join(out_dir, "dag.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(dag_data, f, indent=4, ensure_ascii=False)

    print(f"✅ DAG saved to: {out_path}")
    return dag_data


# ========== Use an LLM to add visual nodes ==========
def extract_and_generate_visual_dag(
    markdown_path: str,
    prompt_for_gpt: str,
    output_json_path: str,
    model="gemini-3.0-pro-preview",
    config=None
):
    """Extract image references from the paper and build a visual DAG.

    Args:
        markdown_path: Path of the original paper markdown.
        prompt_for_gpt: Prompt given to the model.
        output_json_path: Where to write the generated ``visual_dag.json``.
        model: Model name; defaults to gemini-3.0-pro-preview.
        config: Dict containing ``api_keys`` (and optional ``api_base_url``).

    Returns:
        The visual DAG as a Python dict.

    Raises:
        FileNotFoundError: If *markdown_path* does not exist.
        ValueError: If the model output is not repairable JSON.
    """
    # --- 1. Read the markdown ---
    if not os.path.exists(markdown_path):
        raise FileNotFoundError(f"Markdown not found: {markdown_path}")
    with open(markdown_path, "r", encoding="utf-8") as f:
        md_text = f.read()

    # --- 2. Regex-extract all image references ---
    pattern = r"!\[[^\]]*\]\(([^)]+)\)"
    matches = re.findall(pattern, md_text)
    # Keep only relative paths (anything not starting with "http")
    relative_imgs = [m for m in matches if not m.startswith("http")]
    # Normalized form used by the "name" field: "![](xxx)"
    normalized_refs = [f"![]({m})" for m in relative_imgs]

    # --- 3. Send to Gemini (normalize an OpenAI-style /v1 base URL) ---
    raw_url = config.get('api_base_url', '').strip().rstrip("/")
    if raw_url.endswith("/v1"):
        base_url = raw_url[:-3].rstrip("/")  # drop the trailing /v1
    else:
        base_url = raw_url
    client = genai.Client(
        api_key=config['api_keys']['gemini_api_key'],
        http_options={'base_url': base_url} if base_url else None
    )

    gpt_input = prompt_for_gpt + "\n\n" + \
        "### Extracted Image References:\n" + \
        json.dumps(normalized_refs, indent=2) + "\n\n" + \
        "### Full Markdown:\n" + md_text

    try:
        response = client.models.generate_content(
            model=model,
            contents=gpt_input,
            config=types.GenerateContentConfig(
                temperature=0.0,
                response_mime_type="application/json"  # force JSON output
            )
        )
        visual_dag_str = response.text.strip()
    except Exception as e:
        print(f"❌ Gemini API Error: {e}")
        raise

    # === JSON-parsing fallback repair helpers (no semantic rewriting) ===

    def _strip_fenced_code_block(s: str) -> str:
        # Remove a surrounding ``` fence, if present.
        s = (s or "").strip()
        if not s.startswith("```"):
            return s
        lines = s.splitlines()
        if lines and lines[0].strip().startswith("```"):
            lines = lines[1:]
        while lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        return "\n".join(lines).strip()

    def _sanitize_json_string_minimal(s: str) -> str:
        # Collapse raw control whitespace that would break JSON strings.
        s = (s or "")
        s = s.replace("\r", " ").replace("\n", " ").replace("\t", " ")
        s = re.sub(r"\s{2,}", " ", s)
        return s

    def _remove_one_offending_backslash(s: str, err: Exception) -> str:
        # Given a JSONDecodeError caused by an invalid escape, delete the
        # single offending backslash nearest the reported position.
        # Returns "" when no repair is applicable.
        if not isinstance(err, json.JSONDecodeError):
            return ""
        msg = str(err)
        if "Invalid \\escape" not in msg and "Invalid \\u" not in msg:
            return ""
        pos = getattr(err, "pos", None)
        if pos is None or pos <= 0 or pos > len(s):
            return ""
        candidates = []
        if pos < len(s) and s[pos] == "\\":
            candidates.append(pos)
        if pos - 1 >= 0 and s[pos - 1] == "\\":
            candidates.append(pos - 1)
        # Also scan a short window before the error position.
        start = max(0, pos - 16)
        window = s[start:pos + 1]
        last_bs = window.rfind("\\")
        if last_bs != -1:
            candidates.append(start + last_bs)
        # De-duplicate while preserving order.
        seen = set()
        candidates = [i for i in candidates if not (i in seen or seen.add(i))]
        for idx in candidates:
            if 0 <= idx < len(s) and s[idx] == "\\":
                return s[:idx] + s[idx + 1:]
        return ""

    # Parse JSON (the model is asked to return bare JSON)
    try:
        visual_dag = json.loads(visual_dag_str)
    except Exception as e1:
        # Repair pipeline: unfence -> sanitize whitespace -> iteratively
        # remove offending backslashes (bounded).
        try:
            unwrapped = _strip_fenced_code_block(visual_dag_str)
            fixed_str = _sanitize_json_string_minimal(unwrapped).strip()
            if not fixed_str:
                raise ValueError("Gemini returned empty/whitespace-only JSON content after repair.")
            try:
                visual_dag = json.loads(fixed_str)
            except Exception as e2:
                working = fixed_str
                last_err = e2
                max_backslash_removals = 50
                removed_times = 0
                while removed_times < max_backslash_removals:
                    new_working = _remove_one_offending_backslash(working, last_err)
                    if not new_working:
                        break
                    working = new_working
                    removed_times += 1
                    try:
                        visual_dag = json.loads(working)
                        fixed_str = working
                        last_err = None
                        break
                    except Exception as e_next:
                        last_err = e_next
                        continue
                if last_err is not None:
                    raise ValueError(
                        "Gemini returned invalid JSON: " + str(e1) +
                        " | After repair still invalid: " + str(e2) +
                        f" | Tried removing offending backslashes up to {max_backslash_removals} times "
                        f"(actually removed {removed_times}) but still invalid: " + str(last_err)
                    )
        except Exception as e_final:
            if isinstance(e_final, ValueError):
                raise
            raise ValueError(
                "Gemini returned invalid JSON: " + str(e1) +
                " | After repair still invalid: " + str(e_final)
            )

    # --- 4. Save visual_dag.json ---
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(visual_dag, f, indent=2, ensure_ascii=False)

    print(f"\n📂 visual_dag.json is generated successfully")
    return visual_dag


# ========== Compute the resolution of every visual node ==========
def add_resolution_to_visual_dag(auto_path, visual_dag_path):
    """Add an image-resolution field to every node of visual_dag.json.

    Walks the nodes, extracts each node's image path from its ``name``
    (``![](path)`` form), opens the image and stores ``"WxH"`` in the
    node's ``resolution`` field. The JSON file is rewritten in place.

    Args:
        auto_path: Root directory the relative image paths resolve against.
        visual_dag_path: Path to visual_dag.json.

    Returns:
        The updated list of nodes (empty list if the file is unreadable).
    """
    # 1. Load the JSON file
    try:
        with open(visual_dag_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"错误: 找不到文件 {visual_dag_path}")
        return []
    except json.JSONDecodeError:
        print(f"错误: 文件 {visual_dag_path} 不是有效的 JSON 格式")
        return []

    nodes = data.get("nodes", [])

    # Regex capturing the path inside "![](path)"
    pattern = re.compile(r'!\[\]\((.*?)\)')

    for node in nodes:
        name_str = node.get("name", "")
        # 2. Extract the image path from the node name
        match = pattern.search(name_str)
        if match:
            # Path inside the parentheses, e.g. images/xxx.jpg
            relative_image_path = match.group(1)
            # 3. Build the absolute path
            full_image_path = os.path.join(auto_path, relative_image_path)
            # 4. Open the image and read its size
            try:
                with Image.open(full_image_path) as img:
                    width, height = img.size
                    resolution_str = f"{width}x{height}"
                    # 5. Attach the resolution field
                    node["resolution"] = resolution_str
            except FileNotFoundError:
                print(f"警告: 找不到图片文件 {full_image_path},跳过该节点。")
                node["resolution"] = "Unknown"
            except Exception as e:
                print(f"警告: 处理图片 {full_image_path} 时发生错误: {e}")
                node["resolution"] = "Error"
        else:
            print(f"警告: 节点 name 格式不匹配: {name_str}")

    # Write the updated data back to the original file
    try:
        with open(visual_dag_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        print(f"处理完成,已更新文件: {visual_dag_path}")
    except Exception as e:
        print(f"保存文件时出错: {e}")

    return nodes


# ========== Use gemini-3-pro-preview to build each section_dag ==========
def build_section_dags(
    folder_path: str,
    base_prompt: str,
    model: str = "gemini-3.0-pro-preview",
    config: dict = None
):
    """
    Traverse all markdown files in a folder, send each section to Gemini,
    and save _dag.json. Includes robust JSON repair and retry logic.

    Args:
        folder_path: Folder containing per-section markdown files.
        base_prompt: Prompt prefix for every section.
        model: Gemini model name.
        config: Dict containing ``api_keys`` (and optional ``api_base_url``).

    Returns:
        Dict mapping section filename -> saved ``*_dag.json`` path.
    """
    # -----------------------------
    # Tunables (safe defaults)
    # -----------------------------
    ENABLE_FALLBACK_CONTENT_BACKSLASH_STRIP = True
    FALLBACK_STRIP_BACKSLASH_ONLY_IN_CONTENT = True
    MAX_RETRIES_ON_FAIL = 2

    # === Init Client (normalize an OpenAI-style /v1 base URL) ===
    raw_url = config.get('api_base_url', '').strip().rstrip("/")
    if raw_url.endswith("/v1"):
        base_url = raw_url[:-3].rstrip("/")  # drop the trailing /v1
    else:
        base_url = raw_url
    client = genai.Client(
        api_key=config['api_keys']['gemini_api_key'],
        http_options={'base_url': base_url} if base_url else None
    )

    def build_full_prompt(base_prompt: str, section_name: str, md_text: str) -> str:
        # Assemble the per-section prompt.
        return (
            f"{base_prompt}\n\n"
            "=== SECTION NAME ===\n"
            f"{section_name}\n\n"
            "=== SECTION MARKDOWN (FULL) ===\n"
            f"\"\"\"{md_text}\"\"\""
        )

    # === Repair helpers ===

    def remove_invisible_control_chars(s: str) -> str:
        # Strip ASCII control chars and zero-width / bidi Unicode marks.
        if not s:
            return s
        s = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", s)
        s = re.sub(r"[\uFEFF\u200B\u200C\u200D\u2060\u00AD\u061C\u200E\u200F\u202A-\u202E\u2066-\u2069]", "", s)
        return s

    def sanitize_json_literal_newlines(s: str) -> str:
        # Replace raw newlines/tabs that appear INSIDE JSON string
        # literals with spaces (a small state machine tracks whether we
        # are inside a quoted string, honoring escapes).
        out = []
        in_string = False
        escape = False
        for ch in s:
            if in_string:
                if escape:
                    out.append(ch)
                    escape = False
                else:
                    if ch == '\\':
                        out.append(ch)
                        escape = True
                    elif ch == '"':
                        out.append(ch)
                        in_string = False
                    elif ch in ('\n', '\r', '\t'):
                        out.append(' ')
                    else:
                        out.append(ch)
            else:
                if ch == '"':
                    out.append(ch)
                    in_string = True
                    escape = False
                else:
                    out.append(ch)
        return ''.join(out)

    def sanitize_invalid_backslashes_in_strings(s: str) -> str:
        # Double any backslash inside a string literal that does not
        # start a valid JSON escape sequence.
        out = []
        in_string = False
        i = 0
        valid_esc = set(['"', '\\', '/', 'b', 'f', 'n', 'r', 't', 'u'])
        while i < len(s):
            ch = s[i]
            if not in_string:
                out.append(ch)
                if ch == '"':
                    in_string = True
                i += 1
                continue
            if ch == '"':
                out.append(ch)
                in_string = False
                i += 1
                continue
            if ch != '\\':
                out.append(ch)
                i += 1
                continue
            if i == len(s) - 1:
                # Trailing lone backslash: escape it.
                out.append('\\\\')
                i += 1
                continue
            nxt = s[i + 1]
            if nxt in valid_esc:
                out.append('\\')
                out.append(nxt)
                i += 2
            else:
                out.append('\\\\')
                out.append(nxt)
                i += 2
        return ''.join(out)

    def force_content_single_line(dag_obj):
        # Collapse newlines inside every node's "content" to spaces.
        if not isinstance(dag_obj, dict):
            return dag_obj
        nodes = dag_obj.get("nodes", None)
        if not isinstance(nodes, list):
            return dag_obj
        for node in nodes:
            if isinstance(node, dict) and "content" in node and isinstance(node["content"], str):
                node["content"] = re.sub(r"[\r\n]+", " ", node["content"])
        return dag_obj

    def fallback_strip_backslashes_in_content(dag_obj):
        # Last-resort cleanup: drop ALL backslashes from node content.
        if not isinstance(dag_obj, dict):
            return dag_obj
        nodes = dag_obj.get("nodes", None)
        if not isinstance(nodes, list):
            return dag_obj
        for node in nodes:
            if isinstance(node, dict) and "content" in node and isinstance(node["content"], str):
                node["content"] = node["content"].replace("\\", "")
        return dag_obj

    def extract_first_json_object_substring(s: str):
        # Return the first balanced {...} substring (string-aware), or None.
        start = s.find("{")
        if start < 0:
            return None
        in_string = False
        escape = False
        depth = 0
        for i in range(start, len(s)):
            ch = s[i]
            if in_string:
                if escape:
                    escape = False
                elif ch == "\\":
                    escape = True
                elif ch == '"':
                    in_string = False
            else:
                if ch == '"':
                    in_string = True
                elif ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        return s[start:i + 1]
        return None

    def robust_load_json(raw: str):
        # Try progressively more aggressive repairs; report which stage won.
        raw0 = remove_invisible_control_chars(raw)
        try:
            return json.loads(raw0), raw0, "A_raw"
        except json.JSONDecodeError:
            pass
        b = sanitize_json_literal_newlines(raw0)
        try:
            return json.loads(b), b, "B_newlines_fixed"
        except json.JSONDecodeError:
            pass
        c = sanitize_invalid_backslashes_in_strings(b)
        try:
            return json.loads(c), c, "C_backslashes_fixed"
        except json.JSONDecodeError:
            pass
        sub = extract_first_json_object_substring(raw0)
        if sub:
            d0 = remove_invisible_control_chars(sub)
            d1 = sanitize_json_literal_newlines(d0)
            d2 = sanitize_invalid_backslashes_in_strings(d1)
            try:
                return json.loads(d2), d2, "D_extracted_object_repaired"
            except json.JSONDecodeError:
                pass
        return None, raw0, "FAIL"

    def call_llm(full_prompt: str) -> str:
        # Call Gemini; returns "" on API error so the caller can retry.
        try:
            resp = client.models.generate_content(
                model=model,
                contents=full_prompt,
                config=types.GenerateContentConfig(
                    temperature=0.2,
                    # response_mime_type="application/json" could be added
                    # here for extra output stability
                )
            )
            return resp.text.strip()
        except Exception as e:
            print(f"❌ Gemini API Error: {e}")
            return ""

    def preprocess_llm_output(raw_content: str) -> str:
        # Remove control chars and unwrap a ```json fence if present.
        raw_content = remove_invisible_control_chars(raw_content)
        fence_match = re.search(r"```(?:json|JSON)?\s*([\s\S]*?)```", raw_content)
        if fence_match:
            raw_content = fence_match.group(1).strip()
        raw_content = remove_invisible_control_chars(raw_content)
        return raw_content

    outputs = {}

    # === Main Loop ===
    if not os.path.exists(folder_path):
        print(f"❌ Folder not found: {folder_path}")
        return outputs

    for filename in os.listdir(folder_path):
        if not filename.lower().endswith((".md", ".markdown")):
            continue
        markdown_path = os.path.join(folder_path, filename)
        if not os.path.isfile(markdown_path):
            continue

        section_name = filename
        with open(markdown_path, "r", encoding="utf-8") as f:
            md_text = f.read().strip()

        full_prompt = build_full_prompt(base_prompt, section_name, md_text)
        print(f"📐 Sending section '{section_name}' to Gemini for DAG generation...")

        dag_obj = None
        stage = "INIT"

        # Retry loop: one initial attempt plus MAX_RETRIES_ON_FAIL retries
        for attempt_idx in range(1 + MAX_RETRIES_ON_FAIL):
            if attempt_idx > 0:
                print(f"🔁 Retry LLM for section '{section_name}' (retry={attempt_idx}/{MAX_RETRIES_ON_FAIL})...")
            raw_content = call_llm(full_prompt)
            if not raw_content:
                continue  # API error returned ""; retry immediately
            raw_content = preprocess_llm_output(raw_content)
            dag_obj, _, stage = robust_load_json(raw_content)
            if dag_obj is not None:
                break
            print(f"⚠️ JSON parse failed for section '{section_name}' after repairs. Stage={stage}")

        if dag_obj is None:
            print(f"{section_name} 处理失败超过两次,已清除")
            dag_obj = {}
        else:
            dag_obj = force_content_single_line(dag_obj)
            if ENABLE_FALLBACK_CONTENT_BACKSLASH_STRIP and FALLBACK_STRIP_BACKSLASH_ONLY_IN_CONTENT:
                # Only apply the aggressive strip when the JSON needed the
                # heaviest repair stage.
                if stage in ("D_extracted_object_repaired",):
                    dag_obj = fallback_strip_backslashes_in_content(dag_obj)

        # Output path: <parent-of-folder_path>/section_dag/<name>_dag.json
        safe_section_name = re.sub(r"[\\/:*?\"<>|]", "_", section_name)
        output_filename = f"{safe_section_name}_dag.json"
        subdir_path = os.path.dirname(folder_path)
        section_dag_path = os.path.join(subdir_path, "section_dag")
        os.makedirs(section_dag_path, exist_ok=True)
        output_path = os.path.join(section_dag_path, output_filename)

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(dag_obj, f, ensure_ascii=False, indent=4)

        print(f"✅ DAG for section '{section_name}' saved to: {output_path} (parse_stage={stage})")
        outputs[section_name] = output_path

    return outputs


# ========== Merge the section_dags into the main dag ==========
def add_section_dag(
    section_dag_folder: str,
    main_dag_path: str,
    output_path: Optional[str] = None
) -> str:
    """
    Merge all section DAGs under `section_dag_folder` into the main DAG
    at `main_dag_path`.

    For each section DAG JSON:
      - Take its root node name (nodes[0]["name"]) and append that name
        to the edge list of the main DAG's root node
        (main_dag["nodes"][0]["edge"]).
      - Append ALL nodes from that section DAG to the end of
        main_dag["nodes"], preserving their original order.

    Compatibility patch:
      - If a section JSON is a single node object (missing the top-level
        "nodes" wrapper), automatically wrap it into {"nodes": [obj]} so
        the downstream merge logic can proceed.

    Notes:
      - This function does NOT call any LLM, it only manipulates JSON.
      - The main DAG is assumed to have the format:
        {"nodes": [{"name", "content", "edge", "level", "visual_node"}, ...]}

    Args:
        section_dag_folder: Folder containing per-section DAG JSON files.
        main_dag_path: Path to the main DAG JSON file (original).
        output_path: Path to save the merged DAG. If None, overwrite
            main_dag_path.

    Returns:
        The path of the merged DAG JSON file.

    Raises:
        ValueError: On malformed main or section DAG JSON.
    """

    def _coerce_section_dag_to_nodes_wrapper(obj, section_path: str) -> dict:
        """Return *obj* as a {"nodes": [...]} dict.

        A dict that already has "nodes" is returned as-is; a dict that
        looks like a single node (has "name" and "content") is wrapped;
        anything else raises ValueError.
        """
        # Case 1: already in expected format
        if isinstance(obj, dict) and "nodes" in obj:
            return obj
        # Case 2: single-node object (missing wrapper)
        if isinstance(obj, dict) and "nodes" not in obj:
            # Heuristic: treat as a node if it carries the common node keys.
            has_name = isinstance(obj.get("name"), str) and obj.get("name").strip()
            has_content = isinstance(obj.get("content"), str)
            if has_name and has_content:
                return {"nodes": [obj]}
        raise ValueError(
            f"Section DAG JSON at '{section_path}' is neither a valid DAG wrapper "
            f"nor a recognizable single-node object."
        )

    # === Load main DAG ===
    with open(main_dag_path, "r", encoding="utf-8") as f:
        main_dag = json.load(f)

    if "nodes" not in main_dag or not isinstance(main_dag["nodes"], list) or len(main_dag["nodes"]) == 0:
        raise ValueError("main_dag JSON is invalid: missing non-empty 'nodes' array.")

    # Root node is assumed to be the first node
    root_node = main_dag["nodes"][0]
    # Ensure 'edge' field exists and is a list
    if "edge" not in root_node or not isinstance(root_node["edge"], list):
        root_node["edge"] = []

    # === Traverse section DAG folder (sorted for deterministic order) ===
    for filename in sorted(os.listdir(section_dag_folder)):
        if not filename.lower().endswith(".json"):
            continue
        section_path = os.path.join(section_dag_folder, filename)
        # Skip the main DAG file itself, just in case it lives here too
        if os.path.abspath(section_path) == os.path.abspath(main_dag_path):
            continue
        if not os.path.isfile(section_path):
            continue

        # Load section DAG
        with open(section_path, "r", encoding="utf-8") as f:
            try:
                section_raw = json.load(f)
            except json.JSONDecodeError as e:
                raise ValueError(f"Section DAG JSON invalid at '{section_path}': {e}")

        # Coerce into {"nodes":[...]} if the wrapper is missing
        section_dag = _coerce_section_dag_to_nodes_wrapper(section_raw, section_path)

        # Validate nodes array
        if "nodes" not in section_dag or not isinstance(section_dag["nodes"], list) or len(section_dag["nodes"]) == 0:
            raise ValueError(f"Section DAG JSON at '{section_path}' has no valid 'nodes' array.")

        section_nodes = section_dag["nodes"]
        section_root = section_nodes[0]

        # Get section root name
        section_root_name = section_root.get("name")
        if not isinstance(section_root_name, str) or not section_root_name.strip():
            raise ValueError(f"Section DAG root node at '{section_path}' has invalid or empty 'name'.")

        # Link the section root from the main root (skip duplicates so
        # reruns stay idempotent)
        if section_root_name not in root_node["edge"]:
            root_node["edge"].append(section_root_name)

        # Append all section nodes to the end of main_dag["nodes"]
        main_dag["nodes"].extend(section_nodes)

    # === Save merged DAG ===
    if output_path is None:
        output_path = main_dag_path  # overwrite by default
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(main_dag, f, ensure_ascii=False, indent=4)

    return output_path


# ========== Append the visual_dag nodes into the main dag ==========
def add_visual_dag(dag_path: str, visual_dag_path: str) -> str:
    """
    Append all nodes from a visual DAG JSON file into an existing DAG
    JSON file. Both files must share the structure:
    {"nodes": [{"name", "content", "edge", "level", "visual_node"}, ...]}

    Behavior:
      - Load the main DAG from `dag_path` and the visual DAG from
        `visual_dag_path`.
      - Append ALL nodes from visual_dag["nodes"] to the end of
        main_dag["nodes"], preserving their original order.
      - Overwrite `dag_path` with the merged DAG.
      - Does NOT modify any edge relationships automatically.

    Args:
        dag_path: Path to the main DAG JSON (will be overwritten).
        visual_dag_path: Path to the visual DAG JSON whose nodes will be
            appended.

    Returns:
        The `dag_path` of the merged DAG.

    Raises:
        ValueError: If either file lacks a 'nodes' array.
    """
    # === Load main DAG ===
    with open(dag_path, "r", encoding="utf-8") as f:
        main_dag = json.load(f)
    if "nodes" not in main_dag or not isinstance(main_dag["nodes"], list):
        raise ValueError(f"Main DAG at '{dag_path}' is invalid: missing 'nodes' array.")

    # === Load visual DAG ===
    with open(visual_dag_path, "r", encoding="utf-8") as f:
        visual_dag = json.load(f)
    if "nodes" not in visual_dag or not isinstance(visual_dag["nodes"], list):
        raise ValueError(f"Visual DAG at '{visual_dag_path}' is invalid: missing 'nodes' array.")

    # === Append visual nodes to the bottom of the main DAG ===
    main_dag["nodes"].extend(visual_dag["nodes"])

    # === Save merged DAG back to dag_path (overwrite) ===
    with open(dag_path, "w", encoding="utf-8") as f:
        json.dump(main_dag, f, ensure_ascii=False, indent=4)

    return dag_path


# ========== Refine the visual_node field of every DAG node ==========
from typing import List


def refine_visual_node(dag_path: str) -> None:
    """
    Refine `visual_node` for each node in the DAG JSON at `dag_path`.

    For each node in `nodes`:
      * If node["visual_node"] == 1: the node is itself a visual node;
        skip it and do NOT modify `visual_node`.
      * Else: scan node["content"] for markdown image references
        ``![alt](path)``, keep only relative paths (not http://, https://,
        data: or //), and append each full markdown snippet to
        node["visual_node"]. A missing or non-list `visual_node` (that is
        not 1) is replaced by a list.

    The function overwrites the original `dag_path` with the refined DAG.

    Raises:
        ValueError: If the file lacks a 'nodes' array.
    """
    # === Load DAG ===
    with open(dag_path, "r", encoding="utf-8") as f:
        dag = json.load(f)

    if "nodes" not in dag or not isinstance(dag["nodes"], list):
        raise ValueError(f"DAG JSON at '{dag_path}' is invalid: missing 'nodes' array.")

    nodes: List[dict] = dag["nodes"]

    # Regex to match markdown images: ![alt](path)
    # group(0) = full match, group(1) = alt, group(2) = path
    img_pattern = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")

    def is_relative_path(path: str) -> bool:
        """Return True if path looks like a relative path, not URL or absolute."""
        lowered = path.strip().lower()
        if lowered.startswith("http://"):
            return False
        if lowered.startswith("https://"):
            return False
        if lowered.startswith("data:"):
            return False
        if lowered.startswith("//"):
            return False
        # Absolute filesystem paths could optionally be rejected too:
        # if lowered.startswith("/") or re.match(r"^[a-zA-Z]:[\\/]", lowered):
        #     return False
        return True

    for node in nodes:
        # Skip anything that is not a dict
        if not isinstance(node, dict):
            continue
        # visual_node == 1 marks an already-visual node -> skip
        if node.get("visual_node") == 1:
            continue

        content = node.get("content")
        if not isinstance(content, str) or not content:
            # No textual content to search; still normalize visual_node
            # into a list so downstream consumers can rely on the type.
            if "visual_node" not in node or not isinstance(node["visual_node"], list):
                node["visual_node"] = []
            continue

        # Ensure visual_node is a list (node != 1 was filtered above);
        # work on a copy so the original list is not mutated in place.
        visual_list = node.get("visual_node")
        if not isinstance(visual_list, list):
            visual_list = []
        else:
            visual_list = list(visual_list)

        # Iterate matches to get the exact matched substrings.
        # (FIX: removed a redundant, unused findall() pass over content.)
        for match in img_pattern.finditer(content):
            full_str = match.group(0)          # e.g. '![alt](images/xxx.jpg)'
            path_str = match.group(2).strip()  # text inside the parentheses
            if not is_relative_path(path_str):
                continue  # skip URLs / data URIs
            if full_str not in visual_list:
                visual_list.append(full_str)

        # Update node
        node["visual_node"] = visual_list

    # === Save back to disk (overwrite) ===
    with open(dag_path, "w", encoding="utf-8") as f:
        json.dump(dag, f, ensure_ascii=False, indent=4)