File size: 1,281 Bytes
8ea2eff
61275b8
8e73ec9
61275b8
 
 
8e73ec9
8ea2eff
 
8e73ec9
8ea2eff
 
 
 
 
 
8e73ec9
8ea2eff
8e73ec9
 
 
 
61275b8
 
 
8e73ec9
 
 
 
 
61275b8
 
 
8e73ec9
 
61275b8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
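# Balanced patch sampler (default threshold: >= 1% wire pixels)
"""Balanced patch sampling that favours patches containing enough wire pixels.

A patch is accepted when its wire ratio (sum of the mask over the patch
divided by the patch area) is at least ``min_wire_ratio``. Top-left positions
are drawn uniformly from all valid positions for up to ``max_tries`` attempts.
If no attempt meets ``min_wire_ratio``, the sampler falls back to the best
observed candidate (highest wire ratio) instead of raising.
"""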

from dataclasses import dataclass
import numpy as np


@dataclass
class BalancedPatchSampler:
    """Pick square patch positions whose mask contains enough positive pixels.

    Attributes:
        patch_size: Side length ``p`` of the square patch, in pixels.
        min_wire_ratio: Threshold on ``mask_patch.sum() / p**2``; the first
            drawn patch reaching it is accepted immediately.
        max_tries: Number of random top-left positions to draw before
            falling back to the best one seen.
    """

    patch_size: int = 768
    min_wire_ratio: float = 0.01
    max_tries: int = 200

    def sample(
        self,
        image: np.ndarray,
        mask: np.ndarray,
        rng: "np.random.Generator | None" = None,
    ) -> tuple[int, int]:
        """Return the ``(y, x)`` top-left corner of a ``patch_size`` patch.

        Up to ``max_tries`` top-left positions are drawn uniformly from all
        positions where the patch fits inside ``mask``. The first position
        whose ratio ``mask_patch.sum() / patch_size**2`` is at least
        ``min_wire_ratio`` is returned. If none qualifies, the position with
        the highest ratio seen is returned (ties keep the earliest draw).
        With ``max_tries <= 0`` nothing is drawn and ``(0, 0)`` is returned.

        Args:
            image: Not read by this method; accepted so call sites can pass
                the image alongside its mask.
            mask: 2-D array of shape ``(H, W)``. Presumably a 0/1 or boolean
                wire mask: the ratio is a plain sum, so e.g. a 0/255 mask
                would inflate it (NOTE(review): confirm encoding with callers).
            rng: Optional ``numpy.random.Generator`` for reproducible draws.
                When ``None``, the global ``np.random`` state is used.

        Returns:
            ``(y, x)`` as Python ints.

        Raises:
            ValueError: if ``mask`` is not 2-D, ``patch_size`` is not
                positive, or the mask is smaller than the patch.
        """
        if mask.ndim != 2:
            raise ValueError(f"mask must be 2-D, got shape {mask.shape}")
        p = self.patch_size
        if p <= 0:
            raise ValueError(f"patch_size must be positive, got {p}")
        h, w = mask.shape
        if h < p or w < p:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``.
            raise ValueError(
                f"Image smaller than patch size: mask is {h}x{w}, patch is {p}x{p}"
            )

        # Use the caller's Generator when given; otherwise keep the legacy
        # global RNG so existing ``np.random.seed`` calls behave as before.
        randint = np.random.randint if rng is None else rng.integers
        y_hi = h - p + 1  # exclusive upper bound for the top-left row
        x_hi = w - p + 1  # exclusive upper bound for the top-left column
        area = float(p * p)

        best_ratio = -1.0
        best_y = best_x = 0
        for _ in range(self.max_tries):
            y = randint(0, y_hi)
            x = randint(0, x_hi)
            ratio = float(mask[y : y + p, x : x + p].sum()) / area
            if ratio >= self.min_wire_ratio:
                return int(y), int(x)
            if ratio > best_ratio:
                best_ratio = ratio
                best_y, best_x = y, x
        # Fallback: best candidate even though it is below the threshold.
        return int(best_y), int(best_x)