Spaces:
Runtime error
Runtime error
| import math | |
| from typing import Literal | |
| import numpy as np | |
| import torch | |
| import cv2 | |
def extract_patches(
    pages: list[np.ndarray],
    patch_size=(16, 16),
    patches_mode: Literal['all', 'important'] = 'important',
    pages_mode: Literal['concat', 'index'] = 'index',
):
    """Cut page images into flattened, position-tagged patches.

    Each page is padded with ``pad_page``, standardized with ``normalize_page`` and
    split into non-overlapping ``patch_size`` patches. Every patch becomes one row
    made of position ids followed by its pixels flattened in
    (patch_height, patch_width, channels) order.

    Args:
        pages: page images indexed as (height, width, channels). ``get_image_edges``
            converts them with ``cv2.COLOR_BGR2GRAY``, so BGR images are presumably
            expected — TODO confirm with callers.
        patch_size: ``(patch_height, patch_width)`` in pixels.
        patches_mode: ``'all'`` keeps every patch in row-major order; ``'important'``
            keeps only the patches returned by ``get_image_interesting_patches``, in
            the order that function returns them.
        pages_mode: ``'index'`` prefixes each row with a 1-based page id, giving
            ``[page_id, row_id, col_id, pixels...]``; ``'concat'`` stacks pages
            vertically (row ids continue across pages), giving
            ``[row_id, col_id, pixels...]``.

    Returns:
        float32 tensor of shape ``(num_patches, num_ids + patch_height * patch_width * channels)``.
        Row/col (and page) ids are 1-based; 0 is reserved for padding.

    Raises:
        ValueError: if ``patches_mode`` or ``pages_mode`` is not a supported value.
    """
    if patches_mode not in ('all', 'important'):
        raise ValueError(f"patches_mode must be 'all' or 'important', got {patches_mode!r}")
    if pages_mode not in ('concat', 'index'):
        raise ValueError(f"pages_mode must be 'concat' or 'index', got {pages_mode!r}")

    patch_height, patch_width = patch_size
    # Materialize once: pages are iterated twice below.
    pages = list(pages)

    if patches_mode == 'important':
        # get_image_interesting_patches takes its patch size as (width, height). It must
        # use the same grid as the unfold below, so forward the caller's patch size
        # instead of relying on that function's 16x16 default.
        pages_important_patches = [
            get_image_interesting_patches(page, (patch_width, patch_height)) for page in pages
        ]
    else:
        pages_important_patches = [None for _ in pages]

    page_start_row = 0
    all_patches = []
    for page_idx, (page, important_patches) in enumerate(zip(pages, pages_important_patches)):
        grid = _page_to_patch_grid(page, patch_height, patch_width)
        rows, cols, patch_dim = grid.shape
        page_patches = grid.reshape(rows * cols, patch_dim)

        # Row-major position ids, 1-based because 0 is reserved for padding.
        row_ids = torch.arange(1, rows + 1).repeat_interleave(cols).unsqueeze(1)
        col_ids = torch.arange(1, cols + 1).repeat(rows).unsqueeze(1)

        if pages_mode == 'index':
            page_ids = torch.full((rows * cols, 1), page_idx + 1)
            id_columns = [page_ids, row_ids, col_ids]
        else:
            # 'concat': treat all pages as one tall page by offsetting row ids.
            id_columns = [row_ids + page_start_row, col_ids]
            page_start_row += rows

        page_patches = torch.cat([ids.to(torch.float32) for ids in id_columns] + [page_patches], -1)

        if important_patches is not None:
            # Explicit long dtype so an empty selection still indexes correctly.
            keep = torch.tensor([row * cols + col for row, col in important_patches], dtype=torch.long)
            page_patches = page_patches[keep]
        all_patches.append(page_patches)
    return torch.cat(all_patches, 0)


