| from typing import Optional, Tuple, Union, Dict, Literal, Callable |
|
|
| import math |
| import os |
| from contextlib import nullcontext |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.utils import ModelOutput, logging |
| from transformers.activations import ACT2FN |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils.doc import add_code_sample_docstrings, add_start_docstrings |
| from transformers.utils.import_utils import is_triton_available, is_flash_attn_2_available |
| from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput, MaskedLMOutput |
|
|
| try: |
| from transformers.utils.hub import cached_file |
| except ImportError: |
| from transformers.utils import cached_file |
|
|
| try: |
| from transformers.modeling_rope_utils import RopeParameters |
| except ImportError: |
| RopeParameters = object |
|
|
|
|
| if is_flash_attn_2_available(): |
| from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func |
| from flash_attn.layers.rotary import RotaryEmbedding |
| from flash_attn.ops.triton.rotary import apply_rotary |
| else: |
| RotaryEmbedding = object |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| _HF_CONFIG_LOAD_KWARGS = { |
| "cache_dir", |
| "force_download", |
| "local_files_only", |
| "token", |
| "revision", |
| "subfolder", |
| "proxies", |
| } |
|
|
| _HF_NON_MODEL_INIT_KWARGS = { |
| "trust_remote_code", |
| "_from_auto", |
| "adapter_kwargs", |
| } |
|
|
| _HF_MODEL_INIT_BLACKLIST = { |
| "device_map", |
| "low_cpu_mem_usage", |
| "offload_folder", |
| "offload_state_dict", |
| "max_memory", |
| "quantization_config", |
| "tp_plan", |
| "tp_size", |
| "weights_only", |
| |
| |
| |
| } |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| def _split_pretrained_kwargs(kwargs): |
| kwargs = dict(kwargs) |
|
|
| for k in _HF_NON_MODEL_INIT_KWARGS: |
| kwargs.pop(k, None) |
|
|
| config_load_kwargs = { |
| k: kwargs.pop(k) for k in list(kwargs) if k in _HF_CONFIG_LOAD_KWARGS |
| } |
|
|
| use_safetensors = kwargs.pop("use_safetensors", None) |
| weights_only = kwargs.pop("weights_only", True) |
| return config_load_kwargs, use_safetensors, weights_only, kwargs |
|
|
|
|
| def _resolve_weights_file(pretrained_model_name_or_path, use_safetensors=None, **load_kwargs) -> str: |
| pretrained_model_name_or_path = os.fspath(pretrained_model_name_or_path) |
|
|
| if use_safetensors is True: |
| candidates = ("model.safetensors",) |
| elif use_safetensors is False: |
| candidates = ("pytorch_model.bin",) |
| else: |
| candidates = ("model.safetensors", "pytorch_model.bin") |
|
|
| subfolder = load_kwargs.get("subfolder") |
|
|
| if os.path.isdir(pretrained_model_name_or_path): |
| base_dir = ( |
| os.path.join(pretrained_model_name_or_path, subfolder) |
| if subfolder |
| else pretrained_model_name_or_path |
| ) |
| for name in candidates: |
| path = os.path.join(base_dir, name) |
| if os.path.exists(path): |
| return path |
|
|
| for name in candidates: |
| try: |
| path = cached_file(pretrained_model_name_or_path, name, **load_kwargs) |
| if path is not None: |
| return path |
| except Exception: |
| pass |
|
|
| raise FileNotFoundError( |
| f"No checkpoint file found in {pretrained_model_name_or_path!r} " |
| f"(candidates: {', '.join(candidates)})" |
| ) |
|
|
|
|
| def _read_state_dict(weights_path, weights_only: bool = True) -> dict[str, torch.Tensor]: |
| weights_path = os.fspath(weights_path) |
|
|
| if weights_path.endswith(".safetensors"): |
| from safetensors.torch import load_file as safe_load_file |
|
|
| return safe_load_file(weights_path, device="cpu") |
|
|
| try: |
| return torch.load(weights_path, map_location="cpu", weights_only=weights_only) |
| except TypeError: |
| return torch.load(weights_path, map_location="cpu") |
|
|
|
|
| def _align_state_dict_with_base_prefix(model: nn.Module, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: |
| state_dict = dict(state_dict) |
| base_prefix = getattr(model, "base_model_prefix", None) |
| if not base_prefix: |
| return state_dict |
|
|
| prefix = f"{base_prefix}." |
| model_keys = set(model.state_dict().keys()) |
| model_has_prefix = any(k.startswith(prefix) for k in model_keys) |
| ckpt_has_prefix = any(k.startswith(prefix) for k in state_dict.keys()) |
|
|
| |
| if ckpt_has_prefix and not model_has_prefix: |
| stripped = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} |
| return stripped if stripped else state_dict |
|
|
| |
| if model_has_prefix and not ckpt_has_prefix: |
| remapped = {} |
| for k, v in state_dict.items(): |
| prefixed = f"{prefix}{k}" |
| remapped[prefixed if prefixed in model_keys else k] = v |
| return remapped |
|
|
| return state_dict |
|
|
|
|
class _SafeFromPretrainedMixin:
    """Simplified `from_pretrained` that sidesteps the full HF loading machinery.

    The mixin resolves a config and a single (non-sharded) checkpoint file,
    instantiates the model with plain `__init__` kwargs, aligns checkpoint key
    prefixes, and loads the state dict. Advanced loader features (device maps,
    quantization, offloading, sharded weights) are intentionally unsupported
    and their kwargs are dropped before construction.
    """

    @classmethod
    def _adapt_state_dict(cls, model, state_dict):
        # Hook for subclasses to remap legacy checkpoint keys; identity here.
        return state_dict

    @staticmethod
    def _filter_keys_with_patterns(keys, patterns):
        """Drop keys matching any of `patterns` (regex strings or compiled patterns)."""
        if not patterns:
            return list(keys)

        import re

        compiled = [re.compile(p) if isinstance(p, str) else p for p in patterns]
        return [k for k in keys if not any(p.search(k) for p in compiled)]

    @classmethod
    def _resolve_config_and_init_kwargs(
        cls,
        pretrained_model_name_or_path,
        config,
        config_load_kwargs,
        other_kwargs,
    ):
        """Return a `PretrainedConfig` plus the kwargs left over for model `__init__`.

        `config` may be an existing config instance (used as-is), a path-like
        pointing at a config, or None (config is loaded from the model path).

        Raises:
            TypeError: If `config` is of any other type.
        """
        if isinstance(config, PretrainedConfig):
            return config, other_kwargs

        if config is None:
            config_source = pretrained_model_name_or_path
        elif isinstance(config, (str, os.PathLike)):
            config_source = config
        else:
            raise TypeError(
                "`config` must be None, a path-like object, or an instance of PretrainedConfig"
            )

        # `return_unused_kwargs=True` hands back everything the config did not
        # consume; those become candidate model-init kwargs.
        config, init_kwargs = cls.config_class.from_pretrained(
            config_source,
            return_unused_kwargs=True,
            **config_load_kwargs,
            **other_kwargs,
        )
        return config, init_kwargs

    @staticmethod
    def _remove_mismatched_keys(model, state_dict):
        """Drop checkpoint tensors whose shape differs from the model parameter.

        Returns the filtered state dict and a list of
        `(key, checkpoint_shape, model_shape)` tuples for reporting.
        """
        state_dict = dict(state_dict)
        model_state = model.state_dict()
        mismatched_keys = []

        for key in list(state_dict.keys()):
            # Keys absent from the model are left for `load_state_dict` to report.
            if key not in model_state:
                continue

            loaded_value = state_dict[key]
            model_value = model_state[key]

            if not isinstance(loaded_value, torch.Tensor) or not isinstance(model_value, torch.Tensor):
                continue

            if tuple(loaded_value.shape) != tuple(model_value.shape):
                mismatched_keys.append((key, tuple(loaded_value.shape), tuple(model_value.shape)))
                state_dict.pop(key)

        return state_dict, mismatched_keys

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Instantiate the model and load weights from a local dir or the Hub.

        Supports `config`, `state_dict`, `ignore_mismatched_sizes`, `strict`
        and `output_loading_info` in addition to the usual download kwargs.
        Returns the model, or `(model, loading_info)` when
        `output_loading_info=True`.
        """
        output_loading_info = kwargs.pop("output_loading_info", False)
        state_dict = kwargs.pop("state_dict", None)
        config = kwargs.pop("config", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        strict = kwargs.pop("strict", False)

        config_load_kwargs, use_safetensors, weights_only, other_kwargs = _split_pretrained_kwargs(kwargs)

        config, init_kwargs = cls._resolve_config_and_init_kwargs(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            config=config,
            config_load_kwargs=config_load_kwargs,
            other_kwargs=other_kwargs,
        )
        # NOTE(review): attention-implementation auto-selection is currently
        # disabled; the triple-quoted string below is intentionally dead code.
        '''
        config = cls._autoset_attn_implementation(
            config,
            use_flash_attention_2=bool(init_kwargs.get("use_flash_attention_2", False)),
            torch_dtype=init_kwargs.get("torch_dtype", None),
            device_map=init_kwargs.get("device_map", None),
        )
        '''

        # Strip loader-only options the model constructor would choke on.
        for k in _HF_MODEL_INIT_BLACKLIST:
            init_kwargs.pop(k, None)

        model = cls(config, *model_args, **init_kwargs)

        if state_dict is None:
            weights_path = _resolve_weights_file(
                pretrained_model_name_or_path,
                use_safetensors=use_safetensors,
                **config_load_kwargs,
            )
            state_dict = _read_state_dict(
                weights_path,
                weights_only=True if weights_only is None else bool(weights_only),
            )

        if not isinstance(state_dict, dict):
            raise TypeError(f"Expected a state dict, got {type(state_dict).__name__}")

        # Align key prefixes first, then let subclasses apply custom remapping.
        state_dict = _align_state_dict_with_base_prefix(model, state_dict)
        state_dict = cls._adapt_state_dict(model, state_dict)

        mismatched_keys = []
        if ignore_mismatched_sizes:
            state_dict, mismatched_keys = cls._remove_mismatched_keys(model, state_dict)

        incompatible = model.load_state_dict(state_dict, strict=strict)

        # Re-tie shared tensors (e.g. embeddings/decoder) after loading.
        if hasattr(model, "tie_weights"):
            model.tie_weights()

        model.eval()

        # Filter out keys the model declares as safe to ignore.
        missing_keys = cls._filter_keys_with_patterns(
            list(incompatible.missing_keys),
            getattr(model, "_keys_to_ignore_on_load_missing", None),
        )
        unexpected_keys = cls._filter_keys_with_patterns(
            list(incompatible.unexpected_keys),
            getattr(model, "_keys_to_ignore_on_load_unexpected", None),
        )

        info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
            "error_msgs": [],
        }
        return (model, info) if output_loading_info else model
|
|
|
|
|
|
def l2_norm(input, axis=1, epsilon=1e-12):
    """L2-normalize `input` along `axis`, clamping the norm at `epsilon` to avoid division by zero."""
    magnitudes = torch.norm(input, 2, axis, True).clamp(min=epsilon)
    return input / magnitudes
|
|
|
|
def initialize_linear_kaiming(layer: nn.Linear):
    """Re-initialize a linear layer with Kaiming-uniform weights and a zero bias.

    Non-`nn.Linear` modules are left untouched, so this can be passed directly
    to `Module.apply`.
    """
    if not isinstance(layer, nn.Linear):
        return
    nn.init.kaiming_uniform_(layer.weight, nonlinearity='linear')
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
|
|
|
|
| def _autocast_disabled(device_type: str): |
| try: |
| return torch.amp.autocast(device_type=device_type, enabled=False) |
| except (AttributeError, TypeError): |
| if device_type == "cuda": |
| return torch.cuda.amp.autocast(enabled=False) |
| if device_type == "cpu" and hasattr(torch, "cpu") and hasattr(torch.cpu, "amp"): |
| return torch.cpu.amp.autocast(enabled=False) |
| return nullcontext() |
|
|
|
|
| def _normalize_token_attention_mask( |
| attention_mask: Optional[torch.Tensor], |
| *, |
| batch_size: Optional[int] = None, |
| seq_len: Optional[int] = None, |
| device: Optional[torch.device] = None, |
| ) -> Optional[torch.Tensor]: |
| """Convert common attention-mask layouts to a `(batch, seq_len)` boolean keep-mask. |
| |
| Supported inputs are 2D token masks, 3D masks of shape `(batch, 1, seq_len)`, and 4D masks |
| of shape `(batch, 1, query_len, key_len)` used by eager/SDPA attention. Returned values use |
| `True` for valid (non-padding) tokens. |
| """ |
| if attention_mask is None: |
| if batch_size is None or seq_len is None: |
| return None |
| return torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) |
|
|
| mask = attention_mask |
|
|
| if mask.dim() == 4: |
| if mask.dtype == torch.bool: |
| mask = mask[:, 0].any(dim=-2) |
| else: |
| mask = mask[:, 0].amax(dim=-2) >= 0 |
| return mask.to(dtype=torch.bool) |
|
|
| if mask.dim() == 3: |
| if mask.size(1) != 1: |
| raise ValueError( |
| "3D attention masks must have shape (batch, 1, seq_len) when used in ProkBert." |
| ) |
| mask = mask[:, 0, :] |
|
|
| if mask.dim() != 2: |
| raise ValueError( |
| "Attention masks for ProkBert must be 2D, 3D `(batch, 1, seq_len)`, or 4D " |
| "`(batch, 1, query_len, key_len)`." |
| ) |
|
|
| if mask.dtype == torch.bool: |
| return mask |
|
|
| if mask.is_floating_point(): |
| if mask.numel() == 0: |
| return mask.to(dtype=torch.bool) |
| return (mask >= 0) if torch.any(mask < 0) else (mask > 0) |
|
|
| return mask != 0 |
|
|
|
|
| def _to_additive_attention_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: |
| """Convert a boolean keep-mask to the additive attention-bias format expected by eager attention.""" |
| min_dtype = torch.finfo(dtype).min |
| bias = torch.zeros(mask.shape, device=mask.device, dtype=dtype) |
| return bias.masked_fill(~mask, min_dtype) |
|
|
|
|
| def _build_bidirectional_attention_biases( |
| attention_mask: torch.Tensor, |
| dtype: torch.dtype, |
| sliding_window: int, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """Create full-attention and sliding-window additive masks for eager/SDPA attention.""" |
| if attention_mask.dim() == 4: |
| if attention_mask.dtype == torch.bool: |
| global_attention_mask = _to_additive_attention_bias(attention_mask, dtype) |
| else: |
| global_attention_mask = attention_mask.to(dtype=dtype) |
| query_length = global_attention_mask.shape[-2] |
| key_length = global_attention_mask.shape[-1] |
| else: |
| token_mask = _normalize_token_attention_mask(attention_mask) |
| if token_mask is None: |
| raise ValueError("`attention_mask` cannot be None when creating additive attention biases.") |
| query_length = token_mask.shape[-1] |
| key_length = token_mask.shape[-1] |
| expanded_token_mask = token_mask[:, None, None, :].expand(-1, 1, query_length, -1) |
| global_attention_mask = _to_additive_attention_bias(expanded_token_mask, dtype) |
|
|
| row_positions = torch.arange(query_length, device=global_attention_mask.device)[:, None] |
| col_positions = torch.arange(key_length, device=global_attention_mask.device)[None, :] |
| local_window = (row_positions - col_positions).abs() <= sliding_window |
| min_dtype = torch.finfo(dtype).min |
| sliding_window_mask = global_attention_mask.masked_fill( |
| ~local_window.view(1, 1, query_length, key_length), |
| min_dtype, |
| ) |
| return global_attention_mask, sliding_window_mask |
|
|
|
|
class ProkBertConfig(PretrainedConfig):
    r"""
    Configuration for the standalone ProkBERT ModernBERT-style encoder stack.

    Canonical config names follow the current ModernBERT conventions:
    - `layer_types` instead of `global_attn_every_n_layers`
    - nested `rope_parameters` keyed by attention type instead of a single flat RoPE dict
    - `tie_word_embeddings` explicitly tracked in the config

    `classifier_pooling` is also standardized as a single field and extended with the custom
    `"attention"` option used by the standalone ProkBERT sequence-classification head.

    The tokenizer metadata fields `kmer` and `shift` are also stored on the config so the model artifact keeps
    the sequence-tokenization contract alongside the architectural settings.

    Legacy names are still accepted for backward compatibility when loading older checkpoints/configs.
    """

    model_type = "prokbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default RoPE base frequency per attention flavour.
    default_theta = {"full_attention": 160_000.0, "sliding_attention": 10_000.0}

    # Legacy attribute names aliased onto their canonical counterparts
    # (PretrainedConfig resolves these on attribute access).
    attribute_map = {
        "classification_dropout_rate": "classifier_dropout",
        "num_class_labels": "num_labels",
        "curricular_num_labels": "num_labels",
        "curricular_face_m": "curricular_margin",
        "curricular_face_s": "curricular_scale",
        "curriculum_hidden_size": "curricular_embedding_size",
    }

    @classmethod
    def _build_layer_types(
        cls,
        num_hidden_layers: int,
        global_attn_every_n_layers: int,
    ) -> list[str]:
        """Expand the legacy `global_attn_every_n_layers` setting into a per-layer list.

        Layer `i` uses full attention when `i % global_attn_every_n_layers == 0`
        and sliding-window attention otherwise.

        Raises:
            ValueError: If `global_attn_every_n_layers` is not positive.
        """
        if global_attn_every_n_layers <= 0:
            raise ValueError("`global_attn_every_n_layers` must be a positive integer.")
        return [
            "sliding_attention" if bool(i % global_attn_every_n_layers) else "full_attention"
            for i in range(num_hidden_layers)
        ]

    @classmethod
    def _normalize_rope_parameters(
        cls,
        rope_parameters: RopeParameters | dict | None,
        global_rope_theta: float,
        local_rope_theta: float,
    ) -> dict[str, dict]:
        """Normalize RoPE settings into `{"full_attention": {...}, "sliding_attention": {...}}`.

        Accepts `None` (defaults built from the legacy theta arguments), a flat
        dict with `rope_type`/`rope_theta` (applied to both attention types,
        with `local_rope_theta` kept for the sliding theta), or a dict already
        nested per attention type.
        """
        default_rope_parameters = {
            "full_attention": {"rope_type": "default", "rope_theta": float(global_rope_theta)},
            "sliding_attention": {"rope_type": "default", "rope_theta": float(local_rope_theta)},
        }

        if rope_parameters is None:
            return default_rope_parameters

        rope_parameters = dict(rope_parameters)

        # Flat single-dict layout: fan the shared settings out to both types.
        if "rope_type" in rope_parameters or "rope_theta" in rope_parameters:
            shared_rope_type = rope_parameters.get("rope_type", "default")
            full_theta = float(rope_parameters.get("rope_theta", global_rope_theta))
            return {
                "full_attention": {
                    **{k: v for k, v in rope_parameters.items() if k != "rope_theta"},
                    "rope_type": shared_rope_type,
                    "rope_theta": full_theta,
                },
                "sliding_attention": {
                    **{k: v for k, v in rope_parameters.items() if k != "rope_theta"},
                    "rope_type": shared_rope_type,
                    "rope_theta": float(local_rope_theta),
                },
            }

        # Nested layout: fill per-type defaults for anything missing.
        normalized_rope_parameters = {}
        for layer_type in ("full_attention", "sliding_attention"):
            layer_params = rope_parameters.get(layer_type)
            if layer_params is None:
                layer_params = {"rope_type": "default"}
            else:
                layer_params = dict(layer_params)
                layer_params.setdefault("rope_type", "default")
            layer_params.setdefault("rope_theta", cls.default_theta[layer_type])
            normalized_rope_parameters[layer_type] = layer_params

        return normalized_rope_parameters

    def __init__(
        self,
        vocab_size: int = 4608,
        hidden_size: int = 384,
        intermediate_size: int = 1152,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 6,
        hidden_activation: str = "gelu",
        max_position_embeddings: int = 16384,
        initializer_range: float = 0.02,
        initializer_cutoff_factor: float = 2.0,
        norm_eps: float = 1e-6,
        norm_bias: bool = False,
        kmer: int = 6,
        shift: int = 1,
        pad_token_id: int = 0,
        eos_token_id: int = 3,
        bos_token_id: int = 2,
        cls_token_id: int = 2,
        sep_token_id: int = 3,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: list[str] | None = None,
        rope_parameters: RopeParameters | dict | None = None,
        local_attention: int = 256,
        embedding_dropout: float = 0.0,
        mlp_bias: bool = False,
        mlp_dropout: float = 0.0,
        decoder_bias: bool = True,
        classifier_pooling: Literal["attention", "cls", "mean"] = "attention",
        classifier_dropout: float = 0.0,
        classifier_bias: bool = False,
        classifier_activation: str = "gelu",
        deterministic_flash_attn: bool = False,
        sparse_prediction: bool = False,
        sparse_pred_ignore_index: int = -100,
        reference_compile: bool | None = None,
        repad_logits_with_grad: bool = False,
        norm_type: str = "rms",
        tie_word_embeddings: bool = True,
        num_labels: int = 2,
        problem_type: str | None = None,
        curricular_margin: float = 0.5,
        curricular_scale: float = 64.0,
        curricular_embedding_size: int | None = None,
        **kwargs,
    ):
        # --- Legacy kwargs from older ProkBERT configs, popped before super() ---
        legacy_global_attn_every_n_layers = int(kwargs.pop("global_attn_every_n_layers", 1))
        legacy_global_rope_theta = float(kwargs.pop("global_rope_theta", self.default_theta["full_attention"]))
        legacy_local_rope_theta = float(kwargs.pop("local_rope_theta", self.default_theta["sliding_attention"]))

        legacy_num_class_labels = kwargs.pop("num_class_labels", None)
        legacy_curricular_num_labels = kwargs.pop("curricular_num_labels", None)
        legacy_classifier_dropout = kwargs.pop("classification_dropout_rate", None)
        legacy_curricular_margin = kwargs.pop("curricular_face_m", None)
        legacy_curricular_scale = kwargs.pop("curricular_face_s", None)
        legacy_curricular_embedding_size = kwargs.pop("curriculum_hidden_size", None)
        kwargs.pop("bert_base_model", None)  # obsolete field, silently dropped

        # An explicit id2label mapping wins over the legacy label-count fields.
        loaded_id2label = kwargs.get("id2label")
        if loaded_id2label is not None:
            num_labels = len(loaded_id2label)
        elif legacy_curricular_num_labels is not None:
            num_labels = int(legacy_curricular_num_labels)
        elif legacy_num_class_labels is not None:
            num_labels = int(legacy_num_class_labels)

        # Legacy scalar settings override the defaults when present.
        if legacy_classifier_dropout is not None:
            classifier_dropout = float(legacy_classifier_dropout)
        if legacy_curricular_margin is not None:
            curricular_margin = float(legacy_curricular_margin)
        if legacy_curricular_scale is not None:
            curricular_scale = float(legacy_curricular_scale)
        if curricular_embedding_size is None and legacy_curricular_embedding_size not in (None, -1):
            curricular_embedding_size = int(legacy_curricular_embedding_size)

        # Derive per-layer attention types from the legacy cadence if needed.
        if layer_types is None:
            layer_types = self._build_layer_types(
                num_hidden_layers=num_hidden_layers,
                global_attn_every_n_layers=legacy_global_attn_every_n_layers,
            )
        else:
            layer_types = list(layer_types)

        rope_parameters = self._normalize_rope_parameters(
            rope_parameters=rope_parameters,
            global_rope_theta=legacy_global_rope_theta,
            local_rope_theta=legacy_local_rope_theta,
        )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            tie_word_embeddings=tie_word_embeddings,
            num_labels=num_labels,
            problem_type=problem_type,
            **kwargs,
        )

        # Tokenizer contract (k-mer length and stride) stored with the model.
        self.kmer = int(kmer)
        self.shift = int(shift)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types
        self.rope_parameters = rope_parameters
        self.local_attention = local_attention
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.decoder_bias = decoder_bias
        self.classifier_pooling = classifier_pooling
        self.classifier_dropout = classifier_dropout
        self.classifier_bias = classifier_bias
        self.classifier_activation = classifier_activation
        self.deterministic_flash_attn = deterministic_flash_attn
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
        self.repad_logits_with_grad = repad_logits_with_grad
        self.norm_type = norm_type
        self.tie_word_embeddings = tie_word_embeddings
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.curricular_margin = curricular_margin
        self.curricular_scale = curricular_scale
        self.curricular_embedding_size = curricular_embedding_size

        # --- Validation of the resulting configuration ---
        if self.kmer <= 0:
            raise ValueError(f"`kmer` must be a positive integer, got {self.kmer}.")

        if self.shift <= 0:
            raise ValueError(f"`shift` must be a positive integer, got {self.shift}.")

        if len(self.layer_types) != self.num_hidden_layers:
            raise ValueError(
                "`layer_types` must contain one entry per hidden layer: "
                f"expected {self.num_hidden_layers}, got {len(self.layer_types)}."
            )

        invalid_layer_types = sorted(set(self.layer_types) - {"full_attention", "sliding_attention"})
        if invalid_layer_types:
            raise ValueError(
                f"Unsupported values in `layer_types`: {invalid_layer_types}. "
                'Allowed values are ["full_attention", "sliding_attention"].'
            )

        if self.classifier_pooling not in ["attention", "cls", "mean"]:
            raise ValueError(
                f'Invalid value for `classifier_pooling`, should be one of ["attention", "cls", "mean"], but is {self.classifier_pooling}.'
            )

        if self.norm_type not in {"rms", "layernorm"}:
            raise ValueError(
                f'Invalid value for `norm_type`, should be either "rms" or "layernorm", but is {self.norm_type}.'
            )

    def get_rope_parameters(self, layer_type: str) -> dict:
        """Return the RoPE parameter dict for `layer_type`, falling back to defaults.

        Raises:
            ValueError: If `layer_type` is not a supported attention type.
        """
        if layer_type not in {"full_attention", "sliding_attention"}:
            raise ValueError(
                f"Unsupported `layer_type`={layer_type!r}. Expected 'full_attention' or 'sliding_attention'."
            )
        rope_params = self.rope_parameters.get(layer_type)
        if rope_params is None:
            rope_params = {"rope_type": "default", "rope_theta": self.default_theta[layer_type]}
        return rope_params

    @property
    def sliding_window(self) -> int:
        # `local_attention` is the total window span; per-side radius is half.
        return self.local_attention // 2

    @sliding_window.setter
    def sliding_window(self, value: int):
        self.local_attention = int(value) * 2

    def to_dict(self):
        """Serialize the config, excluding the runtime-only `reference_compile` flag."""
        output = super().to_dict()
        output.pop("reference_compile", None)
        return output
|
|
|
|
# Names consumed by the `add_code_sample_docstrings` decorators on the models.
_CHECKPOINT_FOR_DOC = "example/prokbert-base"
_CONFIG_FOR_DOC = "ProkBertConfig"


# Shared boilerplate inserted at the top of the model classes' docstrings.
PROK_BERT_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch module and refer to the PyTorch documentation for general usage and behavior.

    Parameters:
        config ([`ProkBertConfig`]):
            Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the model weights; see [`PreTrainedModel.from_pretrained`] for weight loading.
"""
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned scale and optional bias.

    The last dimension is divided by its RMS, then the optional bias is added
    and the elementwise `weight` applied. NOTE(review): the bias is added
    *before* the weight scaling (output = weight * (x_normed + bias)), which
    differs from the usual LayerNorm ordering — confirm this is intended.
    """

    def __init__(self, dim: int, eps: float = 1e-6, bias: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1 / sqrt(mean(x^2) + eps), computed in the input dtype.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = x * inv_rms
        if self.bias is not None:
            normed = normed + self.bias
        return self.weight * normed
|
|
|
|
def rotate_half(x):
    """Rotate the last dimension by swapping its halves and negating the second."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)




