# dots-ocr / dots_ocr / utils / layout_utils.py
# Uploaded by redhairedshanks1 ("Upload 61 files", commit b56e481, verified).
from PIL import Image
from typing import Dict, List
import fitz
from io import BytesIO
import json
from dots_ocr.utils.image_utils import smart_resize
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.output_cleaner import OutputCleaner
# Color map from layout category name to an (R, G, B, A) tuple.
# NOTE: draw_layout_on_image() in this module reads only the first three
# components (R, G, B) and passes its own fill_opacity when drawing, so the
# 4th component is not what controls translucency there. The value 256 used
# below is outside the usual 0-255 range of an 8-bit channel.
dict_layout_type_to_color = {
    "Text": (0, 128, 0, 256),  # Green
    "Picture": (255, 0, 255, 256),  # Magenta
    "Caption": (255, 165, 0, 256),  # Orange
    "Section-header": (0, 255, 255, 256),  # Cyan
    "Footnote": (0, 128, 0, 256),  # Green (same RGB as "Text")
    "Formula": (128, 128, 128, 256),  # Gray
    "Table": (255, 192, 203, 256),  # Pink
    "Title": (255, 0, 0, 256),  # Red
    "List-item": (0, 0, 255, 256),  # Blue
    "Page-header": (0, 128, 0, 256),  # Green (same RGB as "Text")
    "Page-footer": (128, 0, 128, 256),  # Purple
    "Other": (165, 42, 42, 256),  # Brown
    "Unknown": (0, 0, 0, 0),  # Black RGB, 4th component 0
}
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
    """
    Draw layout boxes and "<index>_<category>" labels on top of an image.

    The image is placed on a single page of a temporary in-memory PDF, the
    boxes and labels are drawn on that page, and the page is rendered back
    to a PIL image at a 1:1 scale.

    Args:
        image: The source PIL Image.
        cells: A list of cells; each cell is indexed with 'bbox'
            (read as [x0, y0, x1, y1]) and 'category'.
        resized_height: The resized height. When both resized_height and
            resized_width are truthy, bbox coordinates are divided by
            (resized / original) to map them back onto `image`.
        resized_width: The resized width (see resized_height).
        fill_bbox: When drawing boxes, fill them (30% opacity) instead of
            drawing only the outline.
        draw_bbox: Whether to draw the bounding box at all.

    Returns:
        PIL.Image: A new RGB image with the drawings.
    """
    original_width, original_height = image.size

    # The scale factors do not depend on the cell, so compute them once
    # instead of on every loop iteration.
    rescale = bool(resized_height and resized_width)
    if rescale:
        scale_x = resized_width / original_width
        scale_y = resized_height / original_height

    # Serialize the image to PNG in memory so it can be loaded as a Pixmap.
    img_bytes = BytesIO()
    image.save(img_bytes, format='PNG')
    pix = fitz.Pixmap(img_bytes)

    # Create a new PDF document; it is closed in `finally` so the temporary
    # document is released even if drawing fails.
    doc = fitz.open()
    try:
        # One page with the same dimensions as the image, image as background.
        page = doc.new_page(width=pix.width, height=pix.height)
        page.insert_image(
            fitz.Rect(0, 0, pix.width, pix.height),
            pixmap=pix
        )

        for order, cell in enumerate(cells):
            bbox = cell['bbox']
            layout_type = cell['category']

            if rescale:
                x0, y0 = int(bbox[0] / scale_x), int(bbox[1] / scale_y)
                x1, y1 = int(bbox[2] / scale_x), int(bbox[3] / scale_y)
            else:
                x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]

            # Only the RGB components are used, normalized by 255.
            color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
            color = [col/255 for col in color[:3]]

            rect_coords = fitz.Rect(x0, y0, x1, y1)
            if draw_bbox:
                if fill_bbox:
                    # Filled, semi-transparent rectangle without an outline.
                    page.draw_rect(
                        rect_coords,
                        color=None,
                        fill=color,
                        fill_opacity=0.3,
                        width=0.5,
                        overlay=True,
                    )
                else:
                    # Outline-only rectangle.
                    page.draw_rect(
                        rect_coords,
                        color=color,
                        fill=None,
                        fill_opacity=1,
                        width=0.5,
                        overlay=True,
                    )

            # NOTE(review): the label is written for every cell, regardless of
            # draw_bbox; confirm this matches the intended nesting.
            # The text is inserted at the box's right edge (x1), 20 units
            # below its top edge.
            order_cate = f"{order}_{layout_type}"
            page.insert_text(
                (x1, y0 + 20), order_cate, fontsize=20, color=color
            )

        # Convert to a Pixmap (identity matrix keeps the original dimensions).
        rendered = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0))
        return Image.frombytes("RGB", [rendered.width, rendered.height], rendered.samples)
    finally:
        doc.close()
