| """This file contains the definition of utility functions for masking.""" | |
| import math | |
| from typing import Text, Tuple | |
| import torch | |
def get_mask_tokens(
    tokens: torch.Tensor,
    mask_token: int,
    mode: Text = "arccos",
    min_masking_ratio: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the masked tokens.

    One masking ratio is drawn per entry along dim 0 of `tokens`; every token
    of that entry is then masked independently with that probability.

    Args:
        tokens -> torch.Tensor: The input tokens. Dim 0 is treated as the batch
            dimension; any number of trailing dimensions is accepted.
        mask_token -> int: The special `mask` token.
        mode -> Text: The masking function to use (default: "arccos").
            One of "linear", "square", "cosine", "arccos".
        min_masking_ratio -> float: Shrinks the uniform sample to
            [0, 1 - min_masking_ratio) before the masking function is applied,
            which raises the smallest ratio that can be drawn. In "linear"
            mode the drawn ratio is always above this value (default: 0.0).

    Returns:
        masked_tokens -> torch.Tensor: A detached copy of the input tokens in
            which each masked token is set to mask_token.
        mask -> torch.Tensor: A boolean tensor mask, same shape and device as
            `tokens`, indicating which tokens are masked.

    Raises:
        ValueError: If `mode` is not one of the supported masking functions.
    """
    device = tokens.device
    batch_size = tokens.size(0)

    # Create the random numbers directly on the tokens' device so the mask
    # never has to be transferred before it is applied.
    r = torch.rand(batch_size, device=device) * (1 - min_masking_ratio)

    if mode == "linear":
        val_to_mask = 1 - r
    elif mode == "square":
        val_to_mask = 1 - (r**2)
    elif mode == "cosine":
        val_to_mask = torch.cos(r * math.pi * 0.5)
    elif mode == "arccos":
        val_to_mask = torch.acos(r) / (math.pi * 0.5)
    else:
        raise ValueError(
            "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos'."
        )

    # Reshape to (batch, 1, ..., 1) so the per-sample ratio broadcasts over
    # every trailing dimension, whatever the rank of `tokens`. The explicit
    # batch size (instead of -1) keeps the reshape valid for an empty batch.
    per_sample_ratio = val_to_mask.view(batch_size, *([1] * (tokens.dim() - 1)))
    mask = torch.rand(tokens.size(), device=device) < per_sample_ratio

    masked_tokens = tokens.detach().clone()
    masked_tokens[mask] = mask_token
    return masked_tokens, mask
def get_masking_ratio(progress: float, mode: Text = "arccos") -> torch.Tensor:
    """Get masking ratio.

    Args:
        progress -> float: The percentage of iterations already done, expected
            in [0, 1]. Values outside that range are clamped to it, and integer
            input is converted to floating point.
        mode -> Text: The masking function to use (default: "arccos").
            One of "linear", "square", "cosine", "arccos", "root".

    Returns:
        val_to_mask -> torch.Tensor: The masking ratio, clamped to [1e-6, 1.0].

    Raises:
        ValueError: If `mode` is not one of the supported masking functions.
    """
    r = torch.as_tensor(progress)
    if not r.is_floating_point():
        # An integer progress (e.g. 1) would otherwise yield an integer tensor
        # and make the final clamp depend on integer/float promotion rules.
        r = r.to(torch.get_default_dtype())
    # acos is NaN above 1 and the square root is NaN below 0; the output clamp
    # does not remove NaN, so bound the input first.
    r = r.clamp(0.0, 1.0)

    if mode == "root":
        val_to_mask = 1 - (r**0.5)
    elif mode == "square":
        val_to_mask = 1 - (r**2)
    elif mode == "cosine":
        val_to_mask = torch.cos(r * math.pi * 0.5)
    elif mode == "arccos":
        val_to_mask = torch.acos(r) / (math.pi * 0.5)
    elif mode == "linear":
        val_to_mask = 1 - r
    else:
        raise ValueError(
            "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos', 'root'."
        )

    # The lower bound keeps the returned ratio strictly positive.
    val_to_mask = torch.clamp(val_to_mask, 1e-6, 1.0)
    return val_to_mask