def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (torch.Tensor): The query tensor.
        k (torch.Tensor): The key tensor.
        cos (torch.Tensor): The cosine part of the rotary embedding.
        sin (torch.Tensor): The sine part of the rotary embedding.
        position_ids (torch.Tensor, optional): Deprecated and unused.
        unsqueeze_dim (int, optional): The dimension along which to unsqueeze cos and sin.
    Returns:
        tuple(torch.Tensor): The rotated query and key tensors.
    """
    input_dtype = q.dtype
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # Rotate in float32 for numerical stability, then cast back.
    qf, kf = q.float(), k.float()
    q_rotated = qf * cos + rotate_half(qf) * sin
    k_rotated = kf * cos + rotate_half(kf) * sin
    return q_rotated.to(input_dtype), k_rotated.to(input_dtype)
|
|
|
|
def eager_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    output_attentions: Optional[bool] = False,
    **_kwargs,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
    """Pure-PyTorch attention over a packed `(bs, seq, 3, heads, head_dim)` QKV tensor.

    Returns the attention output reshaped to `(bs, seq, dim)`, plus the
    attention probabilities when `output_attentions` is set.
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # (bs, seq, 3, heads, hd) -> (bs, heads, 3, seq, hd) -> three (bs, heads, seq, hd)
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    scores = torch.matmul(query, key.transpose(2, 3)) * module.head_dim ** -0.5

    # Sliding-window layers use the band-limited additive mask instead.
    if local_attention != (-1, -1):
        attention_mask = sliding_window_mask
    scores = scores + attention_mask

    # Softmax in float32 for stability, then back to the activation dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=module.attention_dropout, training=module.training)
    context = torch.matmul(probs, value).transpose(1, 2).contiguous().view(bs, -1, dim)

    if output_attentions:
        return (context, probs)
    return (context,)
