Spaces:
Running
Running
"""
将标注转移到 OCR 后的 JSON 文本 chunk 中
------------------------------------------
支持两种 chunk 模式 × 两种文档模式:
    chunk: "length" (SentenceSplitter) / "structure" (DFS-based)
    doc:   "single" (V1.csv, 单文档问题) / "cross" (cross.xlsx, 跨文档问题)
用法:
    python transfer_annotations.py [chunk_mode] [doc_mode]
例:
    python transfer_annotations.py length single      # 默认
    python transfer_annotations.py structure cross    # structure chunk + cross 标注
"""
| import os | |
| import re | |
| import json | |
| import sys | |
| import pandas as pd | |
| from llama_index.core.node_parser import SentenceSplitter | |
# ======================== Configuration ========================
import os  # NOTE(review): duplicate of the top-of-file `import os`; harmless but removable
# Single-document annotation source (report-level V1 CSV).
CSV_PATH = os.path.join(
    "Report-Level Dataset",
    "ClimRetrieve_ReportLevel_V1.csv"
)
# Cross-document annotation source (expert-annotated xlsx).
CROSS_XLSX_PATH = os.path.join(
    "Expert-Annotated Relevant Sources Dataset",
    "ClimRetrieve_cross.xlsx"
)
# Directory containing one MinerU OCR output folder per report.
MINERU_DIR = "MinerU_Reports"
# ---- Chunk mode ----
CHUNK_MODE = "structure"  # "length" or "structure" (overridable via argv[1] in main())
# ---- Document mode ----
# "single": use relevant_text from V1.csv (single-document questions)
# "cross" : use the Relevant column of cross.xlsx (cross-document questions)
DOC_MODE = "cross"  # overridable via argv[2] in main()
# Parameters for "length" mode
CHUNK_SIZE = 350      # target chunk size (in words)
CHUNK_OVERLAP = 50    # overlap between consecutive chunks (in words)
# Parameter for "structure" mode
STRUCTURE_MAX_TOKENS = 550
RELEVANCE_THRESHOLD = 1  # only keep relevant_text rows with relevance >= this (single mode)
MIN_MATCH_LEN = 30       # minimum relevant_text length; shorter spans are skipped (avoids false matches)
# ======================== Utility functions ========================
def word_count_tokenizer(text: str) -> list:
    """Whitespace tokenizer so SentenceSplitter measures chunk size in words."""
    tokens = text.split()
    return tokens
def normalize_text(text: str) -> str:
    """Collapse whitespace runs to single spaces, trim, and lowercase."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip().lower()
def load_content_list_json(json_path: str) -> str:
    """Load a MinerU content_list.json and concatenate every text-type entry.

    Only items with ``type == "text"`` and a non-blank ``text`` field are
    kept. Whitespace-only entries are skipped: previously they passed the
    truthiness check (e.g. ``"   "``) and injected empty lines into the
    joined output.

    Args:
        json_path: path to a ``*_content_list.json`` file (UTF-8).

    Returns:
        The extracted paragraphs joined with newlines ("" when none found).
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        content_list = json.load(f)
    texts = []
    for item in content_list:
        if item.get('type') != 'text':
            continue
        text = (item.get('text') or '').strip()
        # Keep only genuinely non-empty paragraphs.
        if text:
            texts.append(text)
    return "\n".join(texts)
