# artyomxyz's picture
# init
# 98159fd
import math
from typing import Literal
import numpy as np
import torch
import cv2
def extract_patches(
    pages: "list[np.ndarray]",
    patch_size=(16, 16),
    patches_mode: Literal['all', 'important'] = 'important',
    pages_mode: Literal['concat', 'index'] = 'index',
):
    """Cut pages into flattened patches prefixed with 1-based position ids.

    Every page is padded to a multiple of the patch size, normalized, and
    split into non-overlapping patches. Each output row holds the position
    ids followed by the patch pixels flattened in
    (patch_height, patch_width, channels) order.

    Args:
        pages: Page images laid out as (height, width, channels) arrays.
        patch_size: ``(patch_height, patch_width)`` of a single patch.
        patches_mode: ``'important'`` keeps only the patches reported by
            ``get_image_interesting_patches`` (in the order it returns
            them); any other value keeps every patch in row-major order.
        pages_mode: ``'index'`` prepends a 1-based page id column and
            restarts row ids on every page; any other value omits the page
            id and keeps row ids increasing across pages, as if the pages
            were stacked vertically.

    Returns:
        A 2-D float32 tensor. Rows are ``[page_id, row_id, col_id, *pixels]``
        in ``'index'`` mode and ``[row_id, col_id, *pixels]`` otherwise.
        Ids start at 1 because 0 is reserved for padding.
    """
    patch_height, patch_width = patch_size

    if patches_mode == 'important':
        # get_image_interesting_patches reads its patch size as
        # (width, height), the reverse of this function's convention.
        # Passing the real size keeps its grid aligned with the grid built
        # below; relying on its 16x16 default selected wrong patches for
        # any other patch size.
        pages_important_patches = [
            get_image_interesting_patches(page, (patch_width, patch_height))
            for page in pages
        ]
    else:
        pages_important_patches = [None for _ in pages]

    all_patches = []
    page_start_row = 0
    for page_idx, (raw_page, important) in enumerate(zip(pages, pages_important_patches)):
        page = normalize_page(pad_page(raw_page, patch_size))
        # make it channel first
        page = torch.from_numpy(page).to(torch.float32).permute(2, 0, 1)
        channels, height, width = page.shape
        rows, cols = height // patch_height, width // patch_width
        patch_count = rows * cols

        # unfold is documented for batched 4-D input, so add a batch axis.
        unfolded = torch.nn.functional.unfold(
            page.unsqueeze(0),
            (patch_height, patch_width),
            stride=(patch_height, patch_width),
        )
        pixels = (
            unfolded.reshape(channels, patch_height, patch_width, patch_count)
            # (rows * cols, patch_height, patch_width, channels)
            .permute(3, 1, 2, 0)
            # (rows * cols, patch_height * patch_width * channels)
            .reshape(patch_count, patch_height * patch_width * channels)
        )

        # 0 is padding, so ids start at 1
        row_ids = (
            torch.arange(1, rows + 1, dtype=torch.float32)
            .unsqueeze(1)
            .expand(rows, cols)
            .reshape(patch_count, 1)
        )
        col_ids = (
            torch.arange(1, cols + 1, dtype=torch.float32)
            .unsqueeze(0)
            .expand(rows, cols)
            .reshape(patch_count, 1)
        )

        if pages_mode == 'index':
            page_ids = torch.full((patch_count, 1), page_idx + 1, dtype=torch.float32)
            page_patches = torch.cat([page_ids, row_ids, col_ids, pixels], -1)
        else:
            # Continue row numbering where the previous page stopped.
            row_ids = row_ids + page_start_row
            page_start_row += rows
            page_patches = torch.cat([row_ids, col_ids, pixels], -1)

        if patches_mode == 'important':
            # Translate (row, col) grid coordinates into flat row-major indexes.
            keep = torch.tensor(
                [row * cols + col for row, col in important],
                dtype=torch.long,
            )
            page_patches = page_patches[keep]

        all_patches.append(page_patches)

    return torch.cat(all_patches, 0)