|
|
|
|
def flash_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    rotary_emb: "ProkBertUnpaddedRotaryEmbedding",
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Flash-attention path over an unpadded packed QKV tensor.

    Flash attention only supports fp16/bf16, so other dtypes are temporarily
    cast to `target_dtype` and the output cast back afterwards.
    """
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    orig_dtype = qkv.dtype
    needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
    if needs_cast:
        qkv = qkv.to(target_dtype)

    attn = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=module.attention_dropout if module.training else 0.0,
        deterministic=module.deterministic_flash_attn,
        window_size=local_attention,
    )
    if needs_cast:
        attn = attn.to(orig_dtype)

    return (attn.view(bs, dim),)
|
|
|
|
def sdpa_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Scaled-dot-product-attention path over a packed `(bs, seq, 3, heads, head_dim)` QKV tensor."""
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # (bs, seq, 3, heads, hd) -> three tensors of shape (bs, heads, seq, hd)
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    # Sliding-window layers use the band-limited additive mask instead.
    mask = sliding_window_mask if local_attention != (-1, -1) else attention_mask

    context = F.scaled_dot_product_attention(
        query,
        key,
        value,
        dropout_p=module.attention_dropout if module.training else 0.0,
        attn_mask=mask,
    )
    context = context.transpose(1, 2).contiguous().view(bs, -1, dim)
    return (context,)
