|
|
|
|
|
"""Balanced patch sampling with >= ``min_wire_ratio`` positives. |
|
|
|
|
|
Sampling is uniform over valid top-left positions for up to ``max_tries``. |
|
|
If no patch meets ``min_wire_ratio``, it falls back to the best observed |
|
|
candidate (highest wire ratio) instead of raising. |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class BalancedPatchSampler: |
|
|
patch_size: int = 768 |
|
|
min_wire_ratio: float = 0.01 |
|
|
max_tries: int = 200 |
|
|
|
|
|
def sample(self, image: np.ndarray, mask: np.ndarray) -> tuple[int, int]: |
|
|
h, w = mask.shape |
|
|
p = self.patch_size |
|
|
assert h >= p and w >= p, "Image smaller than patch size" |
|
|
best_ratio = -1.0 |
|
|
best_y = 0 |
|
|
best_x = 0 |
|
|
for _ in range(self.max_tries): |
|
|
y = np.random.randint(0, h - p + 1) |
|
|
x = np.random.randint(0, w - p + 1) |
|
|
m = mask[y : y + p, x : x + p] |
|
|
ratio = float(m.sum()) / float(p * p) |
|
|
if ratio > best_ratio: |
|
|
best_ratio = ratio |
|
|
best_y, best_x = y, x |
|
|
if ratio >= self.min_wire_ratio: |
|
|
return int(y), int(x) |
|
|
|
|
|
return int(best_y), int(best_x) |
|
|
|