# figure_table_pipeline.py
import asyncio
import os
import re
import shutil
from collections import defaultdict
from pathlib import Path
from typing import Any

from tqdm.asyncio import tqdm

# NOTE(review): run_figure_extraction explicitly checks these names for None
# before use ("核心依赖项未能加载"), which only makes sense if the imports may
# fail softly. Guard them so a missing backend degrades to that error message
# instead of an ImportError at module import time.
try:
    from pragent.backend.loader import ImagePDFLoader
    from pragent.backend.yolo import extract_and_save_layout_components
except ImportError:
    ImagePDFLoader = None
    extract_and_save_layout_components = None

# Hard cap on processed pages to bound runtime on very long PDFs.
MAX_PAGES = 20


async def run_figure_extraction(pdf_path: str, base_work_dir: str, progress: Any | None = None) -> str | None:
    """Run the complete PDF figure/table extraction and pairing pipeline.

    This is the main entry point called by app.py.

    Args:
        pdf_path: Path of the user-uploaded PDF.
        base_work_dir: Temporary working directory for this session.
        progress: Optional Gradio progress object, called as
            ``progress(fraction, desc=...)``.

    Returns:
        Path of the directory holding the final paired results, or ``None``
        on failure.
    """
    if not all([ImagePDFLoader, extract_and_save_layout_components]):
        tqdm.write("[!] 错误: figure_pipeline 的一个或多个核心依赖项未能加载。")
        return None

    pdf_file = Path(pdf_path)
    pdf_stem = pdf_file.stem
    model_path = "pragent/model/doclayout_yolo_docstructbench_imgsz1024.pt"

    # --- Step 1/3: rasterize every PDF page to a PNG file ---
    tqdm.write(f"\n--- 步骤 1/3: 将PDF '{pdf_file.name}' 转换为图片 ---")
    page_save_dir = os.path.join(base_work_dir, "page_paper", pdf_stem)
    os.makedirs(page_save_dir, exist_ok=True)
    try:
        loader = ImagePDFLoader(pdf_path)
        page_image_paths = []
        for i, img in enumerate(loader.load()):
            path = os.path.join(page_save_dir, f"page_{i + 1}.png")
            img.save(path)
            page_image_paths.append(path)
        tqdm.write(f"[*] 所有 {len(page_image_paths)} 页已保存至: {page_save_dir}")
    except Exception as e:
        tqdm.write(f"[!] 错误:加载或转换PDF时失败: {e}")
        return None

    if len(page_image_paths) > MAX_PAGES:
        # f-string renders identically to the original hard-coded "20".
        tqdm.write(
            f"[!] Warning: PDF has {len(page_image_paths)} pages. "
            f"Processing only the first {MAX_PAGES} pages to avoid timeout."
        )
        page_image_paths = page_image_paths[:MAX_PAGES]

    # --- Step 2/3: layout analysis, crop figures/tables from each page ---
    tqdm.write("\n--- 步骤 2/3: 分析页面布局以裁剪图和表 ---")
    cropped_results_dir = os.path.join(base_work_dir, "cropped_results", pdf_stem)
    num_pages = len(page_image_paths)
    for i, path in enumerate(page_image_paths):
        if progress:
            # Layout analysis occupies the 0.3-0.5 slice of overall progress.
            progress(0.3 + (i / num_pages) * 0.2, desc=f"Analyzing page {i+1}/{num_pages}")
        page_num_str = Path(path).stem
        page_crop_dir = os.path.join(cropped_results_dir, page_num_str)
        # YOLO inference is blocking; run it off the event loop thread.
        await asyncio.to_thread(
            extract_and_save_layout_components,
            image_path=path,
            model_path=model_path,
            save_base_dir=page_crop_dir,
            imgsz=640,
        )
    tqdm.write(f"[*] 所有裁剪结果已保存至: {cropped_results_dir}")

    # --- Step 3/3: pair cropped components with their captions ---
    tqdm.write("\n--- 步骤 3/3: 对裁剪出的组件进行配对 ---")
    final_paired_dir = os.path.join(base_work_dir, "paired_results", pdf_stem)
    run_pairing_process(cropped_results_dir, final_paired_dir, threshold=30)

    return final_paired_dir if os.path.isdir(final_paired_dir) else None