|
|
|
|
# Dispatch table mapping the configured attention implementation name to the
# corresponding kernel defined above.
PROK_BERT_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}
|
|
|
|
| def _unpad_prokbert_input( |
| inputs: torch.Tensor, |
| attention_mask: torch.Tensor, |
| position_ids: Optional[torch.Tensor] = None, |
| labels: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: |
| """ |
| Remove padding from input sequences. |
| |
| Args: |
| inputs: (batch, seqlen, ...) or (batch, seqlen) |
| attention_mask: (batch, seqlen), where 1 means valid and 0 means padding. |
| position_ids: (batch, seqlen), optional position ids. |
| labels: (batch, seqlen), optional labels. |
| |
| Returns: |
| unpadded_inputs: Tensor of shape (total_nnz, ...) containing only valid tokens. |
| indices: Tensor of indices corresponding to valid tokens. |
| cu_seqlens: Cumulative sequence lengths of the unpadded tokens (shape: batch + 1). |
| max_seqlen_in_batch: Maximum sequence length among all sequences (excluding padding). |
| unpadded_position_ids: (total_nnz,) or None. |
| unpadded_labels: (total_nnz,) or None. |
| """ |
| seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| max_seqlen_in_batch = int(seqlens_in_batch.max().item()) |
| cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
|
|
| if inputs.dim() == 2: |
| unpadded_inputs = inputs.flatten()[indices] |
| else: |
| batch, seqlen, *rest = inputs.shape |
| shape = batch * seqlen |
| unpadded_inputs = inputs.view(shape, *rest)[indices] |
|
|
| unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None |
| unpadded_labels = labels.flatten()[indices] if labels is not None else None |
|
|
| return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels |
|
|
|
|
| def _pad_prokbert_output( |
| inputs: torch.Tensor, |
| indices: torch.Tensor, |
| batch: int, |
| seqlen: int, |
| ) -> torch.Tensor: |
| """ |
| Add padding back to the output tensor. |
| |
| Args: |
| inputs: Tensor of shape (total_nnz, ...) containing outputs for only valid tokens. |
| indices: Tensor of indices indicating positions of valid tokens. |
| batch: Batch size. |
| seqlen: Maximum sequence length (including padding). |
| |
| Returns: |
| Tensor of shape (batch, seqlen, ...) with outputs in their original padded positions. |
| """ |
| if inputs.dim() == 1: |
| output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) |
| output[indices] = inputs |
| padded_inputs = output.view(batch, seqlen) |
| else: |
| _, *rest = inputs.shape |
| output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) |
| output[indices] = inputs |
| padded_inputs = output.view(batch, seqlen, *rest) |
| return padded_inputs |
|
|
|
|
|
|
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Autograd op that applies rotary position embeddings in-place to a packed
    (unpadded) QKV tensor via the flash-attn Triton `apply_rotary` kernel."""

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        # qkv: (total_nnz, 3, nheads, headdim) packed query/key/value.
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # Rotation only touches Q and K. Flattening them into a single view lets
        # one kernel launch rotate both; the kernel writes in-place, so the
        # original qkv tensor is updated through this view (V stays untouched).
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )

        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        # The inverse of a rotation is its conjugate: the backward pass re-runs the
        # same kernel with conjugate=True, in-place on the incoming gradient.
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )
        # One gradient per forward input; the non-differentiable args get None.
        return do, None, None, None, None
|
|
|
|
def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Apply rotary embeddings to an unpadded (packed) QKV tensor.

    Thin functional wrapper around :class:`ApplyRotaryEmbUnpad`; note the
    underlying kernel mutates ``qkv`` in place.

    Args:
        qkv: Tensor of shape (total_nnz, 3, nheads, headdim) for packed QKV.
        cos, sin: Precomputed cosine and sine caches.
        cu_seqlens: Cumulative sequence lengths (batch + 1,).
        max_seqlen: Maximum sequence length in the batch.
    Returns:
        Tensor with rotary embeddings applied.
    """
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
|
|
|
|
class ProkBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    Rotary embeddings for unpadded (packed) sequences used in ProkBERT.

    Subclasses flash-attn's ``RotaryEmbedding``; when flash-attn is unavailable the
    base resolves to ``object`` at import time, but this class is only used on the
    flash_attention_2 path.
    """

    def __init__(
        self,
        dim: int,
        base: float = 16000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Args:
            dim: Dimension of each head.
            base: Base for the rotary frequency computation.
            max_seqlen: Maximum sequence length to precompute the cosine and sine cache.
            device: Device on which to create the cache.
            dtype: Data type for the cache.
        """
        super().__init__(dim=dim, base=base, device=device, interleaved=False)

        self.max_seqlen = max_seqlen

        # Warm the cos/sin cache eagerly only when length, device and dtype are all
        # known up front; otherwise forward() fills it lazily per batch.
        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Apply rotary embeddings *inplace* to a packed QKV tensor.

        Args:
            qkv: Tensor of shape (total_nnz, 3, nheads, headdim).
            cu_seqlens: Cumulative sequence lengths tensor (batch + 1,).
            max_seqlen: Maximum sequence length in the current batch.
        Returns:
            Tensor with rotary embeddings applied.
        """
        if max_seqlen is not None:
            # Base-class cache helper: extends _cos_cached/_sin_cached if needed.
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)

        qkv = apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return qkv

    def extra_repr(self) -> str:
        # NOTE(review): `scale_base` is inherited from the flash-attn base class —
        # confirm it exists across all supported flash-attn versions.
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
|
|
|
|
class ProkBertEmbeddings(nn.Module):
    """
    Token embedding layer: lookup, then normalization, then dropout.
    """

    def __init__(self, config: ProkBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # Select the normalization flavor from the config (RMSNorm or LayerNorm).
        norm_cls = RMSNorm if config.norm_type == "rms" else nn.LayerNorm
        self.norm = norm_cls(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.drop = nn.Dropout(config.embedding_dropout)

    def forward(
        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Embed the input tokens (or accept pre-computed embeddings) and apply
        normalization followed by dropout.

        Args:
            input_ids: Tensor of input token ids.
            inputs_embeds: Alternatively, a pre-computed embedding tensor; takes
                precedence over ``input_ids`` when provided.
        Returns:
            Normalized, dropout-regularized embedding tensor.
        """
        embeddings = inputs_embeds if inputs_embeds is not None else self.tok_embeddings(input_ids)
        return self.drop(self.norm(embeddings))