def pad_page(page, patch_size):
    """Pad a page on the right and bottom so it divides evenly into patches.

    Args:
        page: 3-D array; axis 0 is padded to a multiple of the patch height
            and axis 1 to a multiple of the patch width.
        patch_size: ``(patch_height, patch_width)``.

    Returns:
        The page itself when no padding is needed, otherwise a padded copy
        whose new cells are filled with 255.
    """
    tile_height, tile_width = patch_size

    # ``-n % m`` is the distance from n up to the next multiple of m
    # (0 when n is already a multiple).
    extra_cols = -page.shape[1] % tile_width
    extra_rows = -page.shape[0] % tile_height

    if not (extra_cols or extra_rows):
        return page

    return np.pad(
        page,
        ((0, extra_rows), (0, extra_cols), (0, 0)),
        constant_values=255,
    )
def normalize_page(page):
    """Shift a page to zero mean and scale it by its standard deviation.

    The divisor is never smaller than ``1 / sqrt(number of elements)``, so a
    constant page yields zeros instead of dividing by zero.
    """
    # Lower bound for the divisor.
    smallest_divisor = 1.0 / math.sqrt(np.prod(page.shape))
    centered = page - np.mean(page)
    return centered / max(np.std(page), smallest_divisor)
def get_image_interesting_patches(image, patch_size=(16, 16)):
    """List the grid cells of an image whose edge map contains a zero pixel.

    Args:
        image: 3-D image array; its first two axes are tiled.
        patch_size: Tile size read as ``(step along axis 1, step along
            axis 0)``, i.e. (width, height).

    Returns:
        ``(row, col)`` grid coordinates, ordered by descending sum of the
        edge-map values inside each tile.
    """
    step_x = patch_size[0]
    step_y = patch_size[1]
    height, width, _ = image.shape
    edge_map = get_image_edges(image)

    scored = []
    for top in range(0, height, step_y):
        bottom = min(top + step_y, height)
        for left in range(0, width, step_x):
            right = min(left + step_x, width)
            window = edge_map[top:bottom, left:right]
            # Keep only tiles that contain at least one zero in the edge map.
            if np.any(window == 0):
                scored.append((top // step_y, left // step_x, window.sum()))

    scored.sort(key=lambda entry: entry[2], reverse=True)
    return [(row, col) for row, col, _ in scored]
def get_image_edges(img):
    """Compute an inverted, dilated Canny edge map with long lines erased.

    The image is converted to grayscale. Long horizontal and then long
    vertical strokes, found on an Otsu-thresholded inverse of the original
    grayscale, are painted white and the gaps they leave are closed. The
    result is blurred, run through Canny, dilated and inverted, so detected
    edges are 0 and everything else is 255.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # (line-detection kernel size, repair kernel size);
    # horizontal lines are removed first, then vertical ones.
    passes = (
        ((25, 1), (1, 6)),
        ((1, 25), (6, 1)),
    )
    for line_size, repair_size in passes:
        line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, line_size)
        line_mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, line_kernel, iterations=2)

        found = cv2.findContours(line_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # The contours are the first item of a 2-tuple result, otherwise the second.
        contours = found[0] if len(found) == 2 else found[1]
        for contour in contours:
            # Paint the detected line white, in place.
            cv2.drawContours(gray, [contour], -1, (255, 255, 255), 2)

        repair_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, repair_size)
        # Close on the inverted image, then invert back.
        gray = 255 - cv2.morphologyEx(255 - gray, cv2.MORPH_CLOSE, repair_kernel, iterations=1)

    smoothed = cv2.GaussianBlur(gray, (3, 3), 0)
    canny = cv2.Canny(smoothed, 50, 150)
    thickened = cv2.dilate(canny, np.ones((3, 3), np.uint8), iterations=1)
    return cv2.bitwise_not(thickened)