File size: 6,411 Bytes
ec3d86e
 
 
 
 
 
 
 
 
 
ec5f146
43aac1a
 
ec5f146
ec3d86e
 
 
 
 
 
 
43aac1a
ec3d86e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca1833b
 
 
 
ec3d86e
 
43aac1a
 
 
 
ec5f146
ec3d86e
 
ec5f146
 
 
 
 
 
 
 
 
ec3d86e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# figure_table_pipeline.py
import os
import shutil
import re
from pathlib import Path
from collections import defaultdict
from pragent.backend.loader import ImagePDFLoader
from pragent.backend.yolo import extract_and_save_layout_components
from tqdm.asyncio import tqdm

import asyncio
from typing import Any

async def run_figure_extraction(pdf_path: str, base_work_dir: str, progress: Any | None = None) -> str | None:
    """Extract figure/table crops from a PDF and pair them with their captions.

    This is the main entry point called by app.py. The pipeline runs three
    stages: (1) render every PDF page to a PNG image, (2) run the YOLO
    layout model on each page to crop figure/table components, (3) pair each
    cropped figure/table with its nearest caption.

    Args:
        pdf_path (str): Path to the user-uploaded PDF.
        base_work_dir (str): Temporary working directory for this session.
        progress (Any | None): Gradio progress object; invoked with a float
            fraction and a ``desc`` keyword during the layout stage.

    Returns:
        str | None: Directory path of the final paired results, or None on
        failure (missing dependencies or a PDF load/convert error).
    """
    # Bail out early if either core dependency failed to import
    # (the importing module may leave these names as None on ImportError).
    if not all([ImagePDFLoader, extract_and_save_layout_components]):
        tqdm.write("[!] 错误: figure_pipeline 的一个或多个核心依赖项未能加载。")
        return None

    pdf_file = Path(pdf_path)
    pdf_stem = pdf_file.stem
    model_path = "pragent/model/doclayout_yolo_docstructbench_imgsz1024.pt"

    tqdm.write(f"\n--- 步骤 1/3: 将PDF '{pdf_file.name}' 转换为图片 ---")
    page_save_dir = os.path.join(base_work_dir, "page_paper", pdf_stem)
    os.makedirs(page_save_dir, exist_ok=True)
    try:
        loader = ImagePDFLoader(pdf_path)
        page_image_paths = []
        # Pages are numbered starting from 1 (page_1.png, page_2.png, ...).
        for i, img in enumerate(loader.load()):
            path = os.path.join(page_save_dir, f"page_{i+1}.png")
            img.save(path)
            page_image_paths.append(path)
        tqdm.write(f"[*] 所有 {len(page_image_paths)} 页已保存至: {page_save_dir}")
    except Exception as e:
        tqdm.write(f"[!] 错误:加载或转换PDF时失败: {e}")
        return None

    # Hard cap at 20 pages so the layout-analysis stage cannot run unbounded.
    if len(page_image_paths) > 20:
        tqdm.write(f"[!] Warning: PDF has {len(page_image_paths)} pages. Processing only the first 20 pages to avoid timeout.")
        page_image_paths = page_image_paths[:20]

    tqdm.write(f"\n--- 步骤 2/3: 分析页面布局以裁剪图和表 ---")
    cropped_results_dir = os.path.join(base_work_dir, "cropped_results", pdf_stem)
    num_pages = len(page_image_paths)
    for i, path in enumerate(page_image_paths):
        # Layout analysis occupies the 0.3..0.5 span of the overall progress bar.
        if progress:
            progress(0.3 + (i / num_pages) * 0.2, desc=f"Analyzing page {i+1}/{num_pages}")

        page_num_str = Path(path).stem
        page_crop_dir = os.path.join(cropped_results_dir, page_num_str)

        # YOLO inference is blocking; run it in a worker thread so the
        # asyncio event loop stays responsive.
        await asyncio.to_thread(
            extract_and_save_layout_components,
            image_path=path,
            model_path=model_path,
            save_base_dir=page_crop_dir,
            imgsz=640
        )

    tqdm.write(f"[*] 所有裁剪结果已保存至: {cropped_results_dir}")

    tqdm.write(f"\n--- 步骤 3/3: 对裁剪出的组件进行配对 ---")
    final_paired_dir = os.path.join(base_work_dir, "paired_results", pdf_stem)
    run_pairing_process(cropped_results_dir, final_paired_dir, threshold=30)

    # run_pairing_process creates the directory; its absence means pairing failed.
    if os.path.isdir(final_paired_dir):
        return final_paired_dir
    return None

