Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| from typing import Dict, List | |
| import fitz | |
| from io import BytesIO | |
| import json | |
| from dots_ocr.utils.image_utils import smart_resize | |
| from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS | |
| from dots_ocr.utils.output_cleaner import OutputCleaner | |
# Color map keyed by layout category, as (R, G, B, A) tuples with 8-bit
# channels in the range 0-255.
# draw_layout_on_image only reads the RGB part (color[:3]); the translucency of
# filled boxes comes from the fill_opacity passed to page.draw_rect, not from
# the alpha stored here. Alpha is therefore kept at the maximum valid value
# (255) instead of the out-of-range 256.
dict_layout_type_to_color = {
    "Text": (0, 128, 0, 255),  # Green
    "Picture": (255, 0, 255, 255),  # Magenta
    "Caption": (255, 165, 0, 255),  # Orange
    "Section-header": (0, 255, 255, 255),  # Cyan
    "Footnote": (0, 128, 0, 255),  # Green
    "Formula": (128, 128, 128, 255),  # Gray
    "Table": (255, 192, 203, 255),  # Pink
    "Title": (255, 0, 0, 255),  # Red
    "List-item": (0, 0, 255, 255),  # Blue
    "Page-header": (0, 128, 0, 255),  # Green
    "Page-footer": (128, 0, 128, 255),  # Purple
    "Other": (165, 42, 42, 255),  # Brown
    "Unknown": (0, 0, 0, 0),  # Black, alpha 0
}
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
    """
    Draw layout boxes and "<index>_<category>" labels on a copy of an image.

    The image is placed on a single-page in-memory PDF, annotated with
    PyMuPDF (fitz), and the page is rendered back into a PIL image.

    Args:
        image: The source PIL Image.
        cells: A list of dicts, each providing a 'bbox' ([x0, y0, x1, y1])
            and a 'category' entry.
        resized_height: Height of the resized image the bboxes refer to.
        resized_width: Width of the resized image the bboxes refer to.
            Coordinates are rescaled to the original image size only when
            both resized_height and resized_width are truthy.
        fill_bbox: If True, boxes are drawn filled with 0.3 opacity and no
            border; otherwise only the outline is drawn.
        draw_bbox: If False, no rectangle is drawn (labels are still written).

    Returns:
        PIL.Image: A new RGB image with the drawings.
    """
    original_width, original_height = image.size

    # Serialize the image to PNG in memory so PyMuPDF can load it.
    img_bytes = BytesIO()
    image.save(img_bytes, format='PNG')
    pix = fitz.Pixmap(img_bytes)

    # The scale factors do not depend on the cell, so compute them once.
    rescale = bool(resized_height and resized_width)
    if rescale:
        scale_x = resized_width / original_width
        scale_y = resized_height / original_height

    # Temporary PDF document; closed in the finally block so the underlying
    # PyMuPDF resources are released even if drawing fails.
    doc = fitz.open()
    try:
        # One page with the same size as the image, holding the image itself.
        page = doc.new_page(width=pix.width, height=pix.height)
        page.insert_image(
            fitz.Rect(0, 0, pix.width, pix.height),
            pixmap=pix
        )

        for order, cell in enumerate(cells):
            bbox = cell['bbox']
            layout_type = cell['category']

            if rescale:
                x0, y0 = int(bbox[0] / scale_x), int(bbox[1] / scale_y)
                x1, y1 = int(bbox[2] / scale_x), int(bbox[3] / scale_y)
            else:
                x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]

            # Unlisted categories fall back to green. Only RGB is used, and
            # PyMuPDF expects float components in the range 0..1.
            rgba = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 255))
            color = [channel / 255 for channel in rgba[:3]]

            rect_coords = fitz.Rect(x0, y0, x1, y1)
            if draw_bbox:
                if fill_bbox:
                    # Translucent fill, no border.
                    page.draw_rect(
                        rect_coords,
                        color=None,
                        fill=color,
                        fill_opacity=0.3,
                        width=0.5,
                        overlay=True,
                    )
                else:
                    # Outline only.
                    page.draw_rect(
                        rect_coords,
                        color=color,
                        fill=None,
                        fill_opacity=1,
                        width=0.5,
                        overlay=True,
                    )

            # Label "<index>_<category>", anchored at the box's right edge
            # (x1) and 20 units below its top (y0).
            order_cate = f"{order}_{layout_type}"
            page.insert_text(
                (x1, y0 + 20), order_cate, fontsize=20, color=color
            )

        # Render the page at 1:1 scale so the original dimensions are kept,
        # and build the PIL image before the document is closed.
        rendered = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0))
        return Image.frombytes("RGB", [rendered.width, rendered.height], rendered.samples)
    finally:
        doc.close()