def run_pairing_process(source_dir_str: str, output_dir_str: str, threshold: int) -> None:
    """Pair cropped components page by page (part of the pipeline).

    Args:
        source_dir_str: Root directory containing per-page crop folders
            named ``page_*``.
        output_dir_str: Destination root; it is recreated from scratch on
            every run.
        threshold: Maximum index distance allowed between an item and its
            caption for them to be paired.
    """
    source_dir = Path(source_dir_str)
    output_root_dir = Path(output_dir_str)
    # Start from a clean output tree so stale results never leak through.
    if output_root_dir.exists():
        shutil.rmtree(output_root_dir)
    output_root_dir.mkdir(parents=True, exist_ok=True)
    tqdm.write(f" 开始最近邻配对流程 (阈值 = {threshold})")

    page_dirs = sorted(
        d for d in source_dir.iterdir() if d.is_dir() and d.name.startswith('page_')
    )
    for page_dir in page_dirs:
        output_page_dir = output_root_dir / page_dir.name
        output_page_dir.mkdir(exist_ok=True)
        pair_items_on_page(str(page_dir), str(output_page_dir), threshold)


def pair_items_on_page(page_dir: str, output_dir: str, threshold: int) -> None:
    """Greedily pair items with their nearest-index captions on one page.

    Crops are expected under ``<page_dir>/<component_type>/`` with filenames
    like ``figure_3_score0.87.jpg``. Each figure/table claims the not-yet-used
    caption whose ordinal index is closest; candidates farther than
    ``threshold`` stay unpaired. Paired crops are copied into
    ``paired_<type>_<idx>`` folders, everything left over into
    ``unpaired_<type>_<idx>`` folders.
    """
    organized_files = defaultdict(dict)
    component_types = ["figure", "figure_caption", "table", "table_caption_above", "table_caption_below"]

    def parse_filename(filename: str):
        # e.g. 'table_caption_above_2_score0.91.jpg' -> ('table_caption_above', 2)
        match = re.match(r'([a-zA-Z_]+)_(\d+)_score([\d.]+)\.jpg', filename)
        return (match.group(1), int(match.group(2))) if match else (None, None)

    # Index every crop file by component type and ordinal index.
    for comp_type in component_types:
        comp_dir = os.path.join(page_dir, comp_type)
        if os.path.isdir(comp_dir):
            for filename in os.listdir(comp_dir):
                _, index = parse_filename(filename)
                if index is not None:
                    organized_files[comp_type][index] = os.path.join(comp_dir, filename)

    paired_files, used_captions = set(), defaultdict(set)
    pairings = [
        ("figure", ["figure_caption"]),
        ("table", ["table_caption_above", "table_caption_below"]),
    ]
    for item_type, cap_types in pairings:
        for item_index, item_path in organized_files[item_type].items():
            best_match = {'min_diff': float('inf'), 'cap_path': None, 'cap_index': -1, 'cap_type': ''}
            for cap_type in cap_types:
                for cap_index, cap_path in organized_files[cap_type].items():
                    if cap_index in used_captions[cap_type]:
                        continue  # each caption may be claimed only once
                    diff = abs(item_index - cap_index)
                    if diff < best_match['min_diff']:
                        best_match.update({'min_diff': diff, 'cap_path': cap_path, 'cap_index': cap_index, 'cap_type': cap_type})
            if best_match['cap_path'] and best_match['min_diff'] <= threshold:
                target_dir = os.path.join(output_dir, f"paired_{item_type}_{item_index}")
                os.makedirs(target_dir, exist_ok=True)
                shutil.copy(item_path, target_dir)
                shutil.copy(best_match['cap_path'], target_dir)
                paired_files.add(item_path)
                paired_files.add(best_match['cap_path'])
                used_captions[best_match['cap_type']].add(best_match['cap_index'])

    # Anything not consumed by the pairing pass is copied out as unpaired.
    for files_dict in organized_files.values():
        for file_path in files_dict.values():
            if file_path not in paired_files:
                item_type, index = parse_filename(Path(file_path).name)
                if item_type:
                    target_dir = os.path.join(output_dir, f"unpaired_{item_type}_{index}")
                    os.makedirs(target_dir, exist_ok=True)
                    shutil.copy(file_path, target_dir)