def run_pairing_process(source_dir_str: str, output_dir_str: str, threshold: int):
    """Pairing logic, now part of the pipeline.

    Walks every ``page_*`` subdirectory of the source directory in sorted
    order and pairs the components found on each page, mirroring the page
    directory layout under the output root.
    """
    src_root = Path(source_dir_str)
    out_root = Path(output_dir_str)

    # Always start from a clean output tree.
    if out_root.exists():
        shutil.rmtree(out_root)
    out_root.mkdir(parents=True, exist_ok=True)

    tqdm.write(f"    开始最近邻配对流程 (阈值 = {threshold})")

    candidates = (d for d in src_root.iterdir() if d.is_dir() and d.name.startswith('page_'))
    for page_dir in sorted(candidates):
        dest = out_root / page_dir.name
        dest.mkdir(exist_ok=True)
        pair_items_on_page(str(page_dir), str(dest), threshold)

def pair_items_on_page(page_dir: str, output_dir: str, threshold: int):
    """Pair components on one page directory via nearest-neighbour index matching.

    Figures are matched against figure captions, tables against above/below
    table captions. A pair is accepted when the absolute difference of the
    two detection indices is at most ``threshold``; each caption is consumed
    by at most one item. Pairs are copied into ``paired_<type>_<index>``
    directories, leftovers into ``unpaired_<type>_<index>`` directories.
    """
    filename_re = re.compile(r'([a-zA-Z_]+)_(\d+)_score([\d.]+)\.jpg')

    def _parse(filename: str):
        # -> (component_type, index) or (None, None) when the name doesn't match.
        m = filename_re.match(filename)
        if m:
            return m.group(1), int(m.group(2))
        return None, None

    kinds = ["figure", "figure_caption", "table", "table_caption_above", "table_caption_below"]

    # Index every recognised file by component kind and detection index.
    by_kind = defaultdict(dict)
    for kind in kinds:
        kind_dir = os.path.join(page_dir, kind)
        if not os.path.isdir(kind_dir):
            continue
        for name in os.listdir(kind_dir):
            _, idx = _parse(name)
            if idx is not None:
                by_kind[kind][idx] = os.path.join(kind_dir, name)

    consumed = set()
    taken_captions = defaultdict(set)

    matchups = [
        ("figure", ["figure_caption"]),
        ("table", ["table_caption_above", "table_caption_below"]),
    ]
    for main_kind, caption_kinds in matchups:
        for main_idx, main_path in by_kind[main_kind].items():
            # Scan all still-available captions for the closest index.
            best_diff = float('inf')
            best = None  # (path, index, kind) of the current best caption
            for cap_kind in caption_kinds:
                for cap_idx, cap_path in by_kind[cap_kind].items():
                    if cap_idx in taken_captions[cap_kind]:
                        continue
                    gap = abs(main_idx - cap_idx)
                    if gap < best_diff:
                        best_diff = gap
                        best = (cap_path, cap_idx, cap_kind)

            if best is not None and best_diff <= threshold:
                cap_path, cap_idx, cap_kind = best
                dest = os.path.join(output_dir, f"paired_{main_kind}_{main_idx}")
                os.makedirs(dest, exist_ok=True)
                shutil.copy(main_path, dest)
                shutil.copy(cap_path, dest)
                consumed.add(main_path)
                consumed.add(cap_path)
                taken_captions[cap_kind].add(cap_idx)

    # Everything not consumed by a pair goes into its own "unpaired" folder.
    for index_map in by_kind.values():
        for file_path in index_map.values():
            if file_path in consumed:
                continue
            kind, idx = _parse(Path(file_path).name)
            if kind:
                dest = os.path.join(output_dir, f"unpaired_{kind}_{idx}")
                os.makedirs(dest, exist_ok=True)
                shutil.copy(file_path, dest)