def pre_process_bboxes(
    origin_image,
    bboxes,
    input_width,
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
):
    """
    Rescale bounding boxes by the ratio between the original image size and
    the size smart_resize computes for the given input dimensions.

    Each coordinate is divided by (original size / resized size) along its
    axis and truncated to int.

    Args:
        origin_image: Image object exposing ``size`` as (width, height).
        bboxes: A non-empty list of [x0, y0, x1, y1] lists.
        input_width: Input width handed to smart_resize.
        input_height: Input height handed to smart_resize.
        factor: Accepted for interface compatibility; not used by this function.
        min_pixels: Minimum pixel count; falls back to MIN_PIXELS when falsy.
        max_pixels: Maximum pixel count; falls back to MAX_PIXELS when falsy.

    Returns:
        A new list of rescaled [x0, y0, x1, y1] integer lists.
    """
    assert isinstance(bboxes, list) and bboxes and isinstance(bboxes[0], list)

    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS

    orig_w, orig_h = origin_image.size
    # smart_resize takes and returns (height, width), in that order.
    resized_h, resized_w = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)

    ratio_x = orig_w / resized_w
    ratio_y = orig_h / resized_h

    def _rescale(box):
        # x coordinates use the horizontal ratio, y coordinates the vertical one.
        return [
            int(float(box[0]) / ratio_x),
            int(float(box[1]) / ratio_y),
            int(float(box[2]) / ratio_x),
            int(float(box[3]) / ratio_y),
        ]

    return [_rescale(box) for box in bboxes]
def post_process_cells(
    origin_image: Image.Image,
    cells: List[Dict],
    input_width,  # server input width; the server applies smart_resize as well
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
) -> List[Dict]:
    """
    Map cell bounding boxes from the resized dimensions back to the original
    image dimensions.

    Args:
        origin_image: The original PIL Image.
        cells: A non-empty list of dicts, each holding a 'bbox' entry.
        input_width: The width of the input image sent to the server.
        input_height: The height of the input image sent to the server.
        factor: Resizing factor; accepted but not used by this function.
        min_pixels: Minimum number of pixels; falls back to MIN_PIXELS when falsy.
        max_pixels: Maximum number of pixels; falls back to MAX_PIXELS when falsy.

    Returns:
        A new list of shallow-copied cells whose 'bbox' holds the rescaled
        integer coordinates. The input cells are not modified.
    """
    assert isinstance(cells, list) and cells and isinstance(cells[0], dict)

    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS

    orig_w, orig_h = origin_image.size
    # smart_resize takes and returns (height, width), in that order.
    resized_h, resized_w = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)

    ratio_x = resized_w / orig_w
    ratio_y = resized_h / orig_h

    def _restore(cell):
        # Compute the rescaled box first, then attach it to a shallow copy.
        box = cell['bbox']
        restored_box = [
            int(float(box[0]) / ratio_x),
            int(float(box[1]) / ratio_y),
            int(float(box[2]) / ratio_x),
            int(float(box[3]) / ratio_y),
        ]
        restored = cell.copy()
        restored['bbox'] = restored_box
        return restored

    return [_restore(cell) for cell in cells]
def is_legal_bbox(cells):
    """
    Check that every cell's bounding box has positive width and height.

    Args:
        cells: An iterable of dicts, each holding a 'bbox' ([x0, y0, x1, y1]).

    Returns:
        False if any box has x1 <= x0 or y1 <= y0, otherwise True
        (including when ``cells`` is empty).
    """
    boxes = (cell['bbox'] for cell in cells)
    return not any(box[2] <= box[0] or box[3] <= box[1] for box in boxes)
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
    """
    Turn a raw model response into post-processed layout cells.

    Args:
        response: The raw model output (a JSON string for layout prompts).
        prompt_mode: Name of the prompt that produced the response.
        origin_image: The original image, forwarded to post_process_cells.
        input_image: The image sent to the model; its width and height are
            forwarded to post_process_cells.
        min_pixels: Forwarded to post_process_cells.
        max_pixels: Forwarded to post_process_cells.

    Returns:
        - ``response`` unchanged (not a tuple) for the prompt modes
          "prompt_ocr", "prompt_table_html", "prompt_table_latex" and
          "prompt_formula_latex".
        - ``(cells, False)`` when JSON parsing and post_process_cells succeed.
        - ``(cleaned, True)`` when either step raises: the data is passed to
          OutputCleaner.clean_model_output, and a list result is flattened
          into the 'text' entries joined by blank lines.
    """
    passthrough_modes = ("prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex")
    if prompt_mode in passthrough_modes:
        return response

    # Holds the raw response until JSON parsing succeeds, and the parsed
    # object afterwards; the fallback below cleans whichever one is current.
    payload = response
    try:
        payload = json.loads(payload)
        restored = post_process_cells(
            origin_image,
            payload,
            input_image.width,
            input_image.height,
            min_pixels=min_pixels,
            max_pixels=max_pixels
        )
        return restored, False
    except Exception as e:
        print(f"cells post process error: {e}, when using {prompt_mode}")

    # Only reached after a failure above: fall back to the output cleaner.
    fallback = OutputCleaner().clean_model_output(payload)
    if isinstance(fallback, list):
        fallback = "\n\n".join(item['text'] for item in fallback if 'text' in item)
    return fallback, True