# ClimateRAG_QA / Experiments / structure_chunker.py
# (repository page residue: tengfeiCheng — "add cleaned experiments code", commit 12323e1)
"""
Structure-based Chunking + DFS-based Grouping
==============================================
对 MinerU 导出的 content_list.json 进行基于文档结构的智能分块。
核心流程:
A. 预处理: 将 JSON 中的 blocks 拉平并按阅读顺序排序
B. 标题识别: 从 blocks 中筛选候选标题 (text_level / 编号模式 / 启发式)
C. 构建标题层级树: 用栈维护层级, 编号推断 header_level
D. 将普通内容块挂到标题节点下
E. DFS 遍历层级树, 生成 chunk (超过 max_len 则切分)
F. 输出 chunk 列表 (含 document_id, section_path 等元数据)
"""
import os
import re
import json
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
# ======================== Configuration ========================
DEFAULT_MAX_TOKENS = 550  # maximum tokens per chunk (approximated as word count)
# ======================== Data structures ========================
@dataclass
class Block:
    """A single content block from the OCR (MinerU) output."""
    index: int                        # position within all_segments
    page_idx: int                     # zero-based page index
    bbox: List[int]                   # [x0, y0, x1, y1]
    text: str
    block_type: str                   # "text", "table", "image", "equation", "discarded"
    text_level: Optional[int] = None  # text_level tagged by MinerU (1 = heading)
    image_caption: str = ""
    image_footnote: str = ""
    # Inferred fields (filled in during heading detection)
    is_header: bool = False
    header_level: int = 0             # inferred depth: 1=top level, 2=second level, 3=third ...


@dataclass
class TreeNode:
    """A node of the heading hierarchy tree."""
    header_text: str
    header_level: int                 # 0=root, 1=title, 2=section, 3=subsection ...
    block: Optional[Block] = None     # originating heading block (root has none)
    children: List['TreeNode'] = field(default_factory=list)
    content_blocks: List[Block] = field(default_factory=list)  # attached ordinary content blocks
    parent: Optional['TreeNode'] = None


# ======================== A. Preprocessing ========================
def _join_fragments(value) -> str:
    """Normalize an image_caption / image_footnote value (None, str or list) to one stripped string."""
    if not value:
        return ""
    if isinstance(value, list):
        value = ' '.join(str(x) for x in value if x)
    return str(value).strip()


