| """I-JEPA multi-block masking for 1D token sequences (Weimann-style).""" |
| from __future__ import annotations |
|
|
| import torch |
|
|
|
|
def multi_block_mask_1d(
    n_tokens: int,
    n_targets: int = 4,
    target_size_range: tuple[int, int] = (4, 8),
    mask_ratio: float = 0.5,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample (context_idx, target_idx) for one 1D token sequence.

    Up to ``n_targets`` contiguous, mutually non-overlapping blocks are drawn
    as targets; every token outside their union is context.

    Args:
        n_tokens: Length of the token sequence.
        n_targets: Maximum number of target blocks to place. Fewer blocks are
            placed when the ``mask_ratio`` budget is used up first or the
            retry limit is reached.
        target_size_range: Inclusive ``(lo, hi)`` bounds for the sampled block
            length. A block may end up shorter than ``lo`` when it is
            truncated to fit the remaining ``mask_ratio`` budget.
        mask_ratio: Upper bound on the fraction of tokens used as targets.
        generator: Optional RNG passed to ``torch.randint`` for
            reproducible sampling.

    Returns:
        ``(context_idx, target_idx)``: two 1D index tensors in ascending
        order. They are disjoint and together contain every position in
        ``range(n_tokens)``.
    """
    # Rejection-sampling budget: a candidate that overlaps an already placed
    # block is discarded and counts as one attempt.
    max_attempts = 64
    lo, hi = target_size_range

    target_mask = torch.zeros(n_tokens, dtype=torch.bool)
    # Clamp to n_tokens so a mask_ratio above 1 can never request a block
    # longer than the sequence (the slice below would silently truncate it
    # while `covered` would still count the full size).
    max_cover = min(n_tokens, int(mask_ratio * n_tokens))

    covered = 0
    placed = 0
    attempts = 0
    while placed < n_targets and covered < max_cover and attempts < max_attempts:
        attempts += 1
        size = int(torch.randint(lo, hi + 1, (1,), generator=generator).item())
        # Never exceed the remaining target budget.
        size = min(size, max_cover - covered)
        if size <= 0:
            # Only reachable when a non-positive size was sampled (lo <= 0).
            break
        # size <= max_cover <= n_tokens here, so the upper bound is >= 1.
        start = int(
            torch.randint(0, n_tokens - size + 1, (1,), generator=generator).item()
        )
        if target_mask[start : start + size].any():
            # Overlaps an existing target block: reject and resample.
            continue
        target_mask[start : start + size] = True
        covered += size
        placed += 1

    target_idx = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
    context_idx = torch.nonzero(~target_mask, as_tuple=False).squeeze(-1)
    return context_idx, target_idx
|
|