def build_report_name_mapping(csv_reports, mineru_folders):
    """Map each CSV report name to its MinerU output folder.

    Tries an exact case-insensitive match first (after dropping the ".pdf"
    extension); otherwise falls back to a word-set-overlap fuzzy match with
    a 0.8 score cutoff. Unmatched reports are logged to stdout and omitted
    from the returned dict.
    """
    mapping = {}
    for csv_name in csv_reports:
        # Drop the ".pdf" extension before comparing.
        stripped = csv_name.replace(".pdf", "").strip()
        target = stripped.lower()
        # Exact match (case-insensitive, ignoring trailing spaces).
        exact = next((f for f in mineru_folders if f.strip().lower() == target), None)
        if exact is not None:
            mapping[csv_name] = exact
            continue
        # Fuzzy fallback: score every folder by word-set overlap.
        csv_words = set(target.split())
        best_match, best_score = None, 0
        for folder in mineru_folders:
            folder_words = set(folder.lower().split())
            denom = max(len(csv_words), len(folder_words))
            score = len(csv_words & folder_words) / denom if denom > 0 else 0
            if score > best_score:
                best_score, best_match = score, folder
        if best_match and best_score >= 0.8:
            mapping[csv_name] = best_match
            print(f" [模糊匹配] CSV '{csv_name}' -> MinerU '{best_match}' (score={best_score:.2f})")
        else:
            print(f" [未匹配] CSV '{csv_name}' (best: '{best_match}', score={best_score:.2f})")
    return mapping
def find_content_list_json(mineru_dir, folder_name):
    """Locate the *_content_list.json file inside a MinerU report folder.

    Returns the full path of the first matching file (in os.listdir order),
    or None if the folder contains no such file.
    """
    folder_path = os.path.join(mineru_dir, folder_name)
    candidates = (name for name in os.listdir(folder_path)
                  if name.endswith('_content_list.json'))
    hit = next(candidates, None)
    return os.path.join(folder_path, hit) if hit is not None else None
def check_chunk_contains_relevant_text(chunk_normalized: str, relevant_text_normalized: str) -> bool:
    """Decide whether a (normalized) chunk contains a relevant_text span.

    Strategy:
      1. direct substring match on the normalized texts;
      2. for spans of >= 8 words, additionally try the first and last 50%
         of the words (OCR may disagree with the annotation at sentence
         boundaries). Partial candidates shorter than MIN_MATCH_LEN
         characters are ignored to avoid spurious hits.
    """
    # Full-span substring match.
    if relevant_text_normalized in chunk_normalized:
        return True
    words = relevant_text_normalized.split()
    if len(words) < 8:
        return False
    half = int(len(words) * 0.5)
    # Front half first, then back half — same order as a direct scan.
    for part in (' '.join(words[:half]), ' '.join(words[half:])):
        if len(part) >= MIN_MATCH_LEN and part in chunk_normalized:
            return True
    return False
def normalize_chunk_for_matching_raw(chunk_text: str) -> str:
    """Normalize a chunk's raw text (heading lines kept) for matching."""
    normalized = normalize_text(chunk_text)
    return normalized
def build_chunk_match_views(chunk_text: str) -> dict:
    """Build the dual-channel matching views for one chunk.

    Channels:
      - "raw": normalized text with heading lines kept
      - "stripped": normalized text with markdown heading lines removed
    The "stripped" channel is only emitted when it differs from "raw".
    """
    views = {"raw": normalize_chunk_for_matching_raw(chunk_text)}
    stripped = normalize_chunk_for_matching(chunk_text)
    if stripped != views["raw"]:
        views["stripped"] = stripped
    return views
def match_relevant_text_multi_view(chunk_views: dict, relevant_text_normalized: str) -> tuple:
    """Match one relevant_text against every chunk view.

    Returns:
        (is_matched, hit_channels): whether any view matched, plus the list
        of channel names that did, in chunk_views iteration order.
    """
    hits = [name for name, view in chunk_views.items()
            if check_chunk_contains_relevant_text(view, relevant_text_normalized)]
    return bool(hits), hits
# ======================== Main pipeline ========================
def chunk_report_length(json_path, splitter):
    """[length mode] Load the OCR text and split it with SentenceSplitter.

    Returns (chunk_texts, chunk_extras); extras are empty dicts because
    length-based chunks carry no structural metadata.
    """
    full_text = load_content_list_json(json_path)
    if not full_text.strip():
        return [], []
    chunks = splitter.split_text(full_text)
    return chunks, [{} for _ in chunks]
