# APTW_CC / output_stage.py
# Deployed to HF Space (commit ea058c3) by knighterictsai
# output_stage.py
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from langchain_core.embeddings import Embeddings
import numpy as np
import re
@dataclass
class OutputConfig:
    """Settings for the final output/QC stage."""

    # Cosine-similarity cutoff at or above which a paragraph is treated
    # as a near-duplicate and dropped.
    dedupe_threshold: float = 0.6
    # Whether to prepend an in-chapter table of contents (optional).
    add_toc: bool = False
class OutputStage:
def __init__(self, cfg: OutputConfig):
self.cfg = cfg
def assemble_from_sections(self, sections_md: List[str]) -> str:
"""
將逐節 Markdown 合併;確保標題層級一致,插入簡短分隔。
"""
normalized = []
for md in sections_md:
fixed = self._normalize_headings(md)
normalized.append(fixed.strip())
return "\n\n---\n\n".join(normalized)
def normalize_and_qc(self, markdown: str, embedder: Optional[Embeddings] = None) -> str:
"""
- 程式碼 fence 語言標註檢查
- 近似段落去重(若提供 embedder)
- (可選)章內目錄
"""
md = self._ensure_code_fence_lang(markdown)
if embedder:
md = self._dedupe_by_embedding(md, embedder, self.cfg.dedupe_threshold)
if self.cfg.add_toc:
md = self._inject_toc(md)
return md
def write_markdown(self, markdown: str, out_dir: str | Path, topic: str) -> Path:
out = Path(out_dir)
out.mkdir(parents=True, exist_ok=True)
fname = slugify(topic) + ".md"
path = out / fname
path.write_text(markdown, encoding="utf-8")
return path
# ----- helpers -----
def _normalize_headings(self, md: str) -> str:
"""
將第一個標題提升為 H2,其下子標題最多到 H4。
"""
lines = md.splitlines()
out = []
for ln in lines:
if ln.startswith("#"):
# 簡單規則:全部標題至少 H2
ln = re.sub(r"^#{1,6}", "##", ln)
out.append(ln)
return "\n".join(out)
def _ensure_code_fence_lang(self, md: str) -> str:
"""
將沒有語言標籤的三引號程式碼,嘗試補上(依內容猜測)。
"""
def guess_lang(code: str) -> str:
# 粗略猜測
if "def " in code or "import " in code:
return "python"
if "createCanvas(" in code or "function setup()" in code:
return "javascript"
return ""
pat = re.compile(r"```(\s*)\n(.*?)\n```", re.DOTALL)
def repl(m):
lang = guess_lang(m.group(2))
return f"```{lang}\n{m.group(2)}\n```"
return pat.sub(repl, md)
def _inject_toc(self, md: str) -> str:
heads = re.findall(r"^##\s+.+$", md, flags=re.MULTILINE)
if not heads:
return md
toc = ["## 章內目錄"]
for h in heads:
title = h[3:].strip()
toc.append(f"- [{title}](#{slugify(title)})")
return "\n".join(toc) + "\n\n" + md
def _dedupe_by_embedding(self, md: str, embedder: Embeddings, threshold: float) -> str:
"""
以段落為單位做相似去重(餘弦相似度 >= threshold 視為重複)。
"""
paras = [p.strip() for p in md.split("\n\n") if p.strip()]
vecs = np.array(embedder.embed_documents(paras), dtype=np.float32)
norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-8
nv = vecs / norms
keep = []
kept_vecs = []
for i, v in enumerate(nv):
if not kept_vecs:
keep.append(True)
kept_vecs.append(v)
continue
sims = np.dot(np.vstack(kept_vecs), v)
if sims.max() >= threshold:
keep.append(False)
else:
keep.append(True)
kept_vecs.append(v)
kept_paras = [p for p, k in zip(paras, keep) if k]
return "\n\n".join(kept_paras)
# utilities
def slugify(text: str) -> str:
    """Turn *text* into a lowercase, hyphen-separated slug.

    Characters that are neither word characters, whitespace, nor
    hyphens are removed; runs of whitespace/underscores/hyphens are
    collapsed to a single hyphen; edge hyphens are trimmed. The result
    is truncated to 120 characters so it stays a safe Windows filename
    component.
    """
    cleaned = re.sub(r"[^\w\s-]", "", text.strip().lower(), flags=re.UNICODE)
    hyphenated = re.sub(r"[\s_-]+", "-", cleaned)
    return hyphenated.strip("-")[:120]