""" 将标注转移到 OCR 后的 JSON 文本 chunk 中 ------------------------------------------ 支持两种 chunk 模式 × 两种文档模式: chunk: "length" (SentenceSplitter) / "structure" (DFS-based) doc: "single" (V1.csv, 单文档问题) / "cross" (cross.xlsx, 跨文档问题) 用法: python transfer_annotations.py [chunk_mode] [doc_mode] 例: python transfer_annotations.py length single # 默认 python transfer_annotations.py structure cross # structure chunk + cross 标注 """ import os import re import json import sys import pandas as pd from llama_index.core.node_parser import SentenceSplitter # ======================== 配置 ======================== import os CSV_PATH = os.path.join( "Report-Level Dataset", "ClimRetrieve_ReportLevel_V1.csv" ) CROSS_XLSX_PATH = os.path.join( "Expert-Annotated Relevant Sources Dataset", "ClimRetrieve_cross.xlsx" ) MINERU_DIR = "MinerU_Reports" # ---- Chunk 模式 ---- CHUNK_MODE = "structure" # "length" 或 "structure" # ---- 文档模式 ---- # "single" : 使用 V1.csv 的 relevant_text (单文档问题) # "cross" : 使用 cross.xlsx 的 Relevant (跨文档问题) DOC_MODE = "cross" # length 模式参数 CHUNK_SIZE = 350 # 目标 chunk 大小 (词数) CHUNK_OVERLAP = 50 # overlap 大小 (词数) # structure 模式参数 STRUCTURE_MAX_TOKENS = 550 RELEVANCE_THRESHOLD = 1 # 只使用 relevance >= 此阈值的 relevant_text (single 模式) MIN_MATCH_LEN = 30 # relevant_text 最短长度, 太短的跳过(避免误匹配) # ======================== 工具函数 ======================== def word_count_tokenizer(text: str) -> list: """按空格分词, 返回 token 列表, 用于 SentenceSplitter 按词数切分.""" return text.split() def normalize_text(text: str) -> str: """标准化文本: 去除多余空白, 小写化.""" text = re.sub(r'\s+', ' ', text).strip().lower() return text def load_content_list_json(json_path: str) -> str: """加载 content_list.json, 提取所有 text 类型的文本并拼接.""" with open(json_path, 'r', encoding='utf-8') as f: content_list = json.load(f) texts = [] for item in content_list: if item.get('type') == 'text' and item.get('text'): texts.append(item['text'].strip()) # 用换行连接各段文字 full_text = "\n".join(texts) return full_text def build_report_name_mapping(csv_reports, mineru_folders): """建立 CSV report 名称 -> MinerU 文件夹名称 的映射.""" mapping = {} for csv_name in csv_reports: # 去掉 .pdf 后缀 stripped = csv_name.replace(".pdf", "").strip() # 精确匹配 (忽略大小写和尾部空格) found = False for folder in mineru_folders: if stripped.lower() == folder.strip().lower(): mapping[csv_name] = folder found = True break if not found: # 尝试模糊匹配: 基于词集重叠 csv_words = set(stripped.lower().split()) best_match = None best_score = 0 for folder in mineru_folders: folder_words = set(folder.lower().split()) overlap = len(csv_words & folder_words) total = max(len(csv_words), len(folder_words)) score = overlap / total if total > 0 else 0 if score > best_score: best_score = score best_match = folder if best_match and best_score >= 0.8: mapping[csv_name] = best_match print(f" [模糊匹配] CSV '{csv_name}' -> MinerU '{best_match}' (score={best_score:.2f})") else: print(f" [未匹配] CSV '{csv_name}' (best: '{best_match}', score={best_score:.2f})") return mapping def find_content_list_json(mineru_dir, folder_name): """在 MinerU 文件夹中找到 _content_list.json 文件.""" folder_path = os.path.join(mineru_dir, folder_name) for f in os.listdir(folder_path): if f.endswith('_content_list.json'): return os.path.join(folder_path, f) return None def check_chunk_contains_relevant_text(chunk_normalized: str, relevant_text_normalized: str) -> bool: """检查 chunk 是否包含 relevant_text. 策略: 1. 直接子串匹配 (标准化后) 2. 如果 relevant_text 较长, 尝试匹配其中连续子句 (取前/后 50% 词) """ # 直接子串匹配 if relevant_text_normalized in chunk_normalized: return True # 尝试匹配 relevant_text 的前半部分和后半部分 # (因为 OCR 可能在句子边界处有差异) words = relevant_text_normalized.split() if len(words) >= 8: # 取前 50% 的词 front_part = ' '.join(words[:int(len(words) * 0.5)]) if len(front_part) >= MIN_MATCH_LEN and front_part in chunk_normalized: return True # 取后 50% 的词 back_part = ' '.join(words[int(len(words) * 0.5):]) if len(back_part) >= MIN_MATCH_LEN and back_part in chunk_normalized: return True return False def normalize_chunk_for_matching_raw(chunk_text: str) -> str: """标准化 chunk 原始文本 (保留标题行) 用于匹配.""" return normalize_text(chunk_text) def build_chunk_match_views(chunk_text: str) -> dict: """ 为同一 chunk 构建双通道匹配视图: - raw: 保留标题行 - stripped: 去除 markdown 标题行 若两者文本相同, 仅保留 raw. """ raw_norm = normalize_chunk_for_matching_raw(chunk_text) stripped_norm = normalize_chunk_for_matching(chunk_text) if raw_norm == stripped_norm: return {"raw": raw_norm} return {"raw": raw_norm, "stripped": stripped_norm} def match_relevant_text_multi_view(chunk_views: dict, relevant_text_normalized: str) -> tuple: """ 在多个文本视图上匹配 relevant_text. 返回: (is_matched: bool, hit_channels: list[str]) """ hit_channels = [] for channel, chunk_norm in chunk_views.items(): if check_chunk_contains_relevant_text(chunk_norm, relevant_text_normalized): hit_channels.append(channel) return (len(hit_channels) > 0), hit_channels # ======================== 主流程 ======================== def chunk_report_length(json_path, splitter): """[length 模式] 加载 OCR 文本并用 SentenceSplitter 切分.""" full_text = load_content_list_json(json_path) if len(full_text.strip()) == 0: return [], [] chunks = splitter.split_text(full_text) # 返回 (chunk_texts, chunk_extras) — extras 为空 dict 列表 return chunks, [{} for _ in chunks] def chunk_report_structure(json_path, max_tokens): """[structure 模式] 用 Structure-based DFS Chunking 切分.""" from Experiments.structure_chunker import structure_chunk_document chunks_data = structure_chunk_document(json_path, max_tokens=max_tokens) if not chunks_data: return [], [] texts = [c["text"] for c in chunks_data] extras = [{"section_path": " > ".join(c["metadata"]["section_path"]), "document_id": c["metadata"]["document_id"]} for c in chunks_data] return texts, extras def normalize_chunk_for_matching(chunk_text: str) -> str: """ 标准化 chunk 文本用于 relevant_text 匹配. structure chunk 带有 Markdown 标题行, 匹配时需要去除. """ lines = chunk_text.split('\n') content_lines = [l for l in lines if not l.strip().startswith('#')] return normalize_text(' '.join(content_lines)) def _build_output_dir(chunk_mode, doc_mode): """构建输出目录名.""" base = "OCR_Chunked_Annotated" parts = [base] if chunk_mode == "structure": parts.append("structure") if doc_mode == "cross": parts.append("cross") return "_".join(parts) def load_annotation_source(doc_mode): """ 根据 doc_mode 加载标注数据, 返回统一格式: question_doc_relevant: dict[(report, question)] -> list[str] (标准化后的 relevant texts) all_reports: list[str] all_question_docs: list[(question, report)] """ if doc_mode == "single": df = pd.read_csv(CSV_PATH, index_col=0) print(f" 加载 V1.csv: {len(df)} 行, {df['report'].nunique()} 报告, {df['question'].nunique()} 问题") print(f" relevance 分布:\n{df['relevance'].value_counts().sort_index().to_string()}") high_rel = df[df['relevance'] >= RELEVANCE_THRESHOLD].copy() print(f" relevance >= {RELEVANCE_THRESHOLD}: {len(high_rel)} 行, {high_rel['relevant_text'].nunique()} 唯一 relevant_text") all_reports = list(df['report'].unique()) question_doc_relevant = {} for report in all_reports: report_qs = df[df['report'] == report]['question'].unique() for q in report_qs: q_rel = high_rel[(high_rel['report'] == report) & (high_rel['question'] == q)] rel_texts = q_rel['relevant_text'].dropna().unique() normalized = [normalize_text(rt) for rt in rel_texts if len(str(rt).strip()) >= MIN_MATCH_LEN] question_doc_relevant[(report, q)] = normalized return question_doc_relevant, all_reports elif doc_mode == "cross": df = pd.read_excel(CROSS_XLSX_PATH) if 'Unnamed: 0' in df.columns: df = df.drop(columns=['Unnamed: 0']) print(f" 加载 cross.xlsx: {len(df)} 行, {df['Document'].nunique()} 报告, {df['Question'].nunique()} 问题") print(f" Source Relevance Score 分布:\n{df['Source Relevance Score'].value_counts().sort_index().to_string()}") all_reports = list(df['Document'].unique()) question_doc_relevant = {} for _, row in df.iterrows(): report = row['Document'] question = row['Question'] relevant = row.get('Relevant', '') if pd.isna(relevant) or len(str(relevant).strip()) < MIN_MATCH_LEN: continue key = (report, question) if key not in question_doc_relevant: question_doc_relevant[key] = [] norm = normalize_text(str(relevant)) if norm not in question_doc_relevant[key]: question_doc_relevant[key].append(norm) print(f" (report, question) 对数: {len(question_doc_relevant)}") return question_doc_relevant, all_reports else: raise ValueError(f"未知 DOC_MODE: {doc_mode}") def main(): global CHUNK_MODE, DOC_MODE # 命令行参数: python transfer_annotations.py [chunk_mode] [doc_mode] if len(sys.argv) > 1 and sys.argv[1] in ("length", "structure"): CHUNK_MODE = sys.argv[1] if len(sys.argv) > 2 and sys.argv[2] in ("single", "cross"): DOC_MODE = sys.argv[2] output_dir = _build_output_dir(CHUNK_MODE, DOC_MODE) print("=" * 60) print(f"标注转移 (CHUNK={CHUNK_MODE}, DOC={DOC_MODE})") print(f"输出目录: {output_dir}") print("=" * 60) # ---- Step 1: 加载标注数据 ---- print("\nStep 1: 加载标注数据") print("=" * 60) question_doc_relevant, all_reports = load_annotation_source(DOC_MODE) # ---- Step 2: 建立 report 名称映射 ---- print("\n" + "=" * 60) print("Step 2: 建立 report 名称映射") print("=" * 60) mineru_folders = [f for f in os.listdir(MINERU_DIR) if os.path.isdir(os.path.join(MINERU_DIR, f))] report_mapping = build_report_name_mapping(all_reports, mineru_folders) print(f"\n 成功映射: {len(report_mapping)} / {len(all_reports)}") # ---- Step 3: 切分 chunk + 转移标注 ---- print("\n" + "=" * 60) print(f"Step 3: 切分 chunk ({CHUNK_MODE}), 转移标注 ({DOC_MODE})") print("=" * 60) splitter = None if CHUNK_MODE == "length": splitter = SentenceSplitter( chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, tokenizer=word_count_tokenizer, ) os.makedirs(output_dir, exist_ok=True) # 收集每个 report 关联的 questions report_questions_map = {} for (report, question) in question_doc_relevant.keys(): if report not in report_questions_map: report_questions_map[report] = set() report_questions_map[report].add(question) all_results = [] for report_name in all_reports: print(f"\n 处理报告: {report_name}") if report_name not in report_mapping: print(f" [跳过] 未找到对应的 MinerU 文件夹") continue folder_name = report_mapping[report_name] json_path = find_content_list_json(MINERU_DIR, folder_name) if json_path is None: print(f" [跳过] 未找到 content_list.json") continue # 切分 chunk if CHUNK_MODE == "length": chunks, extras = chunk_report_length(json_path, splitter) else: chunks, extras = chunk_report_structure(json_path, STRUCTURE_MAX_TOKENS) if not chunks: print(f" [跳过] 无 chunk 产生") continue print(f" 切分为 {len(chunks)} 个 {CHUNK_MODE} chunk") report_questions = report_questions_map.get(report_name, set()) if not report_questions: print(f" [跳过] 该报告无关联问题") continue # 对每个 chunk × question 匹配 for chunk_idx, chunk_text in enumerate(chunks): chunk_views = build_chunk_match_views(chunk_text) chunk_word_count = len(chunk_text.split()) for q in report_questions: rel_norms = question_doc_relevant.get((report_name, q), []) is_relevant = False matched_logs = {} matched_channels = set() for rt_norm in rel_norms: hit, channels = match_relevant_text_multi_view(chunk_views, rt_norm) if hit: is_relevant = True if rt_norm not in matched_logs: matched_logs[rt_norm] = set() matched_logs[rt_norm].update(channels) matched_channels.update(channels) matched_texts = [] for rt_norm in sorted(matched_logs.keys()): channel_tag = "+".join(sorted(matched_logs[rt_norm])) matched_texts.append(f"[{channel_tag}] {rt_norm[:80]}...") result = { 'report': report_name, 'chunk_idx': chunk_idx, 'chunk_text': chunk_text, 'chunk_word_count': chunk_word_count, 'question': q, 'is_relevant': 1 if is_relevant else 0, 'matched_relevant_texts': "; ".join(matched_texts) if matched_texts else "", 'num_matched': len(matched_logs), 'matched_channels': ",".join(sorted(matched_channels)) if matched_channels else "", } if extras and chunk_idx < len(extras) and extras[chunk_idx]: result.update(extras[chunk_idx]) all_results.append(result) # 统计 report_results = [r for r in all_results if r['report'] == report_name] total_pairs = len(report_results) relevant_pairs = sum(1 for r in report_results if r['is_relevant']) print(f" (chunk, question) 对总数: {total_pairs}, 标为 relevant: {relevant_pairs}") # ---- Step 4: 保存 ---- print("\n" + "=" * 60) print("Step 4: 保存结果") print("=" * 60) result_df = pd.DataFrame(all_results) tag_parts = [] if CHUNK_MODE != "length": tag_parts.append(CHUNK_MODE) if DOC_MODE != "single": tag_parts.append(DOC_MODE) tag = ("_" + "_".join(tag_parts)) if tag_parts else "" output_csv = os.path.join(output_dir, f"ocr_chunks_annotated{tag}.csv") result_df.to_csv(output_csv, index=False, encoding='utf-8-sig') print(f" 完整结果: {output_csv}") print(f" 总行数: {len(result_df)}") print(f" 标为 relevant: {result_df['is_relevant'].sum()}") unique_chunks = result_df.drop_duplicates(subset=['report', 'chunk_idx']).shape[0] print(f" 唯一 chunk 数: {unique_chunks}") # chunk 列表 dedup_cols = ['report', 'chunk_idx', 'chunk_text', 'chunk_word_count'] if 'section_path' in result_df.columns: dedup_cols += ['section_path', 'document_id'] chunks_only = result_df.drop_duplicates(subset=['report', 'chunk_idx'])[dedup_cols].reset_index(drop=True) chunks_json = os.path.join(output_dir, f"ocr_chunks_all{tag}.json") chunks_only.to_json(chunks_json, orient='records', force_ascii=False, indent=2) print(f" chunk 列表: {chunks_json} ({len(chunks_only)} chunks)") # 汇总 summary = result_df.groupby('report').agg( total_chunks=('chunk_idx', 'nunique'), total_pairs=('is_relevant', 'count'), relevant_pairs=('is_relevant', 'sum'), ).reset_index() summary['relevant_ratio'] = summary['relevant_pairs'] / summary['total_pairs'] summary_path = os.path.join(output_dir, f"annotation_summary{tag}.csv") summary.to_csv(summary_path, index=False) print(f" 汇总: {summary_path}") print(summary.to_string()) if __name__ == "__main__": main()