def _page_to_patch_grid(page, patch_height, patch_width):
    """Pad, normalize and split one (H, W, C) page into a (rows, cols, ph * pw * C) float32 tensor."""
    page = normalize_page(pad_page(page, (patch_height, patch_width)))
    # Channel first: (C, H, W).
    tensor = torch.from_numpy(page).to(torch.float32).permute(2, 0, 1)
    channels, height, width = tensor.shape
    # Give unfold an explicit batch dimension: older torch releases only accept 4-D input.
    # Result after [0]: (C * ph * pw, rows * cols), channel-major within each column.
    columns = torch.nn.functional.unfold(
        tensor.unsqueeze(0), (patch_height, patch_width), stride=(patch_height, patch_width)
    )[0]
    return (
        columns.reshape(channels, patch_height, patch_width, -1)
        # (rows * cols, patch_height, patch_width, channels)
        .permute(3, 1, 2, 0)
        # (rows, cols, patch_height * patch_width * channels)
        .reshape(height // patch_height, width // patch_width, channels * patch_height * patch_width)
    )
def pad_page(page, patch_size, fill_value=255):
    """Pad ``page`` on the bottom and right so height and width are multiples of ``patch_size``.

    Args:
        page: array whose first two axes are (height, width). Any trailing axes
            (e.g. channels) are left unpadded, so both 2-D and 3-D pages work.
        patch_size: ``(patch_height, patch_width)``.
        fill_value: constant written into the padded area (default 255).

    Returns:
        The padded array, or ``page`` itself when no padding is needed.
    """
    patch_height, patch_width = patch_size
    # -n % k is the smallest non-negative amount that makes n a multiple of k.
    pad_bottom = -page.shape[0] % patch_height
    pad_right = -page.shape[1] % patch_width
    if pad_bottom == 0 and pad_right == 0:
        return page
    pad_width = [(0, pad_bottom), (0, pad_right)] + [(0, 0)] * (page.ndim - 2)
    return np.pad(page, pad_width, constant_values=fill_value)
def normalize_page(page):
    """Standardize ``page`` to zero mean and (approximately) unit standard deviation.

    Mean and standard deviation are taken over every element of the array. The
    divisor is floored at ``1 / sqrt(number of elements)``, so a constant page
    maps to zeros instead of dividing by zero.
    """
    divisor_floor = 1.0 / math.sqrt(page.size)
    centered = page - np.mean(page)
    return centered / max(np.std(page), divisor_floor)
def get_image_interesting_patches(image, patch_size=(16, 16)):
    """Return ``(row, col)`` grid coordinates of patches whose edge map has a 0 pixel.

    The image is tiled from the top-left corner; tiles on the right/bottom border may
    be smaller than ``patch_size``. For each tile, the edge map from ``get_image_edges``
    is checked, and tiles containing at least one 0 pixel are kept.

    Args:
        image: array indexed as (height, width, channels).
        patch_size: ``(patch_width, patch_height)`` — note the width-first order here.

    Returns:
        List of ``(row, col)`` tuples, sorted by the sum of the tile's edge-map values,
        largest first (ties keep row-major order).
    """
    height, width, _ = image.shape
    step_x, step_y = patch_size
    edge_map = get_image_edges(image)
    candidates = []
    for top in range(0, height, step_y):
        bottom = min(top + step_y, height)
        for left in range(0, width, step_x):
            tile = edge_map[top:bottom, left:min(left + step_x, width)]
            if not np.any(tile == 0):
                continue
            candidates.append((top // step_y, left // step_x, tile.sum()))
    # NOTE(review): get_image_edges marks edges as 0 and background as 255, so a larger
    # sum means FEWER edge pixels. This ranks the least edge-dense tiles (and full-size
    # over partial border tiles) first. Confirm this order is intended before relying on it.
    ranked = sorted(candidates, key=lambda entry: entry[2], reverse=True)
    return [(row, col) for row, col, _ in ranked]
def _erase_lines(gray, binary_inv, line_kernel_size, repair_kernel_size):
    """Paint over straight lines found in ``binary_inv`` and re-close strokes in ``gray``."""
    line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, line_kernel_size)
    detected_lines = cv2.morphologyEx(binary_inv, cv2.MORPH_OPEN, line_kernel, iterations=2)
    contours = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns (image, contours, hierarchy).
    contours = contours[0] if len(contours) == 2 else contours[1]
    for contour in contours:
        # Draws in place on `gray`, whitening the detected line.
        cv2.drawContours(gray, [contour], -1, (255, 255, 255), 2)
    repair_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, repair_kernel_size)
    # Closing on the inverted image bridges small gaps in dark strokes, presumably
    # the ones cut by the whitened lines.
    return 255 - cv2.morphologyEx(255 - gray, cv2.MORPH_CLOSE, repair_kernel, iterations=1)


def get_image_edges(img):
    """Return an inverted edge map of ``img`` after removing long ruled lines.

    The pipeline runs these steps in order:
    grayscale -> erase long horizontal lines -> erase long vertical lines ->
    3x3 Gaussian blur -> Canny(50, 150) -> 3x3 dilation -> bitwise NOT.
    In the result, edge pixels are 0 and everything else is 255.

    Args:
        img: image accepted by ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``, presumably a
            uint8 BGR page — TODO confirm with callers.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Inverted Otsu binarization: dark pixels become 255 so line detection sees ink as foreground.
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
    # Both passes detect lines on `thresh` from the original grayscale; only `gray` is edited.
    # Horizontal lines: 25x1 detection kernel, 1x6 repair kernel.
    gray = _erase_lines(gray, thresh, line_kernel_size=(25, 1), repair_kernel_size=(1, 6))
    # Vertical lines: 1x25 detection kernel, 6x1 repair kernel.
    gray = _erase_lines(gray, thresh, line_kernel_size=(1, 25), repair_kernel_size=(6, 1))
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blurred, 50, 150)
    dilated = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
    return cv2.bitwise_not(dilated)