# AutoPR / pragent/backend/figure_table_pipeline.py
# Last commit by yzweak: ca1833b "Limit page processing to first 20 pages"
# figure_table_pipeline.py
import os
import shutil
import re
from pathlib import Path
from collections import defaultdict
from pragent.backend.loader import ImagePDFLoader
from pragent.backend.yolo import extract_and_save_layout_components
from tqdm.asyncio import tqdm
import asyncio
from typing import Any
async def run_figure_extraction(
    pdf_path: str,
    base_work_dir: str,
    progress: Any | None = None,
    max_pages: int | None = 20,
) -> str | None:
    """Extract figures/tables from a PDF and pair them with their captions.

    This is the main entry point called by app.py. It runs three steps:

    1. Render every PDF page to ``<base_work_dir>/page_paper/<stem>/page_N.png``.
    2. Run layout detection on (at most ``max_pages``) page images, writing the
       crops of each page to ``<base_work_dir>/cropped_results/<stem>/page_N``.
    3. Pair the crops with their captions into
       ``<base_work_dir>/paired_results/<stem>`` via ``run_pairing_process``.

    Args:
        pdf_path: Path of the PDF uploaded by the user.
        base_work_dir: Temporary working directory for this session.
        progress: Optional Gradio progress object; called as
            ``progress(fraction, desc=...)`` once per analysed page.
        max_pages: Only the first ``max_pages`` pages go through layout
            analysis (to avoid timeouts). ``None`` disables the cap.

    Returns:
        The directory containing the final pairing results, or ``None`` if
        the pipeline failed or there was nothing to process.
    """
    if not all([ImagePDFLoader, extract_and_save_layout_components]):
        tqdm.write("[!] 错误: figure_pipeline 的一个或多个核心依赖项未能加载。")
        return None
    pdf_file = Path(pdf_path)
    pdf_stem = pdf_file.stem
    model_path = "pragent/model/doclayout_yolo_docstructbench_imgsz1024.pt"

    # --- Step 1: render PDF pages to PNG files ---
    tqdm.write(f"\n--- 步骤 1/3: 将PDF '{pdf_file.name}' 转换为图片 ---")
    page_save_dir = os.path.join(base_work_dir, "page_paper", pdf_stem)
    os.makedirs(page_save_dir, exist_ok=True)
    try:
        loader = ImagePDFLoader(pdf_path)
        page_image_paths = []
        for i, img in enumerate(loader.load()):
            path = os.path.join(page_save_dir, f"page_{i+1}.png")
            img.save(path)
            page_image_paths.append(path)
        tqdm.write(f"[*] 所有 {len(page_image_paths)} 页已保存至: {page_save_dir}")
    except Exception as e:
        # Pipeline boundary: report the failure and signal it to the caller with None.
        tqdm.write(f"[!] 错误:加载或转换PDF时失败: {e}")
        return None
    if max_pages is not None and len(page_image_paths) > max_pages:
        tqdm.write(f"[!] Warning: PDF has {len(page_image_paths)} pages. Processing only the first {max_pages} pages to avoid timeout.")
        page_image_paths = page_image_paths[:max_pages]
    if not page_image_paths:
        # Without this guard the pairing step would look for a cropped-results
        # directory that was never created and raise FileNotFoundError.
        tqdm.write(f"[!] Error: no pages to analyse in '{pdf_file.name}'.")
        return None

    # --- Step 2: detect layout components and crop figures/tables per page ---
    tqdm.write(f"\n--- 步骤 2/3: 分析页面布局以裁剪图和表 ---")
    cropped_results_dir = os.path.join(base_work_dir, "cropped_results", pdf_stem)
    num_pages = len(page_image_paths)
    for i, path in enumerate(page_image_paths):
        if progress:
            # This step covers the 0.3 -> 0.5 slice of the caller's progress bar.
            progress(0.3 + (i / num_pages) * 0.2, desc=f"Analyzing page {i+1}/{num_pages}")
        page_num_str = Path(path).stem  # e.g. "page_3"
        page_crop_dir = os.path.join(cropped_results_dir, page_num_str)
        # Model inference is blocking; run it off the event loop.
        await asyncio.to_thread(
            extract_and_save_layout_components,
            image_path=path,
            model_path=model_path,
            save_base_dir=page_crop_dir,
            imgsz=640,
        )
    tqdm.write(f"[*] 所有裁剪结果已保存至: {cropped_results_dir}")

    # --- Step 3: pair cropped items with their captions ---
    tqdm.write(f"\n--- 步骤 3/3: 对裁剪出的组件进行配对 ---")
    final_paired_dir = os.path.join(base_work_dir, "paired_results", pdf_stem)
    # Pairing is synchronous file I/O; keep it off the event loop as well.
    await asyncio.to_thread(run_pairing_process, cropped_results_dir, final_paired_dir, threshold=30)
    if os.path.isdir(final_paired_dir):
        return final_paired_dir
    return None
