# PhysioJEPA: src/physiojepa/masking.py
# (uploaded by guychuk via huggingface_hub, commit 31e2456)
"""I-JEPA multi-block masking for 1D token sequences (Weimann-style)."""
from __future__ import annotations
import torch
def multi_block_mask_1d(
n_tokens: int,
n_targets: int = 4,
target_size_range: tuple[int, int] = (4, 8),
mask_ratio: float = 0.5,
generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return (context_idx, target_idx) for one sequence.
Chooses `n_targets` contiguous blocks as targets (no overlap), then the
complement of their union is the context. mask_ratio caps total target
fraction.
"""
target_mask = torch.zeros(n_tokens, dtype=torch.bool)
max_cover = int(mask_ratio * n_tokens)
covered = 0
attempts = 0
while covered < max_cover and attempts < 64:
attempts += 1
lo, hi = target_size_range
size = int(torch.randint(lo, hi + 1, (1,), generator=generator).item())
size = min(size, max_cover - covered)
if size <= 0:
break
start = int(torch.randint(0, max(1, n_tokens - size + 1), (1,), generator=generator).item())
if target_mask[start : start + size].any():
continue
target_mask[start : start + size] = True
covered += size
target_idx = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
context_idx = torch.nonzero(~target_mask, as_tuple=False).squeeze(-1)
return context_idx, target_idx