import math
from typing import Literal, Sequence

import numpy as np
import torch
import cv2


def extract_patches(
    pages: Sequence[np.ndarray],
    patch_size=(16, 16),
    patches_mode: Literal['all', 'important'] = 'important',
    pages_mode: Literal['concat', 'index'] = 'index',
):
    """Cut pages into flattened, position-tagged patches.

    Every page is padded to a multiple of the patch size, standardised
    (see ``normalize_page``) and split into non-overlapping patches.

    Args:
        pages: Page images as 3-D ``(height, width, channels)`` arrays. An
            empty sequence is not supported. With
            ``patches_mode='important'`` each page is also passed to
            ``cv2.cvtColor(..., COLOR_BGR2GRAY)``, so it must be an image
            OpenCV accepts for that conversion.
        patch_size: ``(patch_height, patch_width)`` in pixels.
        patches_mode: ``'important'`` keeps only the grid cells reported by
            ``get_image_interesting_patches`` (in the order it returns
            them); any other value keeps every cell in row-major order.
        pages_mode: ``'index'`` prepends a 1-based page id column; any other
            value omits it and lets row ids continue from one page to the
            next, as if the pages were stacked vertically.

    Returns:
        A float32 tensor with one row per kept patch. Columns are
        ``[page_id, row_id, col_id, pixels...]`` in ``'index'`` mode and
        ``[row_id, col_id, pixels...]`` otherwise. Ids are 1-based so that 0
        stays free for padding. Pixels are flattened in
        (patch row, patch column, channel) order.
    """
    patch_height, patch_width = patch_size
    keep_only_important = patches_mode == 'important'
    use_page_index = pages_mode == 'index'

    all_patches = []
    page_start_row = 0
    for page_idx, page in enumerate(pages):
        important_cells = None
        if keep_only_important:
            # Detect on the raw page (before padding/normalisation) and on the
            # SAME grid used for patching. The helper reads its size argument
            # as (width, height), hence the swap.
            important_cells = get_image_interesting_patches(page, (patch_width, patch_height))

        prepared = normalize_page(pad_page(page, patch_size))
        grid = _page_to_patch_grid(prepared, patch_height, patch_width)
        rows, cols, feature_size = grid.shape
        pixels = grid.reshape(rows * cols, feature_size)

        # Row-major cell coordinates, shifted by one because 0 means padding.
        row_ids = (torch.arange(rows).repeat_interleave(cols) + 1).to(torch.float32).reshape(-1, 1)
        col_ids = (torch.arange(cols).repeat(rows) + 1).to(torch.float32).reshape(-1, 1)

        if use_page_index:
            page_ids = torch.full((rows * cols, 1), float(page_idx + 1), dtype=torch.float32)
            page_patches = torch.cat([page_ids, row_ids, col_ids, pixels], -1)
        else:
            row_ids = row_ids + page_start_row
            page_start_row += rows
            page_patches = torch.cat([row_ids, col_ids, pixels], -1)

        if keep_only_important:
            # An explicit long tensor keeps the result 2-D, shape (0, D), even
            # when a page has no important cells.
            flat_indexes = torch.tensor(
                [row * cols + col for row, col in important_cells],
                dtype=torch.long,
            )
            page_patches = page_patches[flat_indexes]

        all_patches.append(page_patches)

    return torch.cat(all_patches, 0)


def _page_to_patch_grid(page, patch_height, patch_width):
    """Split an already padded ``(H, W, C)`` array into a grid of flat patches.

    Returns a float32 tensor of shape
    ``(H // patch_height, W // patch_width, patch_height * patch_width * C)``
    whose last axis is ordered (patch row, patch column, channel).
    """
    pixels = torch.from_numpy(page).to(torch.float32)
    height, width, channels = pixels.shape
    rows = height // patch_height
    cols = width // patch_width
    return (
        pixels.reshape(rows, patch_height, cols, patch_width, channels)
        # Bring the two grid axes to the front: (rows, cols, ph, pw, C).
        .permute(0, 2, 1, 3, 4)
        .reshape(rows, cols, patch_height * patch_width * channels)
    )