|
|
|
|
class ProkBertRotaryEmbedding(nn.Module):
    """Rotary position embedding for the padded (non-flash) attention paths.

    Produces per-position cos/sin tables; "dynamic" RoPE variants recompute the
    inverse frequencies when sequences outgrow the cached length.
    """

    def __init__(self, config: ProkBertConfig, layer_type: str, device: Optional[torch.device] = None):
        super().__init__()

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.layer_type = layer_type

        # Per-layer-type RoPE settings (theta, scaling variant) come from the config.
        rope_params = config.get_rope_parameters(layer_type)
        self.rope_type = rope_params["rope_type"]
        self.rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            # NOTE(review): assumes the selected ROPE_INIT_FUNCTIONS entry accepts a
            # `layer_type` keyword — confirm against the installed transformers version.
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(config, device, layer_type=layer_type)

        # Non-persistent buffers: recomputed on load, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: ProkBertConfig,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
        layer_type: Optional[str] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation.

        Returns:
            (inv_freq, attention_factor): frequency tensor of shape (dim // 2,) and
            the cos/sin scaling factor (always 1.0 for the default variant).
        """
        current_layer_type = layer_type or "full_attention"
        rope_params = config.get_rope_parameters(current_layer_type)
        base = rope_params["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0
        # Classic RoPE schedule: inv_freq[i] = 1 / base^(2i / dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
        return inv_freq, attention_factor

    def _dynamic_frequency_update(self, position_ids, device):
        """
        Dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - Growing beyond the cached sequence length (allow scaling)
        2 - The current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:
            # Growth path: recompute scaled frequencies for the longer sequence.
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config,
                device,
                seq_len=seq_len,
                layer_type=self.layer_type,
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            # NOTE(review): seq_len is a 0-d tensor here, not an int — the comparisons
            # above still work, but confirm nothing downstream requires a plain int.
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
            # Shrink path: restore the original (unscaled) frequencies.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return (cos, sin) tables of shape (batch, seq_len, dim) in x's dtype."""
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Outer product of inverse frequencies and positions -> per-position angles.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type
        # autocast is force-disabled so the trig runs in full float32; "mps" has no
        # autocast scope of its own, so it borrows cpu's.
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the angles so the table covers the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
class ProkBertMLP(nn.Module):
    """Gated linear unit (GLU) feed-forward block applied at the end of each layer.

    A single fused up-projection produces both the value and gate halves, replacing
    the separate intermediate/output modules of the classic BERT architecture.
    """

    def __init__(self, config: ProkBertConfig):
        super().__init__()
        self.config = config
        # Up-projection is doubled so it can be split into value and gate.
        self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
        self.act = ACT2FN[config.hidden_activation]
        self.drop = nn.Dropout(config.mlp_dropout)
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.Wi(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        gated = self.act(value) * gate
        return self.Wo(self.drop(gated))
|
|
|
|
class ProkBertAttention(nn.Module):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is available, this module uses it to improve throughput.
    Otherwise, it falls back on PyTorch's SDPA (or eager) implementation.
    """

    def __init__(self, config: ProkBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
            )

        self.attention_dropout = config.attention_dropout
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        # Fused projection producing Q, K and V with a single matmul.
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)

        # Layers built without a layer_id default to global attention.
        if layer_id is None:
            self.attention_type = "full_attention"
        else:
            self.attention_type = config.layer_types[layer_id]

        if self.attention_type == "sliding_attention":
            local_window = config.sliding_window
            if config._attn_implementation == "flash_attention_2":
                # NOTE(review): the window is widened by one token only on the
                # flash-attn path — presumably to offset a window-counting
                # convention difference; confirm against the flash-attn docs.
                local_window = local_window + 1
            self.local_attention = (local_window, local_window)
            # NOTE(review): this reads `config.local_attention` while the window
            # above comes from `config.sliding_window` — confirm ProkBertConfig
            # defines both attributes and that this mixed usage is intentional.
            max_position_embeddings = config.local_attention
        elif self.attention_type == "full_attention":
            self.local_attention = (-1, -1)
            max_position_embeddings = config.max_position_embeddings
        else:
            raise ValueError(
                f"Unsupported attention type {self.attention_type!r}. "
                "Expected 'full_attention' or 'sliding_attention'."
            )

        rope_theta = float(config.get_rope_parameters(self.attention_type)["rope_theta"])

        if config._attn_implementation == "flash_attention_2":
            # Unpadded rotary embedding operating directly on packed QKV tensors.
            self.rotary_emb = ProkBertUnpaddedRotaryEmbedding(
                dim=self.head_dim,
                max_seqlen=max_position_embeddings,
                base=rope_theta,
            )
        else:
            self.rotary_emb = ProkBertRotaryEmbedding(config=config, layer_type=self.attention_type)

        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """Project to QKV, dispatch to the configured attention implementation, and re-project."""
        qkv = self.Wqkv(hidden_states)
        bs = hidden_states.shape[0]
        # flash_attention_2 consumes packed (total_nnz, 3, heads, head_dim) tensors;
        # the padded paths keep an explicit (batch, seq, 3, heads, head_dim) layout.
        if self.config._attn_implementation == "flash_attention_2":
            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        else:
            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)

        attn_outputs = PROK_BERT_ATTENTION_FUNCTION[self.config._attn_implementation](
            self,
            qkv=qkv,
            rotary_emb=self.rotary_emb,
            local_attention=self.local_attention,
            bs=bs,
            dim=self.all_head_size,
            output_attentions=output_attentions,
            **kwargs,
        )
        hidden_states = attn_outputs[0]
        hidden_states = self.out_drop(self.Wo(hidden_states))
        # Any extra outputs (e.g. attention probabilities) pass through unchanged.
        return (hidden_states,) + attn_outputs[1:]
|
|
|
|
class ProkBertEncoderLayer(nn.Module):
    """A single pre-norm transformer block: self-attention and a GLU MLP, each
    wrapped in its own residual connection."""

    def __init__(self, config: ProkBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        if layer_id is None:
            self.attention_type = "full_attention"
        else:
            self.attention_type = config.layer_types[layer_id]

        # Pick the normalization flavor once and reuse it for both sub-blocks.
        norm_cls = RMSNorm if config.norm_type == "rms" else nn.LayerNorm

        self.attn_norm = norm_cls(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.mlp_norm = norm_cls(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)

        self.attn = ProkBertAttention(config=config, layer_id=layer_id)
        self.mlp = ProkBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """Pre-norm attention + residual, then pre-norm MLP + residual; extra
        attention outputs (if any) are forwarded untouched."""
        normed_input = self.attn_norm(hidden_states)
        attn_outputs = self.attn(
            normed_input,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_outputs[0]
        hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
        return (hidden_states,) + attn_outputs[1:]
|
|
|
|
|
|
# Shared boilerplate prepended to model-class docstrings via `add_start_docstrings`.
PROK_BERT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods
the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch module and refer to the PyTorch documentation for general usage and behavior.

Parameters:
    config ([`ProkBertConfig`]):
        Model configuration class with all the parameters of the model.
        Initializing with a config file does not load the model weights; see [`PreTrainedModel.from_pretrained`] for weight loading.
"""
|
|
|
|
@add_start_docstrings(
    "The bare ProkBert Model outputting raw hidden-states without any specific head on top.",
    PROK_BERT_START_DOCSTRING,
)
class ProkBertPreTrainedModel(PreTrainedModel):
    # Hooks ProkBERT into the HF PreTrainedModel machinery: weight initialization,
    # attention-implementation selection, and capability flags.
    config_class = ProkBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ProkBertEmbeddings", "ProkBertEncoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module: nn.Module):
        """Initialize weights with truncated normals whose std depends on the module's role."""
        cutoff_factor = self.config.initializer_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        def init_weight(module: nn.Module, std: float):
            # Truncate at +/- cutoff_factor * std to avoid extreme initial values.
            nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-cutoff_factor * std,
                b=cutoff_factor * std,
            )
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        stds = {
            "in": self.config.initializer_range,
            # Output projections are scaled down with depth (residual-scaling style).
            "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
            "embedding": self.config.initializer_range,
            "final_out": self.config.hidden_size ** -0.5,
        }

        if isinstance(module, ProkBertEmbeddings):
            init_weight(module.tok_embeddings, stds["embedding"])
        elif isinstance(module, ProkBertMLP):
            init_weight(module.Wi, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, ProkBertAttention):
            init_weight(module.Wqkv, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, ProkBertPredictionHead):
            init_weight(module.dense, stds["out"])
        elif isinstance(module, (ProkBertForMaskedLM, ProkBertForMaskedLM2)):
            init_weight(module.decoder, stds["out"])

    @classmethod
    def _autoset_attn_implementation(
        cls,
        config,
        use_flash_attention_2: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
    ):
        # When the user did not request a specific implementation, opportunistically
        # try flash_attention_2 first and fall back to the parent's selection logic.
        if config._attn_implementation_internal is None:
            config._attn_implementation_internal = "flash_attention_2"
            try:
                # Capability probe is done with fp16 regardless of the requested dtype.
                return cls._check_and_enable_flash_attn_2(
                    config,
                    torch_dtype=torch.float16,
                    device_map=device_map,
                    hard_check_only=False,
                    check_device_map=check_device_map,
                )
            except (ValueError, ImportError):
                config._attn_implementation_internal = None
        # NOTE(review): the caller-provided `torch_dtype` is ignored and fp16 is
        # forwarded instead — this mirrors the analogous upstream code, but confirm
        # it is intentional here.
        return super()._autoset_attn_implementation(
            config,
            use_flash_attention_2=use_flash_attention_2,
            torch_dtype=torch.float16,
            device_map=device_map,
            check_device_map=check_device_map,
        )

    def resize_token_embeddings(self, *args, **kwargs):
        # Pure delegation to PreTrainedModel; kept as an explicit override point.
        model_embeds = super().resize_token_embeddings(*args, **kwargs)
        return model_embeds
