| | import math |
| | import numpy as np |
| | import torch |
| |
|
| | from typing import TypedDict |
| |
|
# Prefer pyvips for resizing when it can be imported; otherwise use Pillow.
# HAS_VIPS records which backend is available for the rest of the module.
try:
    import pyvips

    HAS_VIPS = True
except Exception:
    # ``Exception`` instead of a bare ``except`` so that KeyboardInterrupt and
    # SystemExit raised during import still propagate, while any ordinary
    # import-time failure of pyvips still selects the Pillow fallback.
    from PIL import Image

    HAS_VIPS = False
| |
|
| |
|
def select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]:
    """
    Determine the optimal number of tiles to cover an image with overlapping crops.

    Args:
        height: Height of the area to cover, in pixels.
        width: Width of the area to cover, in pixels.
        crop_size: Side length, in pixels, covered by a single tile.
        max_crops: Upper bound on the total number of tiles (rows * columns).

    Returns:
        ``(h_tiles, w_tiles)``: the number of tile rows and tile columns. Both
        values are at least 1, and their product does not exceed ``max_crops``
        whenever ``max_crops`` is at least 1.
    """
    # A single tile is used as soon as either dimension fits inside one crop.
    if height <= crop_size or width <= crop_size:
        return (1, 1)

    # Fewest tiles per axis that fully cover the area at native resolution.
    min_h = math.ceil(height / crop_size)
    min_w = math.ceil(width / crop_size)

    if min_h * min_w > max_crops:
        # Full coverage is over budget: shrink both axes by the same factor so
        # the aspect ratio of the grid is kept.
        ratio = math.sqrt(max_crops / (min_h * min_w))
        h_tiles = max(1, math.floor(min_h * ratio))
        w_tiles = max(1, math.floor(min_w * ratio))
    else:
        # Budget allows extra tiles: start from an aspect-ratio-preserving grid
        # and never go below the minimum needed for coverage.
        h_tiles = max(math.floor(math.sqrt(max_crops * height / width)), min_h)
        w_tiles = max(math.floor(math.sqrt(max_crops * width / height)), min_w)

    # Trim the longer axis if the grid is over budget. This also covers the
    # shrink branch above: for extreme aspect ratios one axis is clamped up to
    # 1, which could otherwise leave the other axis larger than max_crops.
    if h_tiles * w_tiles > max_crops:
        if w_tiles > h_tiles:
            w_tiles = math.floor(max_crops / h_tiles)
        else:
            h_tiles = math.floor(max_crops / w_tiles)

    return (max(1, h_tiles), max(1, w_tiles))
| |
|
| |
|
class OverlapCropOutput(TypedDict):
    """Result mapping returned by ``overlap_crop_image``."""

    # Stacked crops. ``overlap_crop_image`` stores the resized full image at
    # index 0 and the grid tiles, in row-major order, from index 1 onward.
    crops: np.ndarray
    # Grid layout as (tile rows, tile columns).
    tiling: tuple[int, int]