def run_pairing_process(source_dir_str: str, output_dir_str: str, threshold: int):
    """Run caption pairing for every page directory of one document.

    This is the pairing step of the figure/table pipeline. Every sub-directory
    of ``source_dir_str`` whose name starts with ``page_`` is handed to
    ``pair_items_on_page``. The results are written to a sub-directory of the
    same name under ``output_dir_str``.

    Args:
        source_dir_str: Directory holding the per-page cropped components.
        output_dir_str: Destination directory. It is deleted and recreated on
            every call.
        threshold: Maximum index distance allowed between an item and its caption.
    """
    source_dir = Path(source_dir_str)
    output_root_dir = Path(output_dir_str)
    # Start from a clean tree so results of a previous run never leak into this one.
    if output_root_dir.exists():
        shutil.rmtree(output_root_dir)
    output_root_dir.mkdir(parents=True, exist_ok=True)
    tqdm.write(f" 开始最近邻配对流程 (阈值 = {threshold})")
    # The cropping step may never have created the source directory (e.g. no
    # pages were analysed). Leave an empty output tree instead of letting
    # iterdir() raise FileNotFoundError.
    if not source_dir.is_dir():
        tqdm.write(f"[!] Warning: cropped results directory not found, nothing to pair: {source_dir}")
        return
    page_dirs = sorted(d for d in source_dir.iterdir() if d.is_dir() and d.name.startswith('page_'))
    for page_dir in page_dirs:
        output_page_dir = output_root_dir / page_dir.name
        output_page_dir.mkdir(exist_ok=True)
        pair_items_on_page(str(page_dir), str(output_page_dir), threshold)
def pair_items_on_page(page_dir: str, output_dir: str, threshold: int):
    """Pair one page's cropped figures/tables with their captions by nearest index.

    Crops are read from these sub-directories of ``page_dir``: ``figure``,
    ``figure_caption``, ``table``, ``table_caption_above`` and
    ``table_caption_below``. Only files named exactly
    ``<type>_<index>_score<score>.jpg`` are considered; other files are ignored.
    "Nearest" is measured as the absolute difference of ``<index>`` values.

    Each figure is greedily matched with the closest unused ``figure_caption``.
    Each table is matched with the closest unused ``table_caption_above`` or
    ``table_caption_below``. A match only counts if the distance is at most
    ``threshold``.

    Items are visited in ascending index order. Candidate captions are scanned
    in caption-type order (as listed above), then by ascending index, and the
    first one with the smallest distance wins. Without this ordering the
    greedy result would depend on the arbitrary order of ``os.listdir``.

    Args:
        page_dir: Directory with one page's per-type crop sub-directories.
        output_dir: Existing directory that receives the results.
        threshold: Maximum allowed index distance between an item and its caption.

    Side effects:
        Each matched pair is copied into ``<output_dir>/paired_<type>_<index>/``.
        Every other recognised file is copied alone into
        ``<output_dir>/unpaired_<type>_<index>/``.
    """
    filename_re = re.compile(r'([a-zA-Z_]+)_(\d+)_score([\d.]+)\.jpg')
    component_types = ["figure", "figure_caption", "table", "table_caption_above", "table_caption_below"]
    pairing_rules = [("figure", ["figure_caption"]), ("table", ["table_caption_above", "table_caption_below"])]

    def parse_filename(filename: str):
        # fullmatch (not match) so names like "figure_1_score0.9.jpg.tmp" are rejected.
        match = filename_re.fullmatch(filename)
        return (match.group(1), int(match.group(2))) if match else (None, None)

    # organized_files[comp_type][index] -> path of that crop file
    organized_files = defaultdict(dict)
    for comp_type in component_types:
        comp_dir = os.path.join(page_dir, comp_type)
        if not os.path.isdir(comp_dir):
            continue
        for filename in sorted(os.listdir(comp_dir)):
            _, index = parse_filename(filename)
            if index is not None:
                organized_files[comp_type][index] = os.path.join(comp_dir, filename)

    paired_files = set()
    used_captions = defaultdict(set)
    for item_type, cap_types in pairing_rules:
        # Candidate captions in a fixed order: caption type as listed, then ascending index.
        candidates = [
            (cap_type, cap_index, cap_path)
            for cap_type in cap_types
            for cap_index, cap_path in sorted(organized_files[cap_type].items())
        ]
        for item_index, item_path in sorted(organized_files[item_type].items()):
            best = None
            best_diff = float('inf')
            for cap_type, cap_index, cap_path in candidates:
                if cap_index in used_captions[cap_type]:
                    continue
                diff = abs(item_index - cap_index)
                if diff < best_diff:
                    best, best_diff = (cap_type, cap_index, cap_path), diff
            if best is None or best_diff > threshold:
                continue
            cap_type, cap_index, cap_path = best
            target_dir = os.path.join(output_dir, f"paired_{item_type}_{item_index}")
            os.makedirs(target_dir, exist_ok=True)
            shutil.copy(item_path, target_dir)
            shutil.copy(cap_path, target_dir)
            paired_files.update((item_path, cap_path))
            used_captions[cap_type].add(cap_index)

    # Every recognised file that did not end up in a pair is copied on its own.
    for files_by_index in organized_files.values():
        for file_path in files_by_index.values():
            if file_path in paired_files:
                continue
            item_type, index = parse_filename(Path(file_path).name)
            if item_type:
                target_dir = os.path.join(output_dir, f"unpaired_{item_type}_{index}")
                os.makedirs(target_dir, exist_ok=True)
                shutil.copy(file_path, target_dir)