def parse_content_list(json_path: str) -> Tuple[List[Block], str]:
    """
    Load content_list.json, parse it into a Block list and sort it into
    reading order.

    Args:
        json_path: path to the content_list.json exported by MinerU.

    Returns:
        (all_segments, document_id) — document_id is the name of the folder
        containing the JSON file.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        content_list = json.load(f)
    # The parent folder name doubles as the document identifier.
    document_id = os.path.basename(os.path.dirname(json_path))
    all_segments = []
    for i, item in enumerate(content_list):
        btype = item.get('type', 'text')
        if btype == 'discarded':
            continue  # skip blocks MinerU flagged for disposal
        # Fields may be present but JSON null; coalesce to safe defaults so a
        # null "text" / "bbox" / "page_idx" cannot crash parsing or sorting.
        text = (item.get('text') or '').strip()
        bbox = item.get('bbox') or [0, 0, 0, 0]
        page_idx = item.get('page_idx') or 0
        text_level = item.get('text_level', None)
        all_segments.append(Block(
            index=i,
            page_idx=page_idx,
            bbox=bbox,
            text=text,
            block_type=btype,
            text_level=text_level,
            image_caption=_join_fragments(item.get('image_caption')),
            image_footnote=_join_fragments(item.get('image_footnote')),
        ))
    # Reading order: page ascending -> y0 ascending -> x0 ascending.
    all_segments.sort(key=lambda b: (b.page_idx,
                                     b.bbox[1] if len(b.bbox) >= 2 else 0,
                                     b.bbox[0] if len(b.bbox) >= 1 else 0))
    # Re-number indices to reflect the sorted order.
    for idx, seg in enumerate(all_segments):
        seg.index = idx
    return all_segments, document_id
# ======================== B. Heading detection ========================
# Numbering patterns: "1", "1.1", "1.1.1", "2.1 Title", "A.1", etc.
_NUMBERED_HEADING_RE = re.compile(
    r'^(?:'
    r'(?P<num>\d+(?:\.\d+)*)\.?\s+'       # "1", "1.1", "1.1.1", "2.1 "
    r'|(?P<letter>[A-Z](?:\.\d+)*)\.\s+'  # "A.", "A.1.", "B.2.1."
    r')'
    r'(?P<title>.+)',
    re.DOTALL
)
# Roman numeral pattern: "I.", "II.", "III.", "IV." ...
_ROMAN_HEADING_RE = re.compile(
    r'^(?P<roman>(?:X{0,3})(?:IX|IV|V?I{0,3}))[\.\)]\s+(?P<title>.+)',
    re.IGNORECASE | re.DOTALL
)


def infer_heading_level_from_numbering(text: str) -> Optional[int]:
    """
    Infer a heading level from the numbering pattern at the start of the text.

    "1 Title" -> level 1, "1.1 Title" -> level 2, "1.1.1 Title" -> level 3,
    "A. Title" -> level 1, "IV. Title" -> level 1.

    Returns:
        The inferred level, or None when no numbering is recognized.
    """
    text = text.strip()
    m = _NUMBERED_HEADING_RE.match(text)
    if m:
        num = m.group('num')
        letter = m.group('letter')
        # Level = number of dot-separated components ("1.1" -> 2).
        if num:
            return len(num.split('.'))
        if letter:
            return len(letter.split('.'))
    # Roman numerals count as level 1.  Every quantifier in the roman pattern
    # is optional, so the regex can match with an *empty* numeral (e.g.
    # ") Hello" or ". Hello"); require a non-empty match to avoid
    # misclassifying such text as a heading.
    m2 = _ROMAN_HEADING_RE.match(text)
    if m2 and m2.group('roman'):
        return 1
    return None
def is_likely_page_header_footer(block: Block, all_segments: List[Block]) -> bool:
    """Heuristically decide whether this block is a running page header/footer (to be excluded)."""
    text = block.text.strip()
    if not text:
        return True
    # A bare 1-4 digit string is almost certainly a page number.
    if re.match(r'^\d{1,4}\s*$', text):
        return True
    # A very short, very flat line hugging the top or bottom edge of the page:
    # line height under 20 and y0 above 750 or below 50.
    bbox = block.bbox
    if len(bbox) >= 4:
        line_height = bbox[3] - bbox[1]
        near_edge = bbox[1] < 50 or bbox[1] > 750
        if line_height < 20 and len(text) < 30 and near_edge:
            return True
    return False
def is_all_caps_short(text: str, max_words: int = 12) -> bool:
    """Return True for short all-uppercase text (a likely section heading)."""
    # Strip digits, punctuation and whitespace; judge the case on what remains.
    letters_only = re.sub(r'[\d\.\s\-:,&/]', '', text)
    if not letters_only:
        return False
    short_enough = len(text.split()) <= max_words
    return short_enough and letters_only.isupper() and len(letters_only) >= 3
def identify_headers(all_segments: List[Block]) -> List[Block]:
    """
    Identify heading blocks among all_segments, setting is_header and
    header_level on every match.  Returns the headings in original order.
    """
    header_list = []
    for candidate in all_segments:
        if candidate.block_type != 'text':
            continue
        text = candidate.text.strip()
        if len(text) < 2:
            continue
        if is_likely_page_header_footer(candidate, all_segments):
            continue

        numbered_level = infer_heading_level_from_numbering(text)

        # --- Strategy 1: trust MinerU's text_level tag ---
        if candidate.text_level is not None and candidate.text_level >= 1:
            candidate.is_header = True
            if numbered_level is not None:
                # A numbering pattern gives a more precise depth than the tag.
                candidate.header_level = numbered_level
            elif is_all_caps_short(text):
                candidate.header_level = 1  # all-caps lines are usually top-level sections
            else:
                candidate.header_level = 2  # default to second level
            header_list.append(candidate)
            continue

        # --- Strategy 2: a numbering pattern without a text_level tag ---
        if numbered_level is not None:
            # Numbered headings stay short; a long "numbered" text is body
            # prose that merely starts with a digit.
            if len(text.split()) <= 20:
                candidate.is_header = True
                candidate.header_level = numbered_level
                header_list.append(candidate)
            continue

        # --- Strategy 3: short all-caps text ---
        if is_all_caps_short(text, max_words=8):
            candidate.is_header = True
            candidate.header_level = 1
            header_list.append(candidate)
    return header_list
# ======================== C. Build the heading hierarchy tree ========================
def build_header_tree(header_list: List[Block], document_id: str) -> TreeNode:
    """
    Build the heading hierarchy tree from header_list using a stack.

    The first heading (or document_id when there are no headings) becomes
    the level-1 title node directly under ROOT.
    """
    root = TreeNode(header_text="ROOT", header_level=0)
    # Title node: the first heading if available, otherwise the document id.
    if header_list:
        title_node = TreeNode(
            header_text=header_list[0].text.strip(),
            header_level=1,
            block=header_list[0],
            parent=root,
        )
    else:
        title_node = TreeNode(header_text=document_id, header_level=1, parent=root)
    root.children.append(title_node)

    # The stack tracks the current path; stack[-1] is the active node.
    stack = [title_node]
    for header_block in header_list[1:]:
        level = max(header_block.header_level, 1)  # clamp to >= 1
        # Pop until the stack top is a valid parent (strictly shallower level).
        while len(stack) > 1 and stack[-1].header_level >= level:
            stack.pop()
        node = TreeNode(
            header_text=header_block.text.strip(),
            header_level=level,
            block=header_block,
            parent=stack[-1],
        )
        stack[-1].children.append(node)
        stack.append(node)
    return root
# ======================== D. Attach content blocks to heading nodes ========================
def attach_content_to_tree(root: TreeNode, all_segments: List[Block], header_list: List[Block]):
    """
    Walk all_segments in order and attach every non-heading block to the
    content_blocks of the most recently seen heading node.
    """
    header_ids = {id(h) for h in header_list}
    # Map each heading block (by object identity) to its tree node.
    node_by_block_id = {}
    pending = [root]
    while pending:
        node = pending.pop()
        if node.block is not None:
            node_by_block_id[id(node.block)] = node
        pending.extend(node.children)

    # Content seen before the first heading goes under the title node.
    current_node = root.children[0] if root.children else root
    for block in all_segments:
        if id(block) in header_ids and block.is_header:
            # A heading block: switch the attachment target.
            current_node = node_by_block_id.get(id(block), current_node)
            continue
        if block.block_type == 'discarded':
            continue
        current_node.content_blocks.append(block)
# ======================== E. DFS-based grouping into chunks ========================
def _count_tokens(text: str) -> int:
"""近似 token 计数 (按空白分词)."""
return len(text.split())
def _block_to_text(block: Block) -> str:
    """Render one content block as plain text (with type markers for non-text blocks)."""
    btype = block.block_type
    body = block.text.strip()
    if btype == 'table':
        # Tables are fenced so downstream consumers can recognize them.
        return f"[TABLE]\n{body}\n[/TABLE]" if body else ""
    if btype == 'image':
        pieces = []
        if block.image_caption:
            pieces.append(f"[IMAGE CAPTION] {block.image_caption}")
        if block.image_footnote:
            pieces.append(f"[IMAGE NOTE] {block.image_footnote}")
        if body:
            pieces.append(body)
        return "\n".join(pieces)
    if btype == 'equation':
        return f"[EQUATION] {body}" if body else ""
    return body
def _get_section_path(node: TreeNode) -> List[str]:
    """Return the heading titles from the document title down to this node."""
    path = []
    cursor = node
    # Walk upward, stopping at the synthetic ROOT (header_level 0).
    while cursor is not None and cursor.header_level > 0:
        path.append(cursor.header_text)
        cursor = cursor.parent
    return path[::-1]
def _make_md_heading(text: str, depth: int) -> str:
"""生成 Markdown 标题行."""
prefix = "#" * min(depth, 6)
return f"{prefix} {text}"
def _collect_node_text(node: TreeNode) -> List[str]:
    """Recursively gather the text lines of the node and its whole subtree (no ancestor-heading prefix)."""
    lines = []
    # Heading line for this node (a node with an empty path gets none).
    depth = len(_get_section_path(node))
    if depth > 0:
        lines.append(_make_md_heading(node.header_text, depth))
    # This node's own content.
    for block in node.content_blocks:
        rendered = _block_to_text(block)
        if rendered:
            lines.append(rendered)
    # Then every child subtree, depth-first.
    for child in node.children:
        lines.extend(_collect_node_text(child))
    return lines
def dfs_generate_chunks(
    root: TreeNode,
    document_id: str,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    min_tokens: int = 50,
) -> List[Dict]:
    """
    DFS over the hierarchy tree, generating the chunk list.

    Core strategy:
    - Bottom-up: when a node's entire subtree text fits in max_tokens, the
      subtree is emitted as a single chunk
    - When it exceeds the limit: the node's own content is emitted first,
      then each child is handled separately
    - Chunks smaller than min_tokens are merged with an adjacent chunk
    - Every chunk starts with its Markdown heading chain (section_path) so
      the context is understandable

    Each chunk:
        {
            "text": "<Markdown-formatted text>",
            "metadata": {
                "document_id": str,
                "section_path": List[str]
            }
        }
    """
    raw_chunks = []  # collect every raw chunk first; undersized ones are merged at the end

    def _build_heading_context(node: TreeNode) -> List[str]:
        """Build the Markdown heading chain from the title down to this node."""
        path = _get_section_path(node)
        lines = []
        for i, title_text in enumerate(path):
            lines.append(_make_md_heading(title_text, i + 1))
        return lines

    def _emit_chunk(lines: List[str], section_path: List[str]):
        """Emit one chunk (list of lines -> joined text); empty text is dropped."""
        text = "\n".join(lines).strip()
        if not text:
            return
        raw_chunks.append({
            "text": text,
            "metadata": {
                "document_id": document_id,
                "section_path": list(section_path),
            }
        })

    def _dfs(node: TreeNode):
        """
        Process one node depth-first. Strategy:
        1. Try to emit the whole subtree as a single chunk
        2. If too large: the node's own content first, then each child separately
        """
        section_path = _get_section_path(node)
        heading_lines = _build_heading_context(node)
        # Try gathering the whole subtree's text.
        all_lines = _collect_node_text(node)
        total_tokens = _count_tokens("\n".join(all_lines))
        if total_tokens <= max_tokens:
            # The whole subtree fits -> emit as one chunk (with heading context).
            # If heading_lines' last line equals all_lines' first, avoid duplicating it.
            if heading_lines and all_lines and heading_lines[-1] == all_lines[0]:
                output_lines = heading_lines[:-1] + all_lines
            else:
                output_lines = heading_lines + all_lines
            # Drop consecutive duplicate heading lines.
            output_lines = _dedup_heading_lines(output_lines)
            _emit_chunk(output_lines, section_path)
            return
        # Subtree too large; split it up:
        # 1) first emit this node's own content_blocks (if any)
        own_content = []
        for block in node.content_blocks:
            piece = _block_to_text(block)
            if piece:
                own_content.append(piece)
        if own_content:
            buffer = list(heading_lines)
            for piece in own_content:
                piece_tokens = _count_tokens(piece)
                buffer_tokens = _count_tokens("\n".join(buffer))
                if buffer_tokens + piece_tokens > max_tokens and len(buffer) > len(heading_lines):
                    _emit_chunk(buffer, section_path)
                    buffer = list(heading_lines)
                # If one piece is itself over the limit, split it on word boundaries.
                if piece_tokens > max_tokens:
                    heading_overhead = _count_tokens("\n".join(heading_lines))
                    available = max_tokens - heading_overhead
                    # NOTE(review): if the heading chain alone reached max_tokens,
                    # `available` would be <= 0 and this while loop would not
                    # terminate — presumably heading chains are always far
                    # shorter than max_tokens; verify.
                    words = piece.split()
                    while len(words) > available:
                        sub = " ".join(words[:available])
                        buffer.append(sub)
                        _emit_chunk(buffer, section_path)
                        buffer = list(heading_lines)
                        words = words[available:]
                    if words:
                        buffer.append(" ".join(words))
                else:
                    buffer.append(piece)
            # flush whatever remains in the buffer (beyond the bare heading chain)
            if _count_tokens("\n".join(buffer)) > _count_tokens("\n".join(heading_lines)):
                _emit_chunk(buffer, section_path)
        # 2) recurse into every child node
        for child in node.children:
            _dfs(child)

    def _dedup_heading_lines(lines: List[str]) -> List[str]:
        """Remove consecutive duplicate heading lines."""
        result = []
        for line in lines:
            if result and line == result[-1]:
                continue
            result.append(line)
        return result

    # Content blocks attached directly to root (outside any heading).
    if root.content_blocks:
        buf = []
        for block in root.content_blocks:
            piece = _block_to_text(block)
            if piece:
                buf.append(piece)
        if buf:
            _emit_chunk(buf, [])
    # DFS over all of root's children.
    for child in root.children:
        _dfs(child)
    # ---- Post-processing: merge undersized chunks ----
    if not raw_chunks:
        return []

    def _strip_shared_heading(existing_text: str, new_text: str) -> str:
        """
        If the heading lines at the start of new_text repeat heading lines
        already present in existing_text, drop those leading duplicates so
        the merged chunk carries no redundant headings.
        """
        existing_lines = existing_text.split("\n")
        new_lines = new_text.split("\n")
        # Collect every heading line ('#'-prefixed) already in existing_text.
        existing_headings = set()
        for line in existing_lines:
            stripped = line.strip()
            if stripped.startswith("#"):
                existing_headings.add(stripped)
        # Skip duplicated heading lines at the start of new_text.
        skip = 0
        for line in new_lines:
            stripped = line.strip()
            if stripped.startswith("#") and stripped in existing_headings:
                skip += 1
            else:
                break
        if skip > 0:
            return "\n".join(new_lines[skip:])
        return new_text

    merged = []
    buffer_text = ""
    buffer_section_path = raw_chunks[0]["metadata"]["section_path"]
    for chunk in raw_chunks:
        chunk_text = chunk["text"]
        chunk_tokens = _count_tokens(chunk_text)
        buffer_tokens = _count_tokens(buffer_text)
        if buffer_tokens == 0:
            # Buffer empty: just take the chunk.
            buffer_text = chunk_text
            buffer_section_path = chunk["metadata"]["section_path"]
        elif buffer_tokens < min_tokens:
            # Buffer too small: merge (removing duplicated headings).
            stripped = _strip_shared_heading(buffer_text, chunk_text)
            buffer_text = buffer_text + "\n" + stripped
            # Keep the deeper (more specific) section path for the merged chunk.
            if len(chunk["metadata"]["section_path"]) > len(buffer_section_path):
                buffer_section_path = chunk["metadata"]["section_path"]
        elif buffer_tokens + chunk_tokens <= max_tokens:
            # Merging stays within the limit (removing duplicated headings).
            stripped = _strip_shared_heading(buffer_text, chunk_text)
            buffer_text = buffer_text + "\n" + stripped
            if len(chunk["metadata"]["section_path"]) > len(buffer_section_path):
                buffer_section_path = chunk["metadata"]["section_path"]
        else:
            # Flush the buffer and start a new one with this chunk.
            final = buffer_text.strip()
            if final:
                merged.append({
                    "text": final,
                    "metadata": {
                        "document_id": document_id,
                        "section_path": list(buffer_section_path),
                    }
                })
            buffer_text = chunk_text
            buffer_section_path = chunk["metadata"]["section_path"]
    # Flush the final buffer.
    final = buffer_text.strip()
    if final:
        merged.append({
            "text": final,
            "metadata": {
                "document_id": document_id,
                "section_path": list(buffer_section_path),
            }
        })
    return merged
# ======================== Main entry ========================
def structure_chunk_document(
    json_path: str,
    max_tokens: int = DEFAULT_MAX_TOKENS,
) -> List[Dict]:
    """
    Run structure-based chunking on a single document.

    Args:
        json_path: path to the document's content_list.json
        max_tokens: maximum token count per chunk

    Returns:
        The chunk list; each chunk is shaped as:
            {
                "text": "<Markdown heading + content>",
                "metadata": {
                    "document_id": str,
                    "section_path": List[str]
                }
            }
    """
    # A. Preprocess: flatten and sort the raw blocks.
    all_segments, document_id = parse_content_list(json_path)
    if not all_segments:
        return []
    # B. Heading detection.
    header_list = identify_headers(all_segments)
    # C. Build the heading hierarchy tree.
    root = build_header_tree(header_list, document_id)
    # D. Attach ordinary content blocks.
    attach_content_to_tree(root, all_segments, header_list)
    # E. DFS chunk generation.
    return dfs_generate_chunks(root, document_id, max_tokens=max_tokens)
def structure_chunk_document_flat(
    json_path: str,
    max_tokens: int = DEFAULT_MAX_TOKENS,
) -> Tuple[List[str], List[Dict]]:
    """
    Convenience wrapper returning (chunk_texts, chunk_metadatas), where
    chunk_texts[i] corresponds to chunk_metadatas[i].
    """
    chunks = structure_chunk_document(json_path, max_tokens=max_tokens)
    return [c["text"] for c in chunks], [c["metadata"] for c in chunks]
# ======================== CLI / testing ========================
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # Usage / example lines (kept in Chinese, as in the original tool).
        print("用法: python structure_chunker.py <content_list.json> [max_tokens]")
        print("示例: python structure_chunker.py ../MinerU_Reports/AEO_2022_ESG_Report/AEO_2022_ESG_Report_content_list.json 550")
        sys.exit(1)

    json_path = sys.argv[1]
    max_tokens = int(sys.argv[2]) if len(sys.argv) > 2 else DEFAULT_MAX_TOKENS
    print(f"输入: {json_path}")
    print(f"max_tokens: {max_tokens}")
    print()

    chunks = structure_chunk_document(json_path, max_tokens=max_tokens)
    print(f"总 chunk 数: {len(chunks)}")
    print()
    for idx, item in enumerate(chunks):
        chunk_text = item["text"]
        chunk_meta = item["metadata"]
        print(f"--- Chunk {idx} ({len(chunk_text.split())} words) ---")
        print(f" doc: {chunk_meta['document_id']}")
        print(f" path: {' > '.join(chunk_meta['section_path'])}")
        # Preview only the first 200 characters, with newlines escaped.
        preview = chunk_text[:200].replace("\n", "\\n")
        print(f" text: {preview}...")
        print()

    # Persist the chunk list next to the input JSON.
    output_path = json_path.replace(".json", "_chunks.json")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(chunks, f, ensure_ascii=False, indent=2)
    print(f"输出已保存: {output_path}")