`, ``, ``, `
`.\n - The target system's renderer **cannot parse** entity encodings; if you output `&lt;`, the system will error. You must output a literal `<`.\n - Overwrite the target node's `content` field.\n\n3. ACTION: TYPOGRAPHY\n - Modify `typography.font-size` (>= 16pt).\n\n4. ACTION: MODIFY_TITLE (Update Heading)\n - Extract the new title text from [DETAIL].\n - Locate the `heading` field of the target node and overwrite it directly.\n\nConstraints\n1. Output must be valid JSON.\n2. **JSON Escaping vs HTML Escaping**:\n - ✅ **MUST** escape double quotes: `\"` (Requirement of JSON syntax)\n - ❌ **FORBIDDEN** to escape angle brackets: output `<`, never `&lt;` (Requirement of your task)\n\nCorrect vs Incorrect Examples\n- ❌ Incorrect (Unstructured text): \"content\": \"This is point one. This is point two.\"\n- ✅ Correct (Bulleted structure): \"content\": \"- This is point one\n- This is point two\n\"\n\nInput Processing\nRead instructions and JSON, output the modified JSON.",
+
+ "poster_outline_prompt":"You are given a section node JSON (SECTION_JSON) from a paper DAG. The section JSON you see has NO visual_node field and must be treated as authoritative.\n\nSECTION_JSON:\n{SECTION_JSON}\n\nHAS_VISUAL: {HAS_VISUAL}\n\nIf HAS_VISUAL is true, you are also given the best visual node JSON (VISUAL_JSON), plus IMAGE_SRC and ALT_TEXT. The visual content MUST ONLY come from this provided IMAGE_SRC (do not invent or substitute any other image).\n\nVISUAL_JSON:\n{VISUAL_JSON}\n\nIMAGE_SRC: {IMAGE_SRC}\nALT_TEXT: {ALT_TEXT}\n\nTask:\n1) Write ONE concise paragraph summarizing ONLY the section's content for a scientific poster. Constraints: 2–5 sentences, factual, non-hallucinatory, no bullet lists, and avoid starting with \"This section\". Additional constraint: The summary must contain no more than 40 words and be written with strong logical coherence and smooth transitions to minimize perplexity (PPL).\n2) Output EXACTLY ONE HTML section block in the required template below. Output ONLY the HTML and nothing else.\n\nStrict output rules:\n- Output only ONE block.\n- Do NOT add markdown fences, explanations, or extra text.\n- The <div class=\"section-bar\"> text must be the section title (use SECTION_JSON.name).\n- Replace the sample paragraph text with your summary paragraph.\n- If HAS_VISUAL is true AND IMAGE_SRC is non-empty, include exactly one <div class=\"img-section\"> with one <img> whose src is exactly IMAGE_SRC and alt is ALT_TEXT.\n- If HAS_VISUAL is false OR IMAGE_SRC is empty, do NOT output any img-section or img tag.\n\nRequired HTML template (follow structure exactly):\n<div class=\"section\">\n  <div class=\"section-bar\">SECTION_TITLE</div>\n  <p>\n    SUMMARY_TEXT\n  </p>\n  <div class=\"img-section\">\n    <img src=\"IMAGE_SRC\" alt=\"ALT_TEXT\" class=\"figure\">\n  </div>\n</div>\n",
+
+ "modified_poster_logic_prompt":"You are given an HTML-like poster outline consisting of multiple <div class=\"section\"> blocks. Each section has a title in <div class=\"section-bar\">...</div> and main text in the first <p>...</p> inside <div class=\"section\">.\n\nYour task: Write ONLY bridging sentences to connect the narrative flow between sections.\n- You MUST generate one bridging sentence for each section from the FIRST section up to the SECOND-TO-LAST section.\n- Each bridging sentence should be placed at the END of that section's first <p>...</p>, and should naturally lead into the NEXT section.\n- The sentence MUST be content-aware: it should reference the next section's topic (based on its title and/or content) without adding new technical claims.\n\nSTRICT OUTPUT REQUIREMENTS:\n1) Output ONLY a valid JSON array of strings.\n2) The array length MUST equal (number_of_sections - 1).\n3) The i-th string corresponds to the i-th section (0-indexed) and bridges to section i+1.\n4) Do NOT output any HTML. Do NOT output explanations. Do NOT output markdown fences.\n5) Each string should be a single sentence in fluent academic English, concise (about 12–25 words), and must not include newline characters.\n\nCONTENT SAFETY / FORMATTING CONSTRAINTS:\n- Do NOT modify, rewrite, paraphrase, or comment on any existing text.\n- Do NOT include section numbers.\n- Avoid generic filler like \"In conclusion\". Prefer specific transitions like \"Next, we turn to ...\" while reflecting the next section’s topic.\n\nNow wait for the user to provide the full poster outline text. You may prioritize the following transition patterns: when the next section is intended to introduce, elaborate on, or substantiate a method or idea, use phrasing such as “To introduce / elaborate on / substantiate [TOPIC], we next present [NEXT SECTION].”; if the next section is an experimental section (e.g., the title contains Experiment(s) or the content discusses experimental setups, results, or performance), use transitions like “Next, we evaluate [METHOD] to demonstrate …” or “We then conduct experiments on [DATASETS] to empirically validate the proposed method.”; in all cases, the transition sentence should explicitly point to the next section’s topic without introducing new technical details or conclusions, serving only to ensure narrative continuity.",
+
+ "poster_refinement_prompt": "You are an expert Academic Poster Designer and Web Developer. Your task is to refine an existing HTML poster based on its visual rendering (screenshot) and the current code. Output ONLY the full, valid, and corrected HTML code inside a single code block.\n\nI will provide the following content:\n1. **Current Poster Code (HTML)**\n2. **Visual Render (Image)**\n\n**TASKS:**\n\n**Task 1: Fix LaTeX/Encoding Issues**\n- Scan the HTML for raw LaTeX code that is not rendering correctly (e.g., \"$d_S \\ge 10$\").\n- **FIX:** Replace raw LaTeX with standard HTML entities or Unicode text to ensure readability (e.g., change \"$d_S \\ge 10$\" to \"dS ≥ 10\").\n- Remove any garbled characters or artifacts caused by math rendering failures.\n\n**Task 2: Normalize Section Headers**\n- Locate all elements with `class=\"section-bar\"`.\n- Remove any leading numbering, alphabetic prefixes, or meaningless characters (e.g., \"1.\", \"2.1\", \"A.\", \"E.\", \"-\") from the text content.\n- Ensure only the core title text remains.\n- **Example:** Change `<div class=\"section-bar\">E. Implementation details</div>` to `<div class=\"section-bar\">Implementation details</div>`.\n\n**Output Requirement:**\n- Return the **complete, runnable HTML code**.\n- Do not break the existing CSS layout style; only inject or fix content.\n\nI will now provide the Current Poster HTML and the Rendered Screenshot. Please wait for these inputs and then generate the final HTML code. IMPORTANT: During refinement, do NOT modify any font-related settings, including but not limited to font-family, font names, font styles, or font sources.",
+
+ "extract_basic_information_prompt":"You will be given a single JSON object named ROOT_NODE_JSON representing the root node of a paper DAG. Your task is to extract and format the paper's basic metadata.\n\nINPUT:\n- ROOT_NODE_JSON.name: paper title.\n- ROOT_NODE_JSON.content: contains authors followed by institutions (the first part is a comma-separated list of author names; the second part is a comma-separated list of institutions). The boundary may be indicated by a newline, multiple spaces, or a transition from person names to organization names. Infer the best split.\n- ROOT_NODE_JSON.github: a GitHub/project URL string. If missing, empty, null, or not a valid URL, output N/A.\n\nOUTPUT REQUIREMENTS (STRICT):\n- Output exactly 4 lines, in this exact order and with these exact labels:\nTitle: \nAuthor: \nInstitution: \nGithub: \n- Do NOT add any extra lines, explanations, bullets, numbering, code fences, or markdown.\n- Preserve original capitalization for names and institutions.\n- If authors or institutions cannot be reliably separated, make the best effort: keep person names in Author, organizations in Institution.",
+
+ "futurework_and_reference_prompt":"You are given a paper Markdown (MD). Your job:\n1) Find whether the paper includes a Future Work / Future work / Future directions / Limitations and future work / Discussion and future work section (or future-work-like content). If it exists, extract the key points and rewrite them into ONE concise paragraph (scientific, factual, non-hallucinatory). If it does NOT exist, write a reasonable Future Work paragraph by summarizing limitations and plausible next steps based ONLY on the provided MD. Do NOT invent results or claims not present in MD.\n2) Find the References/Bibliography section in MD and select the FIRST THREE references as they appear. Preserve author/title/venue/year/URL text exactly as present in MD. Do not invent missing bibliographic fields. If references are numbered, keep the numbering; otherwise create [1], [2], [3]. Each reference must be a single line string.\n\nOUTPUT FORMAT (STRICT):\nReturn ONLY valid JSON (no markdown fences, no extra text):\n{\n \"future_work\": \"\",\n \"references\": [\"ref1\", \"ref2\", \"ref3\"]\n}\n\nMD:\n{MD_TEXT}\n",
+
+ "contribution_prompt":"You are a professional academic editor. Your task is to identify and summarize the 'Contributions' of a research paper based on the provided Markdown content.\n\nRules:\n1. Identification: If a 'Contribution' section exists, extract and refine the key points. If NOT, analyze the Introduction or Conclusion to summarize the main contributions yourself.\n2. Constraints: Summarize the contributions into concise points. **Each point must be strictly under 25 words.**\n3. Format Requirement: Output ONLY pure plain text.\n - Do NOT use Markdown formatting (no bolding, headers, or bullet points).\n - Do NOT use HTML tags.\n - Separate distinct points using double newlines (\\n\\n).\n - The output must be ready to be wrapped in <p> tags directly.",
+
+ "generate_pr_prompt":"You will receive ONE JSON object describing a paper section node with fields: name, content, visual_node (a list of image objects/paths). Your task: (1) Determine which high-level paper part this section belongs to: Introduction-like, Methods-like, Experiments/Results-like, or Conclusion-like, based on name and content. (2) Output ONLY ONE of the following formats (no extra text):\n\n[If Introduction-like]\nKey Question: <2-3 sentences, engaging question/surprising fact/relatable hook>\nBrilliant Idea: <2-3 sentences background/context/idea>\n\n\n[If Methods-like]\nCore Methods: \n\n\n[If Experiments/Results-like]\nCore Results: \n\n\n[If Conclusion-like]\nSignificance/Impact: \n\nRules: Use English labels exactly as shown. If no suitable image exists, omit the image line entirely. Choose at most ONE image. Do not output code fences. Here is the node JSON:\n{NODE_JSON}",
+
+ "add_title_and_hashtag_prompt":"You will receive a Markdown promotion draft for an academic paper.\n\nTASK:\n1) Write a short, catchy, easy-to-understand Title that accurately summarizes the core topic or main finding for a general audience. Prefer a hook style (question / key result / clear takeaway). Avoid excessive jargon, but keep scientific accuracy.\n2) Generate EXACTLY 3 highly relevant Specific Tags (each must start with '#', PascalCase recommended; no spaces inside a tag).\n3) Generate EXACTLY 1 Community Tag related to activity/community (must start with '#'; no spaces inside the tag).\n\nINPUT MARKDOWN (do not rewrite it, only use it as context):\n{MD_TEXT}\n\nOUTPUT FORMAT (strictly follow, one item per line; every tag MUST include '#'):\nTitle: \nSpecific Tag: <#Tag1> <#Tag2> <#Tag3>\nCommunity Tag: <#CommunityTag>",
+
+ "pr_refinement_prompt": "Refine the content of each section to match the 'Xiaohongshu' (Little Red Book) style in English. \n\nRequirements:\n1. **Style**: Use a lively, engaging, and 'sharing with friends' tone. Use short paragraphs and bullet points.\n2. **Visuals**: Generously use relevant emojis to structure the text and add vibe.\n3. **Content**: Strictly retain all original information and technical details. Do NOT add any information not present in the source text.\n4. **Length**: Keep the word count approximately the same as the original text.\n5. **Format**: Return valid Markdown.",
+
+ "section_generation_prompt":"You are given ONE section of an academic paper in Markdown format. Your task is to convert it into a DAG-style JSON object with schema: {\"nodes\":[{\"name\":string,\"content\":string,\"edge\":string[],\"level\":int,\"visual_node\":[]}...]}. STRICT RULES: (1) Output ONLY valid JSON (no markdown fences, no commentary, no extra keys). (2) Use the provided SECTION_FILENAME as the ROOT node name. (3) ROOT node: name = SECTION_FILENAME; content = the ENTIRE original Markdown exactly as given; level = 1; visual_node = []; edge = a list of subsection node names extracted from this Markdown. (4) Subsection nodes: Identify ALL subsections whose headings contain hierarchical numbering like '3.1', '3.2', '4.1', '4.3' etc. Treat each such heading as a subsection boundary. For each subsection, create ONE node where: name = the subsection heading text (remove leading Markdown #'s, trim spaces, keep the numbering and title text); content = the FULL Markdown belonging to that subsection ONLY (from its heading line up to before the next same-level numbered subsection heading or end of document), preserving original text exactly; level = 2; edge = []; visual_node = []. (5) ROOT.edge must exactly match the list of subsection node 'name' values, in their original order of appearance. (6) Do NOT invent subsections that are not present. (7) If no numbered subsections exist, output only the ROOT node with edge=[]. INPUTS YOU WILL RECEIVE: SECTION_FILENAME and SECTION_MARKDOWN. Now produce the JSON."
+
+}
\ No newline at end of file
diff --git a/src/DAG2poster.py b/src/DAG2poster.py
new file mode 100644
index 0000000000000000000000000000000000000000..45e6dcd4a273e48479e299b9eb104a3fb4ac0493
--- /dev/null
+++ b/src/DAG2poster.py
@@ -0,0 +1,871 @@
+from __future__ import annotations
+
+import json
+import os
+import re
+import shutil
+import traceback
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from bs4 import BeautifulSoup
+from google import genai
+from openai import OpenAI
+
+
+# ========== Generate poster_outline.txt via Gemini (section-by-section) ==========
+def _load_json(path: str) -> Dict[str, Any]:
+ if not os.path.exists(path):
+ raise FileNotFoundError(f"JSON not found: {path}")
+ with open(path, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+
+def _ensure_dag_schema(dag_obj: Any) -> Dict[str, Any]:
+ """
+ Robustness: if LLM/other step produced a single node dict instead of {"nodes":[...]},
+ wrap it into the expected schema.
+ """
+ if isinstance(dag_obj, dict) and "nodes" in dag_obj and isinstance(dag_obj["nodes"], list):
+ return dag_obj
+ if isinstance(dag_obj, dict) and "name" in dag_obj and "content" in dag_obj:
+ return {"nodes": [dag_obj]}
+ raise ValueError("Invalid dag.json schema: expected {'nodes': [...]} or a single node dict.")
+
+
+def _resolution_area(resolution: Any) -> int:
+ """
+ resolution can be like "536x86" (string) or [536, 86] etc.
+ Returns area; invalid -> 0
+ """
+ if resolution is None:
+ return 0
+ if isinstance(resolution, str):
+ m = re.match(r"^\s*(\d+)\s*[xX]\s*(\d+)\s*$", resolution)
+ if not m:
+ return 0
+ w = int(m.group(1))
+ h = int(m.group(2))
+ return w * h
+ if isinstance(resolution, (list, tuple)) and len(resolution) >= 2:
+ try:
+ w = int(resolution[0])
+ h = int(resolution[1])
+ return w * h
+ except Exception:
+ return 0
+ if isinstance(resolution, dict):
+ # sometimes {"width":..., "height":...}
+ try:
+ w = int(resolution.get("width", 0))
+ h = int(resolution.get("height", 0))
+ return w * h
+ except Exception:
+ return 0
+ return 0
+
+
+def _strip_md_image(s: str) -> str:
+ # Despite the name, this only trims surrounding whitespace from a Markdown
+ # image reference string; src extraction happens in _extract_image_src_from_md.
+ return (s or "").strip()
+
+
+def _extract_image_src_from_md(md_image: str) -> Optional[str]:
+ """
+ md_image example: "![caption](images/xxx.jpg)" or "  ![](images/xxx.jpg)  "
+ returns "images/xxx.jpg" (without surrounding spaces)
+ """
+ if not md_image:
+ return None
+ m = re.search(r"!\[[^\]]*\]\(([^)]+)\)", md_image.strip())
+ if not m:
+ return None
+ return m.group(1).strip()
+
+
+def _safe_section_title(name: str) -> str:
+ """
+ Optional cleanup: remove trailing .md if present.
+ """
+ if not name:
+ return ""
+ name = name.strip()
+ if name.lower().endswith(".md"):
+ name = name[:-3]
+ return name.strip()
+
+
+def _remove_key_deep(obj: Any, key_to_remove: str) -> Any:
+ """
+ Create a JSON-serializable copy of obj with a top-level key removed if dict.
+ (We only need to remove section_node["visual_node"] at top-level, but keep it safe.)
+ """
+ if isinstance(obj, dict):
+ return {k: _remove_key_deep(v, key_to_remove) for k, v in obj.items() if k != key_to_remove}
+ if isinstance(obj, list):
+ return [_remove_key_deep(x, key_to_remove) for x in obj]
+ return obj
+
+def generate_poster_outline_txt(
+ dag_path: str,
+ poster_outline_path: str,
+ poster_outline_prompt: str,
+ model: str = "gemini-2.5-pro",
+ api_key: Optional[str] = None,
+ base_url: Optional[str] = None,
+ client: Optional[Any] = None, # Type relaxed to accept both clients
+ overwrite: bool = True,
+ config: Optional[Dict[str, Any]] = None
+) -> None:
+ """
+ Read dag.json from dag_path, iterate root->section nodes, and for each section:
+ - choose the largest-resolution visual node referenced by section["visual_node"]
+ - send (section_without_visual_node, best_visual_node_if_any, IMAGE_SRC, ALT_TEXT) to LLM
+ - LLM returns EXACTLY one HTML block
+ - append/write to poster_outline_path
+
+ Supports both OpenAI and Google GenAI (Gemini) clients.
+ """
+ # Resolve dag.json path
+ if os.path.isdir(dag_path):
+ dag_json_path = os.path.join(dag_path, "dag.json")
+ else:
+ dag_json_path = dag_path
+
+ dag_obj = _ensure_dag_schema(_load_json(dag_json_path))
+ nodes: List[Dict[str, Any]] = dag_obj.get("nodes", [])
+ if not nodes:
+ raise ValueError("dag.json has empty 'nodes'.")
+
+ # Root node is the first node by your spec
+ root = nodes[0]
+ root_edges = root.get("edge", [])
+ if not isinstance(root_edges, list) or not root_edges:
+ raise ValueError("Root node has no valid 'edge' list of section names.")
+
+ # Build lookup: name -> node (first occurrence)
+ name2node: Dict[str, Dict[str, Any]] = {}
+ for n in nodes:
+ if isinstance(n, dict) and "name" in n:
+ name2node.setdefault(str(n["name"]), n)
+
+ # Determine Model Type
+ is_gemini = "gemini" in model.lower()
+
+ # Prepare Client if not provided
+ if client is None:
+ api_keys_config = config.get("api_keys", {}) if config else {}
+
+ if is_gemini:
+ # Setup Google GenAI Client
+ api_key = api_key or api_keys_config.get("gemini_api_key") or os.getenv("GOOGLE_API_KEY")
+
+ client = genai.Client(api_key=api_key)
+ else:
+ # Setup OpenAI Client
+ api_key = api_key or api_keys_config.get("openai_api_key") or os.getenv("OPENAI_API_KEY")
+
+ client = OpenAI(api_key=api_key)
+
+ # Output file init
+ out_dir = os.path.dirname(os.path.abspath(poster_outline_path))
+ if out_dir and not os.path.exists(out_dir):
+ os.makedirs(out_dir, exist_ok=True)
+
+ write_mode = "w" if overwrite else "a"
+ with open(poster_outline_path, write_mode, encoding="utf-8") as f_out:
+ # Iterate sections in the order of root.edge
+ for sec_name in root_edges:
+ if sec_name not in name2node:
+ raise KeyError(f"Section node not found by name from root.edge: {sec_name}")
+
+ section_node = name2node[sec_name]
+ if not isinstance(section_node, dict):
+ raise ValueError(f"Invalid section node for name={sec_name}")
+
+ # Find all visual nodes referenced by this section
+ visual_refs = section_node.get("visual_node", [])
+ best_visual_node: Optional[Dict[str, Any]] = None
+ best_area = -1
+ best_image_src: Optional[str] = None
+
+ if isinstance(visual_refs, list) and len(visual_refs) > 0:
+ for ref in visual_refs:
+ ref_str = _strip_md_image(str(ref))
+ cand = name2node.get(ref_str)
+
+ if cand is None:
+ for k, v in name2node.items():
+ if isinstance(k, str) and k.strip() == ref_str.strip():
+ cand = v
+ break
+
+ if cand is None or not isinstance(cand, dict):
+ continue
+
+ area = _resolution_area(cand.get("resolution"))
+ if area > best_area:
+ best_area = area
+ best_visual_node = cand
+ best_image_src = _extract_image_src_from_md(str(cand.get("name", "")))
+
+ if best_visual_node is not None and not best_image_src:
+ for ref in visual_refs:
+ tmp = _extract_image_src_from_md(str(ref))
+ if tmp:
+ best_image_src = tmp
+ break
+
+ # Build the section JSON WITHOUT visual_node attribute
+ section_wo_visual = _remove_key_deep(section_node, "visual_node")
+ section_wo_visual["name"] = _safe_section_title(str(section_wo_visual.get("name", "")))
+
+ # Compose ALT_TEXT
+ alt_text = None
+ if best_visual_node is not None:
+ cap = best_visual_node.get("caption")
+ if isinstance(cap, str) and cap.strip():
+ alt_text = cap.strip()
+ if not alt_text:
+ alt_text = "Figure"
+
+ # Compose prompt input fields
+ section_json_str = json.dumps(section_wo_visual, ensure_ascii=False, indent=2)
+
+ if best_visual_node is not None:
+ visual_json_str = json.dumps(best_visual_node, ensure_ascii=False, indent=2)
+ image_src = best_image_src or ""
+ payload = poster_outline_prompt.format(
+ SECTION_JSON=section_json_str,
+ HAS_VISUAL="true",
+ VISUAL_JSON=visual_json_str,
+ IMAGE_SRC=image_src,
+ ALT_TEXT=alt_text,
+ )
+ else:
+ payload = poster_outline_prompt.format(
+ SECTION_JSON=section_json_str,
+ HAS_VISUAL="false",
+ VISUAL_JSON="",
+ IMAGE_SRC="",
+ ALT_TEXT="",
+ )
+
+ # Call API based on model type
+ html_block = ""
+ if is_gemini:
+ # Gemini New API Call
+ # Note: config is passed via helper if needed, usually default generation config is fine
+ resp = client.models.generate_content(
+ model=model,
+ contents=payload
+ )
+ if resp.text:
+ html_block = resp.text
+ else:
+ raise RuntimeError("Gemini returned empty content.")
+ else:
+ # OpenAI API Call
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": payload}],
+ )
+ if not hasattr(resp, "choices") or not resp.choices:
+ raise RuntimeError("OpenAI returned empty choices.")
+
+ html_block = resp.choices[0].message.content
+
+ if not isinstance(html_block, str) or not html_block.strip():
+ raise RuntimeError("LLM returned empty content string.")
+
+ # Append to output file
+ f_out.write(html_block.strip())
+ f_out.write("\n\n")
+
+
+# ========== Modify poster_outline.txt ==========
+def modify_poster_outline(
+ poster_outline_path: str,
+ poster_paper_name: str,
+ modified_poster_outline_path: str
+):
+ """
+ Steps:
+ 1. Find the section whose section-bar text equals poster_paper_name (case-insensitive):
+    - rename its section-bar to "Introduction"
+    - move that section to the front of the file
+ 2. For the remaining sections:
+    - strip the leading numeric index from the section-bar title
+ 3. Keep only the first 6 processed sections
+ 4. Save the final result to modified_poster_outline_path
+ """
+
+ text = Path(poster_outline_path).read_text(encoding="utf-8")
+
+ # ===== 1. Extract all section blocks =====
+ # Each block runs from an opening <div class="section"> to just before the next one (or end of file)
+ section_pattern = re.compile(
+ r"<div class=\"section\">.*?(?=<div class=\"section\">|\Z)",
+ re.DOTALL
+ )
+ sections = section_pattern.findall(text)
+
+ intro_section = None
+ other_sections = []
+
+ # Normalize the target name (trim + lowercase) for later comparison
+ target_name_normalized = poster_paper_name.strip().lower()
+
+ for sec in sections:
+ # Extract the section-bar content
+ m = re.search(
+ r"<div class=\"section-bar\">(.*?)</div>",
+ sec,
+ re.DOTALL
+ )
+ if not m:
+ continue
+
+ # Keep the original title for later replacement, and a normalized copy for comparison
+ original_title = m.group(1).strip()
+ current_title_normalized = original_title.lower()
+
+ # ===== 2. Handle the section whose title matches the paper title (case-insensitive) =====
+ if current_title_normalized == target_name_normalized:
+ # Rename it to "Introduction"
+ sec = re.sub(
+ r"<div class=\"section-bar\">.*?</div>",
+ '<div class="section-bar">Introduction</div>',
+ sec,
+ count=1,
+ flags=re.DOTALL
+ )
+ intro_section = sec
+ else:
+ # ===== 3. Strip the leading numeric index from the remaining section titles =====
+ # e.g. "2 Contextual Auction Design" -> "Contextual Auction Design"
+ new_title = re.sub(r"^\s*\d+(\.\d+)*\s*", "", original_title)
+
+ # Replace only the title portion
+ sec = sec.replace(original_title, new_title, 1)
+ other_sections.append(sec)
+
+ # ===== 4. Reassemble the content (Introduction first) =====
+ final_sections = []
+ if intro_section is not None:
+ final_sections.append(intro_section)
+ final_sections.extend(other_sections)
+
+ # ===== 5. Keep only the first 6 sections =====
+ final_sections = final_sections[:6]
+ # ===== 5.5 Clean section-bar titles: ensure they start with a letter =====
+ cleaned_sections = []
+
+ for sec in final_sections:
+ def _clean_title(match):
+ title = match.group(1)
+ # Strip all leading non-letter characters (up to the first letter)
+ cleaned_title = re.sub(r"^[^A-Za-z]+", "", title)
+ return f'<div class="section-bar">{cleaned_title}</div>'
+ sec = re.sub(
+ r'<div class="section-bar">(.*?)</div>',
+ _clean_title,
+ sec,
+ count=1,
+ flags=re.DOTALL
+ )
+ cleaned_sections.append(sec)
+
+ final_sections = cleaned_sections
+ final_text = "\n\n".join(final_sections)
+
+ # ===== 6. Save the result =====
+ Path(modified_poster_outline_path).write_text(final_text, encoding="utf-8")
+
+
+# ========== Build final poster HTML from outline.txt + template ==========
+def build_poster_from_outline(
+ poster_outline_path: str,
+ poster_template_path: str,
+ poster_path: str,
+) -> str:
+ """
+ Inputs:
+ - poster_outline_path: path to a .txt file whose content is the HTML fragment to insert
+   (a number of <div class="section"> blocks, etc.)
+ - poster_template_path: poster template path (a directory or a concrete file); the
+   function locates poster_template.html under it
+ - poster_path: output HTML path (a full file path such as /xxx/my_poster.html)
+
+ Behavior:
+ 1) Locate poster_template.html (the original file is never modified)
+ 2) Copy it to poster_path
+ 3) In the copied HTML, find:
+    <main class="main">
+      <div class="flow" id="flow">
+        ...here...
+      </div>
+    </main>
+    and insert the outline txt content between <div class="flow" id="flow"> and its </div>
+ 4) Basic robustness handling: newline normalization, indentation alignment, avoiding
+    broken tag structure
+
+ Returns:
+ - poster_path (for chaining by the caller)
+ """
+ # ---------- Basic checks ----------
+ if not os.path.isfile(poster_outline_path):
+ raise FileNotFoundError(f"poster_outline_path not found: {poster_outline_path}")
+
+ if not poster_path.lower().endswith(".html"):
+ raise ValueError(f"poster_path must be an .html file path, got: {poster_path}")
+
+ os.makedirs(os.path.dirname(os.path.abspath(poster_path)), exist_ok=True)
+
+ # ---------- Locate the template file poster_template.html ----------
+ template_file = None
+ if os.path.isdir(poster_template_path):
+ candidate = os.path.join(poster_template_path, "poster_template.html")
+ if os.path.isfile(candidate):
+ template_file = candidate
+ else:
+ # Fallback: search recursively for the file (the template directory depth may vary)
+ for root, _, files in os.walk(poster_template_path):
+ if "poster_template.html" in files:
+ template_file = os.path.join(root, "poster_template.html")
+ break
+ else:
+ # poster_template_path may itself be a file
+ if os.path.isfile(poster_template_path) and os.path.basename(poster_template_path) == "poster_template.html":
+ template_file = poster_template_path
+ elif os.path.isfile(poster_template_path):
+ # Fallback: if the user passed some other html file, allow it as the template too
+ template_file = poster_template_path
+
+ if template_file is None:
+ raise FileNotFoundError(
+ f"Cannot locate poster_template.html under: {poster_template_path}"
+ )
+
+ # ---------- Read the outline content and normalize newlines ----------
+ with open(poster_outline_path, "r", encoding="utf-8") as f:
+ outline_raw = f.read()
+
+ # Normalize line endings to '\n' and strip any BOM
+ outline_raw = outline_raw.replace("\ufeff", "").replace("\r\n", "\n").replace("\r", "\n").strip()
+
+ # An empty outline is allowed (the structure is still preserved)
+ # Append a trailing newline so the outline does not butt directly against the closing </div>
+ if outline_raw:
+ outline_raw += "\n"
+
+ # ---------- Copy the template to the output path (original template untouched) ----------
+ shutil.copyfile(template_file, poster_path)
+
+ # ---------- Read the copied html ----------
+ with open(poster_path, "r", encoding="utf-8") as f:
+ html = f.read()
+
+ html = html.replace("\ufeff", "").replace("\r\n", "\n").replace("\r", "\n")
+
+ # ---------- Insert inside <div class="flow" id="flow"> ... </div> ----------
+ # Notes:
+ # - Use non-greedy matching to lock onto the main/flow region as tightly as possible
+ # - Capture the opening div tags, the original inner content, and the closing tags
+ pattern = re.compile(
+ r'(<main class="main">\s*'
+ r'<div class="flow" id="flow">\s*)'
+ r'(.*?)'
+ r'(\s*</div>\s*'
+ r'</main>\s*)',
+ flags=re.DOTALL | re.IGNORECASE,
+ )
+
+ m = pattern.search(html)
+ if not m:
+ # Retry with a looser match (only requires the flow div, independent of the main structure)
+ pattern2 = re.compile(
+ r'(<div class="flow" id="flow">\s*)'
+ r'(.*?)'
+ r'(\s*</div>)',
+ flags=re.DOTALL | re.IGNORECASE,
+ )
+ m2 = pattern2.search(html)
+ if not m2:
+ raise ValueError(
+ 'Cannot find target insertion block: <div class="flow" id="flow">...</div>'
+ )
+
+ prefix, _, suffix = m2.group(1), m2.group(2), m2.group(3)
+ base_indent = _infer_indent_from_prefix(prefix, html, m2.start(1))
+ outline_formatted = _indent_block(outline_raw, base_indent + "  ")  # indent two extra spaces inside the div by default
+ new_block = prefix + "\n" + outline_formatted + suffix
+ html = html[: m2.start()] + new_block + html[m2.end():]
+ else:
+ prefix, _, suffix = m.group(1), m.group(2), m.group(3)
+ base_indent = _infer_indent_from_prefix(prefix, html, m.start(1))
+ outline_formatted = _indent_block(outline_raw, base_indent + "  ")
+ new_block = prefix + "\n" + outline_formatted + suffix
+ html = html[: m.start()] + new_block + html[m.end():]
+
+ # ---------- Light formatting robustness: collapse excess blank lines ----------
+ html = _collapse_blank_lines(html)
+
+ # ---------- Write back the output ----------
+ with open(poster_path, "w", encoding="utf-8", newline="\n") as f:
+ f.write(html)
+
+ return poster_path
+
+
+def _infer_indent_from_prefix(prefix: str, full_html: str, prefix_start_idx: int) -> str:
+ """
+ Infer the base indentation of the insertion area (so the inserted block lines up correctly).
+ Strategy: use the leading whitespace of the line containing prefix_start_idx as the base indent.
+ """
+ line_start = full_html.rfind("\n", 0, prefix_start_idx) + 1
+ line = full_html[line_start:prefix_start_idx]
+ m = re.match(r"[ \t]*", line)
+ return m.group(0) if m else ""
+
+
+def _indent_block(text: str, indent: str) -> str:
+ """
+ Indent every line of a multi-line text to the given indent.
+ - Blank lines stay blank (no padded spaces), avoiding messy trailing whitespace.
+ """
+ if not text:
+ return ""
+ lines = text.split("\n")
+ out = []
+ for ln in lines:
+ if ln.strip() == "":
+ out.append("")
+ else:
+ out.append(indent + ln)
+ return "\n".join(out) + ("\n" if not text.endswith("\n") else "")
+
+
+def _collapse_blank_lines(html: str, max_blank: int = 2) -> str:
+ """
+ Collapse consecutive blank lines to at most max_blank, avoiding large gaps after insertion.
+ """
+ # First turn whitespace-only lines into truly empty lines
+ html = re.sub(r"[ \t]+\n", "\n", html)
+ # Collapse blank runs: \n\n\n... -> at most max_blank+1 newlines (i.e. max_blank blank lines)
+ html = re.sub(r"\n{"+str(max_blank+2)+r",}", "\n" * (max_blank + 1), html)
+ return html
+
+
+# ========== Update the title and authors in poster.html ==========
+def modify_title_and_author(dag_path: str, poster_path: str) -> None:
+ if not os.path.exists(dag_path):
+ raise FileNotFoundError(f"dag.json not found: {dag_path}")
+ if not os.path.exists(poster_path):
+ raise FileNotFoundError(f"poster.html not found: {poster_path}")
+
+ with open(dag_path, "r", encoding="utf-8") as f:
+ dag: Dict[str, Any] = json.load(f)
+
+ nodes = dag.get("nodes")
+ if not isinstance(nodes, list) or len(nodes) == 0:
+ raise ValueError("Invalid dag.json: missing or empty 'nodes' list")
+
+ first = nodes[0]
+ if not isinstance(first, dict):
+ raise ValueError("Invalid dag.json: first node is not an object")
+
+ title = str(first.get("name", "")).strip()
+ authors = str(first.get("content", "")).strip()
+
+ if not title:
+ raise ValueError("Invalid dag.json: first node 'name' (title) is empty")
+ if not authors:
+ raise ValueError("Invalid dag.json: first node 'content' (authors) is empty")
+
+ with open(poster_path, "r", encoding="utf-8") as f:
+ html = f.read()
+
+ # The template is assumed to mark these fields with <div class="title"> and <div class="authors">.
+ title_pattern = re.compile(
+ r'(<div class="title"[^>]*>)(.*?)(</div>)',
+ flags=re.IGNORECASE | re.DOTALL,
+ )
+ if not title_pattern.search(html):
+ raise ValueError('Cannot find <div class="title">...</div> in poster.html')
+
+ html = title_pattern.sub(lambda m: m.group(1) + title + m.group(3), html, count=1)
+
+ authors_pattern = re.compile(
+ r'(<div class="authors"[^>]*>)(.*?)(</div>)',
+ flags=re.IGNORECASE | re.DOTALL,
+ )
+ if not authors_pattern.search(html):
+ raise ValueError('Cannot find <div class="authors">...</div> in poster.html')
+
+ html = authors_pattern.sub(lambda m: m.group(1) + authors + m.group(3), html, count=1)
+
+ with open(poster_path, "w", encoding="utf-8") as f:
+ f.write(html)
+
+
+def inject_img_section_to_poster(
+ figure_path: str,
+ auto_path: str,
+ poster_path: str,
+ target_filename: str = "expore_our_work_in_detail.jpg",
+) -> str:
+ """
+ 1) 将 figure_path 指向的图片复制到 auto_path/images/ 下(文件名固定为 target_filename)
+ 2) 读取 poster_path 对应的 HTML,定位 内部的
+ (若存在),把 img-section 插入到该 flow div 的末尾,
+ 从而保证新增块出现在 `
` 的 (flow 的闭合)之前。
+ 若找不到 flow div,则退化为插入到 main 的末尾。
+
+ 返回:写回后的 poster_path(绝对路径)
+ """
+ auto_dir = Path(auto_path).expanduser().resolve()
+ poster_file = Path(poster_path).expanduser().resolve()
+ src_figure = Path(figure_path).expanduser().resolve()
+
+ if not src_figure.exists() or not src_figure.is_file():
+ raise FileNotFoundError(f"figure_path not found or not a file: {src_figure}")
+ if not auto_dir.exists() or not auto_dir.is_dir():
+ raise FileNotFoundError(f"auto_path not found or not a directory: {auto_dir}")
+ if not poster_file.exists() or not poster_file.is_file():
+ raise FileNotFoundError(f"poster_path not found or not a file: {poster_file}")
+
+ # 1) copy image into auto/images/
+ images_dir = auto_dir / "images"
+ images_dir.mkdir(parents=True, exist_ok=True)
+
+ dst_figure = images_dir / target_filename
+ shutil.copy2(src_figure, dst_figure)
+
+ # 2) edit poster html
+ html_text = poster_file.read_text(encoding="utf-8")
+ soup = BeautifulSoup(html_text, "html.parser")
+
+ main_tag = soup.find("main", class_="main")
+    if main_tag is None:
+        raise ValueError(f'Cannot find <main class="main"> in poster: {poster_file}')
+
+    # Prefer inserting into the flow div so the new block sits right before </main>
+ flow_tag = main_tag.find("div", attrs={"class": "flow", "id": "flow"})
+
+ # Avoid duplicate insertion
+ target_src = f"images/{target_filename}"
+ existing_img = main_tag.find("img", attrs={"src": target_src})
+ if existing_img is None:
+ new_div = soup.new_tag("div", attrs={"class": "img-section"})
+ new_img = soup.new_tag(
+ "img",
+ attrs={"src": target_src, "alt": "", "class": "figure"},
+ )
+ new_div.append(new_img)
+
+ if flow_tag is not None:
+            # Insert before the flow div's closing tag
+ flow_tag.append(new_div)
+ else:
+ # Fallback: append to main
+ main_tag.append(new_div)
+
+ # Write back
+ poster_file.write_text(str(soup), encoding="utf-8")
+
+ return str(poster_file)
+
+
+# ================================= Improve logical flow ==================================
+def _parse_sections(html_text: str) -> List[dict]:
+    """
+    Parse each <section>...</section> block and extract:
+      - section_block: the raw block text
+      - title: the title inside the section-bar
+      - first_p_inner: inner text of the first <p>...</p> (may be empty)
+      - p_span: the (start, end) span of the first <p>...</p> within
+        section_block (inclusive of the <p>...</p> tags)
+    """
+    section_pat = re.compile(
+        r'(<section\b[^>]*>.*?</section>)',
+        re.DOTALL | re.IGNORECASE,
+    )
+ sections = []
+ for m in section_pat.finditer(html_text):
+ block = m.group(1)
+
+        # Title (assumes the section bar is a <div class="section-bar">)
+        title_m = re.search(
+            r'<div class="section-bar"[^>]*>(.*?)</div>',
+            block,
+            re.DOTALL | re.IGNORECASE,
+        )
+ title = title_m.group(1).strip() if title_m else ""
+
+        # Only handle the first <p>...</p>
+        p_m = re.search(r'(<p[^>]*>)(.*?)(</p>)', block, re.DOTALL | re.IGNORECASE)
+ if p_m:
+ p_open, p_inner, p_close = p_m.group(1), p_m.group(2), p_m.group(3)
+ p_span = (p_m.start(0), p_m.end(0))
+ first_p_inner = p_inner
+ else:
+ p_span = None
+ first_p_inner = ""
+
+ sections.append(
+ {
+ "section_block": block,
+ "title": title,
+ "first_p_inner": first_p_inner,
+ "p_span": p_span,
+ "match_span_in_full": (m.start(1), m.end(1)), # span in full html_text
+ }
+ )
+ return sections
+
+
+def _extract_json_array(text: str) -> List[str]:
+    """
+    Extract a JSON array from the model output (a little surrounding text is
+    tolerated, but a [...] segment must ultimately be extractable).
+    """
+    text = text.strip()
+    # The text is already a JSON array
+ if text.startswith("["):
+ try:
+ arr = json.loads(text)
+ if isinstance(arr, list) and all(isinstance(x, str) for x in arr):
+ return arr
+ except Exception:
+ pass
+
+    # Try to extract the first [...] segment
+ m = re.search(r"\[[\s\S]*\]", text)
+ if not m:
+ raise ValueError("LLM output does not contain a JSON array.")
+ arr_str = m.group(0)
+ arr = json.loads(arr_str)
+ if not (isinstance(arr, list) and all(isinstance(x, str) for x in arr)):
+ raise ValueError("Extracted JSON is not a list of strings.")
+ return arr
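As a quick sanity check, the extraction behaviour above can be mirrored in a standalone sketch (a re-implementation for illustration, not the function itself):

```python
import json
import re

def extract_json_array(text):
    # Mirror of _extract_json_array: try a whole-string parse first,
    # then fall back to the first [...] span found in the text.
    text = text.strip()
    if text.startswith("["):
        try:
            arr = json.loads(text)
            if isinstance(arr, list) and all(isinstance(x, str) for x in arr):
                return arr
        except Exception:
            pass
    m = re.search(r"\[[\s\S]*\]", text)
    if not m:
        raise ValueError("LLM output does not contain a JSON array.")
    arr = json.loads(m.group(0))
    if not (isinstance(arr, list) and all(isinstance(x, str) for x in arr)):
        raise ValueError("Extracted JSON is not a list of strings.")
    return arr

print(extract_json_array('Sure, here it is:\n["first bridge", "second bridge"]'))
```

Note the `\[[\s\S]*\]` pattern is greedy: with multiple arrays in the output it spans from the first `[` to the last `]`, which is fine for single-array replies.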
+
+def modified_poster_logic(
+ poster_outline_path_modified: str,
+ modified_poster_logic_prompt: str,
+ model: Optional[str] = None,
+ temperature: float = 0.2,
+    config: Optional[dict] = None
+) -> str:
+    """
+    Read the HTML-like content of poster_outline_path_modified (txt);
+    send the full text to the LLM and have it output ONLY the transition
+    sentences to append, for the first through second-to-last sections
+    (as a JSON array, in matching order);
+    then append each transition to the end of the corresponding section's
+    first <p>...</p> (just before </p>), leaving all other formatting and
+    content untouched; finally overwrite the original txt and return its
+    absolute path.
+    """
+ txt_path = Path(poster_outline_path_modified).expanduser().resolve()
+ if not txt_path.exists() or not txt_path.is_file():
+ raise FileNotFoundError(f"txt not found: {txt_path}")
+
+ html_text = txt_path.read_text(encoding="utf-8")
+
+    # Note: _parse_sections and _extract_json_array are defined earlier in this module
+ sections = _parse_sections(html_text)
+ if len(sections) < 2:
+ return str(txt_path)
+
+    # Number of sections that need a transition: first through second-to-last => len(sections)-1
+ expected_n = len(sections) - 1
+
+    # Resolve the model name
+ model_name = model or os.getenv("OPENAI_MODEL") or "gpt-4o"
+ is_gemini = "gemini" in model_name.lower()
+
+    # Fetch the API configuration
+ api_keys_config = config.get("api_keys", {}) if config else {}
+
+ out_text = ""
+
+ if is_gemini:
+ # --- Gemini Client Setup ---
+ api_key = api_keys_config.get("gemini_api_key") or os.getenv("GOOGLE_API_KEY")
+
+ client = genai.Client(api_key=api_key)
+
+ # Call Gemini
+        # Put the system prompt in the config and the user content in contents
+ resp = client.models.generate_content(
+ model=model_name,
+ contents=html_text,
+ config={
+ "system_instruction": modified_poster_logic_prompt,
+ "temperature": temperature,
+ }
+ )
+ if not resp.text:
+ raise RuntimeError("Gemini returned empty text.")
+ out_text = resp.text
+
+ else:
+ # --- OpenAI Client Setup ---
+ api_key = api_keys_config.get("openai_api_key") or os.getenv("OPENAI_API_KEY")
+
+ client = OpenAI(api_key=api_key)
+
+ # Call OpenAI
+ messages = [
+ {"role": "system", "content": modified_poster_logic_prompt},
+ {"role": "user", "content": html_text},
+ ]
+ resp = client.chat.completions.create(
+ model=model_name,
+ messages=messages,
+ temperature=temperature,
+ )
+ out_text = resp.choices[0].message.content or ""
+
+    # Parse the returned JSON
+ transitions = _extract_json_array(out_text)
+
+ if len(transitions) != expected_n:
+        # Tolerance: if the LLM occasionally over- or under-generates, we could
+        # truncate/pad instead; the original strict behavior (raise) is kept here
+ raise ValueError(
+ f"Transition count mismatch: expected {expected_n}, got {len(transitions)}"
+ )
+
+    # Insert section by section: only modify the end of the first <p>...</p>
+    # inner text (just before </p>)
+ new_html_parts = []
+ cursor = 0
+
+ for i, sec in enumerate(sections):
+ full_start, full_end = sec["match_span_in_full"]
+        # First append the content preceding this section (unchanged)
+ new_html_parts.append(html_text[cursor:full_start])
+
+ block = sec["section_block"]
+ p_span = sec["p_span"]
+
+        if i <= len(sections) - 2:  # sections 1 through second-to-last
+ trans = transitions[i].strip()
+ if trans and p_span:
+                # Insert just before the closing </p> of this section's first <p>
+ p_start, p_end = p_span
+ p_block = block[p_start:p_end]
+
+                close_idx = p_block.lower().rfind("</p>")
+ if close_idx == -1:
+ new_block = block
+ else:
+ insert = (" " + trans) if not p_block[:close_idx].endswith((" ", "\n", "\t")) else trans
+ new_p_block = p_block[:close_idx] + insert + p_block[close_idx:]
+ new_block = block[:p_start] + new_p_block + block[p_end:]
+ else:
+ new_block = block
+ else:
+            # The last section gets no transition
+ new_block = block
+
+ new_html_parts.append(new_block)
+ cursor = full_end
+
+    # Append the tail
+ new_html_parts.append(html_text[cursor:])
+ new_html_text = "".join(new_html_parts)
+
+ txt_path.write_text(new_html_text, encoding="utf-8")
+ return str(txt_path)
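The `</p>`-anchored append step can be exercised in isolation; this is a minimal mirror of that logic, not the full function:

```python
def append_transition(p_block: str, trans: str) -> str:
    # Append trans just before the closing </p>, adding a space if needed
    close_idx = p_block.lower().rfind("</p>")
    if close_idx == -1:
        return p_block
    insert = (" " + trans) if not p_block[:close_idx].endswith((" ", "\n", "\t")) else trans
    return p_block[:close_idx] + insert + p_block[close_idx:]

print(append_transition("<p>First point.</p>", "Next, we refine it."))
```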
\ No newline at end of file
diff --git a/src/DAG2ppt.py b/src/DAG2ppt.py
new file mode 100644
index 0000000000000000000000000000000000000000..874aff1ae5ec715e7dd1b316fd851b5adc2c724a
--- /dev/null
+++ b/src/DAG2ppt.py
@@ -0,0 +1,773 @@
+import json
+import os
+import re
+from typing import Any, Dict, List, Optional, Union
+from openai import OpenAI
+from google import genai
+
+
+# ========== Generate selected_nodes.json ==========
+def generate_selected_nodes(dag_json_path, max_len, output_path='selected_node.json'):
+
+    # 1. Load dag.json
+ with open(dag_json_path, 'r', encoding='utf-8') as f:
+ dag_data = json.load(f)
+
+ all_nodes = dag_data.get('nodes', [])
+
+    # 2. Build a helper dict for fast lookup of node info by name
+    # (covers both regular nodes and visual nodes)
+ node_map = {node['name']: node for node in all_nodes}
+
+    # 3. Initialize the queue
+    # Find the root node (level=0)
+ root_node = next((node for node in all_nodes if node.get('level') == 0), None)
+
+ if not root_node:
+ raise ValueError("Root node (level 0) not found in dag.json")
+
+    # Use the root node's children (sections) as the initial queue
+    # Note: the queue stores node names
+ current_queue = list(root_node.get('edge', []))
+
+    # Initialize counters
+ node_num = len(current_queue)
+ level_num = 1
+
+    # 4. Process the queue until level_num reaches 5
+ while level_num < 5:
+ i = 0
+ while i < len(current_queue):
+ node_name = current_queue[i]
+ node_info = node_map.get(node_name)
+
+ if not node_info:
+                # Edge case: a queued node is missing from the map
+ i += 1
+ continue
+
+            # ===== Added logic: skip any node whose name contains "introduction"/"INTRODUCTION" =====
+            # Note: no other logic is changed; the node is simply skipped here
+ if "introduction" in node_name.lower():
+ i += 1
+ continue
+
+            # The 'level' attribute may be missing; default to a value unequal to the current level
+ current_node_level = node_info.get('level', -1)
+
+            # Skip nodes whose level does not equal level_num
+ if current_node_level != level_num:
+ i += 1
+ continue
+
+            # Fetch the children
+ children_names = node_info.get('edge', [])
+ num_children = len(children_names)
+
+ if num_children == 0:
+                # No children; nothing to expand
+ i += 1
+ continue
+
+            potential_total_num = len(current_queue) + num_children
+            if potential_total_num <= max_len:
+                # Expand the node in place into its children
+                current_queue[i:i+1] = children_names
+            else:
+                # Would exceed max_len: leave unexpanded and move on
+                i += 1
+
+        # After a full pass over the current queue, advance to the next level
+ level_num += 1
+
+    # 5. Build the final result
+ final_nodes_list = []
+
+ for node_name in current_queue:
+ original_node = node_map.get(node_name)
+ if not original_node:
+ continue
+
+        # Copy to avoid mutating the original data (note: .copy() is shallow;
+        # the visual_node field is replaced outright below)
+ export_node = original_node.copy()
+
+ original_visual_list = export_node.get('visual_node', [])
+
+        # Some nodes may have an empty or missing visual_node field
+ if original_visual_list:
+ expanded_visual_nodes = []
+
+            # Make sure it is a list; dirty data may not be
+ if isinstance(original_visual_list, list):
+ for v_name in original_visual_list:
+                # Look up the visual node's full details by name
+ v_node_full = node_map.get(v_name)
+ if v_node_full:
+ expanded_visual_nodes.append(v_node_full)
+ else:
+                    # If not found, keep the name or ignore; here we keep a stub flagging the missing node
+ expanded_visual_nodes.append({"name": v_name, "error": "Node not found"})
+
+            # Replace the original attribute
+ export_node['visual_node'] = expanded_visual_nodes
+
+ final_nodes_list.append(export_node)
+
+    # 6. Write the result to file
+    with open(output_path, 'w', encoding='utf-8') as f:
+        json.dump(final_nodes_list, f, ensure_ascii=False, indent=4)
+
+ print(f"Successfully generated {output_path} with {len(final_nodes_list)} nodes.")
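The level-by-level expansion above can be exercised on a tiny synthetic DAG (node names here are illustrative, and this mirrors the loop rather than calling the function, which needs files on disk):

```python
def expand_queue(node_map, queue, max_len, max_level=5):
    # Mirror of the expansion loop: at each level, replace a node with its
    # children only while the queue stays within max_len; nodes whose name
    # contains "introduction" are skipped.
    level = 1
    while level < max_level:
        i = 0
        while i < len(queue):
            info = node_map.get(queue[i])
            if (not info or "introduction" in queue[i].lower()
                    or info.get("level", -1) != level):
                i += 1
                continue
            children = info.get("edge", [])
            if children and len(queue) + len(children) <= max_len:
                queue[i:i + 1] = children
            else:
                i += 1
        level += 1
    return queue

toy = {
    "Methods": {"level": 1, "edge": ["Model", "Training"]},
    "Results": {"level": 1, "edge": []},
    "Model": {"level": 2, "edge": []},
    "Training": {"level": 2, "edge": []},
}
print(expand_queue(toy, ["Methods", "Results"], max_len=4))
```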
+
+
+
+
+# ========== Initialize the outline ==========
+from google.genai import types
+
+def outline_initialize(dag_json_path, outline_initialize_prompt, model, config):
+    """
+    Use an LLM to initialize outline.json (creating just two nodes: Title + Contents).
+    Works with both OpenAI and Google Gemini (new google-genai SDK).
+
+    Inputs:
+        dag_json_path: path to dag.json
+        outline_initialize_prompt: prompt string passed to the LLM
+        model: model name (e.g. "gpt-4o" or "gemini-2.0-flash")
+        config: config dict; must contain ['api_keys']['gemini_api_key']
+
+    Outputs:
+        outline.json saved next to dag.json
+        Returns a Python list (the outline structure)
+    """
+
+ # --- load dag.json ---
+ if not os.path.exists(dag_json_path):
+ raise FileNotFoundError(f"dag.json not found: {dag_json_path}")
+
+ with open(dag_json_path, "r", encoding="utf-8") as f:
+ dag_data = json.load(f)
+
+ # --- extract first node ---
+ if isinstance(dag_data, list):
+ first_node = dag_data[0]
+ elif isinstance(dag_data, dict) and "nodes" in dag_data:
+ first_node = dag_data["nodes"][0]
+ else:
+ raise ValueError("Unsupported dag.json format")
+
+ first_node_text = json.dumps(first_node, ensure_ascii=False, indent=2)
+
+    # System prompt
+ system_prompt = "You are an expert academic presentation outline generator."
+
+ raw_output = ""
+
+    # --- LLM Call Switch ---
+    # Simple heuristic: if the model name contains "gemini", use the Google SDK;
+    # otherwise default to the OpenAI-compatible SDK
+ if "gemini" in model.lower():
+ # --- Gemini Call (google-genai SDK) ---
+ api_key = config['api_keys'].get('gemini_api_key')
+
+        # Configure the client
+ client = genai.Client(api_key=api_key)
+
+        # Build the user message content
+ user_content = f"{outline_initialize_prompt}\n\nData Context:\n{first_node_text}"
+
+ try:
+ response = client.models.generate_content(
+ model=model,
+ contents=user_content,
+ config=types.GenerateContentConfig(
+ system_instruction=system_prompt,
+ temperature=0.0,
+                    response_mime_type="application/json"  # Force Gemini to emit JSON for stability
+ )
+ )
+ raw_output = response.text
+ except Exception as e:
+ raise RuntimeError(f"Gemini API call failed: {str(e)}")
+
+ else:
+ # --- OpenAI Call ---
+ api_key = config['api_keys'].get('openai_api_key')
+
+ client = OpenAI(api_key=api_key)
+
+ try:
+ response = client.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": system_prompt
+ },
+ {
+ "role": "user",
+ "content": outline_initialize_prompt
+ },
+ {
+ "role": "user",
+ "content": first_node_text
+ }
+ ],
+ temperature=0
+ )
+ raw_output = response.choices[0].message.content.strip()
+ except Exception as e:
+ raise RuntimeError(f"OpenAI API call failed: {str(e)}")
+
+ # --- Extract JSON (Generic cleaning logic) ---
+ cleaned = raw_output.strip()
+
+ # Remove ```json ... ``` markdown fences
+ if cleaned.startswith("```"):
+ cleaned = cleaned.strip("`")
+ if cleaned.lstrip().startswith("json"):
+ cleaned = cleaned.split("\n", 1)[1]
+
+ # Robustness: locate JSON block via first [ and last ]
+ try:
+ first = cleaned.index("[")
+ last = cleaned.rindex("]")
+ cleaned = cleaned[first:last + 1]
+ except ValueError:
+ pass # Try parsing the whole string if brackets aren't found cleanly
+
+ try:
+ outline_data = json.loads(cleaned)
+ except json.JSONDecodeError:
+ raise ValueError(f"LLM output is not valid JSON:\nRaw Output: {raw_output}")
+
+ # --- Save outline.json ---
+ out_dir = os.path.dirname(dag_json_path)
+ out_path = os.path.join(out_dir, "outline.json")
+
+ with open(out_path, "w", encoding="utf-8") as f:
+ json.dump(outline_data, f, indent=4, ensure_ascii=False)
+
+ print(f"✅ Outline saved to: {out_path} (Model: {model})")
+
+ return outline_data
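The fence-stripping and bracket-slicing cleanup above can be sketched standalone (an illustrative mirror of the cleaning logic, not the function itself):

```python
import json

def clean_llm_json_array(raw):
    # Mirror of the cleaning logic: strip markdown fences, then slice from
    # the first "[" to the last "]" before parsing.
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`")
        if cleaned.lstrip().startswith("json"):
            cleaned = cleaned.split("\n", 1)[1]
    try:
        first = cleaned.index("[")
        last = cleaned.rindex("]")
        cleaned = cleaned[first:last + 1]
    except ValueError:
        pass  # fall through and try parsing as-is
    return json.loads(cleaned)

print(clean_llm_json_array('```json\n[{"title": "Title"}, {"title": "Contents"}]\n```'))
```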
+
+
+# ========== Call the LLM to generate the complete outline ==========
+def generate_complete_outline(
+ selected_node_path,
+ outline_path,
+ generate_complete_outline_prompt,
+ model,
+ config
+):
+    """
+    Call the LLM once per selected node, generating an outline node that is
+    appended to outline.json.
+    Works with both OpenAI and Google Gemini (new google-genai SDK).
+
+    Inputs:
+        selected_node_path: path to selected_node.json
+        outline_path: path to outline.json
+        generate_complete_outline_prompt: prompt string for the LLM
+        model: model name (e.g. "gpt-4o" or "gemini-2.0-flash")
+        config: config dict
+
+    Outputs:
+        the updated outline.json
+        Returns the outline (list)
+    """
+
+ # --- load selected_node.json ---
+ if not os.path.exists(selected_node_path):
+ raise FileNotFoundError(f"selected_node.json not found: {selected_node_path}")
+
+ with open(selected_node_path, "r", encoding="utf-8") as f:
+ selected_nodes = json.load(f)
+
+ if not isinstance(selected_nodes, list):
+ raise ValueError("selected_node.json must be a list")
+
+ # --- load outline.json ---
+ if not os.path.exists(outline_path):
+ raise FileNotFoundError(f"outline.json not found: {outline_path}")
+
+ with open(outline_path, "r", encoding="utf-8") as f:
+ outline_data = json.load(f)
+
+ if not isinstance(outline_data, list):
+ raise ValueError("outline.json must be a list")
+
+ # --- Initialize Client based on model ---
+ is_gemini = "gemini" in model.lower()
+ client = None
+ system_prompt = "You are an expert academic presentation outline generator."
+
+ if is_gemini:
+ api_key = config['api_keys'].get('gemini_api_key')
+
+ client = genai.Client(api_key=api_key)
+ else:
+ api_key = config['api_keys'].get('openai_api_key')
+ client = OpenAI(api_key=api_key)
+
+ # --- iterate selected nodes ---
+ for idx, node in enumerate(selected_nodes):
+
+ payload = {
+ "name": node.get("name"),
+ "content": node.get("content"),
+ "visual_node": node.get("visual_node", [])
+ }
+
+ payload_text = json.dumps(payload, ensure_ascii=False, indent=2)
+ raw_output = ""
+
+ try:
+ if is_gemini:
+ # --- Gemini Call ---
+ user_content = f"{generate_complete_outline_prompt}\n\nNode Data:\n{payload_text}"
+ response = client.models.generate_content(
+ model=model,
+ contents=user_content,
+ config=types.GenerateContentConfig(
+ system_instruction=system_prompt,
+ temperature=0.0,
+ response_mime_type="application/json"
+ )
+ )
+ raw_output = response.text
+ else:
+ # --- OpenAI Call ---
+ response = client.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": system_prompt
+ },
+ {
+ "role": "user",
+ "content": generate_complete_outline_prompt
+ },
+ {
+ "role": "user",
+ "content": payload_text
+ }
+ ],
+ temperature=0
+ )
+ raw_output = response.choices[0].message.content.strip()
+
+ except Exception as e:
+ print(f"⚠️ Error processing node {idx} ({node.get('name')}): {e}")
+ continue # Skip this node or handle error as needed
+
+ # --- clean JSON ---
+ cleaned = raw_output.strip()
+
+ if cleaned.startswith("```"):
+ cleaned = cleaned.strip("`")
+ if cleaned.lstrip().startswith("json"):
+ cleaned = cleaned.split("\n", 1)[1]
+
+ try:
+ first = cleaned.index("{")
+ last = cleaned.rindex("}")
+ cleaned = cleaned[first:last + 1]
+ except Exception:
+ pass
+
+ try:
+ outline_node = json.loads(cleaned)
+ except json.JSONDecodeError:
+ # print error but maybe continue? strict raise for now
+ raise ValueError(
+ f"LLM output is not valid JSON for selected_node index {idx}:\n{raw_output}"
+ )
+
+ # --- append to outline ---
+ outline_data.append(outline_node)
+
+ # --- save outline.json ---
+ with open(outline_path, "w", encoding="utf-8") as f:
+ json.dump(outline_data, f, indent=4, ensure_ascii=False)
+
+ print(f"✅ Complete outline updated: {outline_path}")
+
+ return outline_data
+
+
+# ========== Call the LLM to assign a template to each slide ==========
+SlideType = Dict[str, Any]
+OutlineType = List[SlideType]
+JsonType = Union[Dict[str, Any], List[Any], str, int, float, bool, None]
+
+def arrange_template(
+ outline_path: str,
+ arrange_template_prompt: str,
+ model: str,
+ config: Dict[str, Any]
+) -> OutlineType:
+    """
+    Read an outline.json and call the LLM to choose a PPT template for slides
+    whose template is null.
+    Works with both OpenAI and Google Gemini (new google-genai SDK).
+    """
+
+ # --- Client Init ---
+ is_gemini = "gemini" in model.lower()
+ client = None
+
+ if is_gemini:
+ api_key = config['api_keys'].get('gemini_api_key')
+
+ client = genai.Client(api_key=api_key)
+ else:
+ api_key = config['api_keys'].get('openai_api_key')
+
+ client = OpenAI(api_key=api_key)
+
+    # Load outline.json
+ with open(outline_path, "r", encoding="utf-8") as f:
+ outline: OutlineType = json.load(f)
+
+ def is_null_template(value: Any) -> bool:
+ """
+ Treat Python None or explicit string 'NULL' / 'null' / ''
+ as empty template that needs to be filled.
+ """
+ if value is None:
+ return True
+ if isinstance(value, str) and value.strip().lower() in {"null", ""}:
+ return True
+ return False
+
+ def select_template_for_slide(slide: SlideType, index: int) -> None:
+ """
+ If slide['template'] is NULL/None, call LLM to select a template.
+ """
+ if not is_null_template(slide.get("template")):
+ return # already has a template, skip
+
+        # Send the entire slide as JSON to the LLM
+ slide_json_str = json.dumps(slide, ensure_ascii=False, indent=2)
+
+        # Summary statistics
+ figures = slide.get("figure", []) or []
+ formulas = slide.get("formula", []) or []
+
+ summary_info = {
+ "slide_index": index,
+ "num_figures": len(figures),
+ "num_formulas": len(formulas),
+ }
+ summary_json_str = json.dumps(summary_info, ensure_ascii=False, indent=2)
+
+        # Build the user content
+ user_content = (
+ "Below is one slide node from outline.json.\n"
+ "First, read the raw slide JSON.\n"
+ "Then, use the template selection rules in the system message to choose "
+ "exactly one template for this slide.\n\n"
+ "A small auto-generated summary is also provided to help you:\n"
+ f"Summary:\n```json\n{summary_json_str}\n```\n\n"
+ "Full slide node (JSON):\n```json\n"
+ + slide_json_str
+ + "\n```"
+ )
+
+ content = ""
+
+ try:
+ if is_gemini:
+ # --- Gemini Call ---
+ response = client.models.generate_content(
+ model=model,
+ contents=user_content,
+ config=types.GenerateContentConfig(
+ system_instruction=arrange_template_prompt,
+ temperature=0.0,
+ response_mime_type="application/json"
+ )
+ )
+ content = response.text
+ else:
+ # --- OpenAI Call ---
+ messages = [
+ {
+ "role": "system",
+ "content": arrange_template_prompt,
+ },
+ {
+ "role": "user",
+ "content": user_content,
+ },
+ ]
+ response = client.chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=0.0,
+ )
+ content = (response.choices[0].message.content or "").strip()
+
+ except Exception as e:
+ print(f"[WARN] Failed to call LLM for slide {index}: {e}")
+ return
+
+        # Expect the LLM to return JSON like: {"template": "T2_ImageRight.html"}
+ template_name: Union[str, None] = None
+
+        # 1) Try to parse directly as JSON
+ try:
+            # Strip possible ```json ... ``` code-fence wrapping
+ content_for_json = content
+ if "```" in content:
+ parts = content.split("```")
+                # Find the json-tagged part, or just take the second segment
+ if len(parts) > 1:
+ candidate = parts[1]
+ if candidate.lstrip().startswith("json"):
+ candidate = candidate.split("\n", 1)[-1]
+ content_for_json = candidate
+
+ parsed = json.loads(content_for_json)
+
+ if isinstance(parsed, dict) and "template" in parsed:
+ template_name = parsed["template"]
+ elif isinstance(parsed, str):
+ template_name = parsed
+ except Exception:
+            # 2) If JSON parsing fails, treat the output as plain text
+ cleaned = content.strip()
+ if cleaned.startswith('"') and cleaned.endswith('"'):
+ cleaned = cleaned[1:-1].strip()
+ template_name = cleaned or None
+
+ if isinstance(template_name, str) and template_name:
+ slide["template"] = template_name
+ else:
+ print(
+ f"[WARN] Could not parse template from model output for slide {index}, "
+ "leaving 'template' unchanged."
+ )
+
+    # The top level is a list; each element is one slide
+ if not isinstance(outline, list):
+ raise ValueError("outline.json must be a list of slide nodes at top level.")
+
+ for idx, slide in enumerate(outline):
+ if isinstance(slide, dict):
+ select_template_for_slide(slide, idx)
+
+    # Write back to file
+ with open(outline_path, "w", encoding="utf-8") as f:
+ json.dump(outline, f, ensure_ascii=False, indent=2)
+
+ return outline
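The two fallback paths above, null-template detection and reply parsing, can be checked on their own; this is an illustrative mirror of the nested helpers, not the helpers themselves:

```python
import json

def is_null_template(value):
    # None, "", and "null"/"NULL" (any case, any surrounding spaces) count as empty
    if value is None:
        return True
    return isinstance(value, str) and value.strip().lower() in {"null", ""}

def parse_template_reply(content):
    # Accept {"template": "..."}, a bare JSON string, or quoted plain text
    body = content
    if "```" in content:
        parts = content.split("```")
        if len(parts) > 1:
            body = parts[1]
            if body.lstrip().startswith("json"):
                body = body.split("\n", 1)[-1]
    try:
        parsed = json.loads(body)
        if isinstance(parsed, dict) and "template" in parsed:
            return parsed["template"]
        if isinstance(parsed, str):
            return parsed
        return None
    except Exception:
        cleaned = content.strip()
        if cleaned.startswith('"') and cleaned.endswith('"'):
            cleaned = cleaned[1:-1].strip()
        return cleaned or None

print(parse_template_reply('{"template": "T2_ImageRight.html"}'))
```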
+
+
+# ========== Generate the final PPT ==========
+_MD_IMAGE_RE = re.compile(r"!\[\s*.*?\s*\]\(\s*([^)]+?)\s*\)")
+def _extract_md_image_path(name_field: str) -> str:
+ """
+ Extracts relative image path from a markdown image string like:
+ '' -> 'images/abc.jpg'
+ If not markdown format, returns the original string stripped.
+ """
+ if not isinstance(name_field, str):
+ return ""
+ s = name_field.strip()
+ m = _MD_IMAGE_RE.search(s)
+ if m:
+ return m.group(1).strip()
+ return s
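The markdown-image pattern behaves as follows; the sketch below re-declares it so the example is self-contained:

```python
import re

# Same pattern as _MD_IMAGE_RE above
MD_IMAGE_RE = re.compile(r"!\[\s*.*?\s*\]\(\s*([^)]+?)\s*\)")

def md_image_path(name_field):
    # Extract 'images/abc.jpg' from '![fig](images/abc.jpg)'; pass non-markdown through
    if not isinstance(name_field, str):
        return ""
    s = name_field.strip()
    m = MD_IMAGE_RE.search(s)
    return m.group(1).strip() if m else s

print(md_image_path("![overview figure]( images/overview.png )"))
```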
+
+
+def _normalize_node(node: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Normalize node fields and extract clean image paths for figure/formula name fields.
+ """
+ text = node.get("text", "")
+ template = node.get("template", "")
+ figure = node.get("figure", []) or []
+ formula = node.get("formula", []) or []
+
+ def norm_imgs(imgs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ out = []
+ for it in imgs:
+ if not isinstance(it, dict):
+ continue
+ name = it.get("name", "")
+ out.append({
+ "name": name,
+ "path": _extract_md_image_path(name),
+ "caption": it.get("caption", ""),
+ "resolution": it.get("resolution", "")
+ })
+ return out
+
+ return {
+ "text": text if isinstance(text, str) else str(text),
+ "template": template if isinstance(template, str) else str(template),
+ "figure": norm_imgs(figure if isinstance(figure, list) else []),
+ "formula": norm_imgs(formula if isinstance(formula, list) else [])
+ }
+
+def generate_ppt(
+ outline_path: str,
+ ppt_template_path: str,
+ generate_ppt_with_gemini_prompt: Union[Dict[str, str], List[Dict[str, str]]],
+ model: str,
+ config: Dict[str, Any]
+) -> List[str]:
+ """
+ Traverse outline JSON nodes, load corresponding HTML templates, send (prompt + node + template)
+ to LLM (Gemini or OpenAI), then save revised HTML to the outline.json directory.
+
+ Args:
+ outline_path: path to outline json file.
+ ppt_template_path: folder containing html templates.
+ generate_ppt_with_gemini_prompt: JSON-like prompt (dict or list of messages).
+ model: model name (e.g., 'gemini-2.0-flash', 'gpt-4o').
+ config: config dict containing api_keys.
+
+ Returns:
+ List of saved HTML file paths (one per node).
+ """
+
+ # --- Client Init ---
+ is_gemini = "gemini" in model.lower()
+ client = None
+
+ if is_gemini:
+ api_key = config['api_keys'].get('gemini_api_key')
+
+ client = genai.Client(api_key=api_key)
+ else:
+ api_key = config['api_keys'].get('openai_api_key')
+
+ client = OpenAI(api_key=api_key)
+
+ outline_path = os.path.abspath(outline_path)
+ ppt_template_path = os.path.abspath(ppt_template_path)
+
+    if not os.path.isfile(outline_path):
+        raise FileNotFoundError(f"outline_path does not exist or is not a file: {outline_path}")
+    if not os.path.isdir(ppt_template_path):
+        raise NotADirectoryError(f"ppt_template_path does not exist or is not a directory: {ppt_template_path}")
+
+ with open(outline_path, "r", encoding="utf-8") as f:
+ outline = json.load(f)
+
+    if not isinstance(outline, list):
+        raise ValueError("The top level of the outline JSON must be a list (one element per PPT slide)")
+
+ out_dir = os.path.dirname(outline_path)
+ saved_files: List[str] = []
+
+ # Allow prompt to be either a single message dict or a list of messages.
+ base_messages = []
+ if isinstance(generate_ppt_with_gemini_prompt, dict):
+ base_messages = [generate_ppt_with_gemini_prompt]
+ elif isinstance(generate_ppt_with_gemini_prompt, list):
+ base_messages = generate_ppt_with_gemini_prompt
+ else:
+        raise TypeError("generate_ppt_with_gemini_prompt must be JSON-shaped: a dict or a list[dict]")
+
+
+ for idx, node in enumerate(outline, start=1):
+ if not isinstance(node, dict):
+ continue
+
+ norm_node = _normalize_node(node)
+ template_file = norm_node.get("template", "").strip()
+
+ # Skip if no template or explicitly null/empty
+ if not template_file or template_file.lower() == "null":
+ continue
+
+ template_full_path = os.path.join(ppt_template_path, template_file)
+ if not os.path.isfile(template_full_path):
+            # raise FileNotFoundError(f"Template file not found: {template_full_path}")
+ print(f"⚠️ Template not found: {template_file}, skipping slide {idx}")
+ continue
+
+ with open(template_full_path, "r", encoding="utf-8") as tf:
+ template_html = tf.read()
+
+ user_payload = {
+ "ppt_index": idx,
+ "node": norm_node,
+ "template_html": template_html,
+ }
+
+ # Construct OpenAI-style messages list
+ current_messages = list(base_messages) + [
+ {
+ "role": "user",
+ "content": (
+ "Here is the slide node JSON and the HTML template. "
+ "Revise the HTML per instructions and return ONLY the final HTML code.\n"
+ "Do NOT include markdown fences like ```html ... ```.\n\n"
+ f"{json.dumps(user_payload, ensure_ascii=False)}"
+ ),
+ }
+ ]
+
+ revised_html = ""
+
+ try:
+ if is_gemini:
+ # --- Gemini Call ---
+ # Convert messages list to a single string prompt for Gemini
+ # (or pass list if using chat interface, but generate_content with string is often simpler for 1-turn)
+ prompt_parts = []
+ for m in current_messages:
+ prompt_parts.append(str(m.get("content", "")))
+ final_prompt = "\n\n".join(prompt_parts)
+
+ resp = client.models.generate_content(
+ model=model,
+ contents=final_prompt
+ )
+ revised_html = getattr(resp, "text", str(resp))
+ else:
+ # --- OpenAI Call ---
+ resp = client.chat.completions.create(
+ model=model,
+ messages=current_messages,
+ temperature=0.0
+ )
+ revised_html = resp.choices[0].message.content
+
+ except Exception as e:
+ print(f"⚠️ API Call failed for slide {idx}: {e}")
+ continue
+
+ # Clean output
+ if revised_html:
+ revised_html = revised_html.strip()
+ # Remove markdown fences if present
+ if revised_html.startswith("```"):
+ revised_html = revised_html.strip("`")
+ if revised_html.lstrip().startswith("html"):
+ revised_html = revised_html.split("\n", 1)[1]
+
+ # Save
+ out_name = f"{idx}_ppt.html"
+ out_path = os.path.join(out_dir, out_name)
+ with open(out_path, "w", encoding="utf-8") as wf:
+ wf.write(revised_html)
+
+ saved_files.append(out_path)
+ print(f"✅ Generated: {out_path}")
+
+ return saved_files
\ No newline at end of file
diff --git a/src/DAG2pr.py b/src/DAG2pr.py
new file mode 100644
index 0000000000000000000000000000000000000000..051f75e3c3145f409c7a5f15eecafacf22a09e5e
--- /dev/null
+++ b/src/DAG2pr.py
@@ -0,0 +1,995 @@
+import json
+import os
+import re
+import time
+import shutil
+from pathlib import Path
+from typing import Optional, List, Tuple, Any, Dict
+from google import genai
+from google.genai import types
+from openai import OpenAI
+
+
+# ========== Extract basic information for PR Generation ==========
+def extract_basic_information(
+ dag_path: str,
+ extract_basic_information_prompt: str,
+ model: str,
+ auto_path: str,
+ output_filename: str = "basic_information.txt",
+ api_key: Optional[str] = None,
+ base_url: Optional[str] = None,
+    config: Optional[dict] = None
+) -> str:
+    """
+    Read the first node (root) of dag.json and send its full information to the
+    LLM to extract and output: Title / Author / Institution / Github.
+    Then write the LLM output to auto_path/basic_information.txt (created if absent).
+
+    Returns: absolute path of the written txt (str)
+    """
+ dag_file = Path(dag_path).expanduser().resolve()
+ auto_dir = Path(auto_path).expanduser().resolve()
+ out_path = auto_dir / output_filename
+
+ if not dag_file.exists() or not dag_file.is_file():
+ raise FileNotFoundError(f"dag_path not found or not a file: {dag_file}")
+
+ auto_dir.mkdir(parents=True, exist_ok=True)
+
+ # 1) load dag.json
+ with dag_file.open("r", encoding="utf-8") as f:
+ dag = json.load(f)
+
+ nodes = dag.get("nodes", [])
+ if not isinstance(nodes, list) or len(nodes) == 0 or not isinstance(nodes[0], dict):
+ raise ValueError("Invalid dag.json: missing or invalid 'nodes[0]' root node.")
+
+ root_node = nodes[0]
+
+ # 2) build LLM input (send the entire root node)
+ root_payload = {
+ "name": root_node.get("name", ""),
+ "content": root_node.get("content", ""),
+ "github": root_node.get("github", ""),
+ "edge": root_node.get("edge", []),
+ "level": root_node.get("level", 0),
+ "visual_node": root_node.get("visual_node", []),
+ }
+
+ user_message = (
+ "ROOT_NODE_JSON:\n"
+ + json.dumps(root_payload, ensure_ascii=False, indent=2)
+ )
+
+ # 3) call LLM
+ llm_text = ""
+ api_keys_config = config.get("api_keys", {}) if config else {}
+
+    # Determine the platform
+ is_gemini = "gemini" in model.lower()
+
+ if is_gemini:
+ # === Gemini Client Setup ===
+ api_key = api_keys_config.get("gemini_api_key") or os.getenv("GOOGLE_API_KEY")
+
+ client = genai.Client(api_key=api_key)
+
+ # Gemini Call
+ resp = client.models.generate_content(
+ model=model,
+ contents=user_message,
+ config={
+ "system_instruction": extract_basic_information_prompt,
+ "temperature": 0,
+ }
+ )
+ if resp.text:
+ llm_text = resp.text
+
+ else:
+ # === OpenAI Client Setup ===
+ api_key = api_keys_config.get("openai_api_key") or os.getenv("OPENAI_API_KEY")
+
+ client = OpenAI(api_key=api_key)
+
+ # OpenAI Call
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": extract_basic_information_prompt},
+ {"role": "user", "content": user_message},
+ ],
+ temperature=0,
+ )
+ llm_text = resp.choices[0].message.content or ""
+
+ if not llm_text:
+ raise RuntimeError("LLM returned empty content for basic information extraction.")
+
+ # 4) write to auto/basic_information.txt
+ with out_path.open("w", encoding="utf-8") as f:
+ f.write(llm_text.strip() + "\n")
+
+ return str(out_path)
+
+def initialize_pr_markdown(
+ basic_information_path: str,
+ auto_path: str,
+ pr_template_path: str
+) -> str:
+    """
+    basic_information_path: path to the txt file (containing Title/Author/Institution/Github)
+    auto_path: output directory
+    pr_template_path: path to the PR markdown template
+    Output: auto_path/markdown.md
+    """
+
+ info_txt_path = Path(basic_information_path)
+ auto_dir = Path(auto_path)
+ template_md = Path(pr_template_path)
+
+ if not info_txt_path.exists():
+ raise FileNotFoundError(f"basic_information_path not found: {basic_information_path}")
+ if info_txt_path.suffix.lower() != ".txt":
+ raise ValueError(f"basic_information_path must be a .txt file, got: {basic_information_path}")
+ if not template_md.exists():
+ raise FileNotFoundError(f"pr_template_path not found: {pr_template_path}")
+
+ auto_dir.mkdir(parents=True, exist_ok=True)
+
+    # -------- 1) Read and parse the txt --------
+ txt = info_txt_path.read_text(encoding="utf-8", errors="ignore")
+
+ def _extract_value(key: str) -> str:
+ pattern = re.compile(
+ rf"^\s*{re.escape(key)}\s*:\s*(.*?)\s*$",
+ re.IGNORECASE | re.MULTILINE
+ )
+ m = pattern.search(txt)
+ return m.group(1).strip() if m else ""
+
+ title = _extract_value("Title")
+ authors = _extract_value("Author") or _extract_value("Authors")
+ institution = _extract_value("Institution")
+ github = _extract_value("Github") or _extract_value("GitHub") or _extract_value("Direct Link")
+
+ # -------- 2) Copy the template to auto_path/markdown.md --------
+ out_md_path = auto_dir / "markdown.md"
+ shutil.copyfile(template_md, out_md_path)
+
+ md = out_md_path.read_text(encoding="utf-8", errors="ignore")
+
+ # Key step: normalize newlines so \r\n does not break whole-line matching
+ md = md.replace("\r\n", "\n").replace("\r", "\n")
+
+ # -------- 3) Replace line by line (more robust) --------
+ def _fill_line_by_anchor(md_text: str, anchor_label: str, value: str) -> str:
+ """
+ anchor_label: 例如 "Authors",用于匹配 **Authors**
+ 将包含 **anchor_label** 的那一行,改写为:
+ <该行中 **anchor_label** 及其之前部分>:
+ 例如:
+ ✍️ **Authors** -> ✍️ **Authors:** xxx
+ 🏛️**Institution** -> 🏛️**Institution:** yyy
+ """
+ if not value:
+ return md_text
+
+ anchor = f"**{anchor_label}**"
+ lines = md_text.split("\n")
+ replaced = False
+
+ for i, line in enumerate(lines):
+ if anchor in line:
+ # Keep everything before the anchor, plus the anchor itself
+ left = line.split(anchor, 1)[0] + anchor
+ lines[i] = f"{left}: {value}"
+ replaced = True
+ break
+
+ return "\n".join(lines) if replaced else md_text
+
+ md = _fill_line_by_anchor(md, "Authors", authors)
+ md = _fill_line_by_anchor(md, "Direct Link", github)
+ md = _fill_line_by_anchor(md, "Paper Title", title)
+ md = _fill_line_by_anchor(md, "Institution", institution)
+
+ out_md_path.write_text(md, encoding="utf-8")
+
+ return str(out_md_path)
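A minimal, self-contained sketch of the anchor-fill behavior above. The template lines here are made up for illustration; only the split-on-anchor logic mirrors `_fill_line_by_anchor`:

```python
# Sketch of the anchor-fill logic (hypothetical template lines)
def fill_line_by_anchor(md_text: str, anchor_label: str, value: str) -> str:
    if not value:
        return md_text
    anchor = f"**{anchor_label}**"
    lines = md_text.split("\n")
    for i, line in enumerate(lines):
        if anchor in line:
            # Keep everything up to and including the anchor, then append the value
            left = line.split(anchor, 1)[0] + anchor
            lines[i] = f"{left}: {value}"
            return "\n".join(lines)
    return md_text

template = "✍️ **Authors**\n🏛️**Institution**"
out = fill_line_by_anchor(template, "Authors", "Alice, Bob")
# out == "✍️ **Authors**: Alice, Bob\n🏛️**Institution**"
```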
+
+
+def generate_pr_from_dag(
+ dag_path: str,
+ pr_path: str,
+ generate_pr_prompt: str,
+ model: str = "gpt-4o-mini",
+ timeout: int = 120,
+ debug: bool = True,
+ max_retries: int = 5,
+ backoff_base: float = 1.5,
+ max_content_chars: int = 12000,
+ max_visuals: int = 20,
+ config: dict = None
+) -> None:
+ """
+ Generate a PR markdown from dag.json by iterating section nodes and querying an LLM.
+ Supports both OpenAI and Google Gemini APIs using official SDKs.
+ """
+
+ def log(msg: str) -> None:
+ if debug:
+ print(msg)
+
+ # -------------------------
+ # 0) Env: API key & Client Setup
+ # -------------------------
+ if config is None:
+ config = {}
+
+ api_keys_conf = config.get("api_keys", {})
+
+ # Simple heuristic to determine provider based on model name
+ is_gemini = "gemini" in model.lower()
+
+ client = None
+
+ if is_gemini:
+ # Gemini Configuration (New SDK)
+ if genai is None:
+ raise ImportError("Google GenAI SDK not installed. Please run `pip install google-genai`.")
+
+ api_key = api_keys_conf.get("gemini_api_key") or os.getenv("GEMINI_API_KEY", "").strip()
+ if not api_key:
+ raise RuntimeError("API Key not found for Gemini.")
+
+ client = genai.Client(api_key=api_key)
+
+ else:
+ # OpenAI Configuration
+ if OpenAI is None:
+ raise ImportError("OpenAI SDK not installed. Please run `pip install openai`.")
+
+ api_key = api_keys_conf.get("openai_api_key") or os.getenv("OPENAI_API_KEY", "").strip()
+
+ if not api_key:
+ raise RuntimeError("API Key not found for OpenAI.")
+
+ client = OpenAI(
+ api_key=api_key,
+ timeout=timeout,
+ max_retries=0 # We handle retries manually below
+ )
+
+ log(f"[INFO] Using provider = {'Gemini' if is_gemini else 'OpenAI'}")
+ log(f"[INFO] Using model = {model}")
+
+ # -------------------------
+ # 1) Load DAG
+ # -------------------------
+ dag_obj = json.loads(Path(dag_path).read_text(encoding="utf-8"))
+ if isinstance(dag_obj, list):
+ nodes = dag_obj
+ elif isinstance(dag_obj, dict) and isinstance(dag_obj.get("nodes"), list):
+ nodes = dag_obj["nodes"]
+ else:
+ raise ValueError("Unsupported dag.json format (expect list or dict with 'nodes').")
+
+ if not nodes:
+ log("[WARN] dag.json has no nodes; exiting.")
+ return
+
+ root = nodes[0]
+ root_edges = root.get("edge", [])
+ if not isinstance(root_edges, list):
+ log("[WARN] Root node edge is not a list; exiting.")
+ return
+
+ log(f"[INFO] Root node name = {root.get('name')}")
+ log(f"[INFO] Root edges (section names) = {root_edges}")
+
+ name2node: Dict[str, Dict[str, Any]] = {}
+ for n in nodes:
+ nm = n.get("name")
+ if isinstance(nm, str) and nm:
+ name2node[nm] = n
+
+ section_nodes: List[Dict[str, Any]] = []
+ for sec_name in root_edges:
+ if sec_name in name2node:
+ section_nodes.append(name2node[sec_name])
+ else:
+ log(f"[WARN] Section node '{sec_name}' not found in DAG; skipped.")
+
+ log(f"[INFO] Resolved {len(section_nodes)} section nodes.")
+ if not section_nodes:
+ log("[WARN] No section nodes to process; exiting.")
+ return
+
+ # -------------------------
+ # 2) Load PR markdown
+ # -------------------------
+ pr_file = Path(pr_path)
+ pr_text = pr_file.read_text(encoding="utf-8")
+
+ KNOWN_LABELS = ["Key Question", "Brilliant Idea", "Core Methods", "Core Results", "Significance/Impact"]
+
+ if debug:
+ log("[INFO] PR template header scan:")
+ for lab in KNOWN_LABELS:
+ found = bool(re.search(rf"(?mi)^.*\*\*{re.escape(lab)}\*\*.*$", pr_text))
+ log(f" - {lab}: found={found}")
+
+ filled_key_question = False
+ filled_brilliant_idea = False
+ filled_significance = False
+ core_methods_inlined = False
+ core_results_inlined = False
+
+ # -------------------------
+ # 3) LLM call (OpenAI & Gemini SDKs)
+ # -------------------------
+ def chat_complete(prompt: str, payload_meta: str = "") -> str:
+ system_msg = "You are a precise scientific writing assistant."
+
+ last_err: Optional[Exception] = None
+
+ for attempt in range(1, max_retries + 1):
+ try:
+ if is_gemini:
+ # Gemini (Google GenAI SDK)
+ response = client.models.generate_content(
+ model=model,
+ contents=prompt,
+ config=types.GenerateContentConfig(
+ system_instruction=system_msg,
+ temperature=0.4,
+ )
+ )
+ # Check whether the response was blocked (e.g. by safety filters)
+ if not response.text:
+ log(f"[WARN] Gemini response empty (possibly safety blocked) {payload_meta}")
+ return ""
+ return response.text.strip()
+
+ else:
+ # OpenAI (Official SDK)
+ response = client.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": system_msg},
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.4,
+ )
+ content = response.choices[0].message.content
+ return content.strip() if content else ""
+
+ except Exception as e:
+ # Catch SDK exceptions (RateLimit, APIError, ConnectionError, etc.)
+ last_err = e
+ sleep_s = backoff_base ** (attempt - 1)
+ log(f"[LLM-RETRY] attempt {attempt}/{max_retries} error: {repr(e)}; sleep {sleep_s:.2f}s {payload_meta}")
+ time.sleep(sleep_s)
+
+ raise RuntimeError(f"LLM request failed after {max_retries} retries. Last error: {repr(last_err)}")
+
+ # -------------------------
+ # 4) Parse LLM output
+ # -------------------------
+ def parse_llm_output(text: str) -> Dict[str, Any]:
+ out: Dict[str, Any] = {
+ "type": "unknown",
+ "key_question": None,
+ "brilliant_idea": None,
+ "core_methods": None,
+ "core_results": None,
+ "significance": None,
+ "image": None,
+ }
+
+ img = re.search(r'!\[\]\(([^)]+)\)', text)
+ if img:
+ out["image"] = f".strip()})"
+
+ def grab(label: str) -> Optional[str]:
+ m = re.search(
+ rf"(?is)(?:^{re.escape(label)}\s*[::]\s*)(.*?)(?=^\w[\w /-]*\s*[::]|\Z)",
+ text,
+ flags=re.MULTILINE,
+ )
+ return m.group(1).strip() if m else None
+
+ out["key_question"] = grab("Key Question")
+ out["brilliant_idea"] = grab("Brilliant Idea")
+ out["core_methods"] = grab("Core Methods")
+ out["core_results"] = grab("Core Results")
+ out["significance"] = grab("Significance/Impact")
+
+ if out["key_question"] or out["brilliant_idea"]:
+ out["type"] = "intro"
+ elif out["core_methods"]:
+ out["type"] = "methods"
+ elif out["core_results"]:
+ out["type"] = "results"
+ elif out["significance"]:
+ out["type"] = "impact"
+
+ return out
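A quick sketch of the label-grab regex above on a made-up LLM answer (the sample text is hypothetical; the pattern is the one used in `parse_llm_output`):

```python
import re

# Made-up LLM output with two labeled fields
text = "Key Question: How to scale?\nBrilliant Idea: Use a DAG.\n"

def grab(label: str):
    # Capture everything after "Label:" up to the next "Label:" line or end of text
    m = re.search(
        rf"(?is)(?:^{re.escape(label)}\s*[::]\s*)(.*?)(?=^\w[\w /-]*\s*[::]|\Z)",
        text,
        flags=re.MULTILINE,
    )
    return m.group(1).strip() if m else None

print(grab("Key Question"))   # How to scale?
print(grab("Brilliant Idea")) # Use a DAG.
```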
+
+ # -------------------------
+ # 5) PR template helpers (emoji-safe + APPEND inside block)
+ # -------------------------
+ def _find_header_line(md: str, label: str) -> Optional[re.Match]:
+ return re.search(rf"(?mi)^(?P.*\*\*{re.escape(label)}\*\*.*)$", md)
+
+ def has_header(md: str, label: str) -> bool:
+ found = _find_header_line(md, label) is not None
+ log(f"[CHECK] Header '{label}' found = {found}")
+ return found
+
+ def set_inline(md: str, label: str, text: str) -> str:
+ m = _find_header_line(md, label)
+ if not m:
+ log(f"[SKIP] set_inline: header '{label}' not found")
+ return md
+
+ line = m.group("line")
+ token_pat = re.compile(rf"\*\*{re.escape(label)}\*\*", re.I)
+ tm = token_pat.search(line)
+ if not tm:
+ log(f"[SKIP] set_inline: token '**{label}**' not found in line")
+ return md
+
+ prefix = line[:tm.end()] # includes **Label**
+ new_line = f"{prefix}: {text}".rstrip()
+
+ start, end = m.start("line"), m.end("line")
+ return md[:start] + new_line + md[end:]
+
+ def _next_header_pos(md: str, start_pos: int) -> int:
+ tail = md[start_pos:]
+ next_pos = None
+ for lab in KNOWN_LABELS:
+ mm = re.search(rf"(?mi)^\s*.*\*\*{re.escape(lab)}\*\*.*$", tail)
+ if mm:
+ cand = start_pos + mm.start()
+ if next_pos is None or cand < next_pos:
+ next_pos = cand
+ return next_pos if next_pos is not None else len(md)
+
+ def append_to_label_block(md: str, label: str, insertion_lines: List[str]) -> str:
+ m = _find_header_line(md, label)
+ if not m:
+ log(f"[SKIP] append_to_label_block: header '{label}' not found")
+ return md
+
+ insertion = "\n".join([ln for ln in insertion_lines if ln and ln.strip()]).rstrip()
+ if not insertion:
+ log(f"[SKIP] append_to_label_block: empty insertion for '{label}'")
+ return md
+
+ hdr_end = m.end("line")
+ block_end = _next_header_pos(md, hdr_end)
+
+ before = md[:hdr_end]
+ middle = md[hdr_end:block_end]
+ after = md[block_end:]
+
+ if not before.endswith("\n"):
+ before += "\n"
+ if middle and not middle.startswith("\n"):
+ middle = "\n" + middle
+
+ middle_stripped_right = middle.rstrip("\n")
+ if middle_stripped_right.strip() == "":
+ new_middle = "\n" + insertion + "\n"
+ else:
+ new_middle = middle_stripped_right + "\n\n" + insertion + "\n"
+
+ return before + new_middle + after
+
+ # -------------------------
+ # 6) Main loop over sections
+ # -------------------------
+ for idx, sec in enumerate(section_nodes):
+ log("\n" + "=" * 90)
+ log(f"[SECTION {idx}] name = {sec.get('name')}")
+
+ content = sec.get("content", "")
+ if not isinstance(content, str):
+ content = ""
+
+ visuals = sec.get("visual_node", [])
+ if not isinstance(visuals, list):
+ visuals = []
+
+ visuals_norm: List[str] = []
+ for v in visuals:
+ if isinstance(v, str):
+ visuals_norm.append(v)
+ elif isinstance(v, dict):
+ for k in ("path", "src", "url", "md", "markdown"):
+ if k in v and isinstance(v[k], str):
+ visuals_norm.append(v[k])
+ break
+
+ if len(content) > max_content_chars:
+ content = content[:max_content_chars] + "\n...(truncated)"
+
+ if len(visuals_norm) > max_visuals:
+ visuals_norm = visuals_norm[:max_visuals]
+
+ node_obj = {
+ "name": sec.get("name", ""),
+ "content": content,
+ "visual_node": visuals_norm,
+ }
+
+ if debug:
+ log(f"[PAYLOAD] content_chars={len(node_obj['content'])} visuals={len(node_obj['visual_node'])}")
+ preview = dict(node_obj)
+ if len(preview["content"]) > 800:
+ preview["content"] = preview["content"][:800] + "...(truncated)"
+ log("[SEND TO LLM] Node payload preview:")
+ print(json.dumps(preview, indent=2, ensure_ascii=False))
+
+ prompt = generate_pr_prompt.replace("{NODE_JSON}", json.dumps(node_obj, ensure_ascii=False))
+ llm_text = chat_complete(prompt, payload_meta=f"(section_idx={idx}, section_name={node_obj['name']})")
+
+ log("\n[LLM RAW OUTPUT]")
+ print(llm_text)
+
+ parsed = parse_llm_output(llm_text)
+ log("\n[PARSED OUTPUT]")
+ print(json.dumps(parsed, indent=2, ensure_ascii=False))
+
+ if parsed["type"] == "intro":
+ log("[TYPE] Introduction-like")
+ kq = parsed.get("key_question")
+ bi = parsed.get("brilliant_idea")
+ img = parsed.get("image")
+
+ if kq and not filled_key_question:
+ if has_header(pr_text, "Key Question"):
+ pr_text = set_inline(pr_text, "Key Question", kq)
+ filled_key_question = True
+ log("[WRITE] Key Question filled (first time).")
+ else:
+ log("[MISS] PR has no Key Question header.")
+ else:
+ log("[SKIP] Key Question ignored (empty or already filled).")
+
+ if bi and not filled_brilliant_idea:
+ if has_header(pr_text, "Brilliant Idea"):
+ pr_text = set_inline(pr_text, "Brilliant Idea", bi)
+ if img:
+ pr_text = append_to_label_block(pr_text, "Brilliant Idea", [img])
+ filled_brilliant_idea = True
+ log("[WRITE] Brilliant Idea filled (first time).")
+ else:
+ log("[MISS] PR has no Brilliant Idea header.")
+ else:
+ log("[SKIP] Brilliant Idea ignored (empty or already filled).")
+
+ elif parsed["type"] == "methods":
+ log("[TYPE] Methods-like")
+ cm = parsed.get("core_methods")
+ img = parsed.get("image")
+
+ if not cm:
+ log("[SKIP] Core Methods empty.")
+ continue
+ if not has_header(pr_text, "Core Methods"):
+ log("[MISS] PR has no Core Methods header.")
+ continue
+
+ if not core_methods_inlined:
+ pr_text = set_inline(pr_text, "Core Methods", cm)
+ core_methods_inlined = True
+ log("[WRITE] Core Methods inlined (first time).")
+ if img:
+ pr_text = append_to_label_block(pr_text, "Core Methods", [img])
+ log("[WRITE] Core Methods image appended.")
+ else:
+ lines = [cm] + ([img] if img else [])
+ pr_text = append_to_label_block(pr_text, "Core Methods", lines)
+ log("[WRITE] Core Methods appended as (text, image) pair.")
+
+ elif parsed["type"] == "results":
+ log("[TYPE] Results-like")
+ cr = parsed.get("core_results")
+ img = parsed.get("image")
+
+ if not cr:
+ log("[SKIP] Core Results empty.")
+ continue
+ if not has_header(pr_text, "Core Results"):
+ log("[MISS] PR has no Core Results header.")
+ continue
+
+ if not core_results_inlined:
+ pr_text = set_inline(pr_text, "Core Results", cr)
+ core_results_inlined = True
+ log("[WRITE] Core Results inlined (first time).")
+ if img:
+ pr_text = append_to_label_block(pr_text, "Core Results", [img])
+ log("[WRITE] Core Results image appended.")
+ else:
+ lines = [cr] + ([img] if img else [])
+ pr_text = append_to_label_block(pr_text, "Core Results", lines)
+ log("[WRITE] Core Results appended as (text, image) pair.")
+
+ elif parsed["type"] == "impact":
+ log("[TYPE] Impact-like")
+ si = parsed.get("significance")
+ if not si:
+ log("[SKIP] Significance/Impact empty.")
+ continue
+ if filled_significance:
+ log("[SKIP] Significance/Impact ignored (already filled).")
+ continue
+ if has_header(pr_text, "Significance/Impact"):
+ pr_text = set_inline(pr_text, "Significance/Impact", si)
+ filled_significance = True
+ log("[WRITE] Significance/Impact filled (first time).")
+ else:
+ log("[MISS] PR has no Significance/Impact header.")
+
+ else:
+ log("[WARN] Unknown section type; ignored.")
+
+ # -------------------------
+ # 7) Save
+ # -------------------------
+ pr_file.write_text(pr_text, encoding="utf-8")
+ log("\n[SAVED] PR markdown updated in-place.")
+
+
+# ================================ Add title and hashtags =================================
+def add_title_and_hashtag(pr_path: str, add_title_and_hashtag_prompt: str, model: str = "gpt-4o-mini", config: dict = None) -> None:
+ """
+ 1) Read markdown from pr_path.
+ 2) Send it to LLM using add_title_and_hashtag_prompt (expects {MD_TEXT} placeholder).
+ 3) Parse LLM output.
+ 4) Update file in-place.
+ """
+
+ # Ensure config is not None
+ if config is None:
+ config = {}
+ api_keys = config.get('api_keys', {})
+
+ # -------------------------
+ # 0) Read markdown
+ # -------------------------
+ pr_file = Path(pr_path)
+ if not pr_file.exists():
+ raise FileNotFoundError(f"pr_path not found: {pr_path}")
+
+ md_text = pr_file.read_text(encoding="utf-8")
+
+ # -------------------------
+ # 1) Call LLM (Modified for Dual Support)
+ # -------------------------
+ prompt_content = add_title_and_hashtag_prompt.replace("{MD_TEXT}", md_text)
+ system_instruction = "You are a precise scientific social media copywriter."
+ llm_out = ""
+
+ # >>> Provider selection >>>
+ if "gemini" in model.lower():
+ # --- Google Gemini (New SDK) ---
+ api_key = api_keys.get("gemini_api_key", "").strip()
+
+ if not api_key:
+ raise ValueError("Missing config['api_keys']['gemini_api_key']")
+
+ try:
+ from google import genai
+ from google.genai import types
+ except ImportError as e:
+ raise ImportError("google-genai package is required. Install with: pip install google-genai") from e
+
+ # Configure the client
+ client = genai.Client(api_key=api_key)
+
+ try:
+ response = client.models.generate_content(
+ model=model,
+ contents=prompt_content,
+ config=types.GenerateContentConfig(
+ system_instruction=system_instruction,
+ temperature=0.6
+ )
+ )
+ llm_out = (response.text or "").strip()
+ except Exception as e:
+ raise RuntimeError(f"Gemini API call failed: {e}")
+
+ else:
+ # --- OpenAI (Existing) ---
+ api_key = api_keys.get("openai_api_key", "").strip()
+
+ if not api_key:
+ # Back-compat: if config carries no key, fall back to the environment variable (optional, depending on your needs)
+ api_key = os.getenv("OPENAI_API_KEY", "").strip()
+
+ if not api_key:
+ raise ValueError("Missing config['api_keys']['openai_api_key']")
+
+ try:
+ from openai import OpenAI
+ except ImportError as e:
+ raise ImportError("openai package is required. Install with: pip install openai") from e
+
+ client = OpenAI(api_key=api_key)
+
+ try:
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": system_instruction},
+ {"role": "user", "content": prompt_content},
+ ],
+ temperature=0.6,
+ )
+ llm_out = (resp.choices[0].message.content or "").strip()
+ except Exception as e:
+ raise RuntimeError(f"OpenAI API call failed: {e}")
+
+ if not llm_out:
+ raise ValueError("LLM returned empty content.")
+
+ # -------------------------
+ # 2) Parse LLM output (Remaining logic unchanged)
+ # -------------------------
+ # Note: assumes _parse_title_and_tags and _line_ending are defined elsewhere or available in this scope
+ title, specific_tags, community_tag = _parse_title_and_tags(llm_out)
+
+ # -------------------------
+ # 3) Update first line: "# {Title}"
+ # -------------------------
+ lines = md_text.splitlines(True)
+ if not lines:
+ raise ValueError("Markdown file is empty.")
+
+ first_line = lines[0].rstrip("\r\n")
+ # Helper: if _line_ending is undefined in the original code, supply your own; the original logic is reused here.
+ # Assume _line_ending(s) returns the line ending of s
+
+ def _local_line_ending(s):
+ if s.endswith("\r\n"): return "\r\n"
+ if s.endswith("\n"): return "\n"
+ return "\n" # default
+
+ current_ending = _local_line_ending(lines[0])
+
+ if re.fullmatch(r"#\s*", first_line):
+ lines[0] = f"# 🔥{title}😯{current_ending}"
+ else:
+ if first_line.startswith("#"):
+ lines[0] = re.sub(r"^#\s*.*$", f"# 🔥{title}😯", first_line) + current_ending
+ else:
+ lines.insert(0, f"# 🔥{title}😯\n")
+
+ updated = "".join(lines)
+
+ # -------------------------
+ # 4) Replace "Specific:" line
+ # -------------------------
+ def _replace_specific_line(match: re.Match) -> str:
+ prefix = match.group(1)
+ tail = match.group(2) or ""
+ has_semicolon = ";" in tail
+ end = ";" if has_semicolon else ""
+ return f"{prefix}{specific_tags[0]} {specific_tags[1]} {specific_tags[2]}{end}"
+
+ updated, n1 = re.subn(
+ r"(?mi)^(Specific:\s*)(.*)$",
+ lambda m: _replace_specific_line(m),
+ updated,
+ count=1,
+ )
+
+ # -------------------------
+ # 5) Replace "Community:" line
+ # -------------------------
+ updated, n2 = re.subn(
+ r"(Community:\s*)(#[^\s;]+)",
+ lambda m: f"{m.group(1)}{community_tag}",
+ updated,
+ count=1,
+ flags=re.IGNORECASE,
+ )
+
+ if n1 == 0:
+ raise ValueError("Could not find a line starting with 'Specific:' to replace.")
+ if n2 == 0:
+ raise ValueError("Could not find the 'Community: #Tag1' pattern to replace.")
+
+ # -------------------------
+ # 6) Write back
+ # -------------------------
+ pr_file.write_text(updated, encoding="utf-8")
+
+
+def _parse_title_and_tags(llm_out: str) -> Tuple[str, List[str], str]:
+ """
+ Parse:
+ Title: ...
+ Specific Tag: #A #B #C
+ Community Tag: #X
+ """
+ def pick_line(prefix: str) -> str:
+ m = re.search(rf"^{re.escape(prefix)}\s*(.+)$", llm_out, flags=re.MULTILINE)
+ if not m:
+ raise ValueError(f"LLM output missing line: '{prefix} ...'")
+ return m.group(1).strip()
+
+ title = pick_line("Title:")
+ spec = pick_line("Specific Tag:")
+ comm = pick_line("Community Tag:")
+
+ spec_tags = re.findall(r"#[A-Za-z0-9_]+", spec)
+ comm_tags = re.findall(r"#[A-Za-z0-9_]+", comm)
+
+ if len(spec_tags) != 3:
+ raise ValueError(f"Expected exactly 3 specific tags, got {len(spec_tags)}. Raw: {spec}")
+ if len(comm_tags) != 1:
+ raise ValueError(f"Expected exactly 1 community tag, got {len(comm_tags)}. Raw: {comm}")
+
+ title = title.strip().strip('"').strip("'")
+ if not title:
+ raise ValueError("Parsed title is empty.")
+
+ return title, spec_tags, comm_tags[0]
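A self-contained sketch of the parsing contract above, fed a made-up LLM reply (the sample text and tags are hypothetical; the regexes mirror `_parse_title_and_tags`):

```python
import re

# Hypothetical LLM output in the expected three-line format
sample = '''Title: "A Great Paper"
Specific Tag: #LLM #DAG #Poster
Community Tag: #MachineLearning
'''

def pick_line(prefix: str) -> str:
    # Grab the remainder of the first line starting with the given prefix
    m = re.search(rf"^{re.escape(prefix)}\s*(.+)$", sample, flags=re.MULTILINE)
    assert m, f"missing line: {prefix}"
    return m.group(1).strip()

title = pick_line("Title:").strip('"').strip("'")
spec_tags = re.findall(r"#[A-Za-z0-9_]+", pick_line("Specific Tag:"))
comm_tags = re.findall(r"#[A-Za-z0-9_]+", pick_line("Community Tag:"))

print(title)      # A Great Paper
print(spec_tags)  # ['#LLM', '#DAG', '#Poster']
print(comm_tags)  # ['#MachineLearning']
```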
+
+
+def _line_ending(original_line: str) -> str:
+ if original_line.endswith("\r\n"):
+ return "\r\n"
+ if original_line.endswith("\n"):
+ return "\n"
+ return "\n"
+
+
+
+
+# =========================== Add institution tags ===============================
+def add_institution_tag(pr_path: str) -> None:
+ """
+ Read all institution names after 🏛️**Institution**: in the markdown (robust splitting), then:
+ 1) Join all institution names into one line and write them after "Strategic Mentions:#", e.g.:
+ Strategic Mentions:# NVIDIA, Tel Aviv University
+ 2) On the line right after the Strategic Mentions line, append one more line:
+ @NVIDIA, Tel Aviv University
+ Modifies the file in place.
+ """
+
+ md_path = Path(pr_path)
+ if not md_path.exists():
+ raise FileNotFoundError(f"Markdown file not found: {pr_path}")
+
+ text = md_path.read_text(encoding="utf-8")
+
+ # -------------------------------------------------
+ # 1) Extract the content of the Institution line
+ # -------------------------------------------------
+ institution_pattern = re.compile(r"🏛️\s*\*\*Institution\*\*\s*:\s*(.+)")
+ m = institution_pattern.search(text)
+ if not m:
+ raise ValueError("Institution section not found in markdown.")
+
+ institution_raw = m.group(1).strip()
+
+ # -------------------------------------------------
+ # 2) Robust splitting: extract all institution names
+ # Supports Chinese/English semicolons, commas, 、, vertical bars, slashes, newlines, tabs, etc.
+ # -------------------------------------------------
+ split_pattern = re.compile(r"\s*(?:;|;|,|,|、|\||\uFF5C|/|\n|\t)\s*")  # \uFF5C = fullwidth vertical bar
+ parts = [p.strip() for p in split_pattern.split(institution_raw) if p.strip()]
+ if not parts:
+ raise ValueError("Institution content is empty after parsing.")
+
+ # Strip common trailing punctuation noise, deduplicate while preserving order
+ seen = set()
+ institutions = []
+ for p in parts:
+ p = p.strip(" .。")
+ if p and p not in seen:
+ institutions.append(p)
+ seen.add(p)
+
+ if not institutions:
+ raise ValueError("No valid institution names parsed.")
+
+ institutions_str = ", ".join(institutions)
+
+ # -------------------------------------------------
+ # 3) Replace the "Strategic Mentions:#" line and insert "@..." on the line below it
+ # Only the first occurrence is handled
+ # -------------------------------------------------
+ strategic_line_pattern = re.compile(r"^(Strategic Mentions:\s*#).*$", re.MULTILINE)
+
+ def repl(mm: re.Match) -> str:
+ prefix = mm.group(1)
+ return f"{prefix} {institutions_str}\n@{institutions_str}"
+
+ new_text, n = strategic_line_pattern.subn(repl, text, count=1)
+ if n == 0:
+ raise ValueError("Strategic Mentions section not found in markdown.")
+
+ md_path.write_text(new_text, encoding="utf-8")
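A small sketch of the split-and-dedup step above, on a made-up institution line (the names are hypothetical; the pattern writes the fullwidth vertical bar as `\uFF5C` for clarity):

```python
import re

# Hypothetical Institution line content with mixed separators and a duplicate
raw = "NVIDIA; Tel Aviv University,NVIDIA、MIT"

split_pattern = re.compile(r"\s*(?:;|;|,|,|、|\||\uFF5C|/|\n|\t)\s*")
parts = [p.strip(" .。") for p in split_pattern.split(raw) if p.strip()]

# Order-preserving deduplication
seen, institutions = set(), []
for p in parts:
    if p and p not in seen:
        institutions.append(p)
        seen.add(p)

print(", ".join(institutions))  # NVIDIA, Tel Aviv University, MIT
```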
+
+
+# ======================= Remove duplicate image references ======================
+def dedup_consecutive_markdown_images(md_path: str, inplace: bool = True) -> Tuple[str, int]:
+ """
+ Remove consecutive duplicate image references in Markdown (matched by src), keeping only one.
+ "Consecutive" means only whitespace (spaces/tabs/newlines) separates the two images;
+ any non-whitespace content in between (text, code, list markers, etc.) breaks the run.
+
+ Supported forms (matching the regex below):
+ ![](path.png)
+ ![alt](path.png)
+ ![alt](path.png "title")
+
+ Returns:
+ (new_text, removed_count)
+ """
+
+ p = Path(md_path)
+ text = p.read_text(encoding="utf-8")
+
+ img_pat = re.compile(
+ r'!\[(?P<alt>[^\]]*)\]\((?P<src>[^)\s]+)(?:\s+"[^"]*")?\)',
+ flags=re.MULTILINE
+ )
+
+ parts = []
+ last = 0
+ for m in img_pat.finditer(text):
+ if m.start() > last:
+ parts.append(("text", text[last:m.start()]))
+ parts.append(("img", m.group(0), m.group("src")))
+ last = m.end()
+ if last < len(text):
+ parts.append(("text", text[last:]))
+
+ removed = 0
+ new_parts = []
+
+ prev_img_src = None
+ # Only non-whitespace text truly breaks a run of consecutive images
+ saw_non_whitespace_since_prev_img = True # start in the "broken" state
+
+ for part in parts:
+ if part[0] == "img":
+ _, raw_img, src = part
+
+ # Consecutive-duplicate test: a previous image exists + only whitespace in between + same src
+ if prev_img_src is not None and (not saw_non_whitespace_since_prev_img) and prev_img_src == src:
+ removed += 1
+ # Drop this image (do not append)
+ continue
+
+ # Keep this image
+ new_parts.append(raw_img)
+ prev_img_src = src
+ saw_non_whitespace_since_prev_img = False
+
+ else:
+ _, raw_text = part
+ new_parts.append(raw_text)
+
+ # Only non-whitespace content counts as breaking the consecutive image run
+ if raw_text.strip() != "":
+ saw_non_whitespace_since_prev_img = True
+
+ new_text = "".join(new_parts)
+
+ if inplace and removed > 0:
+ p.write_text(new_text, encoding="utf-8")
+
+ return new_text, removed
\ No newline at end of file
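A condensed, self-contained sketch of the dedup rule above on an inline string (the file I/O is omitted; the tokenize-then-filter logic mirrors the function):

```python
import re

# Two identical images separated only by whitespace, then a third separated by text
text = "![](a.png)\n\n![](a.png)\nsome text\n![](a.png)\n"

img_pat = re.compile(r'!\[(?P<alt>[^\]]*)\]\((?P<src>[^)\s]+)(?:\s+"[^"]*")?\)')

# Tokenize into alternating text and image parts
parts, last = [], 0
for m in img_pat.finditer(text):
    if m.start() > last:
        parts.append(("text", text[last:m.start()]))
    parts.append(("img", m.group(0), m.group("src")))
    last = m.end()
if last < len(text):
    parts.append(("text", text[last:]))

out, prev_src, broken = [], None, True
for part in parts:
    if part[0] == "img":
        _, raw, src = part
        if prev_src == src and not broken:
            continue  # consecutive duplicate: drop it
        out.append(raw)
        prev_src, broken = src, False
    else:
        out.append(part[1])
        if part[1].strip():
            broken = True  # non-whitespace text breaks the run

result = "".join(out)
print(result)  # the second a.png is removed; the third survives
```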
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2563a80fb93c0abe71a4de4deb0e19590719fcdf
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,34 @@
+from .paper2DAG import clean_paper
+from .paper2DAG import split_paper
+from .paper2DAG import initialize_dag
+from .paper2DAG import extract_and_generate_visual_dag
+from .paper2DAG import add_resolution_to_visual_dag
+from .paper2DAG import build_section_dags
+from .paper2DAG import add_section_dag
+from .paper2DAG import add_visual_dag
+from .paper2DAG import add_section_dag
+from .paper2DAG import refine_visual_node
+
+from .DAG2ppt import generate_selected_nodes
+from .DAG2ppt import outline_initialize
+from .DAG2ppt import generate_complete_outline
+from .DAG2ppt import arrange_template
+from .DAG2ppt import generate_ppt
+
+from .DAG2poster import generate_poster_outline_txt
+from .DAG2poster import modify_poster_outline
+from .DAG2poster import build_poster_from_outline
+from .DAG2poster import modify_title_and_author
+from .DAG2poster import inject_img_section_to_poster
+from .DAG2poster import modified_poster_logic
+
+from .DAG2pr import extract_basic_information
+from .DAG2pr import initialize_pr_markdown
+from .DAG2pr import generate_pr_from_dag
+from .DAG2pr import add_title_and_hashtag
+from .DAG2pr import add_institution_tag
+from .DAG2pr import dedup_consecutive_markdown_images
+
+from .refinement.refinement import refinement_ppt
+from .refinement.refinement import refinement_poster
+from .refinement.refinement import refinement_pr
\ No newline at end of file
diff --git a/src/__pycache__/DAG2poster.cpython-310.pyc b/src/__pycache__/DAG2poster.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6e355d7244560d057f3fbaed55e25512c35036b
Binary files /dev/null and b/src/__pycache__/DAG2poster.cpython-310.pyc differ
diff --git a/src/__pycache__/DAG2ppt.cpython-310.pyc b/src/__pycache__/DAG2ppt.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7e7f5ead7341440dc9add895752a6660fecbd62
Binary files /dev/null and b/src/__pycache__/DAG2ppt.cpython-310.pyc differ
diff --git a/src/__pycache__/DAG2pr.cpython-310.pyc b/src/__pycache__/DAG2pr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48bf372c4a1aee728fbc7350e1260ddcab5a07fb
Binary files /dev/null and b/src/__pycache__/DAG2pr.cpython-310.pyc differ
diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5438c8e419ff2a6fa342b2dfb92ecd23a403f40a
Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/__pycache__/__init__.cpython-312.pyc b/src/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2f3c977e42c02095f7d113435f284727845bb11
Binary files /dev/null and b/src/__pycache__/__init__.cpython-312.pyc differ
diff --git a/src/__pycache__/paper2DAG.cpython-310.pyc b/src/__pycache__/paper2DAG.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d93d07d257464cebba070e8e03ce5b003c77a116
Binary files /dev/null and b/src/__pycache__/paper2DAG.cpython-310.pyc differ
diff --git a/src/__pycache__/paper2DAG.cpython-312.pyc b/src/__pycache__/paper2DAG.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e649eb6aef3426e2dec341dd7bc2c530d17a7cdf
Binary files /dev/null and b/src/__pycache__/paper2DAG.cpython-312.pyc differ
diff --git a/src/paper2DAG.py b/src/paper2DAG.py
new file mode 100644
index 0000000000000000000000000000000000000000..10d3af7da879bc6047bb51545181bd927d9d1cf2
--- /dev/null
+++ b/src/paper2DAG.py
@@ -0,0 +1,1056 @@
+import json
+import os
+import re
+import base64
+from PIL import Image
+from typing import Optional
+from openai import OpenAI
+from google import genai
+from google.genai import types
+
+
+# ========== Call Gemini to remove unneeded sections ==========
+def clean_paper(markdown_path, clean_prompt, model, config):
+ """
+ Clean a paper's Markdown file with Google Gemini:
+ remove sections such as Abstract / Related Work / Appendix / References,
+ keep the title, authors, Introduction, Methods, Experiments, and Conclusion.
+ """
+ # === Initialize the client ===
+ client = genai.Client(
+ api_key=config['api_keys']['gemini_api_key']
+ )
+
+ # === Read the markdown file ===
+ with open(markdown_path, "r", encoding="utf-8") as f:
+ md_text = f.read().strip()
+
+ full_prompt = (
+ f"{clean_prompt}\n\n"
+ "=== PAPER MARKDOWN TO CLEAN ===\n"
+ f"\"\"\"{md_text}\"\"\"\n\n"
+ "Return only the cleaned markdown, keeping all formatting identical to the original."
+ )
+
+ print("🧹 Sending markdown to Gemini for cleaning...")
+
+ try:
+ # === Call the Gemini API (client mode) ===
+ resp = client.models.generate_content(
+ model=model,
+ contents=full_prompt,
+ config=types.GenerateContentConfig(
+ temperature=0.0
+ )
+ )
+ cleaned_text = resp.text.strip()
+
+ except Exception as e:
+ print(f"❌ Gemini API Error: {e}")
+ return None
+
+ # === Extract pure markdown (in case the model returns a ```markdown ``` block) ===
+ m = re.search(r"```markdown\s*([\s\S]*?)```", cleaned_text)
+ if not m:
+ m = re.search(r"```\s*([\s\S]*?)```", cleaned_text)
+ cleaned_text = m.group(1).strip() if m else cleaned_text
+
+ # === Build the output file path ===
+ dir_path = os.path.dirname(markdown_path)
+ base_name = os.path.basename(markdown_path)
+ name, ext = os.path.splitext(base_name)
+ output_path = os.path.join(dir_path, f"{name}_cleaned{ext}")
+
+ # === Save the result ===
+ with open(output_path, "w", encoding="utf-8") as f:
+ f.write(cleaned_text)
+ print(f"✅ Cleaned markdown saved to: {output_path}")
+
+ return output_path
+
+
+# ========== Call Gemini to split sections ==========
+SECTION_RE = re.compile(r'^#\s+\d+(\s|$)', re.MULTILINE)
+
+
+def sanitize_filename(name: str) -> str:
+ """移除非法字符:/ \ : * ? \" < > |"""
+ unsafe = r'\/:*?"<>|'
+ return "".join(c for c in name if c not in unsafe).strip()
+
+
+def split_paper(
+ cleaned_md_path: str, # input: A/auto/clean_paper.md
+ prompt: str,
+ separator: str = "===SPLIT===",
+ model: str = "gemini-3.0-pro-preview",
+ config: dict = None
+):
+ """
+ Split the paper with Gemini and save every resulting markdown file under:
+ <parent_dir>/section_split_output/
+ """
+ # 1️⃣ The auto folder containing the input file
+ auto_dir = os.path.dirname(os.path.abspath(cleaned_md_path))
+
+ # 2️⃣ The parent directory of auto (i.e. A/)
+ parent_dir = os.path.dirname(auto_dir)
+
+ # 3️⃣ Final output directory A/section_split_output/
+ output_dir = os.path.join(parent_dir, "section_split_output")
+ os.makedirs(output_dir, exist_ok=True)
+
+ # === Read the markdown ===
+ with open(cleaned_md_path, "r", encoding="utf-8") as f:
+ markdown_text = f.read()
+
+ # === 2. Initialize the Gemini client ===
+ client = genai.Client(
+ api_key=config['api_keys']['gemini_api_key']
+ )
+
+ # === Extract top-level section info for reference (assumes SECTION_RE is defined elsewhere) ===
+ # Note: make sure SECTION_RE is available in this function's scope
+ section_positions = [(m.start(), m.group()) for m in SECTION_RE.finditer(markdown_text)]
+
+ auto_analysis = "Detected top-level sections:\n"
+ for pos, sec in section_positions:
+ auto_analysis += f"- position={pos}, heading='{sec.strip()}'\n"
+ auto_analysis += "\nThese are for your reference. You MUST still split strictly by the rules.\n"
+
+ # === Build the prompt ===
+ final_prompt = (
+ prompt
+ + "\n\n---\nBelow is an automatic analysis of top-level sections (for reference only):\n"
+ + auto_analysis
+ + "\n---\nHere is the FULL MARKDOWN PAPER:\n\n"
+ + markdown_text
+ )
+
+ # === 3. Gemini call ===
+ try:
+ response = client.models.generate_content(
+ model=model,
+ contents=final_prompt,
+ config=types.GenerateContentConfig(
+ temperature=0.0
+ )
+ )
+ output_text = response.text
+ except Exception as e:
+ print(f"❌ Gemini Split Error: {e}")
+ return []
+
+ # === Split on the separator ===
+ # Simple fault tolerance in case the model did not follow the format exactly
+ if not output_text:
+ print("❌ Empty response from Gemini")
+ return []
+
+ chunks = [c.strip() for c in output_text.split(separator) if c.strip()]
+ saved_paths = []
+
+ # === Save the split chunks ===
+ for chunk in chunks:
+ lines = chunk.splitlines()
+
+ # Find the first non-empty line
+ first_line = next((ln.strip() for ln in lines if ln.strip()), "")
+
+ # Parse the title
+ if first_line.startswith("#"):
+ title = first_line.lstrip("#").strip()
+ else:
+ title = first_line[:20].strip()
+
+ # sanitize_filename is defined above in this module
+ filename = sanitize_filename(title) + ".md"
+ filepath = os.path.join(output_dir, filename)
+
+ # Write the file
+ with open(filepath, "w", encoding="utf-8") as f:
+ f.write(chunk)
+
+ saved_paths.append(filepath)
+
+ print("✅ Paper split successfully")
+ return saved_paths
+
+
+# ========== Call Gemini to initialize dag.json ==========
+def initialize_dag(markdown_path, initialize_dag_prompt, model, config=None):
+ """
+ Use Gemini to initialize the paper DAG.
+
+ Input:
+ markdown_path: path to the markdown file
+ initialize_dag_prompt: the prompt string
+ model: model name (gemini-2.0-flash or pro recommended)
+ config: configuration dict containing api_keys
+
+ Output:
+ dag.json saved in the same directory as the markdown file
+ Returns the DAG as a Python dict
+ """
+ # --- load markdown ---
+ if not os.path.exists(markdown_path):
+ raise FileNotFoundError(f"Markdown not found: {markdown_path}")
+
+ with open(markdown_path, "r", encoding="utf-8") as f:
+ md_text = f.read()
+
+ # --- Gemini Client Init ---
+ client = genai.Client(
+ api_key=config['api_keys']['gemini_api_key']
+ )
+
+ # --- Gemini Call ---
+ # Combine the prompt and the text as user input; the system prompt goes into the config
+ full_content = f"{initialize_dag_prompt}\n\n{md_text}"
+
+ try:
+ response = client.models.generate_content(
+ model=model,
+ contents=full_content,
+ config=types.GenerateContentConfig(
+ system_instruction="You are an expert academic document parser and structural analyzer.",
+ temperature=0.0,
+ response_mime_type="application/json" # <--- force JSON output mode
+ )
+ )
+ raw_output = response.text.strip()
+ except Exception as e:
+ print(f"❌ Gemini API Error: {e}")
+ raise e
+
+ # --- Extract JSON (remove possible markdown fences) ---
+ # In JSON mode Gemini usually returns pure JSON, but keep this logic just in case
+ cleaned = raw_output
+
+ # Remove ```json ... ```
+ if cleaned.startswith("```"):
+ cleaned = cleaned.strip("`")
+ if cleaned.lstrip().startswith("json"):
+ cleaned = cleaned.split("\n", 1)[1]
+
+ # Last safety: locate JSON via first { and last }
+ try:
+ first = cleaned.index("{")
+ last = cleaned.rindex("}")
+ cleaned = cleaned[first:last+1]
+ except Exception:
+ pass
+
+ try:
+ dag_data = json.loads(cleaned)
+ except json.JSONDecodeError:
+ print("⚠️ Standard JSON parsing failed. Attempting regex repair for backslashes...")
+ try:
+ # Actually repair escapes: escape any backslash that does not start a valid JSON escape sequence
+ repaired = re.sub(r'\\(?!["\\/bfnrtu])', r'\\\\', cleaned)
+ dag_data = json.loads(repaired)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Gemini output is not valid JSON:\n{raw_output}") from e
+
+ # --- Save dag.json ---
+ out_dir = os.path.dirname(markdown_path)
+ out_path = os.path.join(out_dir, "dag.json")
+
+ with open(out_path, "w", encoding="utf-8") as f:
+ json.dump(dag_data, f, indent=4, ensure_ascii=False)
+
+ print(f"✅ DAG saved to: {out_path}")
+
+ return dag_data
+
+
+# ========== Call the LLM to add visual nodes ==========
+def extract_and_generate_visual_dag(
+ markdown_path: str,
+ prompt_for_gpt: str,
+ output_json_path: str,
+ model="gemini-3.0-pro-preview",
+ config=None
+):
+ """
+ Input:
+ markdown_path: path to the original paper markdown
+ prompt_for_gpt: the prompt for the LLM
+ output_json_path: where to save the generated visual_dag.json
+ model: defaults to gemini-3.0-pro-preview
+
+ Output:
+ Generates visual_dag.json
+ Returns a Python dict
+ """
+ # === 1. Read the markdown ===
+ if not os.path.exists(markdown_path):
+ raise FileNotFoundError(f"Markdown not found: {markdown_path}")
+
+ with open(markdown_path, "r", encoding="utf-8") as f:
+ md_text = f.read()
+
+ # === 2. Extract all relative image references via regex ===
+ pattern = r"!\[[^\]]*\]\(([^)]+)\)"
+ matches = re.findall(pattern, md_text)
+
+ # Keep only relative paths (exclude http/https URLs)
+ relative_imgs = [m for m in matches if not m.startswith("http")]
+
+ # Normalize into the markdown form used by the "name" field: ![](path)
+ normalized_refs = [f"![]({m})" for m in relative_imgs]
+
+ # === 3. Send to Gemini ===
+ # Initialize the client
+ client = genai.Client(
+ api_key=config['api_keys']['gemini_api_key']
+ )
+
+ gpt_input = prompt_for_gpt + "\n\n" + \
+ "### Extracted Image References:\n" + \
+ json.dumps(normalized_refs, indent=2) + "\n\n" + \
+ "### Full Markdown:\n" + md_text
+
+ try:
+ response = client.models.generate_content(
+ model=model,
+ contents=gpt_input,
+ config=types.GenerateContentConfig(
+ temperature=0.0,
+ response_mime_type="application/json" # force JSON output
+ )
+ )
+ visual_dag_str = response.text.strip()
+ except Exception as e:
+ print(f"❌ Gemini API Error: {e}")
+ raise e
+
+ # === Fallback JSON repair logic (no semantic rewriting) ===
+ def _strip_fenced_code_block(s: str) -> str:
+ s = (s or "").strip()
+ if not s.startswith("```"):
+ return s
+ lines = s.splitlines()
+ if lines and lines[0].strip().startswith("```"):
+ lines = lines[1:]
+ while lines and lines[-1].strip().startswith("```"):
+ lines = lines[:-1]
+ return "\n".join(lines).strip()
+
+ def _sanitize_json_string_minimal(s: str) -> str:
+ s = (s or "")
+ s = s.replace("\r", " ").replace("\n", " ").replace("\t", " ")
+ s = re.sub(r"\s{2,}", " ", s)
+ return s
+
+ # The _remove_one_offending_backslash definition must stay here; it is called below and would otherwise raise NameError
+ def _remove_one_offending_backslash(s: str, err: Exception) -> str:
+ if not isinstance(err, json.JSONDecodeError):
+ return ""
+ msg = str(err)
+ if "Invalid \\escape" not in msg and "Invalid \\u" not in msg:
+ return ""
+ pos = getattr(err, "pos", None)
+ if pos is None or pos <= 0 or pos > len(s):
+ return ""
+ candidates = []
+ if pos < len(s) and s[pos] == "\\":
+ candidates.append(pos)
+ if pos - 1 >= 0 and s[pos - 1] == "\\":
+ candidates.append(pos - 1)
+ start = max(0, pos - 16)
+ window = s[start:pos + 1]
+ last_bs = window.rfind("\\")
+ if last_bs != -1:
+ candidates.append(start + last_bs)
+ seen = set()
+ candidates = [i for i in candidates if not (i in seen or seen.add(i))]
+ for idx in candidates:
+ if 0 <= idx < len(s) and s[idx] == "\\":
+ return s[:idx] + s[idx + 1:]
+ return ""
+
+ # Parse the JSON (GPT/Gemini is expected to return pure JSON)
+ try:
+ # Gemini's response_mime_type already largely guarantees JSON; keep these prints for debugging
+ # print("====== RAW GEMINI OUTPUT ======")
+ # print(visual_dag_str)
+ # print("====== END ======")
+ visual_dag = json.loads(visual_dag_str)
+ except Exception as e1:
+ # The repair logic below is kept as-is
+ try:
+ unwrapped = _strip_fenced_code_block(visual_dag_str)
+ fixed_str = _sanitize_json_string_minimal(unwrapped).strip()
+
+ if not fixed_str:
+ raise ValueError("Gemini returned empty/whitespace-only JSON content after repair.")
+
+ try:
+ visual_dag = json.loads(fixed_str)
+ except Exception as e2:
+ working = fixed_str
+ last_err = e2
+ max_backslash_removals = 50
+ removed_times = 0
+
+ while removed_times < max_backslash_removals:
+ new_working = _remove_one_offending_backslash(working, last_err)
+ if not new_working:
+ break
+ working = new_working
+ removed_times += 1
+ try:
+ visual_dag = json.loads(working)
+ fixed_str = working
+ last_err = None
+ break
+ except Exception as e_next:
+ last_err = e_next
+ continue
+
+ if last_err is not None:
+ raise ValueError(
+ "Gemini returned invalid JSON: " + str(e1) +
+ " | After repair still invalid: " + str(e2) +
+ f" | Tried removing offending backslashes up to {max_backslash_removals} times "
+ f"(actually removed {removed_times}) but still invalid: " + str(last_err)
+ )
+
+ except Exception as e_final:
+ if isinstance(e_final, ValueError):
+ raise
+ raise ValueError(
+ "Gemini returned invalid JSON: " + str(e1) +
+ " | After repair still invalid: " + str(e_final)
+ )
+
+ # === 4. Save visual_dag.json ===
+ with open(output_json_path, "w", encoding="utf-8") as f:
+ json.dump(visual_dag, f, indent=2, ensure_ascii=False)
+
+ print("\n📂 visual_dag.json generated successfully")
+ return visual_dag
+
+
+# ========== Compute the resolution of each visual node ==========
+def add_resolution_to_visual_dag(auto_path, visual_dag_path):
+ """
+ Walk visual_dag.json, extract each image path, compute its resolution,
+ and add it to the node's attributes.
+
+ Args:
+ auto_path (str): root directory containing the images.
+ visual_dag_path (str): path to the visual_dag.json file.
+
+ Returns:
+ list: the updated node list.
+ """
+
+ # 1. Read the JSON file
+ try:
+ with open(visual_dag_path, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+ except FileNotFoundError:
+ print(f"Error: file not found: {visual_dag_path}")
+ return []
+ except json.JSONDecodeError:
+ print(f"Error: {visual_dag_path} is not valid JSON")
+ return []
+
+ nodes = data.get("nodes", [])
+
+ # Regex to extract the path from markdown images of the form ![](path)
+ # Explanation: !\[\]\((.*?)\) captures everything inside the parentheses
+ pattern = re.compile(r'!\[\]\((.*?)\)')
+
+ for node in nodes:
+ name_str = node.get("name", "")
+
+ # 2. Extract the path from the name
+ match = pattern.search(name_str)
+ if match:
+ # The path inside the parentheses, e.g. images/xxx.jpg
+ relative_image_path = match.group(1)
+
+ # 3. Build the full path
+ full_image_path = os.path.join(auto_path, relative_image_path)
+
+ # 4. Open the image and read its resolution
+ try:
+ # Open the image with Pillow
+ with Image.open(full_image_path) as img:
+ width, height = img.size
+ resolution_str = f"{width}x{height}"
+
+ # 5. Add the resolution field
+ node["resolution"] = resolution_str
+ # print(f"Processed: {relative_image_path} -> {resolution_str}")
+
+ except FileNotFoundError:
+ print(f"Warning: image file not found: {full_image_path}; skipping this node.")
+ node["resolution"] = "Unknown" # or choose not to add the field at all
+ except Exception as e:
+ print(f"Warning: error while processing image {full_image_path}: {e}")
+ node["resolution"] = "Error"
+ else:
+ print(f"Warning: node name does not match the expected format: {name_str}")
+
+ # (Optional) write the updated data back to the file, or save it elsewhere
+ # Here we write back to the original file
+ try:
+ with open(visual_dag_path, 'w', encoding='utf-8') as f:
+ json.dump(data, f, indent=2, ensure_ascii=False)
+ print(f"Done; updated file: {visual_dag_path}")
+ except Exception as e:
+ print(f"Error while saving the file: {e}")
+
+ return nodes
+
+
+# ========== Call Gemini to generate each section_dag ==========
+def build_section_dags(
+ folder_path: str,
+ base_prompt: str,
+ model: str = "gemini-3.0-pro-preview", # flash or pro recommended
+ config: dict = None
+):
+ """
+ Traverse all markdown files in a folder, send each section to Gemini,
+ and save _dag.json.
+ Includes robust JSON repair and retry logic.
+ """
+ # -----------------------------
+ # Tunables (safe defaults)
+ # -----------------------------
+ ENABLE_FALLBACK_CONTENT_BACKSLASH_STRIP = True
+ FALLBACK_STRIP_BACKSLASH_ONLY_IN_CONTENT = True
+ MAX_RETRIES_ON_FAIL = 2
+
+ # === Init Client (Gemini) ===
+ # Use the key from config
+ client = genai.Client(
+ api_key=config['api_keys']['gemini_api_key']
+ )
+
+ def build_full_prompt(base_prompt: str, section_name: str, md_text: str) -> str:
+ return (
+ f"{base_prompt}\n\n"
+ "=== SECTION NAME ===\n"
+ f"{section_name}\n\n"
+ "=== SECTION MARKDOWN (FULL) ===\n"
+ f"\"\"\"{md_text}\"\"\""
+ )
+
+ # === Helper Functions (Keep exactly as is) ===
+ def remove_invisible_control_chars(s: str) -> str:
+ if not s: return s
+ s = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", s)
+ s = re.sub(r"[\uFEFF\u200B\u200C\u200D\u2060\u00AD\u061C\u200E\u200F\u202A-\u202E\u2066-\u2069]", "", s)
+ return s
+
+ def sanitize_json_literal_newlines(s: str) -> str:
+ out = []
+ in_string = False
+ escape = False
+ for ch in s:
+ if in_string:
+ if escape:
+ out.append(ch); escape = False
+ else:
+ if ch == '\\': out.append(ch); escape = True
+ elif ch == '"': out.append(ch); in_string = False
+ elif ch in ('\n', '\r', '\t'): out.append(' ')
+ else: out.append(ch)
+ else:
+ if ch == '"': out.append(ch); in_string = True; escape = False
+ else: out.append(ch)
+ return ''.join(out)
+
+ def sanitize_invalid_backslashes_in_strings(s: str) -> str:
+ out = []
+ in_string = False
+ i = 0
+ valid_esc = set(['"', '\\', '/', 'b', 'f', 'n', 'r', 't', 'u'])
+ while i < len(s):
+ ch = s[i]
+ if not in_string:
+ out.append(ch)
+ if ch == '"': in_string = True
+ i += 1; continue
+ if ch == '"':
+ out.append(ch); in_string = False
+ i += 1; continue
+ if ch != '\\':
+ out.append(ch); i += 1; continue
+ if i == len(s) - 1:
+ out.append('\\\\'); i += 1; continue
+ nxt = s[i + 1]
+ if nxt in valid_esc:
+ out.append('\\'); out.append(nxt); i += 2
+ else:
+ out.append('\\\\'); out.append(nxt); i += 2
+ return ''.join(out)
+
+ def force_content_single_line(dag_obj):
+ if not isinstance(dag_obj, dict): return dag_obj
+ nodes = dag_obj.get("nodes", None)
+ if not isinstance(nodes, list): return dag_obj
+ for node in nodes:
+ if isinstance(node, dict) and "content" in node and isinstance(node["content"], str):
+ node["content"] = re.sub(r"[\r\n]+", " ", node["content"])
+ return dag_obj
+
+ def fallback_strip_backslashes_in_content(dag_obj):
+ if not isinstance(dag_obj, dict): return dag_obj
+ nodes = dag_obj.get("nodes", None)
+ if not isinstance(nodes, list): return dag_obj
+ for node in nodes:
+ if isinstance(node, dict) and "content" in node and isinstance(node["content"], str):
+ node["content"] = node["content"].replace("\\", "")
+ return dag_obj
+
+ def extract_first_json_object_substring(s: str):
+ start = s.find("{")
+ if start < 0: return None
+ in_string = False; escape = False; depth = 0
+ for i in range(start, len(s)):
+ ch = s[i]
+ if in_string:
+ if escape: escape = False
+ elif ch == "\\": escape = True
+ elif ch == '"': in_string = False
+ else:
+ if ch == '"': in_string = True
+ elif ch == "{": depth += 1
+ elif ch == "}":
+ depth -= 1
+ if depth == 0: return s[start:i + 1]
+ return None
+
+ def robust_load_json(raw: str):
+ raw0 = remove_invisible_control_chars(raw)
+ try: return json.loads(raw0), raw0, "A_raw"
+ except json.JSONDecodeError: pass
+
+ b = sanitize_json_literal_newlines(raw0)
+ try: return json.loads(b), b, "B_newlines_fixed"
+ except json.JSONDecodeError: pass
+
+ c = sanitize_invalid_backslashes_in_strings(b)
+ try: return json.loads(c), c, "C_backslashes_fixed"
+ except json.JSONDecodeError: pass
+
+ sub = extract_first_json_object_substring(raw0)
+ if sub:
+ d0 = remove_invisible_control_chars(sub)
+ d1 = sanitize_json_literal_newlines(d0)
+ d2 = sanitize_invalid_backslashes_in_strings(d1)
+ try: return json.loads(d2), d2, "D_extracted_object_repaired"
+ except json.JSONDecodeError: pass
+
+ return None, raw0, "FAIL"
+
+ # === Modified: Call Gemini ===
+ def call_llm(full_prompt: str) -> str:
+ try:
+ resp = client.models.generate_content(
+ model=model,
+ contents=full_prompt,
+ config=types.GenerateContentConfig(
+ temperature=0.2,
+ # response_mime_type="application/json" could be added here for extra stability
+ )
+ )
+ return resp.text.strip()
+ except Exception as e:
+ print(f"❌ Gemini API Error: {e}")
+ return ""
+
+ def preprocess_llm_output(raw_content: str) -> str:
+ raw_content = remove_invisible_control_chars(raw_content)
+ fence_match = re.search(r"```(?:json|JSON)?\s*([\s\S]*?)```", raw_content)
+ if fence_match:
+ raw_content = fence_match.group(1).strip()
+ raw_content = remove_invisible_control_chars(raw_content)
+ return raw_content
+
+ outputs = {}
+
+ # === Main Loop ===
+ if not os.path.exists(folder_path):
+ print(f"❌ Folder not found: {folder_path}")
+ return outputs
+
+ for filename in os.listdir(folder_path):
+ if not filename.lower().endswith((".md", ".markdown")):
+ continue
+
+ markdown_path = os.path.join(folder_path, filename)
+ if not os.path.isfile(markdown_path):
+ continue
+
+ section_name = filename
+ with open(markdown_path, "r", encoding="utf-8") as f:
+ md_text = f.read().strip()
+
+ full_prompt = build_full_prompt(base_prompt, section_name, md_text)
+ print(f"📐 Sending section '{section_name}' to Gemini for DAG generation...")
+
+ dag_obj = None
+ used_text = ""
+ stage = "INIT"
+
+ # Retry Loop
+ for attempt_idx in range(1 + MAX_RETRIES_ON_FAIL):
+ if attempt_idx > 0:
+ print(f"🔁 Retry LLM for section '{section_name}' (retry={attempt_idx}/{MAX_RETRIES_ON_FAIL})...")
+
+ raw_content = call_llm(full_prompt)
+ if not raw_content: continue # If the API call errored and returned empty, retry
+
+ raw_content = preprocess_llm_output(raw_content)
+ dag_obj, used_text, stage = robust_load_json(raw_content)
+
+ if dag_obj is not None:
+ break
+
+ print(f"⚠️ JSON parse failed for section '{section_name}' after repairs. Stage={stage}")
+
+ if dag_obj is None:
+ print(f"❌ Section '{section_name}' failed after all retries; clearing to an empty DAG")
+ dag_obj = {}
+ else:
+ dag_obj = force_content_single_line(dag_obj)
+ if ENABLE_FALLBACK_CONTENT_BACKSLASH_STRIP and FALLBACK_STRIP_BACKSLASH_ONLY_IN_CONTENT:
+ if stage in ("D_extracted_object_repaired",):
+ dag_obj = fallback_strip_backslashes_in_content(dag_obj)
+
+ # Output
+ safe_section_name = re.sub(r"[\\/:*?\"<>|]", "_", section_name)
+ output_filename = f"{safe_section_name}_dag.json"
+ subdir_path = os.path.dirname(folder_path)
+ section_dag_path = os.path.join(subdir_path, "section_dag")
+ os.makedirs(section_dag_path, exist_ok=True)
+ output_path = os.path.join(section_dag_path, output_filename)
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ json.dump(dag_obj, f, ensure_ascii=False, indent=4)
+
+ print(f"✅ DAG for section '{section_name}' saved to: {output_path} (parse_stage={stage})")
+ outputs[section_name] = output_path
+
+ return outputs
+
+
+# ========== Merge section_dags into the main dag ==========
+def add_section_dag(
+ section_dag_folder: str,
+ main_dag_path: str,
+ output_path: Optional[str] = None
+) -> str:
+ """
+ Merge all section DAGs under `section_dag_folder` into the main DAG at `main_dag_path`.
+
+ For each section DAG JSON:
+ - Take its root node name (nodes[0]["name"]) and append that name
+ to the edge list of the main DAG's root node (main_dag["nodes"][0]["edge"]).
+ - Append ALL nodes from that section DAG to the end of main_dag["nodes"],
+ preserving their original order.
+
+ Compatibility patch:
+ - If a section JSON is a single node object (missing the top-level "nodes" wrapper),
+ automatically wrap it into:
+ {"nodes": [<the single node object>]}
+ so the downstream merge logic can proceed.
+
+ Notes:
+ - This function does NOT call GPT, it only manipulates JSON.
+ - The main DAG is assumed to have the same format:
+ {
+ "nodes": [
+ {
+ "name": "...",
+ "content": "...",
+ "edge": [],
+ "level": 0 or 1,
+ "visual_node": []
+ },
+ ...
+ ]
+ }
+
+ Args:
+ section_dag_folder: Path to a folder that contains per-section DAG JSON files.
+ main_dag_path: Path to the main DAG JSON file (original).
+ output_path: Path to save the merged DAG. If None, overwrite main_dag_path.
+
+ Returns:
+ The path of the merged DAG JSON file.
+ """
+
+ def _coerce_section_dag_to_nodes_wrapper(obj, section_path: str) -> dict:
+ """
+ If `obj` is already a valid {"nodes": [...]} dict, return as-is.
+ If `obj` looks like a single node dict (has "name"/"content"/"edge"/etc. but no "nodes"),
+ wrap it into {"nodes": [obj]}.
+ Otherwise, raise ValueError.
+ """
+ # Case 1: already in expected format
+ if isinstance(obj, dict) and "nodes" in obj:
+ return obj
+
+ # Case 2: single-node object (missing wrapper)
+ if isinstance(obj, dict) and "nodes" not in obj:
+ # Heuristic: if it has at least "name" and "content" (common node keys), treat it as node.
+ has_name = isinstance(obj.get("name"), str) and obj.get("name").strip()
+ has_content = isinstance(obj.get("content"), str)
+ if has_name and has_content:
+ # Wrap into nodes list
+ return {"nodes": [obj]}
+
+ raise ValueError(
+ f"Section DAG JSON at '{section_path}' is neither a valid DAG wrapper "
+ f"nor a recognizable single-node object."
+ )
+
+ # === Load main DAG ===
+ with open(main_dag_path, "r", encoding="utf-8") as f:
+ main_dag = json.load(f)
+
+ if "nodes" not in main_dag or not isinstance(main_dag["nodes"], list) or len(main_dag["nodes"]) == 0:
+ raise ValueError("main_dag JSON is invalid: missing non-empty 'nodes' array.")
+
+ # Root node is assumed to be the first node
+ root_node = main_dag["nodes"][0]
+
+ # Ensure 'edge' field exists and is a list
+ if "edge" not in root_node or not isinstance(root_node["edge"], list):
+ root_node["edge"] = []
+
+ # === Traverse section DAG folder ===
+ # To keep deterministic order, sort filenames
+ for filename in sorted(os.listdir(section_dag_folder)):
+ # Only process *.json files
+ if not filename.lower().endswith(".json"):
+ continue
+
+ section_path = os.path.join(section_dag_folder, filename)
+
+ # Skip if it's the same file as main_dag_path, just in case
+ if os.path.abspath(section_path) == os.path.abspath(main_dag_path):
+ continue
+
+ if not os.path.isfile(section_path):
+ continue
+
+ # Load section DAG
+ with open(section_path, "r", encoding="utf-8") as f:
+ try:
+ section_raw = json.load(f)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Section DAG JSON invalid at '{section_path}': {e}")
+
+ # NEW: coerce into {"nodes":[...]} if missing wrapper
+ section_dag = _coerce_section_dag_to_nodes_wrapper(section_raw, section_path)
+
+ # Validate nodes array
+ if "nodes" not in section_dag or not isinstance(section_dag["nodes"], list) or len(section_dag["nodes"]) == 0:
+ raise ValueError(f"Section DAG JSON at '{section_path}' has no valid 'nodes' array.")
+
+ section_nodes = section_dag["nodes"]
+ section_root = section_nodes[0]
+
+ # Get section root name
+ section_root_name = section_root.get("name")
+ if not isinstance(section_root_name, str) or not section_root_name.strip():
+ raise ValueError(f"Section DAG root node at '{section_path}' has invalid or empty 'name'.")
+
+ # Append section root name into main root's edge
+ # (avoid duplicates, in case of reruns)
+ if section_root_name not in root_node["edge"]:
+ root_node["edge"].append(section_root_name)
+
+ # Append all section nodes to the end of main_dag["nodes"]
+ main_dag["nodes"].extend(section_nodes)
+
+ # === Save merged DAG ===
+ if output_path is None:
+ output_path = main_dag_path # overwrite by default
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ json.dump(main_dag, f, ensure_ascii=False, indent=4)
+
+ return output_path
+
+
+# ========== Append visual_dag to the original dag ==========
+def add_visual_dag(dag_path: str, visual_dag_path: str) -> str:
+ """
+ Append all nodes from a visual DAG JSON file into an existing DAG JSON file.
+
+ Both JSON files must share the same structure, e.g.:
+
+ {
+ "nodes": [
+ {
+ "name": "...",
+ "content": "...",
+ "edge": [],
+ "level": 0,
+ "visual_node": []
+ }
+ ]
+ }
+
+ Behavior:
+ - Load the main DAG from `dag_path`.
+ - Load the visual DAG from `visual_dag_path`.
+ - Append ALL nodes from visual_dag["nodes"] to the end of main_dag["nodes"],
+ preserving their original order.
+ - Overwrite `dag_path` with the merged DAG.
+ - Does NOT modify any edge relationships automatically.
+
+ Args:
+ dag_path: Path to the main DAG JSON (will be overwritten).
+ visual_dag_path: Path to the visual DAG JSON whose nodes will be appended.
+
+ Returns:
+ The `dag_path` of the merged DAG.
+ """
+ # === Load main DAG ===
+ with open(dag_path, "r", encoding="utf-8") as f:
+ main_dag = json.load(f)
+
+ if "nodes" not in main_dag or not isinstance(main_dag["nodes"], list):
+ raise ValueError(f"Main DAG at '{dag_path}' is invalid: missing 'nodes' array.")
+
+ # === Load visual DAG ===
+ with open(visual_dag_path, "r", encoding="utf-8") as f:
+ visual_dag = json.load(f)
+
+ if "nodes" not in visual_dag or not isinstance(visual_dag["nodes"], list):
+ raise ValueError(f"Visual DAG at '{visual_dag_path}' is invalid: missing 'nodes' array.")
+
+ # === Append visual nodes to main DAG (to the bottom) ===
+ main_dag["nodes"].extend(visual_dag["nodes"])
+
+ # === Save merged DAG back to dag_path (overwrite) ===
+ with open(dag_path, "w", encoding="utf-8") as f:
+ json.dump(main_dag, f, ensure_ascii=False, indent=4)
+
+ return dag_path
+
+
+# ========== Refine the visual_node field of each node in the dag ==========
+from typing import List
+def refine_visual_node(dag_path: str) -> None:
+ """
+ Refine `visual_node` for each node in the DAG JSON at `dag_path`.
+
+ Behavior:
+ - Load the DAG JSON from `dag_path`, whose structure is:
+ {
+ "nodes": [
+ {
+ "name": "...",
+ "content": "...",
+ "edge": [],
+ "level": 0,
+ "visual_node": []
+ },
+ ...
+ ]
+ }
+
+ - For each node in `nodes`:
+ * If node["visual_node"] == 1:
+ - Treat this as a special marker meaning the node is already
+ a visual node; skip it and do NOT modify `visual_node`.
+ * Else:
+ - Look at node["content"] (if it's a string).
+ - Find all markdown image references of the form:
+ ![alt](relative/path.ext)
+ using a regex.
+ - Filter to keep only relative paths (e.g., 'images/xxx.jpg'):
+ - path does NOT start with 'http://', 'https://', 'data:', or '//'.
+ - For each such match, append the full markdown snippet
+ (e.g., '![alt](images/xxx.jpg)') into node["visual_node"].
+ - If `visual_node` is missing or not a list (and not equal to 1),
+ it will be overwritten as a list of these strings.
+
+ - The function overwrites the original `dag_path` with the refined DAG.
+ """
+ # === Load DAG ===
+ with open(dag_path, "r", encoding="utf-8") as f:
+ dag = json.load(f)
+
+ if "nodes" not in dag or not isinstance(dag["nodes"], list):
+ raise ValueError(f"DAG JSON at '{dag_path}' is invalid: missing 'nodes' array.")
+
+ nodes: List[dict] = dag["nodes"]
+
+ # Regex to match markdown images: ![alt](path)
+ # group(0) = full match, group(1) = alt, group(2) = path
+ img_pattern = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")
+
+ def is_relative_path(path: str) -> bool:
+ """Return True if path looks like a relative path, not URL or absolute."""
+ lowered = path.strip().lower()
+ if lowered.startswith("http://"):
+ return False
+ if lowered.startswith("https://"):
+ return False
+ if lowered.startswith("data:"):
+ return False
+ if lowered.startswith("//"):
+ return False
+ # You can optionally reject absolute filesystem paths too:
+ # if lowered.startswith("/") or re.match(r"^[a-zA-Z]:[\\/]", lowered):
+ # return False
+ return True
+
+ for node in nodes:
+ # Skip if this is not a dict
+ if not isinstance(node, dict):
+ continue
+
+ # If visual_node == 1, this is a special visual node -> skip
+ if node.get("visual_node") == 1:
+ continue
+
+ content = node.get("content")
+ if not isinstance(content, str) or not content:
+ # No textual content to search
+ # But if visual_node should still be a list, ensure that
+ if "visual_node" not in node or not isinstance(node["visual_node"], list):
+ node["visual_node"] = []
+ continue
+
+ # Find all markdown image references
+ matches = img_pattern.findall(content) # returns list of (alt, path)
+ full_matches = img_pattern.finditer(content) # to get exact substrings
+
+ # Ensure visual_node is a list (since we already filtered out ==1)
+ visual_list = node.get("visual_node")
+ if not isinstance(visual_list, list):
+ visual_list = []
+ else:
+ # create a copy to safely modify
+ visual_list = list(visual_list)
+
+ # To keep consistent mapping, use the iterator to get full strings
+ for match in full_matches:
+ full_str = match.group(0) # e.g., '![alt](images/xxx.jpg)'
+ path_str = match.group(2).strip() # inside parentheses
+
+ if not is_relative_path(path_str):
+ continue # skip URLs / absolute paths
+
+ if full_str not in visual_list:
+ visual_list.append(full_str)
+
+ # Update node
+ node["visual_node"] = visual_list
+
+ # === Save back to disk (overwrite) ===
+ with open(dag_path, "w", encoding="utf-8") as f:
+ json.dump(dag, f, ensure_ascii=False, indent=4)
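Review note: the image-filtering contract of `refine_visual_node` (keep markdown images with relative paths, drop URLs and data URIs) can be exercised in isolation. A minimal sketch of that logic; the helper name `extract_relative_images` is illustrative and not part of this diff:

```python
import re

# Matches markdown images: group(0) = full snippet, group(1) = alt, group(2) = path
IMG_PATTERN = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")

def extract_relative_images(content: str) -> list:
    """Return the full markdown snippets of images whose paths are relative."""
    results = []
    for match in IMG_PATTERN.finditer(content):
        path = match.group(2).strip().lower()
        # Same filter as refine_visual_node: drop URLs, data URIs, protocol-relative paths
        if path.startswith(("http://", "https://", "data:", "//")):
            continue
        if match.group(0) not in results:
            results.append(match.group(0))
    return results
```

A call like `extract_relative_images("![fig1](images/a.png) and ![ext](https://x.com/b.png)")` keeps only the first snippet, mirroring what the diff appends to `visual_node`.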
diff --git a/src/refinement/__init__.py b/src/refinement/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9985d3146c5f9d3b7cd1f5a7896277ce018eb18
--- /dev/null
+++ b/src/refinement/__init__.py
@@ -0,0 +1,4 @@
+from .refinement import refinement_ppt
+from .refinement import refinement_poster
+from .refinement import refinement_pr
+
diff --git a/src/refinement/__pycache__/__init__.cpython-310.pyc b/src/refinement/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7fa095124a1686d4164ea68d5901504ce36988c0
Binary files /dev/null and b/src/refinement/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/refinement/__pycache__/html_revise.cpython-310.pyc b/src/refinement/__pycache__/html_revise.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d23cd11fd93e21afee78486f12f9f4d4c629a99c
Binary files /dev/null and b/src/refinement/__pycache__/html_revise.cpython-310.pyc differ
diff --git a/src/refinement/__pycache__/refinement.cpython-310.pyc b/src/refinement/__pycache__/refinement.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3877b0b3d6a0b9e922b36ba1910d1f295f0a5627
Binary files /dev/null and b/src/refinement/__pycache__/refinement.cpython-310.pyc differ
diff --git a/src/refinement/commenter.py b/src/refinement/commenter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c466b36f3122d7e21fbac930c3b8eca9006f8d55
--- /dev/null
+++ b/src/refinement/commenter.py
@@ -0,0 +1,28 @@
+# Pseudocode; wire up a real API here (e.g. OpenAI GPT-4o)
+class Commenter:
+ def __init__(self, client):
+ self.client = client
+
+ def review_slide(self, image_path):
+ """
+ Call a VLM to analyze the rendered slide screenshot
+ """
+ system_prompt = """
+ You are a Presentation Design Expert. Analyze the slide screenshot for visual layout issues.
+ Focus strictly on:
+ 1. Text Overflow: Is text going out of its container or the slide boundary?
+ 2. Balance: Is one side too heavy or too empty?
+ 3. Spacing: Are images too small or text too crowded?
+
+ Output valid JSON only:
+ {
+ "status": "PASS" or "NEEDS_REVISION",
+ "issues": ["list of specific issues..."],
+ "suggestion": "Natural language suggestion for the reviser."
+ }
+ """
+
+ # Simulated VLM call
+ # response = client.chat.completions.create(model="gpt-4o", messages=..., image=...)
+ # return json.loads(response)
+ pass
\ No newline at end of file
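Review note: since `review_slide` is still a stub, the JSON contract its prompt promises can be pinned down with a small validator. A sketch under the assumption that the VLM returns the schema verbatim; `parse_review` is a hypothetical helper, not part of this diff:

```python
import json

def parse_review(raw: str) -> dict:
    """Validate a Commenter verdict against the schema promised in the prompt."""
    verdict = json.loads(raw)
    # "status" must be one of the two values the prompt allows
    if verdict.get("status") not in ("PASS", "NEEDS_REVISION"):
        raise ValueError(f"Unexpected status: {verdict.get('status')!r}")
    # "issues" must be a list (possibly empty)
    if not isinstance(verdict.get("issues", []), list):
        raise ValueError("'issues' must be a list")
    # "suggestion" is optional free text; default it for downstream consumers
    verdict.setdefault("suggestion", "")
    return verdict
```

Validating at the boundary keeps the reviser loop from acting on malformed VLM output.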
diff --git a/src/refinement/html_revise.py b/src/refinement/html_revise.py
new file mode 100644
index 0000000000000000000000000000000000000000..52607e274c613339514cf18db675f0f8ee613922
--- /dev/null
+++ b/src/refinement/html_revise.py
@@ -0,0 +1,601 @@
+import re
+import os
+import json
+from bs4 import BeautifulSoup
+
+
+class HTMLMapper:
+ def __init__(self, html_content):
+ # If the input is a file path, read the file contents
+ if os.path.exists(html_content) and os.path.isfile(html_content):
+ with open(html_content, 'r', encoding='utf-8') as f:
+ actual_html = f.read()
+ self.soup = BeautifulSoup(actual_html, 'html.parser')
+ else:
+ # Already an HTML string; use it directly
+ self.soup = BeautifulSoup(html_content, 'html.parser')
+
+ def get_structure_tree(self):
+ """
+ Updated entry point:
+ no longer returns only the layout-container; returns the full slide structure including the title and the layout.
+ """
+ # 1. Initialize the slide root node
+ slide_root = {
+ "id": "slide-root",
+ "type": "slide",
+ "children": []
+ }
+
+ # 2. Extract the title (new logic)
+ # The title lives in SVG -> g -> text
+ title_node = self._extract_svg_title()
+ if title_node:
+ slide_root["children"].append(title_node)
+
+ # 3. Extract the layout container (original logic)
+ layout_elem = self.soup.find("div", class_="layout-container")
+ if layout_elem:
+ # Parsed as a child of the slide, so is_root=False here
+ layout_tree = self._parse_node(layout_elem, is_root=False)
+ # Overwrite the type manually so it is distinguishable as layout-container
+ layout_tree["type"] = "layout-container"
+ slide_root["children"].append(layout_tree)
+ else:
+ slide_root["error"] = "Layout container not found"
+
+ return slide_root
+
+ def _extract_svg_title(self):
+ """
+ Extract the title portion from the SVG
+ """
+ # Find the SVG element
+ svg = self.soup.find("svg")
+ if not svg:
+ return None
+
+ # Look for text tags inside the SVG.
+ # Note: to avoid picking up text inside a foreignObject, the text must be a direct descendant of the SVG structure.
+ # The simple approach: find all text tags and skip those with a foreignObject ancestor.
+ text_tags = svg.find_all("text")
+
+ target_title = None
+
+ for tag in text_tags:
+ # Skip text nested inside a foreignObject (HTML usually sits in a div, but be careful in mixed SVG structures)
+ if tag.find_parent("foreignObject"):
+ continue
+
+ # Assume the first SVG text found is the title (or locate it via shape-id="2"; the sample markup carries shape-id)
+ # The title could also be identified by its font-size
+ target_title = tag
+ break
+
+ if target_title:
+ # SVG attributes sit directly on the tag, e.g. font-size="18.0pt"
+ return {
+ "id": "slide-title",
+ "type": "content-block",
+ "category": "title", # 标记为标题
+ "content": target_title.get_text(strip=True),
+ "typography": {
+ "font-family": target_title.get("font-family"),
+ "font-size": target_title.get("font-size"),
+ "font-weight": target_title.get("font-weight"),
+ "color": target_title.get("fill"), # SVG text color usually comes from the fill attribute
+ "text-align": "center" # SVG text has its own positioning; this may need to be inferred or left empty
+ },
+ "style_raw": "SVG Text Element" # mark the source
+ }
+
+ return None
+
+ def _parse_node(self, element, is_root=False):
+ """
+ (Largely unchanged from the original logic; only the type detection is adjusted)
+ Recursively parse a single node, extracting its styles, flex value, and children.
+ """
+ el_id = element.get("id")
+ style_str = element.get("style", "")
+ styles = self._parse_inline_style(style_str)
+
+ # Determine the node type
+ node_type = self._determine_node_type(element, is_root)
+
+ node_data = {
+ "id": el_id if el_id else "unknown-div",
+ "type": node_type,
+ "flex": self._extract_flex_value(styles.get("flex")),
+ "style_raw": style_str
+ }
+
+ # Logic A: container handling
+ if node_type in ["layout-container", "layout-field"]:
+ children_nodes = []
+ for child in element.find_all("div", recursive=False):
+ child_data = self._parse_node(child)
+ if child_data:
+ children_nodes.append(child_data)
+ node_data["children"] = children_nodes
+
+ # Logic B: content-block handling
+ elif node_type == "content-block":
+ block_category = self._determine_block_category(el_id)
+ node_data["category"] = block_category
+
+ if block_category == "text":
+ node_data["content"] = element.get_text(strip=True)
+ node_data["typography"] = {
+ "font-family": styles.get("font-family"),
+ "font-size": styles.get("font-size"),
+ "font-weight": styles.get("font-weight"),
+ "color": styles.get("color"),
+ "line-height": styles.get("line-height"),
+ "text-align": styles.get("text-align")
+ }
+ elif block_category == "image":
+ img_tag = element.find("img")
+ if img_tag:
+ node_data["src"] = img_tag.get("src")
+ node_data["alt"] = img_tag.get("alt")
+
+ return node_data
+
+ def _determine_node_type(self, element, is_root):
+ """Adjusted: recognize layout-container explicitly"""
+ classes = element.get("class", [])
+ if "layout-container" in classes:
+ return "layout-container"
+
+ el_id = element.get("id", "")
+ if el_id.endswith("-field"):
+ return "layout-field"
+ elif el_id.endswith("-block"):
+ return "content-block"
+ return "generic-container" # fallback default
+
+ def _determine_block_category(self, el_id):
+ if not el_id: return "unknown"
+ if "text" in el_id: return "text"
+ if "image" in el_id: return "image"
+ if "formula" in el_id: return "formula"
+ return "generic"
+
+ def _parse_inline_style(self, style_str):
+ if not style_str: return {}
+ return {
+ rule.split(':')[0].strip(): rule.split(':')[1].strip()
+ for rule in style_str.split(';') if ':' in rule
+ }
+
+ def _extract_flex_value(self, flex_str):
+ if not flex_str: return 1.0 # default to 1.0 when unspecified
+ match = re.match(r'([\d\.]+)', flex_str)
+ return float(match.group(1)) if match else 1.0
+
+
+class HTMLModificationError(Exception):
+ """Raised when a critical error occurs during HTML modification"""
+ def __init__(self, message, errors=None):
+ super().__init__(message)
+ self.errors = errors if errors else []
+
+
+
+class HTMLModifier:
+ def __init__(self, input_path, output_path):
+ self.input_path = input_path
+ self.output_path = output_path
+ if not os.path.exists(input_path):
+ raise FileNotFoundError(f"HTML file not found: {input_path}")
+
+ # Ensure the output directory exists
+ output_dir = os.path.dirname(output_path)
+ if output_dir:
+ os.makedirs(output_dir, exist_ok=True)
+
+ with open(input_path, 'r', encoding='utf-8') as f:
+ self.soup = BeautifulSoup(f, 'html.parser')
+
+ self.errors = []
+
+ def modify(self, modification_tree):
+ """
+ Main entry point: dispatch handling based on the structure-tree type
+ """
+ node_type = modification_tree.get("type")
+
+ # Case 1: new slide structure (SVG title plus HTML content)
+ if node_type == "slide":
+ self._process_slide_root(modification_tree)
+
+ # Case 2: legacy layout structure (starts directly at layout-container)
+ # Kept for backward compatibility so an un-migrated JSON structure does not crash
+ elif node_type == "layout-container" or modification_tree.get("id") == "root":
+ self._process_layout_root(modification_tree)
+
+ else:
+ self._log_error("Critical", f"Unknown root node type: {node_type}")
+
+ # Error handling and saving
+ if self._has_critical_errors():
+ self._handle_aborted_modification()
+ else:
+ self._save_file()
+
+ def _process_slide_root(self, tree_node):
+ """Handle the slide root: iterate the children and dispatch each to the SVG or HTML handler"""
+ children = tree_node.get("children", [])
+
+ for child in children:
+ child_category = child.get("category")
+ child_type = child.get("type")
+
+ # A. Handle the title (inside the SVG)
+ if child_category == "title":
+ self._update_svg_title(child)
+
+ # B. Handle the layout container (inside an HTML div)
+ elif child_type == "layout-container":
+ self._process_layout_root(child)
+
+ # C. Other possible direct children
+ else:
+ # Fall back to a generic update located by ID
+ self._update_generic_node(child)
+
+ def _process_layout_root(self, json_node):
+ """Entry point for processing the HTML layout portion"""
+ # Try to find the layout-container
+ html_root = self.soup.find("div", class_="layout-container")
+ if not html_root:
+ # Fallback: look it up by ID
+ root_id = json_node.get("id")
+ if root_id and root_id != "root":
+ html_root = self.soup.find(id=root_id)
+
+ if not html_root:
+ self._log_error("Critical", "Layout container not found in HTML.")
+ return
+
+ # Recursively update the HTML tree
+ self._update_html_recursive(html_root, json_node)
+
+ # ==========================================
+ # SVG handling logic
+ # ==========================================
+
+ def _update_svg_title(self, json_node):
+ """
+ Update the SVG title node.
+ Tricky part: SVG attributes differ from CSS, and the text is usually nested inside a tspan
+ """
+ # 1. Locate the SVG text node
+ # The mapper may have generated the virtual ID "slide-title", so we cannot simply find(id="slide-title");
+ # reuse the same lookup: the first text under the SVG that is not inside a foreignObject
+ svg = self.soup.find("svg")
+ if not svg:
+ self._log_error("Error", "SVG element not found for title update.")
+ return
+
+ text_tags = svg.find_all("text")
+ target_node = None
+
+ for tag in text_tags:
+ if not tag.find_parent("foreignObject"):
+ target_node = tag
+ break
+
+ if not target_node:
+ self._log_error("Warning", "Target SVG title text node not found.")
+ return
+
+ # 2. Update the text content
+ if "content" in json_node:
+ new_text = json_node["content"]
+ # SVG text is often wrapped in a tspan
+ tspan = target_node.find("tspan")
+ if tspan:
+ tspan.string = new_text
+ else:
+ target_node.string = new_text
+
+ # 3. Update the style (CSS-to-SVG attribute mapping)
+ if "typography" in json_node:
+ style_map = json_node["typography"]
+ for css_key, val in style_map.items():
+ if val:
+ svg_attr = self._map_css_to_svg_attr(css_key)
+ if svg_attr:
+ target_node[svg_attr] = val
+
+ def _map_css_to_svg_attr(self, css_key):
+ """Map CSS property names to SVG attribute names"""
+ mapping = {
+ "color": "fill",
+ "font-family": "font-family",
+ "font-size": "font-size",
+ "font-weight": "font-weight",
+ "text-align": "text-anchor", # note: the value also needs converting (left->start, center->middle);
+ # for simplicity only the key is mapped here, not the value
+ }
+ return mapping.get(css_key)
+
+ # ==========================================
+ # Recursive HTML handling logic
+ # ==========================================
+
+ def _update_html_recursive(self, html_element, json_node):
+ """
+ Recursively update HTML nodes (div, img, plain text)
+ """
+ node_id = json_node.get("id", "unknown")
+
+ # --- A. Update the flex layout property ---
+ if "flex" in json_node:
+ # The flex shorthand is complex; only flex-grow is handled here.
+ # If the JSON gives a bare number, treat it as the flex value
+ self._update_inline_style(html_element, "flex", str(json_node["flex"]))
+
+ # --- B. Update content blocks ---
+ if json_node.get("type") == "content-block":
+ category = json_node.get("category")
+
+ # Text handling
+ if category == "text":
+ if "content" in json_node:
+ try:
+ # Keep the original behavior: clear the old content, parse the new HTML, and insert it
+ html_element.clear()
+ new_content_tag = BeautifulSoup(json_node["content"], 'html.parser')
+ html_element.append(new_content_tag)
+ except Exception as e:
+ self._log_error("Error", f"Failed to parse content for {node_id}: {e}")
+
+ # Style updates
+ typography = json_node.get("typography", {})
+ for css_prop, val in typography.items():
+ if val is not None:
+ self._update_inline_style(html_element, css_prop, val)
+
+ # Image handling
+ elif category == "image":
+ img_tag = html_element if html_element.name == 'img' else html_element.find("img")
+
+ if img_tag:
+ if "src" in json_node: img_tag['src'] = json_node["src"]
+ if "alt" in json_node: img_tag['alt'] = json_node["alt"]
+ else:
+ self._log_error("Warning", f"Image tag not found for node {node_id}")
+
+ # --- C. Recurse into the children ---
+ json_children = json_node.get("children", [])
+ for child_json in json_children:
+ child_id = child_json.get("id")
+ if not child_id: continue
+
+ # Look the child up by ID
+ # Note: soup.find(id=...) searches the whole document, so the parent-child relationship is validated below
+ html_child = self.soup.find(id=child_id)
+
+ if not html_child:
+ self._log_error("Error", f"Child node '{child_id}' not found in HTML.")
+ continue
+
+ # Validate the hierarchy (optional; guards against edits landing at the wrong level)
+ # Simply check that html_child really sits inside html_element
+ if html_element not in html_child.parents:
+ self._log_error("Warning", f"Hierarchy mismatch: '{child_id}' is not inside '{node_id}'.")
+
+ self._update_html_recursive(html_child, child_json)
+
+ def _update_generic_node(self, json_node):
+ """Fallback handling when it is unclear whether the node lives in the layout or the SVG"""
+ node_id = json_node.get("id")
+ if node_id:
+ elem = self.soup.find(id=node_id)
+ if elem:
+ # Simple check: is the element inside the SVG or the HTML?
+ if elem.find_parent("svg"):
+ # Simplified SVG update (not supported yet)
+ pass
+ else:
+ self._update_html_recursive(elem, json_node)
+
+ def _update_inline_style(self, element, property_name, value):
+ """Update an element's inline style attribute string"""
+ current_style = element.get("style", "")
+ style_dict = {}
+ if current_style:
+ for item in current_style.split(';'):
+ if ':' in item:
+ k, v = item.split(':', 1)
+ style_dict[k.strip()] = v.strip()
+
+ style_dict[property_name] = value
+ new_style = "; ".join([f"{k}: {v}" for k, v in style_dict.items()])
+ element['style'] = new_style
+
+ def _log_error(self, level, msg):
+ self.errors.append({"level": level, "msg": msg})
+
+ def _has_critical_errors(self):
+ return any(e["level"] == "Critical" for e in self.errors)
+
+ def _handle_aborted_modification(self):
+ print("Modification aborted due to critical errors:")
+ for err in self.errors:
+ print(f"[{err['level']}] {err['msg']}")
+ raise RuntimeError("HTML Modification Aborted")
+
+ def _save_file(self):
+ with open(self.output_path, "w", encoding='utf-8') as f:
+ # prettify() can break the layout; writing str(soup) directly is usually safer
+ f.write(str(self.soup))
+
+ print(f"Successfully modified: {self.output_path}")
+ if self.errors:
+ print(f"Warnings: {self.errors}")
+
+
+def apply_html_modifications(input_path, output_path, modification_json):
+ """
+ External entry point
+ :param input_path: path to the source HTML file
+ :param output_path: path where the modified HTML is saved
+ :param modification_json: dict describing the modifications (structure produced by HTMLMapper)
+ """
+ print(f"Starting modification for: {input_path}")
+ modifier = HTMLModifier(input_path, output_path)
+ modifier.modify(modification_json)
+
+
+
+
+
+
+
+
+
+# if __name__ == "__main__":
+# html_file_path = "./ppt_template_lzt/T3_ImageLeft.html"
+# html_file_path_refine = "./ppt_template_lzt/T3_ImageLeft_refined.html"
+
+# with open(html_file_path, "r", encoding="utf-8") as f:
+# html_string = f.read()
+
+# htmlMapper = HTMLMapper(html_string)
+
+ # print(htmlMapper.soup)
+ # print(htmlMapper.soup.find_all("div", id=True))
+ # tree = htmlMapper.get_structure_tree()
+ # print(json.dumps(tree, indent=4, ensure_ascii=False))
+
+ # target_modifications = {
+ # "id": "main-container",
+ # "type": "root",
+ # "flex": 1.0,
+ # "style_raw": "",
+ # "children": [
+ # {
+ # "id": "1-col-field",
+ # "type": "layout-field",
+ # "flex": 1.5,
+ # "style_raw": "flex: 1; display: flex; flex-direction: column; gap: 10px; height: 100%; overflow: hidden;",
+ # "children": [
+ # {
+ # "id": "1-1-text-block",
+ # "type": "content-block",
+ # "flex": 1.0,
+ # "style_raw": "flex: 1; font-family: 'Calibri', sans-serif; font-size: 21pt; color: #000000; line-height: 1.3; text-align: left; overflow: hidden; padding: 0 15px;",
+ # "category": "text",
+ # "content": "In this quiet lattice, fragments of hollow constellations drift. A procession of intangible echoes shimmers faintly, attempting to articulate a message that dissolves the moment it forms. This wandering expanse exists in an imagined time, neither aligning with memory nor contradicting blurred architecture.",
+ # "typography": {
+ # "font-family": "'Calibri', sans-serif",
+ # "font-size": "22pt",
+ # "color": "#9B2828",
+ # "line-height": "1.3",
+ # "text-align": "left"
+ # }
+ # }
+ # ]
+ # },
+ # {
+ # "id": "2-col-field",
+ # "type": "layout-field",
+ # "flex": 1.0,
+ # "style_raw": "flex: 1; display: flex; flex-direction: column; gap: 10px; height: 100%; overflow: hidden;",
+ # "children": [
+ # {
+ # "id": "2-1-image-block",
+ # "type": "content-block",
+ # "flex": 1.0,
+ # "style_raw": "flex: 1; display: flex; align-items: center; justify-content: center; overflow: hidden;",
+ # "category": "image",
+ # "src": "images/7.png",
+ # "alt": "Slide Image: Conceptual representation of the transition from 2D to 3D data space or a deep learning architecture for reconstruction."
+ # }
+ # ]
+ # }
+ # ]
+ # }
+
+ # target_modifications = {
+ # "id": "slide-root",
+ # "type": "slide",
+ # "children": [
+ # {
+ # "id": "slide-title",
+ # "type": "content-block",
+ # "category": "title",
+ # "content": "Bridging 2D and 3D: From Lifting to Learning",
+ # "typography": {
+ # "font-family": "Calibri",
+ # "font-size": "18.0pt",
+ # "font-weight": "bold",
+ # "color": "#ffffff",
+ # "text-align": "center"
+ # },
+ # "style_raw": "SVG Text Element"
+ # },
+ # {
+ # "id": "root",
+ # "type": "layout-container",
+ # "flex": 1.0,
+ # "style_raw": "",
+ # "children": [
+ # {
+ # "id": "1-col-field",
+ # "type": "layout-field",
+ # "flex": 1.0,
+ # "style_raw": "flex: 1; display: flex; flex-direction: column; gap: 10px; height: 100%; overflow: hidden;",
+ # "children": [
+ # {
+ # "id": "1-1-image-block",
+ # "type": "content-block",
+ # "flex": 1.0,
+ # "style_raw": "flex: 1; display: flex; align-items: center; justify-content: center; overflow: hidden;",
+ # "category": "image",
+ # "src": "images/7.png",
+ # "alt": "Slide Image"
+ # }
+ # ]
+ # },
+ # {
+ # "id": "2-col-field",
+ # "type": "layout-field",
+ # "flex": 1.0,
+ # "style_raw": "flex: 1; display: flex; flex-direction: column; gap: 10px; height: 100%; overflow: hidden;",
+ # "children": [
+ # {
+ # "id": "2-1-text-block",
+ # "type": "content-block",
+ # "flex": 1.0,
+ # "style_raw": "flex: 1; font-family: 'Calibri', sans-serif; font-size: 21pt; color: #000000; line-height: 1.3; text-align: left; overflow: hidden; padding: 0 15px;",
+ # "category": "text",
+ # "content": "In the quiet lattice of an unnamed elsewhere, the murmuring fragments of hollow constellations drift without purpose, weaving patterns that neither align with memory nor contradict the blurred architecture of imagined time. Within this wandering expanse, a procession of intangible echoes shimmers faintly, as if attempting to articulate a message that dissolves the moment it forms.",
+ # "typography": {
+ # "font-family": "'Calibri', sans-serif",
+ # "font-size": "21pt",
+ # "color": "#000000",
+ # "line-height": "1.3",
+ # "text-align": "left"
+ # }
+ # }
+ # ]
+ # }
+ # ]
+ # }
+ # ]
+ # }
+
+ # apply_html_modifications(html_file_path, html_file_path_refine, target_modifications)
+
+
+
diff --git a/src/refinement/refinement.py b/src/refinement/refinement.py
new file mode 100644
index 0000000000000000000000000000000000000000..a95f7b6397573a103a5fc3842948aaefc8fa71bd
--- /dev/null
+++ b/src/refinement/refinement.py
@@ -0,0 +1,818 @@
+import base64
+import os
+import re
+import json
+import time
+import PIL.Image
+import shutil
+from PIL import Image
+from pathlib import Path
+from openai import OpenAI
+from google import genai
+from google.genai import types
+from .html_revise import HTMLMapper, apply_html_modifications, HTMLModificationError
+from playwright.sync_api import sync_playwright
+
+
+class VLMCommenter:
+ def __init__(self, api_key, prompt, provider="openai", model_name=None):
+ """
+ :param api_key: API key
+ :param prompt: prompt text
+ :param provider: "openai" or "gemini"
+ :param model_name: model name (optional)
+ """
+ self.provider = provider.lower()
+ self.api_key = api_key
+ self.model_name = model_name
+ self.prompt_text = prompt
+
+ if self.provider == "openai":
+ self.client = OpenAI(api_key=api_key)
+ self.model = model_name if model_name else "gpt-4o"
+ elif self.provider == "gemini":
+ self.client = genai.Client(api_key=api_key)
+ self.model = model_name if model_name else "gemini-1.5-flash"
+ else:
+ raise ValueError("Unsupported provider. Choose 'openai' or 'gemini'.")
+
+ def _encode_image(self, image_path):
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode('utf-8')
+
+ def evaluate_slide(self, image_path, outline, pre_comments):
+ """
+ Input: screenshot path
+ Output: diagnostic text (string)
+ """
+ prompt_text = self.prompt_text
+ if not prompt_text:
+ return "Error: Commenter prompt is empty."
+
+ full_prompt = (
+ f"{prompt_text}\n"
+ f"*****previous comments*****\n{pre_comments}\n"
+ f"*****begin of the outline*****\n{outline}\n"
+ f"*****end of the outline*****\n"
+ f"*****the following is the image, not the outline*****"
+ )
+
+ if self.provider == "openai":
+ base64_image = self._encode_image(image_path)
+ try:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": [
+ {"type": "text", "text": full_prompt},
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
+ ]}
+ ],
+ max_tokens=300
+ )
+ return response.choices[0].message.content
+ except Exception as e:
+ return f"Error using OpenAI VLM: {e}"
+
+ elif self.provider == "gemini":
+ try:
+ img = PIL.Image.open(image_path)
+
+ response = self.client.models.generate_content(
+ model=self.model,
+ contents=[full_prompt, img]
+ )
+ return response.text
+ except Exception as e:
+ return f"Error using Gemini VLM (google-genai): {e}"
+
+
+class LLMReviser:
+ def __init__(self, api_key, prompt, provider="openai", model_name=None):
+ """
+ :param api_key: API key
+ :param prompt: prompt text
+ :param provider: "openai" or "gemini"
+ :param model_name: model name
+ """
+ self.provider = provider.lower()
+ self.api_key = api_key
+ self.model_name = model_name
+ self.system_prompt = prompt
+
+ if self.provider == "openai":
+ self.client = OpenAI(api_key=api_key)
+ self.model = model_name if model_name else "gpt-4"
+ elif self.provider == "gemini":
+ self.client = genai.Client(api_key=api_key)
+ self.model = model_name if model_name else "gemini-1.5-pro"
+ else:
+ raise ValueError("Unsupported provider. Choose 'openai' or 'gemini'.")
+
+ def generate_revision_plan(self, current_structure_json, vlm_critique):
+ """
+ Input: the HTML structure JSON and the VLM critique
+ Output: the revised JSON
+ """
+
+ if "PASS" in vlm_critique.upper() and len(vlm_critique) < 10:
+ return None
+
+ prompt_system = self.system_prompt
+ if not prompt_system:
+ print("Error: Reviser prompt is empty.")
+ return None
+
+ user_content = f"""
+ --- CURRENT STRUCTURE JSON ---
+ {json.dumps(current_structure_json, indent=2)}
+
+ --- VISUAL CRITIQUE ---
+ {vlm_critique}
+
+ --- INSTRUCTION ---
+ Generate the modification JSON based on the system instructions.
+ """
+
+ if self.provider == "openai":
+ try:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": prompt_system},
+ {"role": "user", "content": user_content}
+ ],
+ response_format={"type": "json_object"}
+ )
+ return json.loads(response.choices[0].message.content)
+ except Exception as e:
+ print(f"OpenAI Error: {e}")
+ return None
+
+ elif self.provider == "gemini":
+ try:
+ # Concatenate the system prompt and the user content
+ full_prompt = f"{prompt_system}\n\n{user_content}"
+
+ response = self.client.models.generate_content(
+ model=self.model,
+ contents=full_prompt,
+ config=types.GenerateContentConfig(
+ response_mime_type="application/json"
+ )
+ )
+
+ text_response = response.text
+
+ # Strip stray Markdown fences (some models still add ```json even when a JSON mime type is requested)
+ if text_response.startswith("```"):
+ text_response = re.sub(r"^```(?:json)?\s*|\s*```$", "", text_response.strip())
+
+ return json.loads(text_response)
+ except json.JSONDecodeError:
+ print(f"Gemini returned invalid JSON: {response.text}")
+ return None
+ except Exception as e:
+ print(f"Gemini Error (google-genai): {e}")
+ return None
+
+
+def take_screenshot(html_path, output_path):
+ """Simple screenshot helper (Playwright)"""
+ if not os.path.exists(html_path):
+ raise FileNotFoundError(f"File not found: {html_path}")
+ abs_path = Path(os.path.abspath(html_path)).as_uri()
+
+ with sync_playwright() as p:
+ # 1. Explicitly set device_scale_factor=1
+ browser = p.chromium.launch()
+ context = browser.new_context(
+ viewport={'width': 960, 'height': 540},
+ device_scale_factor=1
+ )
+ page = context.new_page()
+
+ # 2. Load the page
+ page.goto(abs_path, wait_until="networkidle") # make sure images and fonts have finished loading
+
+ # 3. Capture the target element rather than the full page; this is the most reliable approach
+ # The root div id is slide1, or the svg can be captured directly
+ element = page.locator(".slideImage")
+ element.screenshot(path=output_path)
+
+ browser.close()
+
+def take_screenshot_poster(html_path, output_path):
+ """Screenshot an HTML poster laid out with .poster/#flow (Playwright, sync API)"""
+ if not os.path.exists(html_path):
+ raise FileNotFoundError(f"File not found: {html_path}")
+
+ abs_uri = Path(os.path.abspath(html_path)).as_uri()
+
+ with sync_playwright() as p:
+ browser = p.chromium.launch(headless=True, args=["--disable-dev-shm-usage"])
+ context = browser.new_context(
+ # The CSS pins --poster-width/height at 1400x900
+ viewport={"width": 1400, "height": 900},
+ device_scale_factor=1
+ )
+ page = context.new_page()
+
+ # 1) Wait for DOMContentLoaded first; waiting for networkidle can hang
+ page.goto(abs_uri, wait_until="domcontentloaded")
+
+ # 2) Make sure the key containers exist
+ page.wait_for_selector(".poster", state="attached", timeout=30000)
+ page.wait_for_selector("#flow", state="attached", timeout=30000)
+
+ # 3) Wait for fonts to be ready (fit() is very sensitive to fonts/typesetting)
+ try:
+ page.evaluate("() => document.fonts ? document.fonts.ready : Promise.resolve()")
+ except Exception:
+ pass
+
+ # 4) Wait for every image inside #flow to load (resolves immediately when there are none)
+ page.evaluate(r"""
+ () => {
+ const flow = document.getElementById("flow");
+ if (!flow) return Promise.resolve();
+ const imgs = Array.from(flow.querySelectorAll("img"));
+ if (imgs.length === 0) return Promise.resolve();
+ return Promise.all(imgs.map(img => {
+ if (img.complete) return Promise.resolve();
+ return new Promise(res => {
+ img.addEventListener("load", res, { once: true });
+ img.addEventListener("error", res, { once: true });
+ });
+ }));
+ }
+ """)
+
+ # 5) Wait for fit() to run and the layout to settle: a few frames plus a stable scrollWidth
+ page.evaluate(r"""
+ () => new Promise((resolve) => {
+ const flow = document.getElementById("flow");
+ if (!flow) return resolve();
+
+ let last = -1;
+ let stableCount = 0;
+
+ function tick() {
+ const cur = flow.scrollWidth; // overflow criterion for the multi-column layout
+ if (cur === last) stableCount += 1;
+ else stableCount = 0;
+
+ last = cur;
+
+ // after several consecutive stable frames, treat the fit/reflow as finished
+ if (stableCount >= 10) return resolve();
+ requestAnimationFrame(tick);
+ }
+
+ // give the load event / fit() a moment to start
+ setTimeout(() => requestAnimationFrame(tick), 50);
+ })
+ """)
+
+ # 6) Screenshot the .poster element (not the stage background)
+ poster = page.locator(".poster").first
+ poster.screenshot(path=output_path, timeout=60000)
+
+ browser.close()
+
+
+def load_prompt(prompt_path="prompt.json", prompt_name="poster_prompt"):
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ return data.get(prompt_name, "")
+
+
+def refine_one_slide(input_path, output_path, prompts, outline, max_iterations, model, config):
+ """
+ Automated repair loop: screenshot -> diagnose -> revise -> repeat
+ """
+ is_gemini = "gemini" in model.lower()
+
+ if is_gemini:
+ api_key = config['api_keys'].get('gemini_api_key')
+ else:
+ api_key = config['api_keys'].get('openai_api_key')
+
+ commenter_prompt = prompts[0]
+ reviser_prompt = prompts[1]
+
+ platform = "gemini" if "gemini" in model.lower() else "openai"
+
+ vlm = VLMCommenter(api_key, commenter_prompt, provider=platform, model_name=model)
+ reviser = LLMReviser(api_key, reviser_prompt, provider=platform, model_name=model)
+
+ current_input = input_path
+ critic_his = ""
+
+ for i in range(max_iterations):
+ print(f"\n=== Iteration {i+1} ===")
+
+ # 1. Render and take a screenshot (Playwright)
+ screenshot_path = f"{Path(output_path).parent}/{Path(current_input).stem}_{i+1}.png" # temporary screenshot path
+ take_screenshot(current_input, screenshot_path)
+ print(f"Screenshot taken: {screenshot_path}")
+
+ # 2. VLM visual diagnosis
+ critique = vlm.evaluate_slide(screenshot_path, outline, critic_his)
+ critic_his = critic_his + f"this is comment {i+1}: {critique}"
+ print(f"VLM Critique: {critique}")
+
+ if "PASS" in critique:
+ print("Layout looks good! Stopping loop.")
+
+ if os.path.abspath(current_input) != os.path.abspath(output_path):
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ shutil.copy2(current_input, output_path)
+ print(f"Final result saved to: {output_path}")
+ else:
+ print(f"Result is already at output path: {output_path}")
+
+ break
+
+ # 3. Read the current HTML structure
+ mapper = HTMLMapper(current_input)
+ current_tree = mapper.get_structure_tree()
+
+ # 4. Ask the LLM for a revision plan
+ modification_json = reviser.generate_revision_plan(current_tree, critique)
+
+ if not modification_json:
+ print("Reviser suggested no changes. Stopping.")
+ break
+
+ print(f"Proposed Changes: {json.dumps(modification_json, indent=2)}")
+
+ # 5. Apply the modifications
+ try:
+ # apply_html_modifications now raises HTMLModificationError on critical failures
+ apply_html_modifications(current_input, output_path, modification_json)
+ print("Modifications applied to HTML.")
+
+ except Exception as e: # covers HTMLModificationError and any unexpected error
+ print(f"❌ Error applying modifications at iteration {i+1}: {e}")
+
+ # ====== Save a backup of the failing file ======
+ try:
+ # Use input_path (the original file) for the base name
+ base_name = Path(input_path).stem
+
+ # Path for the failure backup
+ error_backup_path = f"{Path(output_path).parent}/{base_name}_FAILED_iter{i+1}.html"
+
+ # Copy out the HTML file that triggered the error (current_input)
+ shutil.copy2(current_input, error_backup_path)
+
+ print(f"⚠️ Saved a copy of the failing HTML: {error_backup_path}")
+ print(f"   Open this file and replay the JSON printed above to reproduce the issue.")
+
+ except Exception as copy_err:
+ print(f"❌ Failed while saving the error backup: {copy_err}")
+ # ============================
+
+ # Abort the loop after an error
+ break
+ current_input = output_path
+
+ # Brief pause to avoid concurrent read/write issues
+ time.sleep(1)
+
+ # After refinement finishes, save a final screenshot
+ final_screenshot_path = f"{Path(output_path).parent}/{Path(current_input).stem}_final.png"
+ take_screenshot(current_input, final_screenshot_path)
+ print(f"\n📷 Final screenshot saved: {final_screenshot_path}")
+
+
+def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", config=None):
+ # 1. Define paths
+ outline_path = os.path.join(input_index, "outline.json")
+ output_index = os.path.join(input_index, "final")
+ output_index_images = os.path.join(output_index, "images") # image subdirectory so the refined HTML can still resolve its images
+
+ # Ensure the output directory exists
+ os.makedirs(output_index, exist_ok=True)
+
+ # Copy the images into final/images (shutil is imported at the top of the module)
+ source_images_dir = os.path.join(input_index, "images")
+ if os.path.exists(source_images_dir):
+ shutil.copytree(source_images_dir, output_index_images, dirs_exist_ok=True)
+ print(f"📁 Copied images to: {output_index_images}")
+
+ # 2. Load the outline data
+ with open(outline_path, 'r', encoding='utf-8') as f:
+ outline_data = json.load(f)
+ if isinstance(outline_data, list):
+ # Convert the list into a dict keyed by the stringified index
+ # list[0] is assumed to correspond to 0_ppt.html or 1_ppt.html; keep the original indices
+ outline_full = {str(i): item for i, item in enumerate(outline_data)}
+ else:
+ outline_full = outline_data
+
+ # ================= Core processing =================
+
+ print(f"🚀 Scanning directory: {input_index}")
+
+ # 3.1 Filter out every file matching the "<number>_ppt.html" pattern
+ target_files = []
+ for f in os.listdir(input_index):
+ # Strict match: leading digits followed by the _ppt.html suffix
+ if re.search(r'^\d+_ppt\.html$', f):
+ target_files.append(f)
+
+ # 3.2 Sort key: extract the leading number
+ def get_file_number(filename):
+ # The filtering above guarantees the format, so extraction is safe here
+ return int(filename.split('_')[0])
+
+ # 3.3 Sort numerically (critical: keeps 2 ahead of 10)
+ sorted_files = sorted(target_files, key=get_file_number)
+
+ # Debug: print the first few files to confirm the order
+ print(f"👀 First 5 files after sorting: {sorted_files[:5]}")
+
+ # 4. Iterate over the sorted list
+ for file_name in sorted_files:
+ # Extract the sequence number (format already validated above)
+ num = str(get_file_number(file_name))
+
+ # Fetch the outline for this HTML file; keys are strings, and a list input yields
+ # 0-based keys while the files are 1-based, so look up num-1 directly
+ outline = outline_full.get(str(int(num)-1))
+
+ # [Fallback] Handle an index offset (e.g. the file is 1_ppt but the list starts at 0)
+ # If no outline was found and key num-1 exists, fall back automatically
+ if outline is None and str(int(num)-1) in outline_full:
+ print(f"ℹ️ Correcting index: file {num} -> using outline {int(num)-1}")
+ outline = outline_full.get(str(int(num)-1))
+
+ if outline is None:
+ print(f"⚠️ Skipping {file_name}: neither key {num} nor {int(num)-1} found in outline.json")
+ continue
+
+ # Build the paths
+ html_file_path = os.path.join(input_index, file_name)
+ html_file_path_refine = os.path.join(output_index, file_name)
+
+ print(f"📝 [Processing in order] Refining: {file_name} (outline key: {num})")
+
+ # 5. Call the refinement loop
+ try:
+ refine_one_slide(
+ input_path=html_file_path,
+ output_path=html_file_path_refine,
+ prompts=prompts,
+ outline=outline,
+ max_iterations=max_iterations,
+ model=model,
+ config=config
+ )
+ except Exception as e:
+ print(f"❌ Error while processing {file_name}: {e}")
+
+ print(f"✅ All files processed; results saved in: {output_index}")
+
+def refinement_poster(input_html_path, prompts, output_html_path, model, config=None):
+ # ---------------- 0. Configuration ----------------
+ if config is None:
+ config = {}
+
+ api_keys_conf = config.get('api_keys', {})
+
+ # Detect the provider
+ is_gemini = "gemini" in model.lower()
+
+ # ---------------- 1. Paths and file preparation ----------------
+ auto_path = Path(input_html_path).parent
+ final_index = os.path.join(auto_path, "final")
+ final_index_image = os.path.join(final_index, "images")
+ os.makedirs(final_index, exist_ok=True)
+
+ # Copy the images folder
+ source_images_dir = os.path.join(auto_path, "images")
+ if os.path.exists(source_images_dir):
+ if not os.path.exists(final_index_image):
+ shutil.copytree(source_images_dir, final_index_image, dirs_exist_ok=True)
+ print(f"📁 Images copied to: {final_index_image}")
+
+ with open(input_html_path, 'r', encoding='utf-8') as f:
+ current_html = f.read()
+
+ # ---------------- 2. Screenshot (unchanged) ----------------
+ screenshot_name = Path(input_html_path).stem + ".png"
+ screenshot_path = os.path.join(final_index, screenshot_name)
+
+ print(f"📸 Taking screenshot of {input_html_path}...")
+ # take_screenshot_poster is defined above in this module
+ take_screenshot_poster(input_html_path, screenshot_path)
+
+ if not os.path.exists(screenshot_path):
+ raise FileNotFoundError("Screenshot failed to generate.")
+
+ # Read the screenshot bytes
+ with open(screenshot_path, "rb") as f:
+ image_bytes = f.read()
+
+ generated_text = ""
+
+ # ---------------- 3. Call the LLM ----------------
+ print(f"🤖 Sending to Vision Model ({model}) on {'Gemini' if is_gemini else 'OpenAI'}...")
+
+ try:
+ if is_gemini:
+ # === Gemini Client Setup ===
+ api_key = api_keys_conf.get('gemini_api_key') or os.getenv("GOOGLE_API_KEY")
+
+ client = genai.Client(api_key=api_key)
+
+ # Build the contents Gemini expects
+ # (the construction recommended by the new google.genai SDK)
+ response = client.models.generate_content(
+ model=model,
+ contents=[
+ types.Part.from_text(text=prompts),
+ types.Part.from_text(text=f"--- CURRENT HTML ---\n{current_html}"),
+ types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
+ ]
+ )
+
+ if response.text:
+ generated_text = response.text
+ else:
+ raise RuntimeError("Gemini returned empty text.")
+
+ else:
+ # === OpenAI Client Setup ===
+ api_key = api_keys_conf.get('openai_api_key') or os.getenv("OPENAI_API_KEY")
+
+ client = OpenAI(api_key=api_key)
+
+ # OpenAI needs the image base64-encoded
+ base64_image = base64.b64encode(image_bytes).decode('utf-8')
+
+ messages = [
+ {
+ "role": "system",
+ "content": "You are an expert web designer and code refiner."
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": f"{prompts}\n\n--- CURRENT HTML ---\n{current_html}"
+ },
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/png;base64,{base64_image}",
+ "detail": "high"
+ }
+ }
+ ]
+ }
+ ]
+
+ response = client.chat.completions.create(
+ model=model,
+ messages=messages,
+ max_tokens=4096
+ )
+ generated_text = response.choices[0].message.content
+
+ # ---------------- 4. Parse and save the result ----------------
+ # Strip Markdown code-fence markers
+ if "```html" in generated_text:
+ final_html = generated_text.split("```html")[1].split("```")[0].strip()
+ elif "```" in generated_text:
+ final_html = generated_text.split("```")[1].strip()
+ else:
+ final_html = generated_text
+
+ with open(output_html_path, 'w', encoding='utf-8') as f:
+ f.write(final_html)
+
+ print(f"✅ Refined poster saved to: {output_html_path}")
+
+ # Generate the final screenshot
+ final_screenshot_name = Path(input_html_path).stem + "_final.png"
+ final_screenshot_path = os.path.join(final_index, final_screenshot_name)
+
+ print(f"📸 Taking final poster screenshot of {output_html_path}...")
+ take_screenshot_poster(output_html_path, final_screenshot_path)
+
+ except Exception as e:
+ print(f"❌ Error during AI generation: {e}")
+
+def refinement_pr(pr_path: str, pr_refine_path: str, prompts: dict, model: str, config: dict):
+ """
+ Extract specific sections from a Markdown file, refine them with an LLM according to the given prompts, and reassemble the file.
+ Strictly preserves the original Markdown structure, image references, and any unselected trailing content (e.g. Hashtags).
+ """
+
+ # 1. (Changed) Read the config; no longer relies on environment variables. API keys are extracted right before each call.
+ if config is None:
+ config = {}
+ api_keys = config.get('api_keys', {})
+
+ # 2. Read the original file
+ if not os.path.exists(pr_path):
+ raise FileNotFoundError(f"File not found: {pr_path}")
+
+ with open(pr_path, 'r', encoding='utf-8') as f:
+ original_content = f.read()
+
+ # 3. Define the section header mapping
+ section_headers = {
+ "Key Question": r"🔍 \*\*Key Question\*\*",
+ "Brilliant Idea": r"💡 \*\*Brilliant Idea\*\*",
+ "Core Methods": r"🚀 \*\*Core Methods\*\*",
+ "Core Results": r"📊 \*\*Core Results\*\*",
+ "Significance/Impact": r"🧠 \*\*Significance/Impact\*\*"
+ }
+
+ footer_pattern = r"🏷️\s*\*\*Hashtag\*\*"
+
+ # 4. Locate the core section headers
+ matches = []
+ for key, pattern in section_headers.items():
+ found = list(re.finditer(pattern, original_content))
+ if found:
+ match = found[0]
+ matches.append({
+ "key": key,
+ "header_start": match.start(),
+ "header_end": match.end(),
+ "header_text": match.group()
+ })
+
+ matches.sort(key=lambda x: x["header_start"])
+
+ if not matches:
+ print("No target sections detected; copying the file as-is.")
+ with open(pr_refine_path, 'w', encoding='utf-8') as f:
+ f.write(original_content)
+ return
+
+ # Locate the footer (Hashtag) marker
+ footer_match = re.search(footer_pattern, original_content)
+ if footer_match:
+ global_content_end_limit = footer_match.start()
+ else:
+ print("Warning: '🏷️ **Hashtag**' marker not found; the last section will be read to the end of the file.")
+ global_content_end_limit = len(original_content)
+
+ # 5. Compute the exact content range of each section
+ content_ranges = {}
+ for i, match in enumerate(matches):
+ key = match["key"]
+ content_start = match["header_end"]
+ if i < len(matches) - 1:
+ content_end = matches[i+1]["header_start"]
+ else:
+ content_end = max(content_start, global_content_end_limit)
+
+ content_ranges[key] = {
+ "start": content_start,
+ "end": content_end,
+ "text": original_content[content_start:content_end].strip()
+ }
+
+ # 6. Build the LLM request
+ extracted_data = {k: v["text"] for k, v in content_ranges.items()}
+
+ system_prompt = (
+ "You are an expert academic editor. Your task is to refine the content of specific sections of a paper summary based on user instructions.\n"
+ "Input Format: JSON object {Section Name: Content}.\n"
+ "Output Format: JSON object {Section Name: Refined Content}.\n"
+ "CRITICAL RULES:\n"
+ "1. **KEYS**: Keep the JSON keys EXACTLY the same as the input.\n"
+ "2. **PURE BODY TEXT**: The output value must be pure body text. No Headers.\n"
+ "3. **IMAGES**: Do NOT remove or modify markdown image links.\n"
+ "4. **JSON ONLY**: Output pure JSON string.\n"
+ "5. **FORMAT**: Use bolding ONLY for emphasis."
+ )
+
+ user_message = f"""
+ [Refinement Instructions]
+ {json.dumps(prompts, ensure_ascii=False)}
+
+ [Content to Refine]
+ {json.dumps(extracted_data, ensure_ascii=False)}
+ """
+
+ # === (Changed) Core: dispatch the call based on model type ===
+ llm_output = ""
+ try:
+ if "gemini" in model.lower():
+ # --- Google Gemini (New SDK) ---
+ api_key = api_keys.get("gemini_api_key", "").strip()
+
+ if not api_key:
+ raise ValueError("Missing config['api_keys']['gemini_api_key']")
+
+ from google import genai
+ from google.genai import types
+
+ # Configure the client
+ client = genai.Client(api_key=api_key)
+
+ response = client.models.generate_content(
+ model=model,
+ contents=user_message,
+ config=types.GenerateContentConfig(
+ system_instruction=system_prompt,
+ temperature=0.2,
+ response_mime_type="application/json"  # Force JSON mode for more stable output
+ )
+ )
+ llm_output = response.text
+
+ else:
+ # --- OpenAI (Original) ---
+ api_key = api_keys.get("openai_api_key", "").strip()
+
+ if not api_key:
+ # Compatibility fallback: if not present in the config, try the environment variable
+ api_key = os.getenv("OPENAI_API_KEY")
+
+ if not api_key:
+ raise ValueError("Missing config['api_keys']['openai_api_key']")
+
+ from openai import OpenAI
+ client = OpenAI(api_key=api_key)
+
+ response = client.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_message}
+ ],
+ temperature=0.2
+ )
+ llm_output = response.choices[0].message.content.strip()
+
+ except Exception as e:
+ print(f"LLM API call failed: {e}")
+ return
+
+ # 7. Clean the JSON returned by the LLM
+ try:
+ # Remove any markdown code-fence markers
+ cleaned_output = llm_output.replace("```json", "").replace("```", "").strip()
+ refined_data = json.loads(cleaned_output)
+ except json.JSONDecodeError:
+ print("Failed to parse the JSON returned by the LLM. Raw output:", llm_output)
+ return
+
+ # 8. Reassemble the file
+ new_file_parts = []
+ current_idx = 0
+
+ # Process sections in the order they appear in the original file
+ sorted_matches = sorted(matches, key=lambda x: x["header_start"])
+
+ for item in sorted_matches:
+ key = item["key"]
+ range_info = content_ranges[key]
+ c_start = range_info["start"]
+ c_end = range_info["end"]
+
+ # 1. Append the unmodified span (end of previous section to start of current content)
+ pre_content = original_content[current_idx:c_start]
+ new_file_parts.append(pre_content)
+
+ # 2. Append the refined content
+ if key in refined_data:
+ new_text = refined_data[key]
+ # Simple formatting: surround the refined text with newlines.
+ # (strip() already discards any leading newline, so no conditional prepend is needed.)
+ new_text = "\n" + new_text.strip() + "\n"
+ new_file_parts.append(new_text)
+ else:
+ new_file_parts.append(original_content[c_start:c_end])
+
+ new_file_parts.append('\n')
+
+ # 3. Advance the cursor
+ current_idx = c_end
+
+ # 9. Append the remainder of the file
+ new_file_parts.append(original_content[current_idx:])
+
+ final_markdown = "".join(new_file_parts)
+
+ # 10. Save the result
+ os.makedirs(os.path.dirname(os.path.abspath(pr_refine_path)), exist_ok=True)
+
+ with open(pr_refine_path, 'w', encoding='utf-8') as f:
+ f.write(final_markdown)
+
+ print(f"Refinement complete; saved to: {pr_refine_path}")
+