def pre_process_bboxes(
    origin_image,
    bboxes,
    input_width,
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
):
    """
    Map bounding boxes from original-image coordinates to resized-input coordinates.

    Args:
        origin_image: The original image; only its `.size` (width, height) is read.
        bboxes: A non-empty list of boxes, each indexed as [x0, y0, x1, y1].
        input_width: Input width, passed through smart_resize before use.
        input_height: Input height, passed through smart_resize before use.
        factor: Accepted for interface compatibility; not read by this function.
        min_pixels: Minimum number of pixels; a falsy value falls back to MIN_PIXELS.
        max_pixels: Maximum number of pixels; a falsy value falls back to MAX_PIXELS.
    Returns:
        A new list of [x0, y0, x1, y1] integer boxes (truncated with int()).
    """
    assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)
    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS

    original_width, original_height = origin_image.size
    resized_height, resized_width = smart_resize(
        input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels
    )

    # Per-coordinate divisors in x0, y0, x1, y1 order.
    ratio_x = original_width / resized_width
    ratio_y = original_height / resized_height
    divisors = (ratio_x, ratio_y, ratio_x, ratio_y)

    return [
        [int(float(box[idx]) / divisors[idx]) for idx in range(4)]
        for box in bboxes
    ]
def post_process_cells(
    origin_image: Image.Image,
    cells: List[Dict],
    input_width,  # server input width, also has smart_resize in server
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
) -> List[Dict]:
    """
    Convert cell bounding boxes from resized-input coordinates back to original-image coordinates.

    Args:
        origin_image: The original PIL Image; only its size is read.
        cells: A non-empty list of dicts, each holding a 'bbox' indexed as [x0, y0, x1, y1].
        input_width: The width of the input image sent to the server (passed through smart_resize here).
        input_height: The height of the input image sent to the server (passed through smart_resize here).
        factor: Accepted for interface compatibility; not read by this function.
        min_pixels: Minimum number of pixels; a falsy value falls back to MIN_PIXELS.
        max_pixels: Maximum number of pixels; a falsy value falls back to MAX_PIXELS.
    Returns:
        A list of shallow copies of the cells whose 'bbox' holds the rescaled integer coordinates.
    """
    assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS

    original_width, original_height = origin_image.size
    resized_height, resized_width = smart_resize(
        input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels
    )

    # Per-coordinate divisors in x0, y0, x1, y1 order.
    ratio_x = resized_width / original_width
    ratio_y = resized_height / original_height
    divisors = (ratio_x, ratio_y, ratio_x, ratio_y)

    def _restore(cell):
        # Compute the rescaled box first, then attach it to a shallow copy so
        # the caller's cell is left untouched.
        box = cell['bbox']
        rescaled = [int(float(box[idx]) / divisors[idx]) for idx in range(4)]
        restored = cell.copy()
        restored['bbox'] = rescaled
        return restored

    return [_restore(cell) for cell in cells]
def is_legal_bbox(cells):
    """
    Check that every cell's bounding box has positive width and height.

    Args:
        cells: An iterable of cells, each indexed with 'bbox' and read as
            [x0, y0, x1, y1].
    Returns:
        False as soon as a cell has x1 <= x0 or y1 <= y0, True otherwise
        (including when `cells` is empty).
    """
    return not any(
        bbox[2] <= bbox[0] or bbox[3] <= bbox[1]
        for bbox in (cell['bbox'] for cell in cells)
    )
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
    """
    Post-process a raw model response according to the prompt mode.

    Args:
        response: The raw model output.
        prompt_mode: The prompt mode name that produced the response.
        origin_image: The original image, forwarded to post_process_cells.
        input_image: The image sent to the model; its width and height are
            forwarded to post_process_cells.
        min_pixels: Forwarded to post_process_cells.
        max_pixels: Forwarded to post_process_cells.
    Returns:
        - For "prompt_ocr", "prompt_table_html", "prompt_table_latex" and
          "prompt_formula_latex": `response` itself, unchanged (not a tuple).
        - Otherwise, on success: `(cells, False)` with the rescaled cells.
        - Otherwise, on any failure while parsing or rescaling:
          `(cleaned_output, True)` from the OutputCleaner fallback.
    """
    passthrough_modes = ("prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex")
    if prompt_mode in passthrough_modes:
        return response

    cells = response
    try:
        cells = json.loads(cells)
        cells = post_process_cells(
            origin_image,
            cells,
            input_image.width,
            input_image.height,
            min_pixels=min_pixels,
            max_pixels=max_pixels
        )
    except Exception as e:
        # Deliberate best-effort: any failure falls through to the cleaner
        # below instead of propagating.
        print(f"cells post process error: {e}, when using {prompt_mode}")
    else:
        return cells, False

    # Fallback path. `cells` is the raw response if JSON parsing failed, or
    # the parsed JSON object if parsing succeeded but rescaling failed.
    cleaner = OutputCleaner()
    response_clean = cleaner.clean_model_output(cells)
    if isinstance(response_clean, list):
        # Collapse a list of cells into plain text, skipping cells with no 'text'.
        response_clean = "\n\n".join([cell['text'] for cell in response_clean if 'text' in cell])
    return response_clean, True