# Page header captured with this listing — Spaces status: Sleeping; file size: 2,864 bytes; 1c77735 (presumably the revision id).
import numpy as np
from PIL import Image
from app.preprocessing.base import PreprocessingStep, PreprocessingContext, PreprocessingError
class PatchStep(PreprocessingStep):
    """Preprocessing step that cuts the decoded image into square tiles.

    Tiles are ``patch_size`` × ``patch_size`` and are laid out on a grid whose
    step is ``patch_size - overlap``. The top-left pixel offset of every tile
    is recorded alongside it so that results can be mapped back onto the
    full image.
    """

    name = "patch"
    description = "Tile large satellite image into overlapping 640×640 patches, preserving pixel offsets for stitching"
    version = "2.1.0"
    order = 6
    enabled = True
    required = False

    async def process(self, ctx: PreprocessingContext, params: dict) -> PreprocessingContext:
        """Tile ``ctx.image`` and store the tiles and their offsets on ``ctx``.

        Args:
            ctx: Pipeline context; ``ctx.image`` must already be populated.
            params: Optional settings. ``patch_size`` (default 640) is the
                tile edge length in pixels; ``overlap`` (default 128) is the
                number of pixels shared by neighbouring tiles.

        Returns:
            The same ``ctx`` with ``tiles``, ``tile_offsets``, ``metadata``
            and ``step_outputs["patch"]`` filled in.

        Raises:
            PreprocessingError: If there is no image to tile, if
                ``patch_size`` is not positive, or if ``overlap`` is not
                smaller than ``patch_size``.
        """
        if ctx.image is None:
            raise PreprocessingError("No image to tile — decode step must run first")

        patch_size = params.get("patch_size", 640)
        overlap = params.get("overlap", 128)
        if patch_size <= 0:
            raise PreprocessingError(f"patch_size must be positive, got {patch_size}")
        if overlap >= patch_size:
            # A non-positive stride cannot advance across the image: zero makes
            # range() raise, negative collapses the grid to a single tile.
            raise PreprocessingError(
                f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
            )
        stride = patch_size - overlap

        # Convert to numpy once — avoids PIL C-allocator conflicts with PyTorch
        arr = np.array(ctx.image)
        H, W = arr.shape[:2]

        # Canvas shape for undersized tiles; shape[2:] is empty for a 2-D
        # array, so any trailing channel axis is carried over unchanged.
        padded_shape = (patch_size, patch_size) + arr.shape[2:]

        tiles: list[Image.Image] = []
        offsets: list[tuple[int, int]] = []

        # _grid_positions returns distinct start positions per axis, so every
        # (x, y) pair below is visited exactly once.
        ys = _grid_positions(H, patch_size, stride)
        xs = _grid_positions(W, patch_size, stride)
        for y in ys:
            for x in xs:
                patch = arr[y:y + patch_size, x:x + patch_size]
                ph, pw = patch.shape[:2]
                # Zero-pad on the bottom/right when the slice is smaller than
                # a full patch.
                if ph < patch_size or pw < patch_size:
                    padded = np.zeros(padded_shape, dtype=arr.dtype)
                    padded[:ph, :pw] = patch
                    patch = padded
                tiles.append(Image.fromarray(patch))
                offsets.append((x, y))

        ctx.tiles = tiles
        ctx.tile_offsets = offsets

        ctx.metadata["num_tiles"] = len(tiles)
        ctx.metadata["image_size"] = [W, H]
        ctx.metadata["patch_size"] = patch_size
        ctx.metadata["overlap"] = overlap
        ctx.metadata["stride"] = stride

        ctx.step_outputs["patch"] = {
            "num_tiles": len(tiles),
            "patch_size": patch_size,
            "overlap": overlap,
            "stride": stride,
            "image_size": [W, H],
        }
        return ctx
def _grid_positions(length: int, patch_size: int, stride: int) -> list[int]:
"""Return start positions covering [0, length) with given stride, always including the last patch."""
if length <= patch_size:
return [0]
positions = list(range(0, length - patch_size, stride))
last = length - patch_size
if not positions or positions[-1] != last:
positions.append(last)
return positions
|