Spaces:
Runtime error
Runtime error
File size: 1,748 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | from dataclasses import asdict
from ..data_types import BatchedExample, BatchedViews, UnbatchedViews, BatchedViewsDict, UnbatchedExample
def apply_patch_shim_to_views(views: BatchedViews | UnbatchedViews | BatchedViewsDict,
                              patch_size: int | list[int]) -> BatchedViews | UnbatchedViews | BatchedViewsDict:
    """Center-crop ``views`` so the image height and width are multiples of ``patch_size``.

    Args:
        views: Views exposing ``"image"`` (last two dims are height, width) and
            ``"intrinsics"`` (camera matrices in the last two dims; assumed 3x3 —
            TODO confirm). Either a ``BatchedViews`` dataclass or a plain mapping.
        patch_size: One int used for both dimensions, or a pair
            ``(patch along height, patch along width)``. The first element has always
            been applied to the height; that order is kept for backward compatibility.

    Returns:
        A ``BatchedViews`` when ``views`` is one, otherwise a new dict. Either way,
        ``"image"`` is cropped and ``"intrinsics"`` is adjusted; all other entries
        are passed through unchanged (not copied).

    Raises:
        ValueError: If a patch size is not positive, or the image is smaller than
            the patch size along some dimension (the crop would be empty).
    """
    *_, h, w = views["image"].shape

    if isinstance(patch_size, int):
        patch_h = patch_w = patch_size
    else:
        patch_h, patch_w = patch_size
    if patch_h <= 0 or patch_w <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size!r}")

    # Largest multiple of the patch size that fits, and the top/left offset that centers it.
    h_new = (h // patch_h) * patch_h
    w_new = (w // patch_w) * patch_w
    if h_new == 0 or w_new == 0:
        raise ValueError(
            f"image of size {h}x{w} is smaller than patch size {patch_size!r}"
        )
    row = (h - h_new) // 2
    col = (w - w_new) // 2

    # Center-crop the image.
    image = views["image"][..., row: row + h_new, col: col + w_new]

    # Adjust the intrinsics to account for the cropping. Rescaling fx/fy by old/new
    # size implies the intrinsics are normalized by image size (presumably, confirm
    # against the data loader). Under that convention the principal point moves too:
    # cx_pixels_new = cx_pixels - col, re-normalized by the new width. For a centered
    # principal point and an even crop margin this leaves cx/cy unchanged.
    intrinsics = views["intrinsics"].clone()
    intrinsics[..., 0, 0] *= w / w_new  # fx
    intrinsics[..., 1, 1] *= h / h_new  # fy
    intrinsics[..., 0, 2] = (intrinsics[..., 0, 2] * w - col) / w_new  # cx
    intrinsics[..., 1, 2] = (intrinsics[..., 1, 2] * h - row) / h_new  # cy

    if isinstance(views, BatchedViews):
        from dataclasses import fields

        # Shallow field copy: dataclasses.asdict would deep-copy every tensor only to
        # discard the image/intrinsics copies. This assumes BatchedViews has no nested
        # dataclass fields that from_dict expects as dicts (TODO confirm).
        updated = {f.name: getattr(views, f.name) for f in fields(views)}
        updated["image"] = image
        updated["intrinsics"] = intrinsics
        return BatchedViews.from_dict(updated)
    return {
        **views,
        "image": image,
        "intrinsics": intrinsics,
    }
def apply_patch_shim(batch: BatchedExample | UnbatchedExample,
                     patch_size: int | list[int]) -> BatchedExample | UnbatchedExample:
    """Return a copy of ``batch`` whose images divide evenly by ``patch_size``.

    Both the ``"context"`` and ``"target"`` views are passed through
    ``apply_patch_shim_to_views`` (context first, then target). Every other
    entry of ``batch`` is carried over by reference into a new dict; ``batch``
    itself is not modified.
    """
    shimmed = dict(batch)
    for view_key in ("context", "target"):
        shimmed[view_key] = apply_patch_shim_to_views(batch[view_key], patch_size)
    return shimmed
|