File size: 4,175 Bytes
5cbc675
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# MIXING AUGMENTATIONS
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

def alphamix_data(x, y, alpha_range=(0.3, 0.7), spatial_ratio=0.25, beta_shape=(2.0, 2.0)):
    """
    Standard AlphaMix: Single spatially localized transparent overlay.

    A window covering roughly ``spatial_ratio`` of the image area is placed at
    one random position shared by the whole batch. Inside the window each image
    is alpha-blended with a randomly permuted partner from the same batch;
    pixels outside the window are left untouched.

    Args:
        x: Image batch; the code unpacks ``x.shape`` as four dims
            ``(B, C, H, W)``.
        y: Labels/targets indexed along dim 0 in the same order as ``x``.
        alpha_range: ``(min, max)`` interval the Beta sample is mapped into.
        spatial_ratio: Target fraction of the image area covered by the
            overlay. Values above 1 are clamped to the full image.
        beta_shape: Concentrations passed to ``torch.distributions.Beta`` for
            sampling the blend weight (defaults to the previous fixed (2, 2)).

    Returns:
        ``(composited_x, y_a, y_b, alpha)``: ``y_a`` is ``y``, ``y_b`` is ``y``
        in the permuted partner order, and ``alpha`` is the weight given to
        ``x`` (i.e. ``y_a``) inside the overlay. If the overlay is empty
        (``spatial_ratio`` of 0, or an image too small for the ratio), nothing
        is blended and ``alpha`` is 1.0.

    Raises:
        ValueError: If ``spatial_ratio`` is negative.
    """
    if spatial_ratio < 0:
        raise ValueError(f"spatial_ratio must be non-negative, got {spatial_ratio}")

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    y_a, y_b = y, y[index]

    _, _, H, W = x.shape
    # Side scale of a window whose area is ``spatial_ratio`` of the image.
    # Done in float64: a float32 sqrt can truncate the side by one pixel
    # (e.g. sqrt(0.49) * 10 -> 6 instead of 7).
    overlay_ratio = float(spatial_ratio) ** 0.5
    # Clamp so that spatial_ratio > 1 cannot produce a window larger than the
    # image, which would make the randint ranges below invalid.
    overlay_h = min(H, int(H * overlay_ratio))
    overlay_w = min(W, int(W * overlay_ratio))

    if overlay_h == 0 or overlay_w == 0:
        # Nothing would be blended; report a pure-x weight so the returned
        # alpha agrees with the (unchanged) images.
        return x.clone(), y_a, y_b, 1.0

    # Sample alpha from a Beta distribution, then map it into alpha_range.
    alpha_min, alpha_max = alpha_range
    beta_sample = torch.distributions.Beta(*beta_shape).sample().item()
    alpha = alpha_min + (alpha_max - alpha_min) * beta_sample

    top = torch.randint(0, H - overlay_h + 1, (1,), device=x.device).item()
    left = torch.randint(0, W - overlay_w + 1, (1,), device=x.device).item()
    rows = slice(top, top + overlay_h)
    cols = slice(left, left + overlay_w)

    # Blend only the window; everything else stays as in x.
    composited_x = x.clone()
    overlay_region = alpha * x[:, :, rows, cols]
    background_region = (1 - alpha) * x[index, :, rows, cols]
    composited_x[:, :, rows, cols] = overlay_region + background_region

    return composited_x, y_a, y_b, alpha


def alphamix_fractal(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha_range=(0.3, 0.7),
    steps_range=(1, 3),
    triad_scales=(1/3, 1/9, 1/27),
    beta_shape=(2.0, 2.0),
    seed: int | None = None,
):
    """
    Fractal AlphaMix: Triadic multi-patch overlays aligned to Cantor geometry.
    Pure torch, GPU-compatible.
    """
    if seed is not None:
        torch.manual_seed(seed)
    
    B, C, H, W = x.shape
    device = x.device
    
    # Permutation for mixing
    idx = torch.randperm(B, device=device)
    y_a, y_b = y, y[idx]
    
    x_mix = x.clone()
    total_area = H * W
    
    # Beta distribution for transparency sampling
    k1, k2 = beta_shape
    beta_dist = torch.distributions.Beta(k1, k2)
    alpha_min, alpha_max = alpha_range
    
    # Storage for effective alpha calculation
    alpha_elems = []
    area_weights = []
    
    # Sample number of patches (same for all images in batch)
    steps = torch.randint(steps_range[0], steps_range[1] + 1, (1,), device=device).item()
    
    for _ in range(steps):
        # Choose triadic scale
        scale_idx = torch.randint(0, len(triad_scales), (1,), device=device).item()
        scale = triad_scales[scale_idx]
        
        # Compute patch dimensions (triadic area)
        patch_area = max(1, int(total_area * scale))
        side = int(torch.sqrt(torch.tensor(patch_area, dtype=torch.float32)).item())
        h = max(1, min(H, side))
        w = max(1, min(W, side))
        
        # Random position
        top = torch.randint(0, H - h + 1, (1,), device=device).item()
        left = torch.randint(0, W - w + 1, (1,), device=device).item()
        
        # Sample transparency from Beta distribution
        alpha_raw = beta_dist.sample().item()
        alpha = alpha_min + (alpha_max - alpha_min) * alpha_raw
        
        # Track for effective alpha
        alpha_elems.append(alpha)
        area_weights.append(h * w)
        
        # Blend patches
        fg = alpha * x[:, :, top:top + h, left:left + w]
        bg = (1 - alpha) * x[idx, :, top:top + h, left:left + w]
        x_mix[:, :, top:top + h, left:left + w] = fg + bg
    
    # Compute area-weighted effective alpha
    alpha_t = torch.tensor(alpha_elems, dtype=torch.float32, device=device)
    area_t = torch.tensor(area_weights, dtype=torch.float32, device=device)
    alpha_eff = (alpha_t * area_t).sum() / (area_t.sum() + 1e-12)
    alpha_eff = alpha_eff.item()
    
    return x_mix, y_a, y_b, alpha_eff