MRiabov's picture
Running training
61275b8
# Balanced patch sampler (>=1% wire pixels)
"""Balanced patch sampling with >= ``min_wire_ratio`` positives.
Sampling is uniform over valid top-left positions for up to ``max_tries``.
If no patch meets ``min_wire_ratio``, it falls back to the best observed
candidate (highest wire ratio) instead of raising.
"""
from dataclasses import dataclass
import numpy as np
@dataclass
class BalancedPatchSampler:
patch_size: int = 768
min_wire_ratio: float = 0.01
max_tries: int = 200
def sample(self, image: np.ndarray, mask: np.ndarray) -> tuple[int, int]:
h, w = mask.shape
p = self.patch_size
assert h >= p and w >= p, "Image smaller than patch size"
best_ratio = -1.0
best_y = 0
best_x = 0
for _ in range(self.max_tries):
y = np.random.randint(0, h - p + 1)
x = np.random.randint(0, w - p + 1)
m = mask[y : y + p, x : x + p]
ratio = float(m.sum()) / float(p * p)
if ratio > best_ratio:
best_ratio = ratio
best_y, best_x = y, x
if ratio >= self.min_wire_ratio:
return int(y), int(x)
# Fallback: return best candidate even if below threshold
return int(best_y), int(best_x)