| |
|
| |
|
def overlap_crop_image(
    image: np.ndarray,
    overlap_margin: int,
    max_crops: int,
    base_size: tuple[int, int] = (378, 378),
    patch_size: int = 14,
) -> OverlapCropOutput:
    """
    Split an image into one global crop plus a grid of overlapping local crops.

    The image is resized so that a grid of tiles, chosen by ``select_tiling``,
    covers it exactly. Each tile is a ``base_size`` window, and neighbouring
    tiles overlap by ``2 * overlap_margin`` patches. A copy of the whole image
    resized to ``base_size`` is stored in front of the tiles.

    Resizing uses pyvips when ``HAS_VIPS`` is true and Pillow (Lanczos
    resampling) otherwise.

    Args:
        image (np.ndarray): Input image indexed as (H, W, C).
        overlap_margin (int): Margin kept on each side of a tile, in patches.
        max_crops (int): Maximum number of local tiles passed to
            ``select_tiling``.
        base_size (tuple[int, int]): (height, width) of every returned crop,
            default (378, 378). Only ``base_size[0]`` is used to derive the
            step between tiles.
        patch_size (int): Patch side length in pixels, default 14.

    Returns:
        OverlapCropOutput: Dictionary containing:
            - crops: uint8 array of shape
              (rows * cols + 1, base_size[0], base_size[1], C); index 0 is
              the global crop, indices 1+ are the tiles in row-major order.
            - tiling: Tuple of (rows, cols) tile counts.
    """
    base_h = base_size[0]
    base_w = base_size[1]
    src_h, src_w = image.shape[:2]

    # Overlap margin in pixels, for one side and for both sides together.
    margin_px = patch_size * overlap_margin
    double_margin_px = margin_px * 2

    # Part of a tile that is not shared with its neighbours; this is also the
    # stride between consecutive tiles.
    window_px = (base_h // patch_size - 2 * overlap_margin) * patch_size

    tiling = select_tiling(
        src_h - double_margin_px,
        src_w - double_margin_px,
        window_px,
        max_crops,
    )
    rows = tiling[0]
    cols = tiling[1]

    # Slot 0 holds the global crop, the remaining slots hold the grid tiles.
    crops = np.zeros(
        (rows * cols + 1, base_h, base_w, image.shape[2]), dtype=np.uint8
    )

    # Size at which the tile grid covers the image exactly.
    resized_h = rows * window_px + double_margin_px
    resized_w = cols * window_px + double_margin_px

    if HAS_VIPS:
        source = pyvips.Image.new_from_array(image)
        image = source.resize(
            resized_w / image.shape[1], vscale=resized_h / image.shape[0]
        ).numpy()
        # The global crop is produced from the untouched source image.
        crops[0] = source.resize(
            base_w / source.width, vscale=base_h / source.height
        ).numpy()
    else:
        source = Image.fromarray(image)
        image = np.asarray(
            source.resize(
                (int(resized_w), int(resized_h)),
                resample=Image.Resampling.LANCZOS,
            )
        )
        # The global crop is produced from the untouched source image.
        crops[0] = np.asarray(
            source.resize(
                (int(base_w), int(base_h)), resample=Image.Resampling.LANCZOS
            )
        )

    # Walk the grid in row-major order; tile k lands in slot k + 1.
    for tile_index in range(rows * cols):
        row, col = divmod(tile_index, cols)
        top = row * window_px
        left = col * window_px

        # Clip the window to the resized image; any remainder stays zero.
        bottom = min(top + base_h, image.shape[0])
        right = min(left + base_w, image.shape[1])

        region = image[top:bottom, left:right]
        crops[1 + tile_index, : region.shape[0], : region.shape[1]] = region

    return {"crops": crops, "tiling": tiling}
| |
|
| |
|
def reconstruct_from_crops(
    crops: torch.Tensor,
    tiling: tuple[int, int],
    overlap_margin: int,
    patch_size: int = 14,
) -> torch.Tensor:
    """
    Stitch a row-major grid of overlapping crops back into one tensor.

    Every crop contributes its interior region; the overlap margin is kept
    only on sides that lie on the outer border of the grid, so each output
    pixel is written by exactly one crop.

    Args:
        crops: Sequence of equally sized crops, each indexed as (H, W, C), in
            row-major grid order. The first crop supplies the size, device and
            dtype of the output.
        tiling: Tuple of (rows, cols) describing the crop grid layout.
        overlap_margin: Number of overlapping patches on each crop edge.
        patch_size: Size in pixels of each patch, default 14.

    Returns:
        Tensor of shape (H_out, W_out, C) where
        ``H_out = (crop_h - 2 * margin) * rows + 2 * margin`` and
        ``W_out = (crop_w - 2 * margin) * cols + 2 * margin``, with
        ``margin = overlap_margin * patch_size``.
    """
    rows, cols = tiling
    first = crops[0]
    tile_h, tile_w = first.shape[:2]
    margin = overlap_margin * patch_size

    # Distance between the origins of neighbouring crops in the output.
    stride_h = tile_h - 2 * margin
    stride_w = tile_w - 2 * margin

    canvas = torch.zeros(
        (stride_h * rows + 2 * margin, stride_w * cols + 2 * margin, first.shape[2]),
        device=first.device,
        dtype=first.dtype,
    )

    for index, tile in enumerate(crops):
        row, col = divmod(index, cols)

        # Drop the margin on every side that borders another crop; keep it on
        # sides that coincide with the outer edge of the grid.
        top = margin if row else 0
        left = margin if col else 0
        bottom = tile_h - margin if row != rows - 1 else tile_h
        right = tile_w - margin if col != cols - 1 else tile_w

        # Origin of this crop inside the stitched output.
        origin_y = row * stride_h
        origin_x = col * stride_w

        canvas[
            origin_y + top : origin_y + bottom, origin_x + left : origin_x + right
        ] = tile[top:bottom, left:right]

    return canvas
| |
|