Spaces:
Runtime error
Runtime error
File size: 8,017 Bytes
b56e481 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
from PIL import Image
from typing import Dict, List
import fitz
from io import BytesIO
import json
from dots_ocr.utils.image_utils import smart_resize
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.output_cleaner import OutputCleaner
# Layout category -> RGBA color (8-bit channels, 0-255).
# draw_layout_on_image only uses the RGB channels; the translucency of filled
# boxes there comes from its fill_opacity argument, not from this alpha value.
# Alpha is 255 (the 8-bit maximum); the previous value 256 was out of range.
dict_layout_type_to_color = {
    "Text": (0, 128, 0, 255),  # Green
    "Picture": (255, 0, 255, 255),  # Magenta
    "Caption": (255, 165, 0, 255),  # Orange
    "Section-header": (0, 255, 255, 255),  # Cyan
    "Footnote": (0, 128, 0, 255),  # Green
    "Formula": (128, 128, 128, 255),  # Gray
    "Table": (255, 192, 203, 255),  # Pink
    "Title": (255, 0, 0, 255),  # Red
    "List-item": (0, 0, 255, 255),  # Blue
    "Page-header": (0, 128, 0, 255),  # Green
    "Page-footer": (128, 0, 128, 255),  # Purple
    "Other": (165, 42, 42, 255),  # Brown
    "Unknown": (0, 0, 0, 0),  # Black, alpha 0
}
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
    """
    Draw layout boxes and "<index>_<category>" labels on an image.

    The image is placed on a single page of a new in-memory PDF document
    (PyMuPDF). Boxes and labels are drawn on that page, and the page is then
    rendered back to an RGB PIL image at a 1:1 matrix.

    Args:
        image: The source PIL Image.
        cells: A list of cells. Each cell is read for a 'bbox'
            ([x0, y0, x1, y1]) and a 'category' string.
        resized_height: Height of the coordinate space the bboxes are given
            in. When both this and resized_width are truthy, each bbox is
            rescaled from that space to the original image size.
        resized_width: Width of the coordinate space the bboxes are given in
            (see resized_height).
        fill_bbox: If True, boxes are filled at 0.3 opacity with no border;
            otherwise only a colored border is drawn.
        draw_bbox: Whether to draw the boxes. The labels are inserted either
            way.

    Returns:
        PIL.Image: An RGB image with the drawings.
    """
    original_width, original_height = image.size

    # The scale factors do not depend on the cell; compute them once.
    rescale = bool(resized_height and resized_width)
    if rescale:
        scale_x = resized_width / original_width
        scale_y = resized_height / original_height

    img_bytes = BytesIO()
    image.save(img_bytes, format='PNG')

    doc = fitz.open()  # new, empty PDF document
    try:
        pix = fitz.Pixmap(img_bytes.getvalue())
        # Page size equals the pixel size, so a 1:1 render reproduces it.
        page = doc.new_page(width=pix.width, height=pix.height)
        page.insert_image(fitz.Rect(0, 0, pix.width, pix.height), pixmap=pix)

        for order, cell in enumerate(cells):
            bbox = cell['bbox']
            layout_type = cell['category']
            if rescale:
                # Map from the resized coordinate space back to the original image.
                x0, y0 = int(bbox[0] / scale_x), int(bbox[1] / scale_y)
                x1, y1 = int(bbox[2] / scale_x), int(bbox[3] / scale_y)
            else:
                x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]

            rgba = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 255))
            # PyMuPDF expects color components as floats in [0, 1]; alpha is unused.
            color = [channel / 255 for channel in rgba[:3]]
            rect_coords = fitz.Rect(x0, y0, x1, y1)

            if draw_bbox:
                if fill_bbox:
                    page.draw_rect(
                        rect_coords,
                        color=None,
                        fill=color,
                        fill_opacity=0.3,
                        width=0.5,
                        overlay=True,
                    )
                else:
                    page.draw_rect(
                        rect_coords,
                        color=color,
                        fill=None,
                        fill_opacity=1,
                        width=0.5,
                        overlay=True,
                    )

            # The label is anchored at the box's right edge (x1), 20 units
            # below its top edge.
            page.insert_text(
                (x1, y0 + 20), f"{order}_{layout_type}", fontsize=20, color=color
            )

        # Render at the page's native size (1:1 matrix).
        rendered = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0))
        result = Image.frombytes("RGB", [rendered.width, rendered.height], rendered.samples)
    finally:
        # Release the in-memory PDF. Image.frombytes has already copied the pixels.
        doc.close()
    return result
