# ClimateRAG_QA / transfer_annotations.py
# tengfeiCheng's picture
# clean initial deploy
# fa0db8b
"""
将标注转移到 OCR 后的 JSON 文本 chunk 中
------------------------------------------
支持两种 chunk 模式 × 两种文档模式:
chunk: "length" (SentenceSplitter) / "structure" (DFS-based)
doc: "single" (V1.csv, 单文档问题) / "cross" (cross.xlsx, 跨文档问题)
用法:
python transfer_annotations.py [chunk_mode] [doc_mode]
例:
python transfer_annotations.py length single # 默认
python transfer_annotations.py structure cross # structure chunk + cross 标注
"""
import os
import re
import json
import sys
import pandas as pd
from llama_index.core.node_parser import SentenceSplitter
# ======================== Configuration ========================
# NOTE: removed a redundant `import os` here — os is already imported at the
# top of the file.
# Single-document annotations (V1.csv).
CSV_PATH = os.path.join(
    "Report-Level Dataset",
    "ClimRetrieve_ReportLevel_V1.csv"
)
# Cross-document annotations (cross.xlsx).
CROSS_XLSX_PATH = os.path.join(
    "Expert-Annotated Relevant Sources Dataset",
    "ClimRetrieve_cross.xlsx"
)
# Directory with one MinerU OCR output folder per report.
MINERU_DIR = "MinerU_Reports"
# ---- Chunk mode ----
CHUNK_MODE = "structure"  # "length" or "structure"
# ---- Document mode ----
# "single" : use relevant_text from V1.csv (single-document questions)
# "cross"  : use Relevant from cross.xlsx (cross-document questions)
DOC_MODE = "cross"
# length-mode parameters
CHUNK_SIZE = 350      # target chunk size (in words)
CHUNK_OVERLAP = 50    # overlap size (in words)
# structure-mode parameters
STRUCTURE_MAX_TOKENS = 550
RELEVANCE_THRESHOLD = 1  # only use relevant_text with relevance >= this threshold (single mode)
MIN_MATCH_LEN = 30       # minimum relevant_text length; shorter texts are skipped (avoids false matches)
# ======================== 工具函数 ========================
def word_count_tokenizer(text: str) -> list:
    """Whitespace tokenizer so SentenceSplitter measures chunk size in words."""
    tokens = text.split()
    return tokens
