| import torch |
| import numpy as np |
| from tqdm import tqdm |
| from pathlib import Path |
| from typing import Dict, Tuple, List |
|
|
def _unique_window_starts(length: int, window: int, step: int) -> List[int]:
    """Return the distinct start offsets of the sliding windows along one axis.

    Candidate offsets are taken every ``step`` pixels; a window that would run
    past ``length`` is shifted back so it ends exactly at the border. Several
    candidates can collapse onto that same shifted window, so duplicates are
    dropped while keeping the ascending order.
    """
    starts = (
        max(min(pos + window, length) - window, 0)
        for pos in range(0, length, step)
    )
    return list(dict.fromkeys(starts))


def batch_sliding_window_logits(images: torch.Tensor, model: torch.nn.Module,
                                device: torch.device,
                                crop_size: Tuple[int, int],
                                stride: Tuple[int, int]) -> torch.Tensor:
    """
    Run sliding-window inference on a batch and return the averaged logits.

    The batch is tiled with windows of ``crop_size`` placed every ``stride``
    pixels. Windows that would overrun the image are shifted back to end at the
    border. Each distinct window is evaluated exactly once, and overlapping
    predictions are averaged per pixel.

    Args:
        images: Batch of shape ``(B, C, H, W)``; it is moved to ``device``.
        model: Module called as ``model(pixel_values=patch)`` whose result has
            a ``.logits`` attribute, and which exposes ``config.num_labels``.
        device: Device on which inference and accumulation take place.
        crop_size: Window size as ``(height, width)``.
        stride: Step between windows as ``(vertical, horizontal)``.

    Returns:
        Tensor of shape ``(B, num_labels, H, W)`` with the per-pixel mean of
        the logits of all windows covering that pixel. Pixels covered by no
        window (possible when the stride exceeds the crop size) stay zero.

    Raises:
        ValueError: If any crop size or stride component is not positive.
    """
    B, _, H, W = images.shape
    ph, pw = crop_size
    sh, sw = stride
    if min(ph, pw, sh, sw) <= 0:
        raise ValueError(
            f"crop_size and stride must be positive, got crop_size={crop_size}, stride={stride}"
        )
    images = images.to(device)

    num_classes = model.config.num_labels
    full_logits = torch.zeros((B, num_classes, H, W), device=device)
    count_map = torch.zeros((H, W), device=device)

    # De-duplicated offsets: without this the border window is evaluated (and
    # weighted) several times whenever consecutive strides clamp to the same spot.
    tops = _unique_window_starts(H, ph, sh)
    lefts = _unique_window_starts(W, pw, sw)

    with torch.no_grad():
        for top in tops:
            bottom = min(top + ph, H)
            for left in lefts:
                right = min(left + pw, W)
                patch = images[:, :, top:bottom, left:right]
                logits = model(pixel_values=patch).logits
                # Some models emit logits at a reduced resolution; bring them
                # back to the patch size so they can be accumulated.
                if logits.shape[-2:] != patch.shape[-2:]:
                    logits = torch.nn.functional.interpolate(
                        logits, size=patch.shape[-2:],
                        mode="bilinear", align_corners=False,
                    )
                full_logits[:, :, top:bottom, left:right] += logits
                count_map[top:bottom, left:right] += 1

    # clamp avoids a division by zero on pixels that no window covered;
    # the (H, W) count map broadcasts over the batch and class dimensions.
    return full_logits / count_map.clamp(min=1)
|
|
|
|
def export_logits_images(model, loader, device, crop_size, stride, output_dir: Path):
    """
    Run sliding-window inference over a loader and export class probabilities.

    For every batch the averaged logits are turned into softmax probabilities
    over the class dimension, quantized to ``uint8`` (probability 1.0 maps to
    255, rounded to the nearest integer) and written as one ``.npy`` file per
    image with layout ``(H, W, num_classes)``.

    Args:
        model: Segmentation model passed on to ``batch_sliding_window_logits``;
            it is switched to eval mode and left in that mode.
        loader: Iterable yielding ``(images, rel_paths)`` pairs, where
            ``rel_paths[b]`` is the output path of sample ``b`` relative to
            ``output_dir``.
        device: Device on which inference runs.
        crop_size: Sliding-window size as ``(height, width)``.
        stride: Sliding-window step as ``(vertical, horizontal)``.
        output_dir: Root directory of the export; created when missing.
            The sub-directories implied by ``rel_paths`` are created as needed.
    """
    model.eval()
    output_dir.mkdir(parents=True, exist_ok=True)

    for images, rel_paths in tqdm(loader, desc=f"Export logits to {output_dir}", leave=False):
        avg_logits = batch_sliding_window_logits(images, model, device, crop_size, stride)
        probs = torch.softmax(avg_logits, dim=1)
        # Round before the cast: .byte() truncates, which would bias every
        # stored probability downwards (e.g. 0.999 -> 254 instead of 255).
        quantized = (probs * 255.0).round().clamp(0, 255).byte()
        # Channels-last and contiguous so that np.save can write each sample
        # as one plain C-ordered buffer.
        channels_last = quantized.permute(0, 2, 3, 1).contiguous().cpu()

        for b, sample in enumerate(channels_last):
            out_path = output_dir / rel_paths[b]
            out_path.parent.mkdir(parents=True, exist_ok=True)
            # The original file extension is replaced by .npy.
            np.save(out_path.with_suffix('.npy'), sample.numpy())
|
|