File size: 1,423 Bytes
31e2456 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | """I-JEPA multi-block masking for 1D token sequences (Weimann-style)."""
from __future__ import annotations
import torch
def multi_block_mask_1d(
    n_tokens: int,
    n_targets: int = 4,
    target_size_range: tuple[int, int] = (4, 8),
    mask_ratio: float = 0.5,
    generator: torch.Generator | None = None,
    *,
    max_attempts: int = 64,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (context_idx, target_idx) for one sequence.

    Samples up to ``n_targets`` non-overlapping contiguous blocks of tokens as
    prediction targets; the complement of their union is the context. Block
    lengths are drawn uniformly from ``target_size_range`` (both ends
    inclusive). The total number of target tokens never exceeds
    ``int(mask_ratio * n_tokens)`` (nor ``n_tokens``), so the last block may
    be truncated to fit that budget.

    Placement is rejection-based: a block that would overlap an existing
    target is discarded and redrawn, and sampling stops after
    ``max_attempts`` draws even if fewer than ``n_targets`` blocks were
    placed. Blocks may land directly next to each other, in which case they
    show up as one longer contiguous run of target indices.

    Args:
        n_tokens: Length of the token sequence.
        n_targets: Maximum number of target blocks to place.
        target_size_range: Inclusive ``(lo, hi)`` bounds on block length.
        mask_ratio: Upper bound on the fraction of tokens marked as targets.
        generator: Optional RNG used for all draws (for reproducibility).
        max_attempts: Maximum number of block draws, rejected ones included.

    Returns:
        ``(context_idx, target_idx)``: 1-D index tensors, each sorted
        ascending, that together partition ``range(n_tokens)``.
    """
    target_mask = torch.zeros(n_tokens, dtype=torch.bool)
    lo, hi = target_size_range
    # Clamp the budget to the sequence length so mask_ratio > 1 cannot
    # request blocks longer than the sequence itself.
    max_cover = min(n_tokens, int(mask_ratio * n_tokens))
    covered = 0
    placed = 0
    for _ in range(max_attempts):
        if placed >= n_targets or covered >= max_cover:
            break
        size = int(torch.randint(lo, hi + 1, (1,), generator=generator).item())
        # Truncate so the running total never exceeds the mask_ratio budget.
        size = min(size, max_cover - covered)
        if size <= 0:
            # Only reachable when lo <= 0; treat as "nothing left to place".
            break
        # size <= max_cover <= n_tokens, so the upper bound is always >= 1
        # and the block always lies fully inside the sequence.
        start = int(
            torch.randint(0, n_tokens - size + 1, (1,), generator=generator).item()
        )
        if target_mask[start : start + size].any():
            continue  # overlaps an existing target block; redraw
        target_mask[start : start + size] = True
        covered += size
        placed += 1
    target_idx = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
    context_idx = torch.nonzero(~target_mask, as_tuple=False).squeeze(-1)
    return context_idx, target_idx
|