dplotnikov's picture
Upload StrataBERT diagnostic checkpoint
c30089a verified
Raw
History Blame Contribute Delete
1.06 kB
"""Mask helpers for padding-safe model code."""
from __future__ import annotations
import torch
def make_attention_mask(input_ids: torch.Tensor, attention_mask: torch.Tensor | None, pad_token_id: int) -> torch.Tensor:
    """Return a boolean attention mask located on the device of ``input_ids``.

    Args:
        input_ids: Token-id tensor; used to derive the mask when none is
            supplied, and as the device reference otherwise.
        attention_mask: Optional caller-provided mask. When given it is cast
            to ``torch.bool`` and moved to ``input_ids.device``.
        pad_token_id: Id treated as padding when the mask has to be derived.

    Returns:
        The supplied mask as a bool tensor, or ``input_ids != pad_token_id``
        when ``attention_mask`` is ``None``.
    """
    if attention_mask is not None:
        # Normalise whatever dtype/device the caller used.
        return attention_mask.to(device=input_ids.device, dtype=torch.bool)
    # No explicit mask: every non-pad token counts as valid.
    return input_ids != pad_token_id
def masked_hidden(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Scale ``hidden_states`` by the attention mask.

    The mask gains a trailing singleton axis and is cast to the dtype of
    ``hidden_states`` so that it broadcasts across the last dimension.

    Args:
        hidden_states: Tensor to be masked.
        attention_mask: Mask with one fewer trailing dimension than
            ``hidden_states`` (presumably ``(batch, seq)`` against
            ``(batch, seq, dim)`` - confirm against callers).

    Returns:
        The element-wise product of ``hidden_states`` and the broadcast mask.
    """
    keep = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    return hidden_states * keep
def reverse_valid(x: torch.Tensor, attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Reverse ``x`` and its mask along dimension 1.

    Both tensors are flipped over the whole of dimension 1, so the mask keeps
    lining up with the data; masked positions move together with everything
    else (trailing masked entries end up leading).

    Args:
        x: Tensor to reverse along dimension 1.
        attention_mask: Mask to reverse along dimension 1.

    Returns:
        A ``(flipped_x, flipped_mask)`` tuple.
    """
    flipped_x = x.flip(dims=(1,))
    flipped_mask = attention_mask.flip(dims=(1,))
    return flipped_x, flipped_mask
def segment_reset_mask(segment_ids: torch.Tensor | None, attention_mask: torch.Tensor) -> torch.Tensor | None:
    """Mark the positions where a new segment begins.

    A position is flagged when it is the first along dimension 1, or when its
    segment id differs from the id of the position immediately before it.
    Flags at positions the attention mask excludes are cleared.

    Args:
        segment_ids: Per-position segment ids, indexed as ``[row, position]``,
            or ``None`` when no segment information is available.
        attention_mask: Validity mask with the same shape as ``segment_ids``.
            Any dtype is accepted; non-zero entries count as valid.

    Returns:
        ``None`` when ``segment_ids`` is ``None``; otherwise a ``torch.bool``
        tensor shaped like ``attention_mask``.
    """
    if segment_ids is None:
        return None
    # Cast up front: ``bool & float`` raises and ``bool & int`` would promote
    # the result to an integer tensor instead of a boolean mask.
    valid = attention_mask.to(dtype=torch.bool)
    reset = torch.zeros_like(valid, dtype=torch.bool)
    # Slice rather than index so a zero-length sequence axis does not raise.
    reset[:, :1] = True
    reset[:, 1:] = segment_ids[:, 1:] != segment_ids[:, :-1]
    return reset & valid