|
|
@add_start_docstrings(
    "The bare ProkBert Model outputting raw hidden-states without any specific head on top.",
    PROK_BERT_START_DOCSTRING,
)
class ProkBertModel(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    """Encoder-only ProkBERT backbone: token embeddings -> stacked encoder layers -> final norm."""

    def __init__(self, config: ProkBertConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = ProkBertEmbeddings(config)
        # One encoder layer per hidden layer; layer_id selects the per-layer attention type.
        self.layers = nn.ModuleList(
            [ProkBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
        )
        if config.norm_type == "rms":
            self.final_norm = RMSNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        else:
            self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.tok_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.tok_embeddings = value

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
        """
        Run the encoder stack.

        The `indices` / `cu_seqlens` / `max_seqlen` / `batch_size` / `seq_len`
        arguments are only meaningful on the flash_attention_2 path: callers may
        pass pre-unpadded (packed) inputs; otherwise this method unpads internally
        and re-pads the outputs before returning.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Flash-attn kernels cannot expose attention probabilities.
        if output_attentions and self.config._attn_implementation == "flash_attention_2":
            logger.warning_once(
                "`output_attentions=True` is not supported with flash_attention_2 in ProkBertModel. "
                "Falling back to `output_attentions=False`."
            )
            output_attentions = False

        # XOR check: raises when both or neither of input_ids / inputs_embeds are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

        if batch_size is None or seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

        repad = False
        restore_attn_implementation = None
        if self.config._attn_implementation == "flash_attention_2":
            # Unpad here only when the caller has not already supplied packed inputs.
            if indices is None and cu_seqlens is None and max_seqlen is None:
                repad = True
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_prokbert_input(
                            inputs=input_ids,
                            attention_mask=attention_mask,
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_prokbert_input(
                        inputs=inputs_embeds,
                        attention_mask=attention_mask,
                    )
        else:
            # _update_attention_mask may flip the config to "eager" when attentions
            # are requested under sdpa; remember the original so it can be restored.
            if output_attentions and self.config._attn_implementation == "sdpa":
                restore_attn_implementation = self.config._attn_implementation
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            attention_mask, sliding_window_mask = self._update_attention_mask(
                attention_mask,
                output_attentions=output_attentions,
            )

        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

        for encoder_layer in self.layers:
            if output_hidden_states:
                # Hidden states are recorded *before* each layer; the final one is added after the loop.
                all_hidden_states = all_hidden_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    sliding_window_mask,
                    position_ids,
                    cu_seqlens,
                    max_seqlen,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    sliding_window_mask=sliding_window_mask,
                    position_ids=position_ids,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                    output_attentions=output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions and len(layer_outputs) > 1:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        hidden_states = self.final_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Undo the unpadding done above so callers always see (batch, seq, dim).
        if repad:
            hidden_states = _pad_prokbert_output(inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len)
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    _pad_prokbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
                    for hs in all_hidden_states
                )

        if restore_attn_implementation is not None:
            self.config._attn_implementation = restore_attn_implementation

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_attention_mask(
        self,
        attention_mask: torch.Tensor,
        output_attentions: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert the 0/1 padding mask into additive biases for global and sliding attention.

        Side effect: may switch `config._attn_implementation` from "sdpa" to "eager"
        when attentions were requested (the caller restores it after the forward pass).
        """
        if output_attentions:
            if self.config._attn_implementation == "sdpa":
                logger.warning_once(
                    "Outputting attentions is only supported with the 'eager' attention implementation, "
                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
                )
                self.config._attn_implementation = "eager"
            elif self.config._attn_implementation != "eager":
                logger.warning_once(
                    "Outputting attentions is only supported with the eager attention implementation, "
                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`. '
                    "Setting `output_attentions=False`."
                )

        return _build_bidirectional_attention_biases(
            attention_mask=attention_mask,
            dtype=self.dtype,
            sliding_window=self.config.sliding_window,
        )
|
|
|
|
|
|
class ProkBertPredictionHead(nn.Module):
    """Dense + activation + norm transform applied before the MLM decoder."""

    def __init__(self, config: ProkBertConfig):
        super().__init__()
        norm_cls = RMSNorm if getattr(config, "norm_type", "layernorm") == "rms" else nn.LayerNorm
        # Third positional arg of nn.Linear is `bias`.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = norm_cls(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        transformed = self.act(self.dense(hidden_states))
        return self.norm(transformed)
|
|
|
|
|
|
@add_start_docstrings(
    "The ProkBert Model with a decoder head on top that is used for masked language modeling.",
    PROK_BERT_START_DOCSTRING,
)
class ProkBertForMaskedLM(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    # The decoder's weight matrix is tied to the input token embeddings.
    _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"}

    def __init__(self, config: ProkBertConfig):
        super().__init__(config)
        self.config = config
        self.model = ProkBertModel(config)
        self.head = ProkBertPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)

        # When sparse prediction is enabled, logits are computed only for labeled
        # (masked) positions instead of the full sequence.
        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index

        self.post_init()

    def get_output_embeddings(self):
        return self.decoder

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    def _prediction_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Prediction head transform followed by the (tied) vocabulary projection.
        return self.decoder(self.head(hidden_states))

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """
        Masked-LM forward pass. On the flash_attention_2 path the inputs are unpadded
        here (keeping labels/position_ids aligned with the packed tokens) so the
        encoder is called only once on valid tokens; logits are re-padded at the end
        unless sparse prediction already reduced them to masked positions only.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config._attn_implementation == "flash_attention_2":
            # Only unpad if the caller has not already supplied packed inputs.
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None or seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device

                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

                if inputs_embeds is None:
                    # Unpadding integer ids needs no gradient tracking.
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_prokbert_input(
                            inputs=input_ids,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            labels=labels,
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_prokbert_input(
                        inputs=inputs_embeds,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        labels=labels,
                    )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        # Sparse prediction: keep only positions whose label differs from the ignore
        # index, so the head and decoder run on masked tokens alone.
        logits_are_sparse = False
        if self.sparse_prediction and labels is not None:
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]
            logits_are_sparse = True

        logits = self._prediction_logits(last_hidden_state)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

        # Re-pad logits to (batch, seq, vocab) unless they were sparsified above, in
        # which case their rows no longer correspond to sequence positions.
        should_repad_logits = (
            self.config._attn_implementation == "flash_attention_2"
            and indices is not None
            and batch_size is not None
            and seq_len is not None
            and not logits_are_sparse
        )
        if should_repad_logits:
            # Repad without gradient tracking unless explicitly requested (or when
            # there is no loss, in which case gradients are moot anyway).
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_prokbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
| class ProkBertForSequenceClassification(_SafeFromPretrainedMixin, ProkBertPreTrainedModel): |
| """Standalone ProkBERT sequence classifier with mask-aware pooling.""" |
|
|
    def __init__(self, config: ProkBertConfig):
        super().__init__(config)
        # Number of target classes; cast defensively in case the config stores a non-int.
        self.num_labels = int(config.num_labels)
        self.config = config

        # Backbone encoder producing per-token hidden states.
        self.model = ProkBertModel(config)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)

        # Per-token scoring layer used by the learned attention-pooling strategy.
        self.weighting_layer = nn.Linear(config.hidden_size, 1)
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        self.post_init()
|
|
| def _init_weights(self, module: nn.Module): |
| super()._init_weights(module) |
|
|
| if module is self.weighting_layer: |
| nn.init.zeros_(module.weight) |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
|
|
| if module is self.classifier: |
| nn.init.xavier_uniform_(module.weight, gain=1.0) |
| module.weight.data /= math.sqrt(self.classifier.in_features) |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
|
|
| def _pool_cls(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor: |
| if token_mask is None: |
| return sequence_output[:, 0] |
|
|
| first_token_indices = token_mask.to(dtype=torch.long).argmax(dim=-1) |
| batch_indices = torch.arange(sequence_output.shape[0], device=sequence_output.device) |
| return sequence_output[batch_indices, first_token_indices] |
|
|
| def _pool_mean(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor: |
| if token_mask is None: |
| return sequence_output.mean(dim=1) |
|
|
| weights = token_mask.unsqueeze(-1).to(dtype=sequence_output.dtype) |
| denom = weights.sum(dim=1).clamp(min=1.0) |
| return (sequence_output * weights).sum(dim=1) / denom |
|
|
| def _pool_attention(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor: |
| scores = self.weighting_layer(sequence_output) |
|
|
| if token_mask is not None: |
| empty_rows = token_mask.sum(dim=1) == 0 |
| if empty_rows.any(): |
| token_mask = token_mask.clone() |
| token_mask[empty_rows, 0] = True |
| scores = scores.masked_fill(~token_mask.unsqueeze(-1), torch.finfo(scores.dtype).min) |
|
|
| weights = torch.softmax(scores.float(), dim=1).to(dtype=sequence_output.dtype) |
| return torch.sum(weights * sequence_output, dim=1) |
|
|
| def _pool_sequence(self, sequence_output: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor: |
| token_mask = _normalize_token_attention_mask( |
| attention_mask, |
| batch_size=sequence_output.shape[0], |
| seq_len=sequence_output.shape[1], |
| device=sequence_output.device, |
| ) |
|
|
| pooling = self.config.classifier_pooling |
| if pooling == "attention": |
| return self._pool_attention(sequence_output, token_mask) |
| if pooling == "mean": |
| return self._pool_mean(sequence_output, token_mask) |
| if pooling == "cls": |
| return self._pool_cls(sequence_output, token_mask) |
| raise ValueError( |
| f"Unsupported `classifier_pooling`={pooling!r}. Expected one of ['attention', 'cls', 'mean']." |
| ) |
|
|
| def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: |
| if self.config.problem_type is None: |
| if self.num_labels == 1: |
| self.config.problem_type = "regression" |
| elif labels.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.long, torch.uint8): |
| self.config.problem_type = "single_label_classification" |
| else: |
| self.config.problem_type = "multi_label_classification" |
|
|
| if self.config.problem_type == "regression": |
| loss_fct = nn.MSELoss() |
| if self.num_labels == 1: |
| return loss_fct(logits.squeeze(), labels.squeeze().to(logits.dtype)) |
| return loss_fct(logits, labels.to(logits.dtype)) |
|
|
| if self.config.problem_type == "single_label_classification": |
| loss_fct = nn.CrossEntropyLoss() |
| return loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
|
| if self.config.problem_type == "multi_label_classification": |
| loss_fct = nn.BCEWithLogitsLoss() |
| return loss_fct(logits, labels.to(logits.dtype)) |
|
|
| raise ValueError( |
| f"Unsupported `problem_type`={self.config.problem_type!r}. " |
| "Expected 'regression', 'single_label_classification', or 'multi_label_classification'." |
| ) |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: torch.Tensor = None, |
| sliding_window_mask: torch.Tensor = None, |
| position_ids: torch.LongTensor = None, |
| inputs_embeds: torch.Tensor = None, |
| labels: torch.Tensor = None, |
| indices: torch.Tensor = None, |
| cu_seqlens: torch.Tensor = None, |
| max_seqlen: int = None, |
| batch_size: int = None, |
| seq_len: int = None, |
| output_attentions: bool = None, |
| output_hidden_states: bool = None, |
| return_dict: bool = None, |
| **kwargs, |
| ) -> SequenceClassifierOutput: |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| outputs = self.model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| sliding_window_mask=sliding_window_mask, |
| position_ids=position_ids, |
| inputs_embeds=inputs_embeds, |
| indices=indices, |
| cu_seqlens=cu_seqlens, |
| max_seqlen=max_seqlen, |
| batch_size=batch_size, |
| seq_len=seq_len, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| sequence_output = outputs[0] |
| if sequence_output.dim() == 2: |
| if indices is None or batch_size is None or seq_len is None: |
| raise ValueError( |
| "Received unpadded hidden states from `ProkBertModel`, but `indices`, `batch_size`, and " |
| "`seq_len` were not provided to repad them for sequence classification." |
| ) |
| sequence_output = _pad_prokbert_output( |
| inputs=sequence_output, |
| indices=indices, |
| batch=batch_size, |
| seqlen=seq_len, |
| ) |
|
|
| pooled_output = self._pool_sequence(sequence_output, attention_mask) |
| pooled_output = self.norm(pooled_output) |
| pooled_output = self.dropout(pooled_output) |
| logits = self.classifier(pooled_output) |
|
|
| loss = None |
| if labels is not None: |
| loss = self._compute_loss(logits, labels) |
|
|
| if not return_dict: |
| output = (logits,) + outputs[1:] |
| return ((loss,) + output) if loss is not None else output |
|
|
| return SequenceClassifierOutput( |
| loss=loss, |
| logits=logits, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
@dataclass
class CurricularSequenceClassifierOutput(ModelOutput):
    """Output container for `ProkBertForCurricularClassification`.

    Extends the usual classifier output with the pooled/projected embeddings
    fed to the CurricularFace head (populated only on request).
    """

    # Cross-entropy loss over the margin-adjusted logits, when labels are given.
    loss: Optional[torch.FloatTensor] = None
    # Scale-multiplied cosine logits (no margin applied).
    logits: Optional[torch.FloatTensor] = None
    # Pooled embeddings, optionally L2-normalized; None unless requested.
    embeddings: Optional[torch.FloatTensor] = None
    # Per-layer hidden states / attention maps from the encoder, if requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
|
|
class CurricularFace(nn.Module):
    """CurricularFace margin head (adaptive curriculum angular-margin softmax).

    Holds a class-prototype matrix ``kernel`` (in_features x out_features) and
    an EMA statistic ``t`` that grows with the average target cosine, so hard
    negatives are emphasized more as training progresses.
    """

    def __init__(self, in_features: int, out_features: int, m: float = 0.5, s: float = 64.0, ema_alpha: float = 0.01):
        super().__init__()
        margin = float(m)
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.m = margin
        self.s = float(s)
        self.ema_alpha = float(ema_alpha)

        # Constants of the additive angular margin cos(theta + m), plus the
        # fallback used when theta + m would wrap past pi.
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        wrap = math.pi - margin
        self.threshold = math.cos(wrap)
        self.mm = math.sin(wrap) * margin

        self.kernel = nn.Parameter(torch.empty(self.in_features, self.out_features))
        # EMA of the mean target cosine; a buffer so it travels with state_dict.
        self.register_buffer("t", torch.zeros(1, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the prototypes and reset the hardness statistic to zero."""
        nn.init.xavier_uniform_(self.kernel)
        self.t.zero_()

    def cosine(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Cosine similarity between embeddings and class prototypes, in fp32."""
        with _autocast_disabled(embeddings.device.type):
            unit_emb = F.normalize(embeddings.float(), p=2.0, dim=1, eps=1e-12)
            unit_proto = F.normalize(self.kernel.float(), p=2.0, dim=0, eps=1e-12)
            return F.linear(unit_emb, unit_proto.t()).clamp(-1.0, 1.0)

    def inference_logits(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Margin-free logits: scaled cosine similarities."""
        return self.cosine(embeddings) * self.s

    def margin_logits_from_cosine(
        self,
        cos_theta: torch.Tensor,
        labels: torch.LongTensor,
        update_t: bool = False,
    ) -> torch.Tensor:
        """Apply the CurricularFace margin to ``cos_theta`` and scale by ``s``.

        Args:
            cos_theta: (batch, out_features) cosine similarities.
            labels: target class index per row.
            update_t: when True, fold the batch's mean target cosine into the
                EMA buffer ``t`` (no gradient).
        """
        labels = labels.reshape(-1).long()
        label_cols = labels.unsqueeze(1)
        target = cos_theta.gather(1, label_cols)

        # cos(theta + m) via the angle-addition formula; clamp keeps sqrt real.
        sin_theta = torch.sqrt((1.0 - target.square()).clamp(min=0.0))
        cos_theta_m = target * self.cos_m - sin_theta * self.sin_m

        # Negatives whose cosine already exceeds the margined target are "hard".
        hard_mask = cos_theta > cos_theta_m
        # Past the wrap-around threshold, fall back to a linear penalty.
        final_target = torch.where(
            target > self.threshold,
            cos_theta_m,
            target - self.mm,
        )

        if update_t:
            with torch.no_grad():
                batch_stat = target.mean().to(dtype=self.t.dtype).view_as(self.t)
                self.t.lerp_(batch_stat, self.ema_alpha)

        t = self.t.to(device=cos_theta.device, dtype=cos_theta.dtype)
        # Hard negatives are re-weighted by (t + cos); easy ones pass through.
        adjusted = torch.where(hard_mask, cos_theta * (t + cos_theta), cos_theta)
        adjusted = adjusted.scatter(1, label_cols, final_target)

        return adjusted * self.s
|
|
|
|
class ProkBertForCurricularClassification(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    """ProkBERT sequence classifier with CurricularFace logits for single-label classification.

    Pools encoder states (same "attention"/"mean"/"cls" strategies as the
    plain classifier), optionally projects them to a smaller embedding, and
    scores classes with a `CurricularFace` head instead of a linear layer.
    """

    def __init__(self, config: ProkBertConfig):
        """Build the backbone, pooling scorer, optional projection, and margin head.

        Raises:
            ValueError: if ``num_labels < 2`` or ``problem_type`` is set to
                anything other than single-label classification.
        """
        super().__init__(config)
        self.config = config
        self.num_labels = int(config.num_labels)

        if self.num_labels < 2:
            raise ValueError(
                "`ProkBertForCurricularClassification` requires `config.num_labels >= 2`. "
                "CurricularFace is intended for single-label classification."
            )
        if self.config.problem_type is None:
            self.config.problem_type = "single_label_classification"
        elif self.config.problem_type != "single_label_classification":
            raise ValueError(
                "`ProkBertForCurricularClassification` only supports `problem_type='single_label_classification'`."
            )

        self.model = ProkBertModel(config)
        # Per-token scoring layer; used only by the "attention" pooling mode.
        self.weighting_layer = nn.Linear(config.hidden_size, 1)
        self.dropout = nn.Dropout(config.classifier_dropout)

        # `curricular_embedding_size` of None/-1 means "no projection":
        # embeddings stay at hidden_size and `linear` is an Identity.
        use_projection = config.curricular_embedding_size not in (None, -1)
        embedding_dim = config.hidden_size if not use_projection else int(config.curricular_embedding_size)
        self.linear = nn.Linear(config.hidden_size, embedding_dim) if use_projection else nn.Identity()

        self.curricular_face = CurricularFace(
            in_features=embedding_dim,
            out_features=self.num_labels,
            m=float(config.curricular_margin),
            s=float(config.curricular_scale),
        )
        self.loss_fct = nn.CrossEntropyLoss()

        self.post_init()

        # Override `post_init` defaults: zero the pooling scorer (uniform
        # attention at start) and Kaiming-init the projection, if present.
        # `initialize_linear_kaiming` is defined elsewhere in this file.
        with torch.no_grad():
            nn.init.zeros_(self.weighting_layer.weight)
            if self.weighting_layer.bias is not None:
                nn.init.zeros_(self.weighting_layer.bias)
            if isinstance(self.linear, nn.Linear):
                initialize_linear_kaiming(self.linear)

    def get_input_embeddings(self):
        """Expose the backbone's token embeddings (HF convention)."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's token embeddings (HF convention)."""
        self.model.set_input_embeddings(value)

    def _pool_cls(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Return the hidden state of the first unmasked token per row."""
        if token_mask is None:
            return sequence_output[:, 0]

        # argmax over the long-cast bool mask finds the first True per row.
        first_token_indices = token_mask.to(dtype=torch.long).argmax(dim=-1)
        batch_indices = torch.arange(sequence_output.shape[0], device=sequence_output.device)
        return sequence_output[batch_indices, first_token_indices]

    def _pool_mean(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Masked mean over the sequence dimension (plain mean when unmasked)."""
        if token_mask is None:
            return sequence_output.mean(dim=1)

        weights = token_mask.unsqueeze(-1).to(dtype=sequence_output.dtype)
        # clamp guards against division by zero for rows with an all-False mask.
        denom = weights.sum(dim=1).clamp(min=1.0)
        return (sequence_output * weights).sum(dim=1) / denom

    def _pool_attention(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Softmax-weighted pooling with scores from ``weighting_layer``."""
        scores = self.weighting_layer(sequence_output)

        if token_mask is not None:
            # A fully-masked row would softmax over all -inf scores (NaN);
            # un-mask its first position as a fallback.
            empty_rows = token_mask.sum(dim=1) == 0
            if empty_rows.any():
                token_mask = token_mask.clone()
                token_mask[empty_rows, 0] = True
            scores = scores.masked_fill(~token_mask.unsqueeze(-1), torch.finfo(scores.dtype).min)

        # Softmax in fp32 for numerical stability, then cast back.
        weights = torch.softmax(scores.float(), dim=1).to(dtype=sequence_output.dtype)
        return torch.sum(weights * sequence_output, dim=1)

    def _pool_sequence(self, sequence_output: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Dispatch to the pooling strategy named by ``config.classifier_pooling``."""
        # NOTE(review): `_normalize_token_attention_mask` is defined elsewhere
        # in this file; presumably it returns a per-token boolean mask or None.
        token_mask = _normalize_token_attention_mask(
            attention_mask,
            batch_size=sequence_output.shape[0],
            seq_len=sequence_output.shape[1],
            device=sequence_output.device,
        )

        pooling = self.config.classifier_pooling
        if pooling == "attention":
            return self._pool_attention(sequence_output, token_mask)
        if pooling == "mean":
            return self._pool_mean(sequence_output, token_mask)
        if pooling == "cls":
            return self._pool_cls(sequence_output, token_mask)
        raise ValueError(
            f"Unsupported `classifier_pooling`={pooling!r}. Expected one of ['attention', 'cls', 'mean']."
        )

    def _repad_sequence_output_if_needed(
        self,
        sequence_output: torch.Tensor,
        indices: Optional[torch.Tensor],
        batch_size: Optional[int],
        seq_len: Optional[int],
    ) -> torch.Tensor:
        """Repad unpadded 2-D encoder output to (batch, seq, hidden).

        3-D input is already padded and passes through unchanged.

        Raises:
            ValueError: if repadding is required but the metadata is missing.
        """
        if sequence_output.dim() != 2:
            return sequence_output

        if indices is None or batch_size is None or seq_len is None:
            raise ValueError(
                "Received unpadded hidden states from `ProkBertModel`, but `indices`, `batch_size`, and "
                "`seq_len` were not provided to repad them for curricular classification."
            )

        return _pad_prokbert_output(
            inputs=sequence_output,
            indices=indices,
            batch=batch_size,
            seqlen=seq_len,
        )

    def _compute_embeddings(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        apply_dropout: bool = True,
    ) -> tuple[torch.Tensor, BaseModelOutput]:
        """Encode, repad if needed, pool, (optionally) drop out, and project.

        Returns:
            ``(embeddings, encoder_outputs)`` — embeddings are the pooled
            vectors after the optional linear projection.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = self._repad_sequence_output_if_needed(
            outputs.last_hidden_state,
            indices=indices,
            batch_size=batch_size,
            seq_len=seq_len,
        )
        pooled_output = self._pool_sequence(sequence_output, attention_mask)

        # Dropout is skipped for inference-style calls (see `encode`).
        if apply_dropout:
            pooled_output = self.dropout(pooled_output)

        embeddings = self.linear(pooled_output)
        return embeddings, outputs

    def encode(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Return pooled embeddings without dropout; L2-normalized by default.

        `l2_norm` is a helper defined elsewhere in this file.
        """
        embeddings, _ = self._compute_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            apply_dropout=False,
        )
        return l2_norm(embeddings, axis=1) if normalize else embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_embeddings: bool = False,
        normalize_embeddings: bool = True,
    ) -> Union[Tuple, CurricularSequenceClassifierOutput]:
        """Compute scaled cosine logits; margin applies only to the loss.

        The returned ``logits`` are the plain (margin-free) scaled cosines, so
        they are valid for inference; the CurricularFace margin is applied
        only to the logits used for the cross-entropy loss, and the EMA
        statistic ``t`` is updated only in training mode.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        embeddings, outputs = self._compute_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            apply_dropout=self.training,
        )

        exported_embeddings = None
        if return_embeddings:
            exported_embeddings = l2_norm(embeddings, axis=1) if normalize_embeddings else embeddings

        cos_theta = self.curricular_face.cosine(embeddings)
        logits = cos_theta * self.curricular_face.s

        loss = None
        if labels is not None:
            labels = labels.view(-1).long()
            train_logits = self.curricular_face.margin_logits_from_cosine(
                cos_theta,
                labels,
                update_t=self.training,
            )
            loss = self.loss_fct(train_logits, labels)

        if not return_dict:
            out = (logits,)
            if return_embeddings:
                out = out + (exported_embeddings,)
            if output_hidden_states:
                out = out + (outputs.hidden_states,)
            if output_attentions:
                out = out + (outputs.attentions,)
            return ((loss,) + out) if loss is not None else out

        return CurricularSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            embeddings=exported_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class ProkBertForMaskedLM2(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    """Masked-LM head supporting hard labels OR soft label distributions (KL loss)."""

    # Tie the decoder weight to the input token embeddings.
    # NOTE(review): mapping form (dict) — confirm the installed transformers
    # version accepts a dict here; older versions expect a list of key names.
    _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = ProkBertModel(config)
        self.head = ProkBertPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)

        # Sparse prediction computes logits only at labeled (non-ignored)
        # positions, which skips the decoder on the rest of the sequence.
        self.sparse_prediction = config.sparse_prediction
        self.sparse_pred_ignore_index = config.sparse_pred_ignore_index

        self.post_init()

    def get_output_embeddings(self):
        """Expose the decoder for HF weight tying."""
        return self.decoder

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        labels_dist: Optional[torch.FloatTensor] = None,
        loss_mask: Optional[torch.BoolTensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """MLM forward pass.

        Two mutually exclusive loss modes:
          * ``labels``: standard cross-entropy over the vocabulary.
          * ``labels_dist`` + ``loss_mask``: KL divergence between the model's
            log-softmax and a target distribution at the masked positions.

        On the flash-attention path the inputs are unpadded here when the
        caller did not pre-unpad them, and logits are repadded at the end
        (unless sparse prediction already reduced them to labeled positions).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Unpad inputs for flash-attention when the caller has not done so.
        if self.config._attn_implementation == "flash_attention_2" \
                and indices is None and cu_seqlens is None and max_seqlen is None:
            if batch_size is None or seq_len is None:
                if inputs_embeds is not None:
                    batch_size, seq_len = inputs_embeds.shape[:2]
                else:
                    batch_size, seq_len = input_ids.shape[:2]
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # Default to a full (all-ones) mask when none was supplied.
            attention_mask = attention_mask if attention_mask is not None else \
                torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
            # `_unpad_prokbert_input` is defined elsewhere in this file;
            # it flattens padded tensors to (total_tokens, ...) plus metadata.
            if inputs_embeds is None:
                with torch.no_grad():
                    input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = \
                        _unpad_prokbert_input(
                            inputs=input_ids,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            labels=labels
                        )
            else:
                inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = \
                    _unpad_prokbert_input(
                        inputs=inputs_embeds,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        labels=labels
                    )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        # Sparse prediction: keep only positions with a real label so the
        # prediction head and decoder run on a much smaller tensor.
        if self.sparse_prediction and labels is not None:
            flat_labels = labels.view(-1)
            flat_hidden = sequence_output.view(flat_labels.shape[0], -1)
            mask_tokens = flat_labels != self.sparse_pred_ignore_index
            sequence_output = flat_hidden[mask_tokens]
            labels = flat_labels[mask_tokens]

        hidden = self.head(sequence_output)
        logits = self.decoder(hidden)

        loss = None
        V = self.config.vocab_size

        # Mode 1: hard labels -> cross-entropy.
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, V), labels.view(-1))

        # Mode 2: soft target distributions -> KL divergence at masked positions.
        elif labels_dist is not None and loss_mask is not None:
            B, L = loss_mask.shape
            flat_mask = loss_mask.view(-1)
            flat_dist = labels_dist.view(-1, V)

            # Case A: logits are unpadded over attended tokens (flash path,
            # no sparse selection) — map the loss mask into attended space.
            if logits.dim() == 2 and logits.shape[0] != flat_mask.sum().item():
                full_attn = attention_mask.view(-1)
                assert logits.shape[0] == full_attn.sum().item()
                dist_attn = flat_dist[full_attn]
                mask_in_attn = flat_mask[full_attn]
                pred = logits[mask_in_attn]
                targ = dist_attn[mask_in_attn]

            # Case B: logits already reduced to exactly the masked positions.
            elif logits.dim() == 2 and logits.shape[0] == flat_mask.sum().item():
                pred = logits
                targ = flat_dist[flat_mask]

            # Case C: full padded (B, L, V) logits — select masked positions.
            else:
                flat_logits = logits.view(-1, V)
                pred = flat_logits[flat_mask]
                targ = flat_dist[flat_mask]

            # Renormalize the target after clamping away zeros, and detach so
            # no gradient flows into the target distribution.
            eps = 1e-8
            targ = targ.clamp_min(eps)
            targ = targ / targ.sum(dim=-1, keepdim=True)
            targ = targ.to(pred.dtype).detach()

            logp = F.log_softmax(pred, dim=-1)
            loss = F.kl_div(logp, targ, reduction="batchmean")

        # Repad unpadded logits back to (batch, seq, vocab) — skipped when
        # sparse prediction already reduced them to labeled positions only.
        should_repad_logits = (
            self.config._attn_implementation == "flash_attention_2"
            and indices is not None
            and batch_size is not None
            and seq_len is not None
            and not (self.sparse_prediction and labels is not None)
        )
        if should_repad_logits:
            # Repad without grad unless configured otherwise (or no loss needs it).
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_prokbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        if not return_dict:
            out = (logits,) + outputs[1:]
            return ((loss,) + out) if loss is not None else out

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|