def chunk_report_structure(json_path, max_tokens):
    """[structure mode] Split a report with structure-based DFS chunking.

    Returns (chunk_texts, chunk_extras); each extra carries the chunk's
    section path (joined with " > ") and its document_id.
    """
    from Experiments.structure_chunker import structure_chunk_document
    chunks_data = structure_chunk_document(json_path, max_tokens=max_tokens)
    if not chunks_data:
        return [], []
    texts = []
    extras = []
    for chunk in chunks_data:
        texts.append(chunk["text"])
        meta = chunk["metadata"]
        extras.append({
            "section_path": " > ".join(meta["section_path"]),
            "document_id": meta["document_id"],
        })
    return texts, extras
def normalize_chunk_for_matching(chunk_text: str) -> str:
    """Normalize chunk text for relevant_text matching.

    Structure chunks embed markdown heading lines ("# ..."); those are
    removed before normalization so headings never cause spurious matches.
    """
    body_lines = [line for line in chunk_text.split('\n')
                  if not line.strip().startswith('#')]
    return normalize_text(' '.join(body_lines))
| def _build_output_dir(chunk_mode, doc_mode): | |
| """构建输出目录名.""" | |
| base = "OCR_Chunked_Annotated" | |
| parts = [base] | |
| if chunk_mode == "structure": | |
| parts.append("structure") | |
| if doc_mode == "cross": | |
| parts.append("cross") | |
| return "_".join(parts) | |
def load_annotation_source(doc_mode):
    """
    Load the annotation data for the given doc_mode into a unified format.

    Args:
        doc_mode: "single" (V1.csv, single-document questions) or
                  "cross" (cross.xlsx, cross-document questions).

    Returns:
        question_doc_relevant: dict[(report, question)] -> list[str]
            normalized relevant texts, deduplicated, each at least
            MIN_MATCH_LEN characters before normalization.
        all_reports: list[str] of unique report names.

    Raises:
        ValueError: when doc_mode is neither "single" nor "cross".
    """
    if doc_mode == "single":
        df = pd.read_csv(CSV_PATH, index_col=0)
        print(f" 加载 V1.csv: {len(df)} 行, {df['report'].nunique()} 报告, {df['question'].nunique()} 问题")
        print(f" relevance 分布:\n{df['relevance'].value_counts().sort_index().to_string()}")
        # Keep only rows whose relevance score reaches the threshold.
        high_rel = df[df['relevance'] >= RELEVANCE_THRESHOLD].copy()
        print(f" relevance >= {RELEVANCE_THRESHOLD}: {len(high_rel)} 行, {high_rel['relevant_text'].nunique()} 唯一 relevant_text")
        all_reports = list(df['report'].unique())
        question_doc_relevant = {}
        for report in all_reports:
            # Every question asked of this report (independent of relevance),
            # so (report, question) keys exist even with zero relevant texts.
            report_qs = df[df['report'] == report]['question'].unique()
            for q in report_qs:
                q_rel = high_rel[(high_rel['report'] == report) & (high_rel['question'] == q)]
                rel_texts = q_rel['relevant_text'].dropna().unique()
                # Normalize; drop spans too short to match reliably.
                normalized = [normalize_text(rt) for rt in rel_texts if len(str(rt).strip()) >= MIN_MATCH_LEN]
                question_doc_relevant[(report, q)] = normalized
        return question_doc_relevant, all_reports
    elif doc_mode == "cross":
        df = pd.read_excel(CROSS_XLSX_PATH)
        # Drop the stray index column an unnamed-index export leaves behind.
        if 'Unnamed: 0' in df.columns:
            df = df.drop(columns=['Unnamed: 0'])
        print(f" 加载 cross.xlsx: {len(df)} 行, {df['Document'].nunique()} 报告, {df['Question'].nunique()} 问题")
        print(f" Source Relevance Score 分布:\n{df['Source Relevance Score'].value_counts().sort_index().to_string()}")
        all_reports = list(df['Document'].unique())
        question_doc_relevant = {}
        for _, row in df.iterrows():
            report = row['Document']
            question = row['Question']
            relevant = row.get('Relevant', '')
            # Skip missing or too-short spans.
            if pd.isna(relevant) or len(str(relevant).strip()) < MIN_MATCH_LEN:
                continue
            key = (report, question)
            if key not in question_doc_relevant:
                question_doc_relevant[key] = []
            norm = normalize_text(str(relevant))
            # Deduplicate while preserving insertion order.
            if norm not in question_doc_relevant[key]:
                question_doc_relevant[key].append(norm)
        print(f" (report, question) 对数: {len(question_doc_relevant)}")
        return question_doc_relevant, all_reports
    else:
        raise ValueError(f"未知 DOC_MODE: {doc_mode}")
