# NOTE: removed non-code scrape residue (site header, git commit hashes, and a
# line-number gutter) that was accidentally captured above this file's contents.
# figure_table_pipeline.py
import os
import shutil
import re
from pathlib import Path
from collections import defaultdict
from pragent.backend.loader import ImagePDFLoader
from pragent.backend.yolo import extract_and_save_layout_components
from tqdm.asyncio import tqdm
import asyncio
from typing import Any
async def run_figure_extraction(pdf_path: str, base_work_dir: str, progress: Any | None = None) -> str | None:
    """End-to-end pipeline: extract figures/tables from a PDF and pair them with captions.

    This is the main entry point called by app.py. It renders the PDF to page
    images, runs the YOLO layout model to crop figure/table components from
    each page, and finally pairs items with their captions.

    Args:
        pdf_path: Path to the user-uploaded PDF.
        base_work_dir: Temporary working directory for this session.
        progress: Gradio progress object (optional); called with a fraction
            and a description while pages are being analyzed.

    Returns:
        Path to the directory holding the final paired results, or ``None``
        if loading fails, the PDF has no pages, or a core dependency is missing.
    """
    # Bail out early if either core dependency failed to import at module load.
    if not all([ImagePDFLoader, extract_and_save_layout_components]):
        tqdm.write("[!] 错误: figure_pipeline 的一个或多个核心依赖项未能加载。")
        return None
    pdf_file = Path(pdf_path)
    pdf_stem = pdf_file.stem
    model_path = "pragent/model/doclayout_yolo_docstructbench_imgsz1024.pt"

    # --- Step 1/3: render each PDF page to a PNG image ---
    tqdm.write(f"\n--- 步骤 1/3: 将PDF '{pdf_file.name}' 转换为图片 ---")
    page_save_dir = os.path.join(base_work_dir, "page_paper", pdf_stem)
    os.makedirs(page_save_dir, exist_ok=True)
    try:
        loader = ImagePDFLoader(pdf_path)
        page_image_paths = []
        for i, img in enumerate(loader.load()):
            path = os.path.join(page_save_dir, f"page_{i+1}.png")
            img.save(path)
            page_image_paths.append(path)
        tqdm.write(f"[*] 所有 {len(page_image_paths)} 页已保存至: {page_save_dir}")
    except Exception as e:
        tqdm.write(f"[!] 错误:加载或转换PDF时失败: {e}")
        return None
    # A zero-page PDF has nothing to analyze or pair; fail explicitly instead
    # of producing an empty "paired" directory downstream.
    if not page_image_paths:
        return None
    # Cap at 20 pages to keep the layout-analysis stage from timing out.
    if len(page_image_paths) > 20:
        tqdm.write(f"[!] Warning: PDF has {len(page_image_paths)} pages. Processing only the first 20 pages to avoid timeout.")
        page_image_paths = page_image_paths[:20]

    # --- Step 2/3: run the layout model on each page and crop components ---
    tqdm.write("\n--- 步骤 2/3: 分析页面布局以裁剪图和表 ---")
    cropped_results_dir = os.path.join(base_work_dir, "cropped_results", pdf_stem)
    num_pages = len(page_image_paths)
    for i, path in enumerate(page_image_paths):
        if progress:
            # Layout analysis occupies the 0.3-0.5 band of the overall progress bar.
            progress(0.3 + (i / num_pages) * 0.2, desc=f"Analyzing page {i+1}/{num_pages}")
        page_num_str = Path(path).stem
        page_crop_dir = os.path.join(cropped_results_dir, page_num_str)
        # Model inference is blocking; run it off the event loop.
        await asyncio.to_thread(
            extract_and_save_layout_components,
            image_path=path,
            model_path=model_path,
            save_base_dir=page_crop_dir,
            imgsz=640
        )
    tqdm.write(f"[*] 所有裁剪结果已保存至: {cropped_results_dir}")

    # --- Step 3/3: pair cropped figures/tables with their captions ---
    tqdm.write("\n--- 步骤 3/3: 对裁剪出的组件进行配对 ---")
    final_paired_dir = os.path.join(base_work_dir, "paired_results", pdf_stem)
    run_pairing_process(cropped_results_dir, final_paired_dir, threshold=30)
    if os.path.isdir(final_paired_dir):
        return final_paired_dir
    return None
