File size: 1,423 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""I-JEPA multi-block masking for 1D token sequences (Weimann-style)."""
from __future__ import annotations

import torch


def multi_block_mask_1d(
    n_tokens: int,
    n_targets: int = 4,
    target_size_range: tuple[int, int] = (4, 8),
    mask_ratio: float = 0.5,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (context_idx, target_idx) for one sequence.

    Samples up to `n_targets` non-overlapping contiguous blocks as targets;
    the complement of their union is the context. Sampling is best-effort:
    fewer than `n_targets` blocks are returned when the `mask_ratio` budget
    is exhausted first or when the retry limit is hit.

    Args:
        n_tokens: Sequence length; indices are drawn from ``[0, n_tokens)``.
        n_targets: Maximum number of target blocks to place.
        target_size_range: Inclusive ``(lo, hi)`` bounds for a block's length.
            A block is shortened when the remaining budget is smaller than
            the drawn length.
        mask_ratio: Upper bound on the fraction of tokens that may become
            targets (the budget is ``int(mask_ratio * n_tokens)``, never more
            than ``n_tokens``).
        generator: Optional RNG passed to ``torch.randint`` for
            reproducible sampling.

    Returns:
        ``(context_idx, target_idx)``: two 1-D index tensors in ascending
        order that are disjoint and together cover ``range(n_tokens)``.
    """
    max_attempts = 64  # rejection-sampling retry limit (placed + rejected draws)

    target_mask = torch.zeros(n_tokens, dtype=torch.bool)
    # Clamp the budget to the sequence length so a drawn block can never be
    # longer than the sequence (which would be silently clipped by slicing
    # while still being counted in full).
    max_cover = min(int(mask_ratio * n_tokens), n_tokens)
    lo, hi = target_size_range

    covered = 0
    placed = 0
    attempts = 0
    while placed < n_targets and covered < max_cover and attempts < max_attempts:
        attempts += 1
        size = int(torch.randint(lo, hi + 1, (1,), generator=generator).item())
        # Shrink the block if it would exceed the remaining budget.
        size = min(size, max_cover - covered)
        if size <= 0:
            break
        # size <= max_cover <= n_tokens, so the start range is never empty.
        start = int(
            torch.randint(0, n_tokens - size + 1, (1,), generator=generator).item()
        )
        # Reject candidates that overlap an already placed block and redraw.
        if target_mask[start : start + size].any():
            continue
        target_mask[start : start + size] = True
        covered += size
        placed += 1

    target_idx = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
    context_idx = torch.nonzero(~target_mask, as_tuple=False).squeeze(-1)
    return context_idx, target_idx