def pad_page(page, patch_size):
    """Pad a page on the bottom/right with 255 up to a multiple of the patch size.

    Args:
        page: 3-D ``(height, width, channels)`` array.
        patch_size: ``(patch_height, patch_width)``.

    Returns:
        The padded array, or ``page`` itself when no padding is needed.
    """
    patch_height, patch_width = patch_size
    pad_bottom = -page.shape[0] % patch_height
    pad_right = -page.shape[1] % patch_width
    if pad_bottom or pad_right:
        page = np.pad(page, ((0, pad_bottom), (0, pad_right), (0, 0)), constant_values=255)
    return page


def normalize_page(page):
    """Standardise a page to zero mean and unit variance over all its values.

    The divisor is floored at ``1 / sqrt(number of elements)`` so that a
    constant page does not divide by zero.
    """
    mean = np.mean(page)
    std = np.std(page)
    min_std = 1.0 / math.sqrt(page.size)
    return (page - mean) / max(std, min_std)


def get_image_interesting_patches(image, patch_size=(16, 16)):
    """List the grid cells of ``image`` that contain at least one edge pixel.

    Args:
        image: 3-D ``(height, width, channels)`` image accepted by
            ``get_image_edges``.
        patch_size: Cell size read as ``(width, height)``. NOTE: this is the
            reverse of the ``(height, width)`` order used by ``pad_page`` and
            ``extract_patches``; it is kept for backward compatibility.

    Returns:
        ``(row, col)`` cell coordinates, sorted by descending sum of the edge
        map inside the cell. Edge pixels are 0 and all others 255 in that
        map, so cells with the most non-edge pixels come first.
    """
    height, width, _ = image.shape
    patch_width, patch_height = patch_size[0], patch_size[1]
    edge_map = get_image_edges(image)

    scored_cells = []
    for top in range(0, height, patch_height):
        for left in range(0, width, patch_width):
            # Slices are clipped at the image border, so edge cells may be smaller.
            cell = edge_map[top:top + patch_height, left:left + patch_width]
            if np.any(cell == 0):
                scored_cells.append((top // patch_height, left // patch_width, cell.sum()))

    scored_cells.sort(key=lambda item: item[2], reverse=True)
    return [(row, col) for row, col, _ in scored_cells]


def _erase_lines(gray, thresh, line_kernel_size, repair_kernel_size):
    """Paint over long straight lines found in ``thresh`` and repair the gaps.

    Lines are detected by a morphological opening of ``thresh`` with a
    ``line_kernel_size`` rectangle, drawn in white onto ``gray`` (modified in
    place), and the result is closed with a ``repair_kernel_size`` rectangle on
    the inverted image. Returns the repaired grayscale image.
    """
    line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, line_kernel_size)
    detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, line_kernel, iterations=2)
    found = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # findContours returns 2 values in some OpenCV versions and 3 in others.
    contours = found[0] if len(found) == 2 else found[1]
    for contour in contours:
        cv2.drawContours(gray, [contour], -1, (255, 255, 255), 2)
    repair_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, repair_kernel_size)
    return 255 - cv2.morphologyEx(255 - gray, cv2.MORPH_CLOSE, repair_kernel, iterations=1)


def get_image_edges(img):
    """Return an edge map of ``img`` with long ruling lines removed.

    The image is converted from BGR to grayscale and Otsu-thresholded;
    horizontal and then vertical lines found in the thresholded image are
    erased, and Canny edges of the blurred result are dilated and inverted.

    Returns:
        A single-channel image in which edge pixels are 0 and all other
        pixels are 255.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

    # Remove horizontal lines, then vertical ones; both passes detect on the
    # same thresholded image.
    gray = _erase_lines(gray, thresh, (25, 1), (1, 6))
    gray = _erase_lines(gray, thresh, (1, 25), (6, 1))

    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blurred, 50, 150)
    dilated = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
    return cv2.bitwise_not(dilated)