File size: 1,892 Bytes
ec3d86e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# yolo.py
import os
from PIL import Image
from doclayout_yolo import YOLOv10
from tqdm.asyncio import tqdm
# Maps the integer class id to its layout label. The labels are listed in
# class-id order, so a label's position in the tuple IS its class id
# (0 = "title" ... 9 = "formula_caption").
CLASS_NAMES = dict(
    enumerate(
        (
            "title",
            "plain_text",
            "abandon",
            "figure",
            "figure_caption",
            "table",
            "table_caption_above",
            "table_caption_below",
            "formula",
            "formula_caption",
        )
    )
)

def extract_and_save_layout_components(image_path, model_path, save_base_dir="./cropped_results", imgsz=1024, conf=0.2, device="cpu"):
    """Detect layout components in an image and save one JPEG crop per detection.

    Each detection is cropped out of the source image and written to
    ``<save_base_dir>/<class_name>/<class_name>_<idx>_score<score>.jpg``,
    where ``idx`` is the detection's position in the model output. A summary
    line is printed through ``tqdm.write`` when done.

    Args:
        image_path (str): Path of the input image.
        model_path (str): Path of the model weights (.pt).
        save_base_dir (str): Root directory under which the crops are saved.
        imgsz (int): Image size passed to the model's ``predict`` call.
        conf (float): Confidence threshold passed to the model's ``predict`` call.
        device (str): Compute device, e.g. 'cuda:0' or 'cpu'.

    Returns:
        None
    """
    # Pixel modes the JPEG writer accepts as-is; anything else (RGBA, P, LA,
    # I, F, ...) must be converted first or ``save`` raises.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    model = YOLOv10(model_path)
    saved_count = 0
    # The context manager guarantees the underlying file handle is released,
    # including when prediction or saving raises.
    with Image.open(image_path) as image:
        # NOTE(review): the detector is given the file path and decodes the
        # image itself, while crops are taken from the PIL-decoded image —
        # confirm both agree on orientation (e.g. EXIF rotation).
        det_results = model.predict(image_path, imgsz=imgsz, conf=conf, device=device)

        result = det_results[0]
        boxes = result.boxes.xyxy.cpu().tolist()
        classes = result.boxes.cls.cpu().tolist()
        scores = result.boxes.conf.cpu().tolist()

        width, height = image.size
        for idx, (box, cls_id, score) in enumerate(zip(boxes, classes, scores)):
            cls_id = int(cls_id)
            # Unknown ids fall back to a synthetic "cls<N>" label.
            class_name = CLASS_NAMES.get(cls_id, f"cls{cls_id}")

            # Truncate to whole pixels and clamp to the image bounds.
            x1, y1, x2, y2 = map(int, box)
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            if x2 <= x1 or y2 <= y1:
                # A zero-area crop cannot be encoded; skip it instead of
                # aborting the remaining detections.
                continue

            cropped = image.crop((x1, y1, x2, y2))
            if cropped.mode not in jpeg_modes:
                cropped = cropped.convert('RGB')

            save_dir = os.path.join(save_base_dir, class_name)
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f"{class_name}_{idx}_score{score:.2f}.jpg")
            cropped.save(save_path)
            saved_count += 1
    # Report the number of crops actually written (skipped boxes excluded).
    tqdm.write(f"共保存 {saved_count} 张截图,按类别分别保存在 {save_base_dir}/")