Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from PIL import Image | |
| from app.preprocessing.base import PreprocessingStep, PreprocessingContext, PreprocessingError | |
class PatchStep(PreprocessingStep):
    """Preprocessing step that cuts the decoded image into square tiles.

    The image on the context is split into ``patch_size`` x ``patch_size``
    tiles laid out on a regular grid with ``patch_size - overlap`` pixels
    between neighbouring tile origins. The top-left pixel offset of every tile
    is recorded alongside it so results can be mapped back onto the full image.
    """

    # Step metadata. How these fields are consumed is defined by
    # PreprocessingStep / the pipeline runner, which are not part of this file.
    name = "patch"
    description = "Tile large satellite image into overlapping 640×640 patches, preserving pixel offsets for stitching"
    version = "2.1.0"
    order = 6
    enabled = True
    required = False

    async def process(self, ctx: PreprocessingContext, params: dict) -> PreprocessingContext:
        """Tile ``ctx.image`` and store the tiles and their offsets on ``ctx``.

        Args:
            ctx: Pipeline context; ``ctx.image`` must already be set.
            params: Optional settings. ``"patch_size"`` (default 640) is the
                tile edge length in pixels; ``"overlap"`` (default 128) is the
                number of pixels shared by neighbouring tiles.

        Returns:
            The same ``ctx`` with ``tiles`` (list of PIL images),
            ``tile_offsets`` (list of ``(x, y)`` top-left pixel offsets) and
            summary entries in ``metadata`` and ``step_outputs["patch"]``.

        Raises:
            PreprocessingError: If there is no image on the context, if
                ``patch_size`` is not positive, or if ``overlap`` is outside
                ``[0, patch_size)``.
        """
        if ctx.image is None:
            raise PreprocessingError("No image to tile — decode step must run first")

        patch_size = params.get("patch_size", 640)
        overlap = params.get("overlap", 128)

        # Reject configurations that cannot tile the image: a zero stride makes
        # range() fail, a negative stride silently yields a single tile, and a
        # negative overlap leaves uncovered gaps between tiles.
        if patch_size <= 0:
            raise PreprocessingError(f"patch_size must be positive, got {patch_size}")
        if not 0 <= overlap < patch_size:
            raise PreprocessingError(
                f"overlap must satisfy 0 <= overlap < patch_size, "
                f"got overlap={overlap}, patch_size={patch_size}"
            )
        stride = patch_size - overlap

        # Convert to numpy once — avoids PIL C-allocator conflicts with PyTorch
        arr = np.array(ctx.image)
        H, W = arr.shape[:2]

        # Shape of one full tile; any trailing axes (e.g. channels) are kept,
        # so grayscale (2-D) and multi-channel (3-D) arrays share one code path.
        tile_shape = (patch_size, patch_size) + arr.shape[2:]

        ys = _grid_positions(H, patch_size, stride)
        xs = _grid_positions(W, patch_size, stride)

        tiles: list[Image.Image] = []
        offsets: list[tuple[int, int]] = []
        # Grid positions are unique per axis, so every (x, y) pair is visited
        # exactly once and no de-duplication is needed.
        for y in ys:
            for x in xs:
                patch = arr[y:y + patch_size, x:x + patch_size]
                ph, pw = patch.shape[:2]
                if ph < patch_size or pw < patch_size:
                    # Image is smaller than a tile along this axis: zero-pad on
                    # the bottom/right so every tile has the same size.
                    padded = np.zeros(tile_shape, dtype=arr.dtype)
                    padded[:ph, :pw] = patch
                    patch = padded
                tiles.append(Image.fromarray(patch))
                offsets.append((x, y))

        ctx.tiles = tiles
        ctx.tile_offsets = offsets

        ctx.metadata["num_tiles"] = len(tiles)
        ctx.metadata["image_size"] = [W, H]
        ctx.metadata["patch_size"] = patch_size
        ctx.metadata["overlap"] = overlap
        ctx.metadata["stride"] = stride

        ctx.step_outputs["patch"] = {
            "num_tiles": len(tiles),
            "patch_size": patch_size,
            "overlap": overlap,
            "stride": stride,
            "image_size": [W, H],
        }
        return ctx
| def _grid_positions(length: int, patch_size: int, stride: int) -> list[int]: | |
| """Return start positions covering [0, length) with given stride, always including the last patch.""" | |
| if length <= patch_size: | |
| return [0] | |
| positions = list(range(0, length - patch_size, stride)) | |
| last = length - patch_size | |
| if not positions or positions[-1] != last: | |
| positions.append(last) | |
| return positions | |