Spaces:
Sleeping
Sleeping
Update spm.py
Browse files
spm.py
CHANGED
|
@@ -32,9 +32,7 @@ def _to_divisible_by(img, N):
|
|
| 32 |
def _edgelogic(i, j, ph, pw, N, overlap):
|
| 33 |
"""
|
| 34 |
Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
|
| 35 |
-
Extend with overlap, biasing inward.
|
| 36 |
-
Uses 2*overlap for edges to keep patch areas roughly comparable.
|
| 37 |
-
Returns (start_h, end_h, start_w, end_w) BEFORE clamping to image bounds.
|
| 38 |
"""
|
| 39 |
start_h = i * ph
|
| 40 |
start_w = j * pw
|
|
@@ -70,21 +68,20 @@ def spm_augment(
|
|
| 70 |
mix_prob=0.5,
|
| 71 |
beta_a=2.0,
|
| 72 |
beta_b=2.0,
|
| 73 |
-
|
| 74 |
seed=None
|
| 75 |
):
|
| 76 |
"""
|
| 77 |
SPM-style augmentation with optional overlap + feathered blending.
|
| 78 |
|
| 79 |
-
When
|
| 80 |
- Standard global shuffle over N×N patches;
|
| 81 |
- Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
|
| 82 |
|
| 83 |
-
When
|
| 84 |
-
- Each base cell (N×N grid) expands by
|
| 85 |
-
clipped to the image. Patches are mixed per location
|
| 86 |
-
|
| 87 |
-
- Patches are blended into the canvas with a feather mask of size `overlap_px`.
|
| 88 |
"""
|
| 89 |
# Normalize to PIL and ensure divisibility
|
| 90 |
if isinstance(image, np.ndarray):
|
|
@@ -100,10 +97,10 @@ def spm_augment(
|
|
| 100 |
ph = H // N
|
| 101 |
pw = W // N
|
| 102 |
|
| 103 |
-
#
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
overlap_px = int(
|
| 107 |
max_ov = max(0, min(ph, pw) // 2 - 1)
|
| 108 |
ov = int(np.clip(overlap_px, 0, max_ov))
|
| 109 |
|
|
@@ -167,7 +164,6 @@ def spm_augment(
|
|
| 167 |
total = len(patches)
|
| 168 |
perm = rng.permutation(total)
|
| 169 |
|
| 170 |
-
# We'll sample alpha per-patch to echo your overlap snippet
|
| 171 |
def sample_alpha():
|
| 172 |
if beta_a > 0 and beta_b > 0:
|
| 173 |
return float(rng.beta(beta_a, beta_b))
|
|
@@ -178,9 +174,7 @@ def spm_augment(
|
|
| 178 |
|
| 179 |
for k, (sh, eh, sw, ew) in enumerate(coords):
|
| 180 |
if rng.random() >= float(mix_prob):
|
| 181 |
-
|
| 182 |
-
src = patches[k]
|
| 183 |
-
patch = src
|
| 184 |
else:
|
| 185 |
lam = sample_alpha()
|
| 186 |
src = patches[k].astype(np.float32)
|
|
@@ -188,18 +182,12 @@ def spm_augment(
|
|
| 188 |
patch = lam * shf + (1.0 - lam) * src
|
| 189 |
|
| 190 |
ph_k, pw_k, _ = patch.shape
|
| 191 |
-
# Slice feather mask down if needed (near borders)
|
| 192 |
mask2d = feather_full[:ph_k, :pw_k]
|
| 193 |
-
if arr.shape[2] == 1
|
| 194 |
-
mask3d = mask2d[..., None]
|
| 195 |
-
else:
|
| 196 |
-
mask3d = np.repeat(mask2d[..., None], arr.shape[2], axis=2)
|
| 197 |
|
| 198 |
-
# Accumulate
|
| 199 |
canvas[sh:eh, sw:ew] += patch * mask3d
|
| 200 |
weight[sh:eh, sw:ew] += mask2d
|
| 201 |
|
| 202 |
-
# Normalize
|
| 203 |
weight = np.clip(weight, 1e-8, None)
|
| 204 |
out = (canvas / weight[..., None])
|
| 205 |
out = np.clip(out, 0, 255).astype(np.uint8)
|
|
|
|
| 32 |
def _edgelogic(i, j, ph, pw, N, overlap):
|
| 33 |
"""
|
| 34 |
Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
|
| 35 |
+
Extend with overlap, biasing inward. Uses 2*overlap at borders.
|
|
|
|
|
|
|
| 36 |
"""
|
| 37 |
start_h = i * ph
|
| 38 |
start_w = j * pw
|
|
|
|
| 68 |
mix_prob=0.5,
|
| 69 |
beta_a=2.0,
|
| 70 |
beta_b=2.0,
|
| 71 |
+
overlap_pct=0.0, # percentage of patch size (0..49 typically)
|
| 72 |
seed=None
|
| 73 |
):
|
| 74 |
"""
|
| 75 |
SPM-style augmentation with optional overlap + feathered blending.
|
| 76 |
|
| 77 |
+
When overlap_pct <= 0:
|
| 78 |
- Standard global shuffle over N×N patches;
|
| 79 |
- Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
|
| 80 |
|
| 81 |
+
When overlap_pct > 0:
|
| 82 |
+
- Each base cell (N×N grid) expands by ±overlap_px (derived from percentage),
|
| 83 |
+
clipped to the image. Patches are mixed per location with per-patch alpha.
|
| 84 |
+
- Patches are blended into the canvas with a feather mask of size overlap_px.
|
|
|
|
| 85 |
"""
|
| 86 |
# Normalize to PIL and ensure divisibility
|
| 87 |
if isinstance(image, np.ndarray):
|
|
|
|
| 97 |
ph = H // N
|
| 98 |
pw = W // N
|
| 99 |
|
| 100 |
+
# Convert percentage to pixel overlap; clamp to < half patch size
|
| 101 |
+
pct = float(overlap_pct)
|
| 102 |
+
pct = max(0.0, min(pct, 49.0)) # keep below 50% for stability
|
| 103 |
+
overlap_px = int(round((pct / 100.0) * min(ph, pw)))
|
| 104 |
max_ov = max(0, min(ph, pw) // 2 - 1)
|
| 105 |
ov = int(np.clip(overlap_px, 0, max_ov))
|
| 106 |
|
|
|
|
| 164 |
total = len(patches)
|
| 165 |
perm = rng.permutation(total)
|
| 166 |
|
|
|
|
| 167 |
def sample_alpha():
|
| 168 |
if beta_a > 0 and beta_b > 0:
|
| 169 |
return float(rng.beta(beta_a, beta_b))
|
|
|
|
| 174 |
|
| 175 |
for k, (sh, eh, sw, ew) in enumerate(coords):
|
| 176 |
if rng.random() >= float(mix_prob):
|
| 177 |
+
patch = patches[k]
|
|
|
|
|
|
|
| 178 |
else:
|
| 179 |
lam = sample_alpha()
|
| 180 |
src = patches[k].astype(np.float32)
|
|
|
|
| 182 |
patch = lam * shf + (1.0 - lam) * src
|
| 183 |
|
| 184 |
ph_k, pw_k, _ = patch.shape
|
|
|
|
| 185 |
mask2d = feather_full[:ph_k, :pw_k]
|
| 186 |
+
mask3d = mask2d[..., None] if arr.shape[2] == 1 else np.repeat(mask2d[..., None], arr.shape[2], axis=2)
|
|
|
|
|
|
|
|
|
|
| 187 |
|
|
|
|
| 188 |
canvas[sh:eh, sw:ew] += patch * mask3d
|
| 189 |
weight[sh:eh, sw:ew] += mask2d
|
| 190 |
|
|
|
|
| 191 |
weight = np.clip(weight, 1e-8, None)
|
| 192 |
out = (canvas / weight[..., None])
|
| 193 |
out = np.clip(out, 0, 255).astype(np.uint8)
|