def main():
    """Entry point: load annotations, map report names to MinerU folders,
    chunk each report, transfer the relevance labels onto chunks, and save
    CSV/JSON outputs into the mode-specific output directory."""
    global CHUNK_MODE, DOC_MODE
    # CLI override: python transfer_annotations.py [chunk_mode] [doc_mode]
    if len(sys.argv) > 1 and sys.argv[1] in ("length", "structure"):
        CHUNK_MODE = sys.argv[1]
    if len(sys.argv) > 2 and sys.argv[2] in ("single", "cross"):
        DOC_MODE = sys.argv[2]
    output_dir = _build_output_dir(CHUNK_MODE, DOC_MODE)
    print("=" * 60)
    print(f"标注转移 (CHUNK={CHUNK_MODE}, DOC={DOC_MODE})")
    print(f"输出目录: {output_dir}")
    print("=" * 60)
    # ---- Step 1: load the annotation data ----
    print("\nStep 1: 加载标注数据")
    print("=" * 60)
    question_doc_relevant, all_reports = load_annotation_source(DOC_MODE)
    # ---- Step 2: build the report-name mapping ----
    print("\n" + "=" * 60)
    print("Step 2: 建立 report 名称映射")
    print("=" * 60)
    mineru_folders = [f for f in os.listdir(MINERU_DIR)
                      if os.path.isdir(os.path.join(MINERU_DIR, f))]
    report_mapping = build_report_name_mapping(all_reports, mineru_folders)
    print(f"\n 成功映射: {len(report_mapping)} / {len(all_reports)}")
    # ---- Step 3: chunk + transfer annotations ----
    print("\n" + "=" * 60)
    print(f"Step 3: 切分 chunk ({CHUNK_MODE}), 转移标注 ({DOC_MODE})")
    print("=" * 60)
    splitter = None
    if CHUNK_MODE == "length":
        splitter = SentenceSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            tokenizer=word_count_tokenizer,
        )
    os.makedirs(output_dir, exist_ok=True)
    # Collect the set of questions associated with each report.
    report_questions_map = {}
    for (report, question) in question_doc_relevant.keys():
        if report not in report_questions_map:
            report_questions_map[report] = set()
        report_questions_map[report].add(question)
    all_results = []
    for report_name in all_reports:
        print(f"\n 处理报告: {report_name}")
        if report_name not in report_mapping:
            print(f" [跳过] 未找到对应的 MinerU 文件夹")
            continue
        folder_name = report_mapping[report_name]
        json_path = find_content_list_json(MINERU_DIR, folder_name)
        if json_path is None:
            print(f" [跳过] 未找到 content_list.json")
            continue
        # Split into chunks with the selected strategy.
        if CHUNK_MODE == "length":
            chunks, extras = chunk_report_length(json_path, splitter)
        else:
            chunks, extras = chunk_report_structure(json_path, STRUCTURE_MAX_TOKENS)
        if not chunks:
            print(f" [跳过] 无 chunk 产生")
            continue
        print(f" 切分为 {len(chunks)} 个 {CHUNK_MODE} chunk")
        report_questions = report_questions_map.get(report_name, set())
        if not report_questions:
            print(f" [跳过] 该报告无关联问题")
            continue
        # Match every chunk against every question of this report.
        for chunk_idx, chunk_text in enumerate(chunks):
            chunk_views = build_chunk_match_views(chunk_text)
            chunk_word_count = len(chunk_text.split())
            for q in report_questions:
                rel_norms = question_doc_relevant.get((report_name, q), [])
                is_relevant = False
                matched_logs = {}       # rt_norm -> set of hit channels
                matched_channels = set()
                for rt_norm in rel_norms:
                    hit, channels = match_relevant_text_multi_view(chunk_views, rt_norm)
                    if hit:
                        is_relevant = True
                        if rt_norm not in matched_logs:
                            matched_logs[rt_norm] = set()
                        matched_logs[rt_norm].update(channels)
                        matched_channels.update(channels)
                # Human-readable log of what matched, tagged by channel.
                matched_texts = []
                for rt_norm in sorted(matched_logs.keys()):
                    channel_tag = "+".join(sorted(matched_logs[rt_norm]))
                    matched_texts.append(f"[{channel_tag}] {rt_norm[:80]}...")
                result = {
                    'report': report_name,
                    'chunk_idx': chunk_idx,
                    'chunk_text': chunk_text,
                    'chunk_word_count': chunk_word_count,
                    'question': q,
                    'is_relevant': 1 if is_relevant else 0,
                    'matched_relevant_texts': "; ".join(matched_texts) if matched_texts else "",
                    'num_matched': len(matched_logs),
                    'matched_channels': ",".join(sorted(matched_channels)) if matched_channels else "",
                }
                # Structure mode adds section_path / document_id per chunk.
                if extras and chunk_idx < len(extras) and extras[chunk_idx]:
                    result.update(extras[chunk_idx])
                all_results.append(result)
        # Per-report statistics.
        report_results = [r for r in all_results if r['report'] == report_name]
        total_pairs = len(report_results)
        relevant_pairs = sum(1 for r in report_results if r['is_relevant'])
        print(f" (chunk, question) 对总数: {total_pairs}, 标为 relevant: {relevant_pairs}")
    # ---- Step 4: save outputs ----
    print("\n" + "=" * 60)
    print("Step 4: 保存结果")
    print("=" * 60)
    result_df = pd.DataFrame(all_results)
    # Filename tag mirrors _build_output_dir: suffixes only for non-defaults.
    tag_parts = []
    if CHUNK_MODE != "length":
        tag_parts.append(CHUNK_MODE)
    if DOC_MODE != "single":
        tag_parts.append(DOC_MODE)
    tag = ("_" + "_".join(tag_parts)) if tag_parts else ""
    output_csv = os.path.join(output_dir, f"ocr_chunks_annotated{tag}.csv")
    result_df.to_csv(output_csv, index=False, encoding='utf-8-sig')
    print(f" 完整结果: {output_csv}")
    print(f" 总行数: {len(result_df)}")
    print(f" 标为 relevant: {result_df['is_relevant'].sum()}")
    unique_chunks = result_df.drop_duplicates(subset=['report', 'chunk_idx']).shape[0]
    print(f" 唯一 chunk 数: {unique_chunks}")
    # Deduplicated chunk list: one row per (report, chunk_idx).
    dedup_cols = ['report', 'chunk_idx', 'chunk_text', 'chunk_word_count']
    if 'section_path' in result_df.columns:
        dedup_cols += ['section_path', 'document_id']
    chunks_only = result_df.drop_duplicates(subset=['report', 'chunk_idx'])[dedup_cols].reset_index(drop=True)
    chunks_json = os.path.join(output_dir, f"ocr_chunks_all{tag}.json")
    chunks_only.to_json(chunks_json, orient='records', force_ascii=False, indent=2)
    print(f" chunk 列表: {chunks_json} ({len(chunks_only)} chunks)")
    # Per-report summary with relevant ratio.
    summary = result_df.groupby('report').agg(
        total_chunks=('chunk_idx', 'nunique'),
        total_pairs=('is_relevant', 'count'),
        relevant_pairs=('is_relevant', 'sum'),
    ).reset_index()
    summary['relevant_ratio'] = summary['relevant_pairs'] / summary['total_pairs']
    summary_path = os.path.join(output_dir, f"annotation_summary{tag}.csv")
    summary.to_csv(summary_path, index=False)
    print(f" 汇总: {summary_path}")
    print(summary.to_string())
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()