"""I-JEPA multi-block masking for 1D token sequences (Weimann-style).""" from __future__ import annotations import torch def multi_block_mask_1d( n_tokens: int, n_targets: int = 4, target_size_range: tuple[int, int] = (4, 8), mask_ratio: float = 0.5, generator: torch.Generator | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Return (context_idx, target_idx) for one sequence. Chooses `n_targets` contiguous blocks as targets (no overlap), then the complement of their union is the context. mask_ratio caps total target fraction. """ target_mask = torch.zeros(n_tokens, dtype=torch.bool) max_cover = int(mask_ratio * n_tokens) covered = 0 attempts = 0 while covered < max_cover and attempts < 64: attempts += 1 lo, hi = target_size_range size = int(torch.randint(lo, hi + 1, (1,), generator=generator).item()) size = min(size, max_cover - covered) if size <= 0: break start = int(torch.randint(0, max(1, n_tokens - size + 1), (1,), generator=generator).item()) if target_mask[start : start + size].any(): continue target_mask[start : start + size] = True covered += size target_idx = torch.nonzero(target_mask, as_tuple=False).squeeze(-1) context_idx = torch.nonzero(~target_mask, as_tuple=False).squeeze(-1) return context_idx, target_idx