import torch

# ───────────────────────────────────────────────────────────────────────────
# MIXING AUGMENTATIONS
# ───────────────────────────────────────────────────────────────────────────
def alphamix_data(x, y, alpha_range=(0.3, 0.7), spatial_ratio=0.25):
    """
    Standard AlphaMix: single spatially localized transparent overlay.
    """
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    y_a, y_b = y, y[index]

    # Sample alpha from a Beta distribution, rescaled into alpha_range
    alpha_min, alpha_max = alpha_range
    beta_sample = torch.distributions.Beta(2.0, 2.0).sample().item()
    alpha = alpha_min + (alpha_max - alpha_min) * beta_sample

    # Compute overlay region: side ratio is sqrt(spatial_ratio), so the
    # patch covers spatial_ratio of the image area
    _, _, H, W = x.shape
    overlay_ratio = torch.sqrt(torch.tensor(spatial_ratio)).item()
    overlay_h = int(H * overlay_ratio)
    overlay_w = int(W * overlay_ratio)
    top = torch.randint(0, H - overlay_h + 1, (1,), device=x.device).item()
    left = torch.randint(0, W - overlay_w + 1, (1,), device=x.device).item()

    # Blend: keep alpha of the original patch, add (1 - alpha) of the partner's
    composited_x = x.clone()
    overlay_region = alpha * x[:, :, top:top+overlay_h, left:left+overlay_w]
    background_region = (1 - alpha) * x[index, :, top:top+overlay_h, left:left+overlay_w]
    composited_x[:, :, top:top+overlay_h, left:left+overlay_w] = overlay_region + background_region
    return composited_x, y_a, y_b, alpha
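

# Usage sketch (added for illustration, not part of the original module): the
# returned (y_a, y_b, alpha) triple is meant for the standard mixup-style
# interpolated loss. Note that alpha is the blend weight *inside* the overlay
# patch only; pixels outside the patch stay 100% image a. The demo function
# and its toy model are hypothetical names, assumed for this example only.
def _demo_alphamix():
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    criterion = torch.nn.CrossEntropyLoss()
    x = torch.randn(8, 3, 32, 32)
    y = torch.randint(0, 10, (8,))
    x_mixed, y_a, y_b, alpha = alphamix_data(x, y)
    logits = model(x_mixed)
    # Interpolate the loss by the same alpha used for the pixel blend
    loss = alpha * criterion(logits, y_a) + (1 - alpha) * criterion(logits, y_b)
    return loss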


def alphamix_fractal(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha_range=(0.3, 0.7),
    steps_range=(1, 3),
    triad_scales=(1/3, 1/9, 1/27),
    beta_shape=(2.0, 2.0),
    seed: int | None = None,
):
    """
    Fractal AlphaMix: triadic multi-patch overlays aligned to Cantor geometry.
    Pure torch, GPU-compatible.
    """
    if seed is not None:
        torch.manual_seed(seed)

    B, C, H, W = x.shape
    device = x.device

    # Permutation for mixing
    idx = torch.randperm(B, device=device)
    y_a, y_b = y, y[idx]
    x_mix = x.clone()
    total_area = H * W

    # Beta distribution for transparency sampling
    k1, k2 = beta_shape
    beta_dist = torch.distributions.Beta(k1, k2)
    alpha_min, alpha_max = alpha_range

    # Storage for effective-alpha calculation
    alpha_elems = []
    area_weights = []

    # Sample number of patches (same for all images in the batch)
    steps = torch.randint(steps_range[0], steps_range[1] + 1, (1,), device=device).item()

    for _ in range(steps):
        # Choose a triadic scale
        scale_idx = torch.randint(0, len(triad_scales), (1,), device=device).item()
        scale = triad_scales[scale_idx]

        # Compute square patch dimensions from the triadic area fraction
        patch_area = max(1, int(total_area * scale))
        side = int(torch.sqrt(torch.tensor(patch_area, dtype=torch.float32)).item())
        h = max(1, min(H, side))
        w = max(1, min(W, side))

        # Random position
        top = torch.randint(0, H - h + 1, (1,), device=device).item()
        left = torch.randint(0, W - w + 1, (1,), device=device).item()

        # Sample transparency from the Beta distribution
        alpha_raw = beta_dist.sample().item()
        alpha = alpha_min + (alpha_max - alpha_min) * alpha_raw

        # Track for effective alpha
        alpha_elems.append(alpha)
        area_weights.append(h * w)

        # Blend the patch. Both sides read from the original tensors, so
        # where patches overlap, the later blend overwrites the earlier one.
        fg = alpha * x[:, :, top:top + h, left:left + w]
        bg = (1 - alpha) * x[idx, :, top:top + h, left:left + w]
        x_mix[:, :, top:top + h, left:left + w] = fg + bg

    # Compute area-weighted effective alpha across all sampled patches
    alpha_t = torch.tensor(alpha_elems, dtype=torch.float32, device=device)
    area_t = torch.tensor(area_weights, dtype=torch.float32, device=device)
    alpha_eff = (alpha_t * area_t).sum() / (area_t.sum() + 1e-12)
    alpha_eff = alpha_eff.item()

    return x_mix, y_a, y_b, alpha_eff
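

# Usage sketch for the fractal variant (illustrative addition, same caveats
# and hypothetical names as the demo above). alpha_eff is the area-weighted
# blend weight across all patches, so it serves as the interpolation
# coefficient for the mixed loss; a fixed seed makes the patch layout
# reproducible.
def _demo_alphamix_fractal():
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    criterion = torch.nn.CrossEntropyLoss()
    x = torch.randn(8, 3, 32, 32)
    y = torch.randint(0, 10, (8,))
    x_mix, y_a, y_b, alpha_eff = alphamix_fractal(x, y, seed=0)
    logits = model(x_mix)
    loss = alpha_eff * criterion(logits, y_a) + (1 - alpha_eff) * criterion(logits, y_b)
    return loss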