"""Mask helpers for padding-safe model code."""

from __future__ import annotations

import torch


def make_attention_mask(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor | None,
    pad_token_id: int,
) -> torch.Tensor:
    """Return a boolean attention mask aligned with ``input_ids``.

    Args:
        input_ids: Token ids. When no explicit mask is given, every element
            is compared against ``pad_token_id``.
        attention_mask: Optional explicit mask; non-zero / ``True`` entries
            mark positions that should be attended to.
        pad_token_id: Id treated as padding when ``attention_mask`` is None.

    Returns:
        A ``torch.bool`` tensor on ``input_ids.device``.
    """
    if attention_mask is None:
        # Derive the mask from the ids themselves: any non-pad token is valid.
        return input_ids.ne(pad_token_id)
    return attention_mask.to(dtype=torch.bool, device=input_ids.device)


def masked_hidden(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Zero out hidden states at masked positions.

    Args:
        hidden_states: Tensor whose last dimension is the feature dimension;
            the mask is broadcast across it via ``unsqueeze(-1)``.
        attention_mask: Mask shaped like ``hidden_states`` without its last
            dimension (presumably ``(batch, seq)`` — confirm with callers).

    Returns:
        ``hidden_states`` multiplied element-wise by the mask, in the dtype of
        ``hidden_states``.
    """
    # Match both dtype and device so a CPU mask can be used with GPU states.
    mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
    return hidden_states * mask.unsqueeze(-1)


def reverse_valid(x: torch.Tensor, attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Reverse the valid tokens of each row, leaving padding where it is.

    ``x`` is indexed with the batch on dim 0 and the sequence on dim 1; any
    further dimensions are carried along unchanged. For every row, the
    positions where ``attention_mask`` is truthy are reversed among
    themselves, and masked positions keep their original content and place.
    For example, ``[a, b, c, PAD] -> [c, b, a, PAD]`` and
    ``[PAD, a, b, c] -> [PAD, c, b, a]``. Applying the function twice with
    the same mask restores the input.

    Args:
        x: Tensor of shape ``(batch, seq, ...)``.
        attention_mask: Mask of shape ``(batch, seq)``.

    Returns:
        ``(reversed_x, attention_mask)``. The mask is returned unchanged
        because valid positions do not move; only their contents are permuted.
    """
    mask = attention_mask.to(dtype=torch.bool, device=x.device)
    batch, seq_len = mask.shape

    positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch, seq_len)
    lengths = mask.sum(dim=1, keepdim=True)

    # 0-based rank of each valid token among the valid tokens of its row.
    rank = mask.long().cumsum(dim=1) - 1

    # Keys are unique, so argsort lists each row's valid positions first, in
    # ascending order, followed by the padded positions.
    sort_key = torch.where(mask, positions, positions + seq_len)
    valid_positions = torch.argsort(sort_key, dim=1)

    # The token with rank r swaps with the one of rank (length - 1 - r). The
    # clamp only matters for padded slots, whose values are discarded below.
    mirror_rank = (lengths - 1 - rank).clamp(min=0, max=max(seq_len - 1, 0))
    source = torch.where(mask, valid_positions.gather(1, mirror_rank), positions)

    # Broadcast the per-token source index over any trailing feature dims.
    index = source.reshape(batch, seq_len, *([1] * (x.dim() - 2))).expand_as(x)
    return torch.gather(x, 1, index), attention_mask


def segment_reset_mask(segment_ids: torch.Tensor | None, attention_mask: torch.Tensor) -> torch.Tensor | None:
    """Mark valid positions that start a new segment.

    A valid position is a reset point when any of the following holds:
      * it is in column 0;
      * its segment id differs from its left neighbour's;
      * its left neighbour is masked out (e.g. the first real token after
        left padding).
    Masked positions are never reset points.

    Args:
        segment_ids: Per-token segment ids of shape ``(batch, seq)``, or None.
        attention_mask: Mask of the same shape; truthy entries are valid.

    Returns:
        A ``torch.bool`` tensor on ``attention_mask``'s device, or None when
        ``segment_ids`` is None.
    """
    if segment_ids is None:
        return None

    valid = attention_mask.to(dtype=torch.bool)
    seg = segment_ids.to(device=valid.device)

    # Column 0 has no left neighbour, so it always counts as a change.
    changed = torch.ones_like(valid)
    changed[:, 1:] = seg[:, 1:] != seg[:, :-1]

    # A token preceded by padding must reset even when the pad slots share
    # its segment id; otherwise left-padded rows would never reset.
    prev_valid = torch.zeros_like(valid)
    prev_valid[:, 1:] = valid[:, :-1]

    return valid & (changed | ~prev_valid)