File size: 4,039 Bytes
37d290c
 
 
 
 
 
 
53ef3cd
 
 
 
 
 
 
 
 
 
 
37d290c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import re
import html
import logging
import json
from PIL import Image


def extract_json_content(text):
    """Return the body of a Markdown ```json code fence found in *text*.

    If a complete fence (opening "```json" line ... closing "```") is present,
    its content is returned stripped. If only the opening fence is present
    (e.g. the generation was cut off), everything after it is returned
    stripped. Otherwise *text* is returned unchanged.

    Trailing spaces/tabs after the ``json`` tag and CRLF line endings are
    accepted on the opening fence line; a trailing carriage return before the
    closing fence is removed by ``strip()``.
    """
    fenced = re.search(r"```json[ \t]*\r?\n(.*?)\n```", text, re.DOTALL)
    if fenced:
        return fenced.group(1).strip()

    # No closing fence: keep everything after the opening one.
    partial = re.search(r"```json[ \t]*\r?\n(.*)", text, re.DOTALL)
    if partial:
        return partial.group(1).strip()

    return text


def truncate_last_incomplete_element(text: str):
    """Cut off the final ``{"bbox": ...`` element when *text* looks truncated.

    *text* is treated as suspect when it is longer than 50,000 characters or
    does not end (ignoring surrounding whitespace) with ``]``. In that case,
    provided at least two ``{"bbox":`` markers exist, everything from the last
    marker onward is removed (after right-stripping whitespace); if what
    remains ends with a comma, that comma is replaced by ``]``.

    Returns:
        A ``(result_text, was_truncated)`` tuple.
    """
    marker = '{"bbox":'

    # Short answers that end with "]" are assumed complete and left alone.
    # Long ones are cut regardless -- presumably because a cut-off generation
    # can still end in "]" from a coordinate list (TODO confirm intent).
    if len(text) <= 50000 and text.strip().endswith("]"):
        return text, False

    # With fewer than two markers, cutting would discard the only element.
    if text.count(marker) < 2:
        return text, False

    cut_at = text.rfind(marker)
    if cut_at <= 0:
        return text, False

    kept = text[:cut_at].rstrip()
    if kept.endswith(","):
        # Swap the dangling separator for a closing bracket.
        kept = kept[:-1] + "]"
    return kept, True


def obtain_origin_hw(image_paths: list[str]):
    """Return ``(height, width)`` in pixels of the first image in *image_paths*.

    The size is read from the lazily opened image, without decoding the pixel
    data, and the file is closed by the ``with`` block.

    Best-effort: if *image_paths* is empty or the image cannot be opened, a
    message is printed and ``(1000, 1000)`` is returned. With that fallback,
    ``restore_abs_bbox_coordinates`` (which divides by 1000) leaves
    coordinates unchanged.
    """
    if not image_paths:
        print("obtain_origin_hw: image_paths is empty, falling back to 1000x1000")
        return 1000, 1000

    image_path = image_paths[0]
    try:
        with Image.open(image_path) as image:
            origin_width, origin_height = image.size
    except Exception as e:  # deliberate catch-all: any unreadable image -> default size
        # Message reads: "Error while processing image {path}: {e}"
        print(f"处理图像 {image_path} 时出错: {e}")
        return 1000, 1000
    return origin_height, origin_width


def restore_abs_bbox_coordinates(ans: str, origin_height: float, origin_width: float):
    """Convert 0-1000 normalised bbox coordinates in a model answer to pixels.

    The answer is first unwrapped from a ```json fence and stripped of a
    trailing cut-off element. It is then parsed as JSON, falling back to
    Python ``eval``. For each dict in the parsed list/tuple, every value has
    its key examined; if the key contains ``"bbox"`` and the value is a
    4-number list/tuple ``[x1, y1, x2, y2]``, it is rescaled. x values are
    multiplied by ``origin_width / 1000`` and y values by
    ``origin_height / 1000``, each truncated to ``int``. Malformed bbox values
    are logged and left unchanged.

    Args:
        ans: Raw model output.
        origin_height: Image height in pixels.
        origin_width: Image width in pixels.

    Returns:
        The rescaled data serialised with ``json.dumps(..., indent=4)``, or
        the cleaned answer text when it cannot be parsed or serialised.
    """
    logger = logging.getLogger(__name__)

    ans = extract_json_content(ans)
    ans, _ = truncate_last_incomplete_element(ans)
    try:
        data = json.loads(ans)
    except Exception:
        try:
            # SECURITY: eval() executes arbitrary code contained in the model
            # output. It is presumably here to accept Python-literal answers
            # (single quotes, tuples); ast.literal_eval covers that safely and
            # should replace this call.
            data = eval(ans)
        except Exception as e:
            # Message reads: "Error while parsing json: {e}"
            print(f"解析json时出错: {e}")
            return ans

    # Only a top-level list/tuple is walked, and only its dict items are
    # rescaled; anything else passes through untouched instead of raising.
    items = data if isinstance(data, (list, tuple)) else ()
    for index, item in enumerate(items):
        if not isinstance(item, dict):
            continue
        for key, value in item.items():
            if not (isinstance(key, str) and "bbox" in key):
                continue
            if (
                isinstance(value, (list, tuple))
                and len(value) == 4
                and all(isinstance(coord, (int, float)) for coord in value)
            ):
                x1, y1, x2, y2 = value
                # Reassigning an existing key during items() iteration is safe:
                # the dict's size does not change.
                item[key] = [
                    int(x1 / 1000.0 * origin_width),
                    int(y1 / 1000.0 * origin_height),
                    int(x2 / 1000.0 * origin_width),
                    int(y2 / 1000.0 * origin_height),
                ]
            else:
                logger.info("ERROR CHECK: idx %s, %s", index, data)

    try:
        return json.dumps(data, indent=4)
    except (TypeError, ValueError):
        # eval() may yield objects JSON cannot encode (e.g. sets, bytes).
        return ans


def convert_json_to_markdown(ans: str, keep_header_footer: bool = False):
    """Join the ``text`` fields of a JSON layout answer into Markdown paragraphs.

    The answer is unwrapped from a ```json fence and stripped of a trailing
    cut-off element, then parsed with ``json.loads``. Non-empty ``text``
    values of dict items are joined with blank lines. Unless
    *keep_header_footer* is true, items whose ``category`` is ``header``,
    ``footer`` or ``page_footnote`` are skipped. Items without a
    ``category`` are kept rather than aborting the whole conversion.

    Args:
        ans: Raw model output.
        keep_header_footer: Also keep header/footer/page_footnote text.

    Returns:
        The joined Markdown text, or the cleaned answer text when parsing
        fails or no text was collected.
    """
    ans = extract_json_content(ans)
    ans, _ = truncate_last_incomplete_element(ans)

    excluded_categories = ("header", "footer", "page_footnote")
    try:
        parsed = json.loads(ans)
        paragraphs = []
        for sub_item in parsed:
            # Non-dict entries carry no "text" field; skip instead of failing.
            if not isinstance(sub_item, dict):
                continue
            text = sub_item.get("text")
            if not text:
                continue
            if keep_header_footer or sub_item.get("category") not in excluded_categories:
                paragraphs.append(text)
        return "\n\n".join(paragraphs) if paragraphs else ans
    except Exception as e:
        # Best-effort: any parse/join failure returns the cleaned raw answer.
        print(f"process ans error: {e}")
        return ans