import torch
from typing import Tuple
def batch_sliding_window_inference(images: torch.Tensor, model: torch.nn.Module,
                                   device: torch.device,
                                   crop_size: Tuple[int, int],
                                   stride: Tuple[int, int]) -> torch.Tensor:
    """
    Run sliding-window inference over a batch of images and return per-pixel
    class predictions.

    Crops of size ``crop_size`` are taken at offsets spaced by ``stride`` and
    fed through ``model``; logits are accumulated and averaged wherever windows
    overlap, then reduced with an argmax over the class dimension.

    Args:
        images: Input batch of shape (B, C, H, W).
        model: Segmentation model with a HuggingFace-style interface — called
            as ``model(pixel_values=patch)``, output exposes ``.logits`` and
            the model exposes ``model.config.num_labels``. Assumed to already
            live on ``device``.
        device: Device on which to run inference (``images`` is moved there).
        crop_size: (height, width) of each sliding window.
        stride: (vertical, horizontal) step between window origins.

    Returns:
        Tensor of shape (B, H, W) holding the predicted class index per pixel.
    """
    B, C, H, W = images.shape
    ph, pw = crop_size
    sh, sw = stride
    images = images.to(device)
    num_classes = model.config.num_labels

    full_logits = torch.zeros((B, num_classes, H, W), device=device)
    count_map = torch.zeros((H, W), device=device)

    # One no-grad guard around the whole sweep (the original re-entered it per
    # patch), and deduplicated window starts: the original `range(0, H, sh)`
    # re-ran the same clamped edge window many times when ph > sh.
    with torch.no_grad():
        for top in _window_starts(H, ph, sh):
            for left in _window_starts(W, pw, sw):
                bottom = min(top + ph, H)
                right = min(left + pw, W)
                patch = images[:, :, top:bottom, left:right].contiguous()
                logits = model(pixel_values=patch).logits
                # Many HF segmentation models (e.g. SegFormer) emit logits at
                # a reduced spatial resolution; upsample to the patch size so
                # accumulation lines up pixel-for-pixel. No-op when the model
                # already returns full-resolution logits.
                if logits.shape[-2:] != patch.shape[-2:]:
                    logits = torch.nn.functional.interpolate(
                        logits, size=patch.shape[-2:],
                        mode="bilinear", align_corners=False)
                full_logits[:, :, top:bottom, left:right] += logits
                count_map[top:bottom, left:right] += 1

    # Every pixel is covered by at least one window; the clamp is a defensive
    # guard so a zero count can never yield NaNs. (H, W) broadcasts against
    # (B, num_classes, H, W) without explicit unsqueezes.
    avg_logits = full_logits / count_map.clamp(min=1)
    return avg_logits.argmax(dim=1)


def _window_starts(size: int, window: int, step: int) -> list:
    """Return duplicate-free start offsets covering [0, size) with windows of
    length `window` stepped by `step`; the last window is right-aligned so the
    far edge is always covered."""
    if window >= size:
        return [0]
    starts = list(range(0, size - window + 1, step))
    if starts[-1] != size - window:
        starts.append(size - window)
    return starts