def pre_process_bboxes(
    origin_image,
    bboxes,
    input_width,
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
):
    """
    Map bounding boxes from original-image coordinates to the coordinate space
    of the smart-resized model input.

    Args:
        origin_image: The original PIL Image (only its .size is read).
        bboxes: A non-empty list of boxes. Each box is a list whose first four
            entries are x0, y0, x1, y1 (numbers or numeric strings).
        input_width: Width of the input image, before smart_resize.
        input_height: Height of the input image, before smart_resize.
        factor: Resizing factor forwarded to smart_resize.
        min_pixels: Minimum pixel count for smart_resize; falsy -> MIN_PIXELS.
        max_pixels: Maximum pixel count for smart_resize; falsy -> MAX_PIXELS.

    Returns:
        A new list of [x0, y0, x1, y1] integer boxes in resized coordinates
        (truncated toward zero by int()).
    """
    assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list), \
        "bboxes must be a non-empty list of [x0, y0, x1, y1] lists"
    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS
    original_width, original_height = origin_image.size
    # Reproduce the resize applied to the input so the target space matches.
    # 'factor' used to be silently ignored here.
    input_height, input_width = smart_resize(
        input_height, input_width, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels
    )
    scale_x = original_width / input_width
    scale_y = original_height / input_height
    return [
        [
            int(float(bbox[0]) / scale_x),
            int(float(bbox[1]) / scale_y),
            int(float(bbox[2]) / scale_x),
            int(float(bbox[3]) / scale_y),
        ]
        for bbox in bboxes
    ]
def post_process_cells(
    origin_image: Image.Image,
    cells: List[Dict],
    input_width,  # server input width, also has smart_resize in server
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
) -> List[Dict]:
    """
    Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.

    Args:
        origin_image: The original PIL Image (only its .size is read).
        cells: A non-empty list of cell dicts, each with a 'bbox' whose first
            four entries are x0, y0, x1, y1 in resized coordinates.
        input_width: The width of the input image sent to the server (before smart_resize).
        input_height: The height of the input image sent to the server (before smart_resize).
        factor: Resizing factor forwarded to smart_resize.
        min_pixels: Minimum number of pixels; falsy -> MIN_PIXELS.
        max_pixels: Maximum number of pixels; falsy -> MAX_PIXELS.

    Returns:
        A list of shallow copies of the input cells, whose 'bbox' is replaced
        by [x0, y0, x1, y1] integers (truncated by int()) in original-image
        coordinates. The input cells are not modified.
    """
    assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict), \
        "cells must be a non-empty list of dicts"
    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS
    original_width, original_height = origin_image.size
    # The server resized the input with smart_resize; recompute that size.
    # 'factor' used to be silently ignored here.
    input_height, input_width = smart_resize(
        input_height, input_width, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels
    )
    scale_x = input_width / original_width
    scale_y = input_height / original_height

    def _to_original(bbox):
        # Dividing by (resized / original) maps resized coords back to the original image.
        return [
            int(float(bbox[0]) / scale_x),
            int(float(bbox[1]) / scale_y),
            int(float(bbox[2]) / scale_x),
            int(float(bbox[3]) / scale_y),
        ]

    return [{**cell, 'bbox': _to_original(cell['bbox'])} for cell in cells]
def is_legal_bbox(cells):
    """
    Check that every cell's bbox has strictly positive width and height.

    Args:
        cells: An iterable of dicts, each holding a 'bbox' indexed as
            [x0, y0, x1, y1].

    Returns:
        False as soon as a bbox with x1 <= x0 or y1 <= y0 is found (later
        cells are not inspected); True otherwise, including for no cells.
    """
    boxes = (cell['bbox'] for cell in cells)
    return all(not (box[2] <= box[0] or box[3] <= box[1]) for box in boxes)
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
    """
    Post-process a raw model response according to the prompt mode.

    Text-only modes ("prompt_ocr", "prompt_table_html", "prompt_table_latex",
    "prompt_formula_latex") return ``response`` unchanged. Note that this is
    NOT a tuple, unlike the other paths.

    For every other mode, ``response`` is parsed as JSON into a list of cells
    and their bboxes are mapped back to ``origin_image`` coordinates with
    post_process_cells (using ``input_image.width`` / ``.height``).

    Args:
        response: The raw model output.
        prompt_mode: Name of the prompt used to produce ``response``.
        origin_image: The original PIL Image.
        input_image: The image sent to the model.
        min_pixels: Forwarded to post_process_cells.
        max_pixels: Forwarded to post_process_cells.

    Returns:
        - ``response`` for text-only modes;
        - ``(cells, False)`` when parsing and post-processing succeed;
        - ``(cleaned, True)`` otherwise, where ``cleaned`` is the result of
          OutputCleaner.clean_model_output. If that result is a list, it is
          reduced to the 'text' fields of its dict entries joined by blank
          lines.
    """
    if prompt_mode in ("prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"):
        return response

    cells = response
    try:
        cells = json.loads(cells)
        cells = post_process_cells(
            origin_image,
            cells,
            input_image.width,
            input_image.height,
            min_pixels=min_pixels,
            max_pixels=max_pixels
        )
        return cells, False
    except Exception as e:
        # Deliberate best-effort boundary: any parse/post-process failure
        # falls through to the cleaning path below.
        print(f"cells post process error: {e}, when using {prompt_mode}")

    # If json.loads succeeded but post-processing failed, `cells` is the
    # parsed JSON; otherwise it is still the raw response.
    cleaner = OutputCleaner()
    response_clean = cleaner.clean_model_output(cells)
    if isinstance(response_clean, list):
        # Only dict entries can carry a 'text' key. On other entries (e.g.
        # strings), `'text' in x` would be a substring test or a TypeError.
        response_clean = "\n\n".join(
            cell['text'] for cell in response_clean
            if isinstance(cell, dict) and 'text' in cell
        )
    return response_clean, True
|