def normalize_text(text: str) -> str:
    """Collapse all whitespace runs to single spaces, trim, and lowercase."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip().lower()
def load_content_list_json(json_path: str) -> str:
    """Read a MinerU content_list.json and join every non-empty 'text' item.

    Items whose type is not 'text' (images, tables, ...) are ignored; the
    surviving paragraphs are stripped and joined with newlines.
    """
    with open(json_path, 'r', encoding='utf-8') as fh:
        items = json.load(fh)
    paragraphs = [entry['text'].strip()
                  for entry in items
                  if entry.get('type') == 'text' and entry.get('text')]
    return "\n".join(paragraphs)
def build_report_name_mapping(csv_reports, mineru_folders):
    """Map each CSV report name to its MinerU folder name.

    Tries an exact match first (case/surrounding-space insensitive, ".pdf"
    stripped); otherwise falls back to a fuzzy word-set-overlap score and
    accepts the best folder when its score reaches 0.8.
    """
    mapping = {}
    for csv_name in csv_reports:
        base = csv_name.replace(".pdf", "").strip()
        # Exact match (ignoring case and trailing spaces).
        exact = next((folder for folder in mineru_folders
                      if base.lower() == folder.strip().lower()), None)
        if exact is not None:
            mapping[csv_name] = exact
            continue
        # Fuzzy fallback: overlap of word sets over the larger word count.
        base_words = set(base.lower().split())
        best_match, best_score = None, 0
        for folder in mineru_folders:
            folder_words = set(folder.lower().split())
            denom = max(len(base_words), len(folder_words))
            score = len(base_words & folder_words) / denom if denom > 0 else 0
            if score > best_score:
                best_score, best_match = score, folder
        if best_match and best_score >= 0.8:
            mapping[csv_name] = best_match
            print(f" [模糊匹配] CSV '{csv_name}' -> MinerU '{best_match}' (score={best_score:.2f})")
        else:
            print(f" [未匹配] CSV '{csv_name}' (best: '{best_match}', score={best_score:.2f})")
    return mapping
def find_content_list_json(mineru_dir, folder_name):
    """Return the path to the *_content_list.json in a MinerU folder, or None."""
    folder_path = os.path.join(mineru_dir, folder_name)
    candidates = (name for name in os.listdir(folder_path)
                  if name.endswith('_content_list.json'))
    hit = next(candidates, None)
    return os.path.join(folder_path, hit) if hit is not None else None
def check_chunk_contains_relevant_text(chunk_normalized: str, relevant_text_normalized: str) -> bool:
    """Decide whether a normalized chunk contains a normalized relevant_text.

    Strategy:
      1. whole-string substring match;
      2. for texts of at least 8 words, also try the first or last 50% of
         the words (OCR output may differ at sentence boundaries), provided
         that half is at least MIN_MATCH_LEN characters long.
    """
    # Direct substring match.
    if relevant_text_normalized in chunk_normalized:
        return True
    words = relevant_text_normalized.split()
    if len(words) < 8:
        return False
    mid = int(len(words) * 0.5)
    halves = (' '.join(words[:mid]), ' '.join(words[mid:]))
    return any(len(half) >= MIN_MATCH_LEN and half in chunk_normalized
               for half in halves)
def normalize_chunk_for_matching_raw(chunk_text: str) -> str:
    """Normalize the chunk verbatim (heading lines kept) for matching."""
    normalized = normalize_text(chunk_text)
    return normalized
def build_chunk_match_views(chunk_text: str) -> dict:
    """Build the dual-channel match views for one chunk.

    Channels:
      - "raw":      normalized text, markdown heading lines kept
      - "stripped": normalized text with heading lines removed
    When the two are identical (no headings), only "raw" is returned.
    """
    views = {"raw": normalize_chunk_for_matching_raw(chunk_text)}
    stripped = normalize_chunk_for_matching(chunk_text)
    if stripped != views["raw"]:
        views["stripped"] = stripped
    return views
def match_relevant_text_multi_view(chunk_views: dict, relevant_text_normalized: str) -> tuple:
    """Match relevant_text against every view of a chunk.

    Returns:
        (is_matched, hit_channels) — hit_channels lists every view name
        (in dict order) on which the text matched.
    """
    hit_channels = [name for name, view_text in chunk_views.items()
                    if check_chunk_contains_relevant_text(view_text, relevant_text_normalized)]
    return bool(hit_channels), hit_channels
# ======================== 主流程 ========================
def chunk_report_length(json_path, splitter):
    """[length mode] Load the OCR text and split it with SentenceSplitter.

    Returns (chunk_texts, chunk_extras); extras are empty dicts in this mode
    to mirror the structure-mode return shape.
    """
    full_text = load_content_list_json(json_path)
    if not full_text.strip():
        return [], []
    chunk_texts = splitter.split_text(full_text)
    return chunk_texts, [dict() for _ in chunk_texts]
def chunk_report_structure(json_path, max_tokens):
    """[structure mode] Split the report with structure-based DFS chunking.

    Returns (chunk_texts, chunk_extras); each extra carries the chunk's
    section path (joined with " > ") and its document id.
    """
    from Experiments.structure_chunker import structure_chunk_document
    records = structure_chunk_document(json_path, max_tokens=max_tokens)
    if not records:
        return [], []
    texts, extras = [], []
    for record in records:
        texts.append(record["text"])
        meta = record["metadata"]
        extras.append({"section_path": " > ".join(meta["section_path"]),
                       "document_id": meta["document_id"]})
    return texts, extras
def normalize_chunk_for_matching(chunk_text: str) -> str:
    """Normalize chunk text for relevant_text matching.

    Structure chunks carry markdown heading lines ("# ..."); those are
    dropped first so headings never pollute the match.
    """
    body_lines = [line for line in chunk_text.split('\n')
                  if not line.strip().startswith('#')]
    return normalize_text(' '.join(body_lines))
def _build_output_dir(chunk_mode, doc_mode):
"""构建输出目录名."""
base = "OCR_Chunked_Annotated"
parts = [base]
if chunk_mode == "structure":
parts.append("structure")
if doc_mode == "cross":
parts.append("cross")
return "_".join(parts)
def load_annotation_source(doc_mode):
    """Load annotation data for *doc_mode* into a unified shape.

    Returns:
        question_doc_relevant: dict[(report, question)] -> list[str] of
            normalized relevant texts (whitespace-collapsed, lowercased).
        all_reports: list[str] of every report name found in the source file.

    Raises:
        ValueError: when *doc_mode* is neither "single" nor "cross".
    """
    if doc_mode == "single":
        # Single-document questions: V1.csv with a per-row relevance score.
        df = pd.read_csv(CSV_PATH, index_col=0)
        print(f" 加载 V1.csv: {len(df)} 行, {df['report'].nunique()} 报告, {df['question'].nunique()} 问题")
        print(f" relevance 分布:\n{df['relevance'].value_counts().sort_index().to_string()}")
        # Keep only rows meeting the relevance threshold for matching.
        high_rel = df[df['relevance'] >= RELEVANCE_THRESHOLD].copy()
        print(f" relevance >= {RELEVANCE_THRESHOLD}: {len(high_rel)} 行, {high_rel['relevant_text'].nunique()} 唯一 relevant_text")
        all_reports = list(df['report'].unique())
        question_doc_relevant = {}
        for report in all_reports:
            # Every question asked of this report, regardless of relevance,
            # gets an entry (possibly an empty list).
            report_qs = df[df['report'] == report]['question'].unique()
            for q in report_qs:
                q_rel = high_rel[(high_rel['report'] == report) & (high_rel['question'] == q)]
                rel_texts = q_rel['relevant_text'].dropna().unique()
                # Normalize; drop texts too short to match reliably.
                normalized = [normalize_text(rt) for rt in rel_texts if len(str(rt).strip()) >= MIN_MATCH_LEN]
                question_doc_relevant[(report, q)] = normalized
        return question_doc_relevant, all_reports
    elif doc_mode == "cross":
        # Cross-document questions: cross.xlsx with a 'Relevant' text column.
        df = pd.read_excel(CROSS_XLSX_PATH)
        if 'Unnamed: 0' in df.columns:
            # Drop the leftover index column written by a previous to_excel.
            df = df.drop(columns=['Unnamed: 0'])
        print(f" 加载 cross.xlsx: {len(df)} 行, {df['Document'].nunique()} 报告, {df['Question'].nunique()} 问题")
        print(f" Source Relevance Score 分布:\n{df['Source Relevance Score'].value_counts().sort_index().to_string()}")
        all_reports = list(df['Document'].unique())
        question_doc_relevant = {}
        for _, row in df.iterrows():
            report = row['Document']
            question = row['Question']
            relevant = row.get('Relevant', '')
            # Skip missing / too-short relevant texts (unreliable matches).
            if pd.isna(relevant) or len(str(relevant).strip()) < MIN_MATCH_LEN:
                continue
            key = (report, question)
            if key not in question_doc_relevant:
                question_doc_relevant[key] = []
            norm = normalize_text(str(relevant))
            # De-duplicate while preserving insertion order.
            if norm not in question_doc_relevant[key]:
                question_doc_relevant[key].append(norm)
        print(f" (report, question) 对数: {len(question_doc_relevant)}")
        return question_doc_relevant, all_reports
    else:
        raise ValueError(f"未知 DOC_MODE: {doc_mode}")
def main():
    """Entry point: chunk every mapped report and transfer annotations onto chunks.

    Pipeline:
      1. load annotations (single/cross mode);
      2. map CSV report names to MinerU OCR folders;
      3. chunk each report (length/structure mode) and mark each
         (chunk, question) pair relevant when a relevant_text matches;
      4. write the annotated pairs, the chunk list, and a per-report summary.
    """
    global CHUNK_MODE, DOC_MODE
    # CLI overrides: python transfer_annotations.py [chunk_mode] [doc_mode]
    if len(sys.argv) > 1 and sys.argv[1] in ("length", "structure"):
        CHUNK_MODE = sys.argv[1]
    if len(sys.argv) > 2 and sys.argv[2] in ("single", "cross"):
        DOC_MODE = sys.argv[2]
    output_dir = _build_output_dir(CHUNK_MODE, DOC_MODE)
    print("=" * 60)
    print(f"标注转移 (CHUNK={CHUNK_MODE}, DOC={DOC_MODE})")
    print(f"输出目录: {output_dir}")
    print("=" * 60)
    # ---- Step 1: load the annotation source ----
    print("\nStep 1: 加载标注数据")
    print("=" * 60)
    question_doc_relevant, all_reports = load_annotation_source(DOC_MODE)
    # ---- Step 2: map report names to MinerU folders ----
    print("\n" + "=" * 60)
    print("Step 2: 建立 report 名称映射")
    print("=" * 60)
    mineru_folders = [f for f in os.listdir(MINERU_DIR)
                      if os.path.isdir(os.path.join(MINERU_DIR, f))]
    report_mapping = build_report_name_mapping(all_reports, mineru_folders)
    print(f"\n 成功映射: {len(report_mapping)} / {len(all_reports)}")
    # ---- Step 3: chunk reports and transfer annotations ----
    print("\n" + "=" * 60)
    print(f"Step 3: 切分 chunk ({CHUNK_MODE}), 转移标注 ({DOC_MODE})")
    print("=" * 60)
    splitter = None
    if CHUNK_MODE == "length":
        # Word-count-based splitter (see word_count_tokenizer).
        splitter = SentenceSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            tokenizer=word_count_tokenizer,
        )
    os.makedirs(output_dir, exist_ok=True)
    # Collect the set of questions associated with each report.
    report_questions_map = {}
    for (report, question) in question_doc_relevant.keys():
        if report not in report_questions_map:
            report_questions_map[report] = set()
        report_questions_map[report].add(question)
    all_results = []
    for report_name in all_reports:
        print(f"\n 处理报告: {report_name}")
        if report_name not in report_mapping:
            print(f" [跳过] 未找到对应的 MinerU 文件夹")
            continue
        folder_name = report_mapping[report_name]
        json_path = find_content_list_json(MINERU_DIR, folder_name)
        if json_path is None:
            print(f" [跳过] 未找到 content_list.json")
            continue
        # Chunk the report according to the selected mode.
        if CHUNK_MODE == "length":
            chunks, extras = chunk_report_length(json_path, splitter)
        else:
            chunks, extras = chunk_report_structure(json_path, STRUCTURE_MAX_TOKENS)
        if not chunks:
            print(f" [跳过] 无 chunk 产生")
            continue
        print(f" 切分为 {len(chunks)}{CHUNK_MODE} chunk")
        report_questions = report_questions_map.get(report_name, set())
        if not report_questions:
            print(f" [跳过] 该报告无关联问题")
            continue
        # Match every chunk against every question of this report.
        for chunk_idx, chunk_text in enumerate(chunks):
            chunk_views = build_chunk_match_views(chunk_text)
            chunk_word_count = len(chunk_text.split())
            for q in report_questions:
                rel_norms = question_doc_relevant.get((report_name, q), [])
                is_relevant = False
                matched_logs = {}
                matched_channels = set()
                for rt_norm in rel_norms:
                    hit, channels = match_relevant_text_multi_view(chunk_views, rt_norm)
                    if hit:
                        is_relevant = True
                        if rt_norm not in matched_logs:
                            matched_logs[rt_norm] = set()
                        # Record which view(s) each relevant_text hit on.
                        matched_logs[rt_norm].update(channels)
                        matched_channels.update(channels)
                matched_texts = []
                for rt_norm in sorted(matched_logs.keys()):
                    channel_tag = "+".join(sorted(matched_logs[rt_norm]))
                    # Truncate each matched text to 80 chars for the log column.
                    matched_texts.append(f"[{channel_tag}] {rt_norm[:80]}...")
                result = {
                    'report': report_name,
                    'chunk_idx': chunk_idx,
                    'chunk_text': chunk_text,
                    'chunk_word_count': chunk_word_count,
                    'question': q,
                    'is_relevant': 1 if is_relevant else 0,
                    'matched_relevant_texts': "; ".join(matched_texts) if matched_texts else "",
                    'num_matched': len(matched_logs),
                    'matched_channels': ",".join(sorted(matched_channels)) if matched_channels else "",
                }
                # Structure mode adds section_path / document_id extras.
                if extras and chunk_idx < len(extras) and extras[chunk_idx]:
                    result.update(extras[chunk_idx])
                all_results.append(result)
        # Per-report statistics.
        report_results = [r for r in all_results if r['report'] == report_name]
        total_pairs = len(report_results)
        relevant_pairs = sum(1 for r in report_results if r['is_relevant'])
        print(f" (chunk, question) 对总数: {total_pairs}, 标为 relevant: {relevant_pairs}")
    # ---- Step 4: save outputs ----
    print("\n" + "=" * 60)
    print("Step 4: 保存结果")
    print("=" * 60)
    result_df = pd.DataFrame(all_results)
    # File-name tag: only non-default modes are encoded in the tag.
    tag_parts = []
    if CHUNK_MODE != "length":
        tag_parts.append(CHUNK_MODE)
    if DOC_MODE != "single":
        tag_parts.append(DOC_MODE)
    tag = ("_" + "_".join(tag_parts)) if tag_parts else ""
    output_csv = os.path.join(output_dir, f"ocr_chunks_annotated{tag}.csv")
    result_df.to_csv(output_csv, index=False, encoding='utf-8-sig')
    print(f" 完整结果: {output_csv}")
    print(f" 总行数: {len(result_df)}")
    print(f" 标为 relevant: {result_df['is_relevant'].sum()}")
    unique_chunks = result_df.drop_duplicates(subset=['report', 'chunk_idx']).shape[0]
    print(f" 唯一 chunk 数: {unique_chunks}")
    # Deduplicated chunk list (one row per (report, chunk_idx)).
    dedup_cols = ['report', 'chunk_idx', 'chunk_text', 'chunk_word_count']
    if 'section_path' in result_df.columns:
        dedup_cols += ['section_path', 'document_id']
    chunks_only = result_df.drop_duplicates(subset=['report', 'chunk_idx'])[dedup_cols].reset_index(drop=True)
    chunks_json = os.path.join(output_dir, f"ocr_chunks_all{tag}.json")
    chunks_only.to_json(chunks_json, orient='records', force_ascii=False, indent=2)
    print(f" chunk 列表: {chunks_json} ({len(chunks_only)} chunks)")
    # Per-report summary with relevant-pair ratio.
    summary = result_df.groupby('report').agg(
        total_chunks=('chunk_idx', 'nunique'),
        total_pairs=('is_relevant', 'count'),
        relevant_pairs=('is_relevant', 'sum'),
    ).reset_index()
    summary['relevant_ratio'] = summary['relevant_pairs'] / summary['total_pairs']
    summary_path = os.path.join(output_dir, f"annotation_summary{tag}.csv")
    summary.to_csv(summary_path, index=False)
    print(f" 汇总: {summary_path}")
    print(summary.to_string())
# Run the full pipeline only when executed as a script.
if __name__ == "__main__":
    main()