def run_pairing_process(source_dir_str: str, output_dir_str: str, threshold: int):
    """Pairing logic; now part of the pipeline.

    Wipes and recreates the output root, then runs nearest-neighbour
    caption pairing on every ``page_*`` subdirectory of the source.

    Args:
        source_dir_str: Directory containing one ``page_*`` folder per page.
        output_dir_str: Destination root; rebuilt from scratch on each run.
        threshold: Maximum index distance between an item and its caption.
    """
    src_root = Path(source_dir_str)
    dst_root = Path(output_dir_str)
    # Start from a clean slate so stale results never leak between runs.
    if dst_root.exists():
        shutil.rmtree(dst_root)
    dst_root.mkdir(parents=True, exist_ok=True)
    tqdm.write(f" 开始最近邻配对流程 (阈值 = {threshold})")
    page_candidates = (entry for entry in src_root.iterdir()
                       if entry.is_dir() and entry.name.startswith('page_'))
    for page in sorted(page_candidates):
        target = dst_root / page.name
        target.mkdir(exist_ok=True)
        pair_items_on_page(str(page), str(target), threshold)
def pair_items_on_page(page_dir: str, output_dir: str, threshold: int):
    """Pair figures/tables with their nearest caption inside one page directory.

    Cropped components are named like ``figure_3_score0.87.jpg``; the integer
    index encodes detection order on the page, so the caption whose index is
    closest to an item's index is the best pairing candidate. Pairs within
    ``threshold`` are copied into ``paired_<type>_<index>`` folders; every
    leftover component is copied into an ``unpaired_<type>_<index>`` folder.

    Args:
        page_dir: Directory with one subfolder per component type.
        output_dir: Destination directory for the paired/unpaired folders.
        threshold: Maximum allowed ``|item_index - caption_index|`` for a pair.
    """
    organized_files = defaultdict(dict)
    component_types = ["figure", "figure_caption", "table", "table_caption_above", "table_caption_below"]
    # e.g. "table_caption_above_12_score0.93.jpg" -> ("table_caption_above", 12)
    name_pattern = re.compile(r'([a-zA-Z_]+)_(\d+)_score([\d.]+)\.jpg')

    def parse_filename(filename: str):
        match = name_pattern.match(filename)
        return (match.group(1), int(match.group(2))) if match else (None, None)

    for comp_type in component_types:
        comp_dir = os.path.join(page_dir, comp_type)
        if os.path.isdir(comp_dir):
            # Sort filenames: os.listdir order is arbitrary, and dict insertion
            # order would otherwise make the greedy pairing nondeterministic.
            for filename in sorted(os.listdir(comp_dir)):
                _, index = parse_filename(filename)
                if index is not None:
                    organized_files[comp_type][index] = os.path.join(comp_dir, filename)

    paired_files, used_captions = set(), defaultdict(set)
    for item_type, cap_types in [("figure", ["figure_caption"]), ("table", ["table_caption_above", "table_caption_below"])]:
        # Iterate items in index order so greedy caption assignment is reproducible.
        for item_index, item_path in sorted(organized_files[item_type].items()):
            best_match = {'min_diff': float('inf'), 'cap_path': None, 'cap_index': -1, 'cap_type': ''}
            for cap_type in cap_types:
                for cap_index, cap_path in sorted(organized_files[cap_type].items()):
                    if cap_index in used_captions[cap_type]:
                        continue
                    diff = abs(item_index - cap_index)
                    if diff < best_match['min_diff']:
                        best_match.update({'min_diff': diff, 'cap_path': cap_path, 'cap_index': cap_index, 'cap_type': cap_type})
            if best_match['cap_path'] and best_match['min_diff'] <= threshold:
                target_dir = os.path.join(output_dir, f"paired_{item_type}_{item_index}")
                os.makedirs(target_dir, exist_ok=True)
                shutil.copy(item_path, target_dir)
                shutil.copy(best_match['cap_path'], target_dir)
                paired_files.add(item_path)
                paired_files.add(best_match['cap_path'])
                # A caption may serve only one item.
                used_captions[best_match['cap_type']].add(best_match['cap_index'])

    # Anything not consumed by a pair gets its own "unpaired" folder.
    for files_dict in organized_files.values():
        for file_path in files_dict.values():
            if file_path in paired_files:
                continue
            item_type, index = parse_filename(Path(file_path).name)
            if item_type:
                target_dir = os.path.join(output_dir, f"unpaired_{item_type}_{index}")
                os.makedirs(target_dir, exist_ok=True)
                shutil.copy(file_path, target_dir)