Spaces:
Runtime error
Runtime error
File size: 5,498 Bytes
98159fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import math
from typing import Literal
import numpy as np
import torch
import cv2
def extract_patches(
    pages: 'list[np.ndarray]',
    patch_size=(16, 16),
    patches_mode: Literal['all', 'important'] = 'important',
    pages_mode: Literal['concat', 'index'] = 'index',
):
    """Cut pages into non-overlapping patches and prefix each with position ids.

    Every page is padded to a multiple of the patch size, normalized, split
    into a grid of ``patch_height x patch_width`` patches, and each patch is
    flattened into one feature row.

    Args:
        pages: Page images as ``(height, width, channels)`` arrays (the code
            pads three axes and permutes axis 2 to the front).
            NOTE(review): ``'important'`` mode runs the pages through an
            OpenCV BGR->gray conversion, so colour pages are presumably
            expected there - confirm against callers.
        patch_size: ``(patch_height, patch_width)`` in pixels.
        patches_mode: ``'important'`` keeps only the patches reported by
            ``get_image_interesting_patches`` (in the order it returns them);
            any other value keeps every patch in row-major order.
        pages_mode: ``'index'`` prepends a 1-based page id column; any other
            value omits it and instead offsets the row ids by the number of
            patch rows of all previous pages, as if the pages were stacked
            vertically.

    Returns:
        A float32 tensor of shape ``(num_patches, num_ids + channels *
        patch_height * patch_width)`` where the id columns are
        ``[page, row, col]`` in ``'index'`` mode and ``[row, col]`` otherwise.
        Ids are 1-based; 0 is left free for padding. Pixel values inside a
        row are laid out as ``(patch_height, patch_width, channels)``.
    """
    patch_height, patch_width = patch_size
    select_important = patches_mode == 'important'

    if select_important:
        # Importance is computed on the raw (unpadded, unnormalized) page.
        # The helper indexes its patch size as (x-step, y-step), i.e.
        # (width, height), so the pair is swapped here. Forwarding the real
        # patch size keeps the returned (row, col) grid aligned with the grid
        # produced by unfold below for non-default patch sizes.
        pages_important_patches = [
            get_image_interesting_patches(page, (patch_width, patch_height))
            for page in pages
        ]
    else:
        pages_important_patches = [None for _ in pages]

    page_tensors = [
        # Channel first: (channels, height, width).
        torch.from_numpy(normalize_page(pad_page(page, patch_size)))
        .to(torch.float32)
        .permute(2, 0, 1)
        for page in pages
    ]

    page_start_row = 0
    all_patches = []
    for page_idx, (page, important) in enumerate(zip(page_tensors, pages_important_patches)):
        channels, height, width = page.shape
        rows = height // patch_height
        cols = width // patch_width

        # unfold is documented for batched 4-D input, so add a batch axis of
        # one; the reshape below absorbs it again.
        page_patches = (
            torch.nn.functional.unfold(
                page.unsqueeze(0),
                (patch_height, patch_width),
                stride=(patch_height, patch_width),
            )
            # (channels, patch_height, patch_width, rows * cols)
            .reshape(channels, patch_height, patch_width, -1)
            # (rows * cols, patch_height, patch_width, channels)
            .permute(3, 1, 2, 0)
            # (rows * cols, patch_height * patch_width * channels)
            .reshape(rows * cols, channels * patch_height * patch_width)
        )

        # 1-based grid coordinates of every patch, in row-major order
        # (0 is reserved for padding).
        row_ids = (
            torch.arange(1, rows + 1, dtype=torch.float32)
            .reshape(rows, 1)
            .repeat(1, cols)
            .reshape(rows * cols, 1)
        )
        col_ids = (
            torch.arange(1, cols + 1, dtype=torch.float32)
            .reshape(1, cols)
            .repeat(rows, 1)
            .reshape(rows * cols, 1)
        )

        if pages_mode == 'index':
            page_ids = torch.full((rows * cols, 1), float(page_idx + 1), dtype=torch.float32)
            page_patches = torch.cat([page_ids, row_ids, col_ids, page_patches], -1)
        else:
            # Treat the pages as one tall image: continue the row numbering
            # where the previous page stopped.
            page_patches = torch.cat([row_ids + page_start_row, col_ids, page_patches], -1)
            page_start_row += rows

        if select_important:
            # Flatten (row, col) to the row-major patch index; an explicit
            # long tensor also handles a page with no important patches.
            keep = torch.tensor(
                [row * cols + col for row, col in important],
                dtype=torch.long,
            )
            page_patches = page_patches[keep]

        all_patches.append(page_patches)

    return torch.cat(all_patches, 0)
