Spaces:
Sleeping
Sleeping
| import re | |
| import html | |
| import logging | |
| import json | |
| from PIL import Image | |
def extract_json_content(text):
    """Pull the payload out of a Markdown ```json code fence.

    A complete fence (```json ... ```) is preferred. If none is found, an
    unterminated fence (```json followed by anything up to end of text) is
    accepted, which covers generations that were cut off before the closing
    backticks. The captured payload is whitespace-stripped.

    Args:
        text: Raw model output.

    Returns:
        The stripped fenced payload, or ``text`` unchanged (not stripped) when
        no ```json fence is present.
    """
    # Ordered from strict (closed fence) to lenient (open-ended fence).
    for fence_pattern in (r"```json\n(.*?)\n```", r"```json\n(.*)"):
        hit = re.search(fence_pattern, text, re.DOTALL)
        if hit:
            return hit.group(1).strip()
    return text
def truncate_last_incomplete_element(text: str):
    """Drop the trailing ``{"bbox": ...}`` element of a possibly cut-off list.

    Truncation is attempted only when the text is longer than 50,000
    characters or does not end with ``]`` (ignoring surrounding whitespace).
    Text longer than 50,000 characters is therefore truncated even if it is a
    complete list, which drops its last element.

    When truncating, everything from the last ``{"bbox":`` onward is cut, the
    remainder is right-stripped, one trailing comma is removed, and ``]`` is
    appended to close the list. Text containing at most one ``{"bbox":``
    marker is never truncated, so a lone element is not deleted.

    Args:
        text: JSON-like text, typically a list of ``{"bbox": ...}`` objects.

    Returns:
        A ``(text, was_truncated)`` tuple.
    """
    marker = '{"bbox":'

    # Short text that already closes its list is left alone.
    if len(text) <= 50000 and text.strip().endswith("]"):
        return text, False

    # Keep single-element answers intact rather than emptying them.
    if text.count(marker) <= 1:
        return text, False

    cut_at = text.rfind(marker)
    if cut_at <= 0:
        return text, False

    kept = text[:cut_at].rstrip().removesuffix(",")
    return kept + "]", True
def obtain_origin_hw(image_paths: list[str]):
    """Return ``(height, width)`` in pixels of the first image in ``image_paths``.

    Only the image header is read, to get the size. The file is closed
    before returning.

    Args:
        image_paths: Image file paths; only the first entry is inspected.

    Returns:
        ``(height, width)`` of the first image, or ``(1000, 1000)`` when the
        list is empty or the image cannot be opened.
    """
    # Guard the empty list up front: the old handler indexed image_paths[0]
    # while reporting the error, which raised IndexError itself.
    if not image_paths:
        return 1000, 1000

    path = image_paths[0]
    try:
        # The context manager releases the file handle. Reading ``.size`` only
        # needs the header, so there is no need to decode/convert the pixels.
        with Image.open(path) as image:
            origin_width, origin_height = image.size
    except Exception as e:  # deliberate best-effort fallback to the default size
        print(f"处理图像 {path} 时出错: {e}")
        return 1000, 1000
    return origin_height, origin_width
def restore_abs_bbox_coordinates(ans: str, origin_height: float, origin_width: float):
    """Rescale normalized bbox coordinates in a model answer to absolute pixels.

    The answer is first unwrapped from a ```json fence with
    ``extract_json_content``. Its last incomplete element is then dropped
    with ``truncate_last_incomplete_element``. The result is parsed as JSON,
    falling back to a Python-literal parse.

    For every dict item, each string key containing ``"bbox"`` whose value is
    a 4-number sequence ``[x1, y1, x2, y2]`` is rescaled from a 0-1000 scale
    with ``int(v / 1000.0 * size)``. x values use ``origin_width`` and y
    values use ``origin_height``. Any other ``"bbox"`` value is logged and
    left unchanged. A single top-level object is treated like a one-element
    list.

    Args:
        ans: Raw model output.
        origin_height: Height of the original image in pixels.
        origin_width: Width of the original image in pixels.

    Returns:
        The (rescaled) data serialized with ``json.dumps(indent=4)``. If
        parsing or serialization fails, the cleaned (unwrapped/truncated)
        text is returned instead.
    """
    import ast

    logger = logging.getLogger(__name__)

    ans = extract_json_content(ans)
    ans, _ = truncate_last_incomplete_element(ans)
    try:
        data = json.loads(ans)
    except (ValueError, RecursionError):
        # Models sometimes emit Python-literal syntax (single quotes, True/None).
        # ast.literal_eval parses that without executing code. The previous
        # eval() ran arbitrary expressions taken from untrusted model output.
        try:
            data = ast.literal_eval(ans)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            print(f"解析json时出错: {e}")
            return ans

    items = [data] if isinstance(data, dict) else data
    if isinstance(items, (list, tuple)):
        for index, item in enumerate(items):
            # Non-dict entries (strings, nested lists, numbers) carry no bbox keys.
            if not isinstance(item, dict):
                continue
            # Reassigning existing keys does not resize the dict, so mutating
            # while iterating over items() is safe here.
            for key, value in item.items():
                if not isinstance(key, str) or "bbox" not in key:
                    continue
                if (
                    isinstance(value, (list, tuple))
                    and len(value) == 4
                    and all(isinstance(coord, (int, float)) for coord in value)
                ):
                    x1, y1, x2, y2 = value
                    item[key] = [
                        int(x1 / 1000.0 * origin_width),
                        int(y1 / 1000.0 * origin_height),
                        int(x2 / 1000.0 * origin_width),
                        int(y2 / 1000.0 * origin_height),
                    ]
                else:
                    logger.info("ERROR CHECK: idx %s, %s", index, data)

    try:
        return json.dumps(data, indent=4)
    except (TypeError, ValueError):
        # literal_eval can produce values JSON cannot encode (e.g. sets, bytes).
        return ans
def convert_json_to_markdown(ans: str, keep_header_footer: bool = False):
    """Join the ``text`` fields of a layout answer into Markdown paragraphs.

    The answer is first unwrapped from a ```json fence with
    ``extract_json_content``. Its last incomplete element is then dropped
    with ``truncate_last_incomplete_element``. The result is parsed as JSON,
    falling back to a Python-literal parse.

    Every dict item with a truthy ``"text"`` value contributes one
    paragraph. Unless ``keep_header_footer`` is set, items whose
    ``"category"`` is ``"header"``, ``"footer"`` or ``"page_footnote"`` are
    skipped. Items without a ``"category"`` are kept. A single top-level
    object is treated like a one-element list.

    Args:
        ans: Raw model output.
        keep_header_footer: Also keep header/footer/page-footnote text.

    Returns:
        The paragraphs joined by blank lines, or the cleaned
        (unwrapped/truncated) text when parsing fails or no text is found.
    """
    import ast

    excluded_categories = ("header", "footer", "page_footnote")

    ans = extract_json_content(ans)
    ans, _ = truncate_last_incomplete_element(ans)
    try:
        parsed = json.loads(ans)
    except (ValueError, RecursionError):
        # Same Python-literal fallback as restore_abs_bbox_coordinates;
        # literal_eval never executes code from the model output.
        try:
            parsed = ast.literal_eval(ans)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            print(f"process ans error: {e}")
            return ans

    items = [parsed] if isinstance(parsed, dict) else parsed
    if not isinstance(items, (list, tuple)):
        return ans

    paragraphs = []
    for item in items:
        if not isinstance(item, dict):
            continue
        text = item.get("text")
        if not text:
            continue
        # .get() avoids KeyError on items missing "category", which previously
        # discarded the whole answer through the broad except.
        if keep_header_footer or item.get("category") not in excluded_categories:
            paragraphs.append(text if isinstance(text, str) else str(text))

    return "\n\n".join(paragraphs) if paragraphs else ans