def pad_page(page, patch_size):
    """Pad a page at the bottom and right so it splits evenly into patches.

    Args:
        page: Array indexed as ``(height, width, channels)``.
        patch_size: ``(patch_height, patch_width)`` in pixels.

    Returns:
        The page extended with the constant value 255 until its height is a
        multiple of ``patch_height`` and its width a multiple of
        ``patch_width``. The input object itself is returned when it is
        already aligned.
    """
    patch_height, patch_width = patch_size
    # (-n) % k is the distance from n up to the next multiple of k
    # (0 when n is already a multiple).
    extra_rows = -page.shape[0] % patch_height
    extra_cols = -page.shape[1] % patch_width
    if not (extra_rows or extra_cols):
        return page
    return np.pad(
        page,
        ((0, extra_rows), (0, extra_cols), (0, 0)),
        constant_values=255,
    )
def normalize_page(page):
    """Standardize a page to zero mean and (approximately) unit variance.

    Mean and standard deviation are taken over the whole array. The divisor
    is bounded below by ``1 / sqrt(number of elements)`` so that a constant
    page (standard deviation 0) does not cause a division by zero.

    Args:
        page: Array of pixel values.

    Returns:
        ``(page - mean) / max(std, 1 / sqrt(page element count))``.
    """
    page_mean = np.mean(page)
    page_std = np.std(page)
    std_floor = 1.0 / math.sqrt(np.prod(page.shape))
    divisor = std_floor if std_floor > page_std else page_std
    centered = page - page_mean
    return centered / divisor
def get_image_interesting_patches(image, patch_size=(16, 16)):
    """List the grid positions of patches that contain at least one edge pixel.

    The image is converted to an edge map with ``get_image_edges`` (edge
    pixels are 0 there) and scanned on a regular grid. Patches at the right
    and bottom border are clipped to the image.

    Args:
        image: Array whose shape unpacks as ``(height, width, channels)``.
        patch_size: Grid step in pixels; element 0 is the horizontal step and
            element 1 the vertical step.

    Returns:
        A list of ``(row, col)`` patch indices, sorted by the sum of the
        edge-map values inside the patch, largest sum first (ties keep scan
        order).
        NOTE(review): because edge pixels are 0, a larger sum means fewer
        edge pixels, so the least edge-dense patches come first - confirm
        this ordering is intended.
    """
    step_x = patch_size[0]
    step_y = patch_size[1]
    height, width, _ = image.shape
    edge_map = get_image_edges(image)

    scored = []
    for top in range(0, height, step_y):
        bottom = min(top + step_y, height)
        for left in range(0, width, step_x):
            right = min(left + step_x, width)
            window = edge_map[top:bottom, left:right]
            # Keep only patches touched by at least one edge pixel.
            if (window == 0).any():
                scored.append((top // step_y, left // step_x, window.sum()))

    scored.sort(key=lambda entry: entry[2], reverse=True)
    return [(row, col) for row, col, _ in scored]
def get_image_edges(img):
    """Build an inverted edge map of an image with long ruling lines removed.

    Long horizontal and vertical lines are detected on an Otsu-thresholded
    copy, painted white in the grayscale image and the resulting gaps are
    closed; Canny edges are then computed, dilated and inverted.

    Args:
        img: Image accepted by ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``.

    Returns:
        A single-channel image of the same height and width in which dilated
        edge pixels are 0 and all other pixels are 255.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

    # (line-detection kernel size, gap-repair kernel size) as (width, height):
    # first pass removes horizontal lines, second pass vertical lines.
    passes = (
        ((25, 1), (1, 6)),
        ((1, 25), (6, 1)),
    )
    for line_size, repair_size in passes:
        line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, line_size)
        detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, line_kernel, iterations=2)
        found = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # findContours returns 2 or 3 values depending on the OpenCV version.
        contours = found[0] if len(found) == 2 else found[1]
        for contour in contours:
            # Paint the detected line white (in place).
            cv2.drawContours(gray, [contour], -1, (255, 255, 255), 2)
        repair_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, repair_size)
        # Close, on the inverted image, the gaps the erased line left behind.
        gray = 255 - cv2.morphologyEx(255 - gray, cv2.MORPH_CLOSE, repair_kernel, iterations=1)

    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blurred, 50, 150)
    dilated = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
    return cv2.bitwise_not(dilated)
|