from typing import Optional, Tuple, Union, Dict, Literal, Callable
import math
import os
from contextlib import nullcontext
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.utils import ModelOutput, logging
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.configuration_utils import PretrainedConfig
from transformers.utils.doc import add_code_sample_docstrings, add_start_docstrings
from transformers.utils.import_utils import is_triton_available, is_flash_attn_2_available
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput, MaskedLMOutput

try:
    from transformers.utils.hub import cached_file
except ImportError:  # transformers<5 compatibility
    from transformers.utils import cached_file  # type: ignore

try:
    from transformers.modeling_rope_utils import RopeParameters
except ImportError:
    # Older transformers: fall back to a plain object so annotations still resolve.
    RopeParameters = object

if is_flash_attn_2_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
    from flash_attn.layers.rotary import RotaryEmbedding
    from flash_attn.ops.triton.rotary import apply_rotary
else:
    # Without flash-attn the rotary base class degrades to `object`; the padded
    # (eager/SDPA) code paths do not touch it.
    RotaryEmbedding = object

# NOTE: defined once here; the original file re-assigned the identical logger a
# second time further down, which was redundant and has been removed.
logger = logging.get_logger(__name__)

# kwargs forwarded to config loading and hub file resolution (download options).
_HF_CONFIG_LOAD_KWARGS = {
    "cache_dir",
    "force_download",
    "local_files_only",
    "token",
    "revision",
    "subfolder",
    "proxies",
}

# kwargs injected by the transformers auto classes that the model __init__ must never see.
_HF_NON_MODEL_INIT_KWARGS = {
    "trust_remote_code",
    "_from_auto",
    "adapter_kwargs",
}

# from_pretrained kwargs this standalone loader does not support; silently dropped
# before the config's leftover kwargs are handed to the model constructor.
_HF_MODEL_INIT_BLACKLIST = {
    "device_map",
    "low_cpu_mem_usage",
    "offload_folder",
    "offload_state_dict",
    "max_memory",
    "quantization_config",
    "tp_plan",
    "tp_size",
    "weights_only",
    # "use_flash_attention_2",
    # "attn_implementation",
    # "torch_dtype"
}


def _split_pretrained_kwargs(kwargs):
    """Split `from_pretrained(**kwargs)` into the pieces the loader needs.

    Returns:
        (config_load_kwargs, use_safetensors, weights_only, remaining_kwargs)
        where `config_load_kwargs` holds hub/download options, and
        `remaining_kwargs` is everything else (later given to the config).
    """
    kwargs = dict(kwargs)
    # Drop auto-class plumbing first so it never reaches the model.
    for k in _HF_NON_MODEL_INIT_KWARGS:
        kwargs.pop(k, None)
    config_load_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in _HF_CONFIG_LOAD_KWARGS}
    use_safetensors = kwargs.pop("use_safetensors", None)
    weights_only = kwargs.pop("weights_only", True)
    return config_load_kwargs, use_safetensors, weights_only, kwargs


def _resolve_weights_file(pretrained_model_name_or_path, use_safetensors=None, **load_kwargs) -> str:
    """Locate a checkpoint file locally or on the hub.

    Args:
        pretrained_model_name_or_path: Local directory or hub repo id.
        use_safetensors: True → only `model.safetensors`; False → only
            `pytorch_model.bin`; None → prefer safetensors, fall back to bin.
        **load_kwargs: Hub download options (see `_HF_CONFIG_LOAD_KWARGS`).

    Raises:
        FileNotFoundError: If no candidate file can be resolved.
    """
    pretrained_model_name_or_path = os.fspath(pretrained_model_name_or_path)
    if use_safetensors is True:
        candidates = ("model.safetensors",)
    elif use_safetensors is False:
        candidates = ("pytorch_model.bin",)
    else:
        candidates = ("model.safetensors", "pytorch_model.bin")
    subfolder = load_kwargs.get("subfolder")
    # Local directory: look directly on disk before touching the hub cache.
    if os.path.isdir(pretrained_model_name_or_path):
        base_dir = (
            os.path.join(pretrained_model_name_or_path, subfolder)
            if subfolder
            else pretrained_model_name_or_path
        )
        for name in candidates:
            path = os.path.join(base_dir, name)
            if os.path.exists(path):
                return path
    # Hub lookup: best-effort per candidate; any failure just tries the next one.
    for name in candidates:
        try:
            path = cached_file(pretrained_model_name_or_path, name, **load_kwargs)
            if path is not None:
                return path
        except Exception:
            pass
    raise FileNotFoundError(
        f"No checkpoint file found in {pretrained_model_name_or_path!r} "
        f"(candidates: {', '.join(candidates)})"
    )


def _read_state_dict(weights_path, weights_only: bool = True) -> dict[str, torch.Tensor]:
    """Load a state dict from a `.safetensors` or torch pickle checkpoint (to CPU)."""
    weights_path = os.fspath(weights_path)
    if weights_path.endswith(".safetensors"):
        from safetensors.torch import load_file as safe_load_file

        return safe_load_file(weights_path, device="cpu")
    try:
        return torch.load(weights_path, map_location="cpu", weights_only=weights_only)
    except TypeError:  # older torch versions lack the `weights_only` kwarg
        return torch.load(weights_path, map_location="cpu")


def _align_state_dict_with_base_prefix(model: nn.Module, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Reconcile `base_model_prefix` differences between checkpoint and model keys.

    Handles both directions: stripping the prefix when loading a wrapped
    checkpoint into the bare encoder, and adding it when loading a bare
    encoder checkpoint into a wrapped task model.
    """
    state_dict = dict(state_dict)
    base_prefix = getattr(model, "base_model_prefix", None)
    if not base_prefix:
        return state_dict
    prefix = f"{base_prefix}."
    model_keys = set(model.state_dict().keys())
    model_has_prefix = any(k.startswith(prefix) for k in model_keys)
    ckpt_has_prefix = any(k.startswith(prefix) for k in state_dict.keys())
    # Loading a wrapped checkpoint (model.*) into the bare encoder.
    if ckpt_has_prefix and not model_has_prefix:
        stripped = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        return stripped if stripped else state_dict
    # Loading a bare encoder checkpoint into a wrapped task model.
    if model_has_prefix and not ckpt_has_prefix:
        remapped = {}
        for k, v in state_dict.items():
            prefixed = f"{prefix}{k}"
            # Only add the prefix when the prefixed key actually exists on the model.
            remapped[prefixed if prefixed in model_keys else k] = v
        return remapped
    return state_dict
class _SafeFromPretrainedMixin:
    """Lightweight replacement for `PreTrainedModel.from_pretrained`.

    Resolves the config, instantiates the model, reads the checkpoint with a
    plain `load_state_dict`, and reports missing/unexpected/mismatched keys —
    without the device-map / quantization machinery of the stock loader.
    """

    @classmethod
    def _adapt_state_dict(cls, model, state_dict):
        # Hook for subclasses to rename/convert checkpoint keys; identity by default.
        return state_dict

    @staticmethod
    def _filter_keys_with_patterns(keys, patterns):
        """Drop keys matching any of the given regex patterns (strings or compiled)."""
        if not patterns:
            return list(keys)
        import re
        compiled = [re.compile(p) if isinstance(p, str) else p for p in patterns]
        return [k for k in keys if not any(p.search(k) for p in compiled)]

    @classmethod
    def _resolve_config_and_init_kwargs(
        cls,
        pretrained_model_name_or_path,
        config,
        config_load_kwargs,
        other_kwargs,
    ):
        """Return a `PretrainedConfig` plus the kwargs the config did not consume."""
        # Already a config instance: nothing to load, all kwargs pass through.
        if isinstance(config, PretrainedConfig):
            return config, other_kwargs
        if config is None:
            config_source = pretrained_model_name_or_path
        elif isinstance(config, (str, os.PathLike)):
            config_source = config
        else:
            raise TypeError(
                "`config` must be None, a path-like object, or an instance of PretrainedConfig"
            )
        # `return_unused_kwargs=True` hands back whatever the config didn't absorb.
        config, init_kwargs = cls.config_class.from_pretrained(
            config_source,
            return_unused_kwargs=True,
            **config_load_kwargs,
            **other_kwargs,
        )
        return config, init_kwargs

    @staticmethod
    def _remove_mismatched_keys(model, state_dict):
        """Remove checkpoint tensors whose shape disagrees with the model's.

        Returns the pruned state dict and a list of
        `(key, checkpoint_shape, model_shape)` tuples for reporting.
        """
        state_dict = dict(state_dict)
        model_state = model.state_dict()
        mismatched_keys = []
        for key in list(state_dict.keys()):
            if key not in model_state:
                continue
            loaded_value = state_dict[key]
            model_value = model_state[key]
            # Non-tensor entries (e.g. metadata) are left untouched.
            if not isinstance(loaded_value, torch.Tensor) or not isinstance(model_value, torch.Tensor):
                continue
            if tuple(loaded_value.shape) != tuple(model_value.shape):
                mismatched_keys.append((key, tuple(loaded_value.shape), tuple(model_value.shape)))
                state_dict.pop(key)
        return state_dict, mismatched_keys

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Instantiate the model and load checkpoint weights.

        Supports `config`, `state_dict`, `ignore_mismatched_sizes`, `strict`,
        `output_loading_info`, plus the hub kwargs in `_HF_CONFIG_LOAD_KWARGS`.
        Returns the model, or `(model, loading_info)` when
        `output_loading_info=True`.
        """
        output_loading_info = kwargs.pop("output_loading_info", False)
        state_dict = kwargs.pop("state_dict", None)
        config = kwargs.pop("config", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        strict = kwargs.pop("strict", False)
        config_load_kwargs, use_safetensors, weights_only, other_kwargs = _split_pretrained_kwargs(kwargs)
        config, init_kwargs = cls._resolve_config_and_init_kwargs(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            config=config,
            config_load_kwargs=config_load_kwargs,
            other_kwargs=other_kwargs,
        )
        '''
        config = cls._autoset_attn_implementation(
            config,
            use_flash_attention_2=bool(init_kwargs.get("use_flash_attention_2", False)),
            torch_dtype=init_kwargs.get("torch_dtype", None),
            device_map=init_kwargs.get("device_map", None),
        )
        '''
        # Unsupported loader kwargs are dropped silently before model construction.
        for k in _HF_MODEL_INIT_BLACKLIST:
            init_kwargs.pop(k, None)
        model = cls(config, *model_args, **init_kwargs)
        if state_dict is None:
            weights_path = _resolve_weights_file(
                pretrained_model_name_or_path,
                use_safetensors=use_safetensors,
                **config_load_kwargs,
            )
            state_dict = _read_state_dict(
                weights_path,
                weights_only=True if weights_only is None else bool(weights_only),
            )
        if not isinstance(state_dict, dict):
            raise TypeError(f"Expected a state dict, got {type(state_dict).__name__}")
        # Reconcile base-prefix layout, then let subclasses adapt legacy keys.
        state_dict = _align_state_dict_with_base_prefix(model, state_dict)
        state_dict = cls._adapt_state_dict(model, state_dict)
        mismatched_keys = []
        if ignore_mismatched_sizes:
            state_dict, mismatched_keys = cls._remove_mismatched_keys(model, state_dict)
        incompatible = model.load_state_dict(state_dict, strict=strict)
        # Re-tie shared embeddings after loading (decoder/embedding weight sharing).
        if hasattr(model, "tie_weights"):
            model.tie_weights()
        model.eval()
        missing_keys = cls._filter_keys_with_patterns(
            list(incompatible.missing_keys),
            getattr(model, "_keys_to_ignore_on_load_missing", None),
        )
        unexpected_keys = cls._filter_keys_with_patterns(
            list(incompatible.unexpected_keys),
            getattr(model, "_keys_to_ignore_on_load_unexpected", None),
        )
        info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
            "error_msgs": [],
        }
        return (model, info) if output_loading_info else model
def l2_norm(input, axis=1, epsilon=1e-12):
    """Scale *input* to unit L2 norm along *axis*, guarding against division by zero."""
    lengths = torch.norm(input, 2, axis, True).clamp(min=epsilon)
    return input / lengths


def initialize_linear_kaiming(layer: nn.Linear):
    """Kaiming-uniform initialization for a linear layer (zeroed bias); no-op otherwise."""
    if not isinstance(layer, nn.Linear):
        return
    nn.init.kaiming_uniform_(layer.weight, nonlinearity='linear')
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)


def _autocast_disabled(device_type: str):
    """Return a context manager that disables autocast for *device_type*.

    Falls back through the legacy per-device autocast APIs for older torch
    releases, and finally to a no-op context when none are available.
    """
    try:
        return torch.amp.autocast(device_type=device_type, enabled=False)
    except (AttributeError, TypeError):
        pass
    if device_type == "cuda":
        return torch.cuda.amp.autocast(enabled=False)
    if device_type == "cpu" and hasattr(torch, "cpu") and hasattr(torch.cpu, "amp"):
        return torch.cpu.amp.autocast(enabled=False)
    return nullcontext()


def _normalize_token_attention_mask(
    attention_mask: Optional[torch.Tensor],
    *,
    batch_size: Optional[int] = None,
    seq_len: Optional[int] = None,
    device: Optional[torch.device] = None,
) -> Optional[torch.Tensor]:
    """Convert common attention-mask layouts to a `(batch, seq_len)` boolean keep-mask.

    Accepts 2D token masks, 3D masks shaped `(batch, 1, seq_len)`, and 4D masks
    shaped `(batch, 1, query_len, key_len)` as used by eager/SDPA attention.
    `True` marks a valid (non-padding) token in the result.
    """
    if attention_mask is None:
        # Without a mask we can only synthesize an all-valid one if sizes are known.
        if batch_size is None or seq_len is None:
            return None
        return torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

    mask = attention_mask
    if mask.dim() == 4:
        head_slice = mask[:, 0]
        if mask.dtype == torch.bool:
            collapsed = head_slice.any(dim=-2)
        else:
            # Additive bias layout: a key is valid iff some query row keeps it (bias >= 0).
            collapsed = head_slice.amax(dim=-2) >= 0
        return collapsed.to(dtype=torch.bool)

    if mask.dim() == 3:
        if mask.size(1) != 1:
            raise ValueError(
                "3D attention masks must have shape (batch, 1, seq_len) when used in ProkBert."
            )
        mask = mask[:, 0, :]

    if mask.dim() != 2:
        raise ValueError(
            "Attention masks for ProkBert must be 2D, 3D `(batch, 1, seq_len)`, or 4D "
            "`(batch, 1, query_len, key_len)`."
        )

    if mask.dtype == torch.bool:
        return mask
    if not mask.is_floating_point():
        return mask != 0
    if mask.numel() == 0:
        return mask.to(dtype=torch.bool)
    # Float masks: negatives imply an additive bias (keep == bias >= 0);
    # otherwise treat as a multiplicative mask (keep == value > 0).
    return (mask > 0) if not torch.any(mask < 0) else (mask >= 0)


def _to_additive_attention_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean keep-mask to the additive attention-bias format expected by eager attention."""
    fill_value = torch.finfo(dtype).min
    additive = torch.zeros(mask.shape, device=mask.device, dtype=dtype)
    additive.masked_fill_(~mask, fill_value)
    return additive


def _build_bidirectional_attention_biases(
    attention_mask: torch.Tensor,
    dtype: torch.dtype,
    sliding_window: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create full-attention and sliding-window additive masks for eager/SDPA attention."""
    if attention_mask.dim() == 4:
        # Already broadcast to (batch, 1, q, k); only the dtype may need fixing.
        if attention_mask.dtype == torch.bool:
            global_bias = _to_additive_attention_bias(attention_mask, dtype)
        else:
            global_bias = attention_mask.to(dtype=dtype)
        query_len = global_bias.shape[-2]
        key_len = global_bias.shape[-1]
    else:
        keep_mask = _normalize_token_attention_mask(attention_mask)
        if keep_mask is None:
            raise ValueError("`attention_mask` cannot be None when creating additive attention biases.")
        query_len = key_len = keep_mask.shape[-1]
        expanded = keep_mask[:, None, None, :].expand(-1, 1, query_len, -1)
        global_bias = _to_additive_attention_bias(expanded, dtype)

    rows = torch.arange(query_len, device=global_bias.device)[:, None]
    cols = torch.arange(key_len, device=global_bias.device)[None, :]
    # Positions within `sliding_window` of the diagonal stay visible locally.
    in_window = (rows - cols).abs() <= sliding_window
    window_bias = global_bias.masked_fill(
        ~in_window.view(1, 1, query_len, key_len),
        torch.finfo(dtype).min,
    )
    return global_bias, window_bias
class ProkBertConfig(PretrainedConfig):
    r"""
    Configuration for the standalone ProkBERT ModernBERT-style encoder stack.

    Canonical config names follow the current ModernBERT conventions:
    - `layer_types` instead of `global_attn_every_n_layers`
    - nested `rope_parameters` keyed by attention type instead of a single flat RoPE dict
    - `tie_word_embeddings` explicitly tracked in the config

    `classifier_pooling` is also standardized as a single field and extended with the
    custom `"attention"` option used by the standalone ProkBERT
    sequence-classification head. The tokenizer metadata fields `kmer` and `shift`
    are also stored on the config so the model artifact keeps the
    sequence-tokenization contract alongside the architectural settings.

    Legacy names are still accepted for backward compatibility when loading older
    checkpoints/configs.
    """

    model_type = "prokbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default RoPE base frequencies per attention layer type.
    default_theta = {"full_attention": 160_000.0, "sliding_attention": 10_000.0}
    # Legacy attribute names mapped onto the canonical config attributes.
    attribute_map = {
        "classification_dropout_rate": "classifier_dropout",
        "num_class_labels": "num_labels",
        "curricular_num_labels": "num_labels",
        "curricular_face_m": "curricular_margin",
        "curricular_face_s": "curricular_scale",
        "curriculum_hidden_size": "curricular_embedding_size",
    }

    @classmethod
    def _build_layer_types(
        cls,
        num_hidden_layers: int,
        global_attn_every_n_layers: int,
    ) -> list[str]:
        """Expand the legacy `global_attn_every_n_layers` into an explicit layer-type list."""
        if global_attn_every_n_layers <= 0:
            raise ValueError("`global_attn_every_n_layers` must be a positive integer.")
        # Layer i is global ("full_attention") when i is a multiple of the period.
        return [
            "sliding_attention" if bool(i % global_attn_every_n_layers) else "full_attention"
            for i in range(num_hidden_layers)
        ]

    @classmethod
    def _normalize_rope_parameters(
        cls,
        rope_parameters: RopeParameters | dict | None,
        global_rope_theta: float,
        local_rope_theta: float,
    ) -> dict[str, dict]:
        """Return RoPE parameters as a dict keyed by attention type, filling defaults."""
        default_rope_parameters = {
            "full_attention": {"rope_type": "default", "rope_theta": float(global_rope_theta)},
            "sliding_attention": {"rope_type": "default", "rope_theta": float(local_rope_theta)},
        }
        if rope_parameters is None:
            return default_rope_parameters
        rope_parameters = dict(rope_parameters)
        # Backward compatibility: older configs used a single flat dict plus `global_rope_theta` / `local_rope_theta`.
        if "rope_type" in rope_parameters or "rope_theta" in rope_parameters:
            shared_rope_type = rope_parameters.get("rope_type", "default")
            full_theta = float(rope_parameters.get("rope_theta", global_rope_theta))
            return {
                "full_attention": {
                    **{k: v for k, v in rope_parameters.items() if k != "rope_theta"},
                    "rope_type": shared_rope_type,
                    "rope_theta": full_theta,
                },
                "sliding_attention": {
                    **{k: v for k, v in rope_parameters.items() if k != "rope_theta"},
                    "rope_type": shared_rope_type,
                    "rope_theta": float(local_rope_theta),
                },
            }
        # Modern layout: per-attention-type sub-dicts with defaults filled in.
        normalized_rope_parameters = {}
        for layer_type in ("full_attention", "sliding_attention"):
            layer_params = rope_parameters.get(layer_type)
            if layer_params is None:
                layer_params = {"rope_type": "default"}
            else:
                layer_params = dict(layer_params)
            layer_params.setdefault("rope_type", "default")
            layer_params.setdefault("rope_theta", cls.default_theta[layer_type])
            normalized_rope_parameters[layer_type] = layer_params
        return normalized_rope_parameters

    def __init__(
        self,
        vocab_size: int = 4608,
        hidden_size: int = 384,
        intermediate_size: int = 1152,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 6,
        hidden_activation: str = "gelu",
        max_position_embeddings: int = 16384,
        initializer_range: float = 0.02,
        initializer_cutoff_factor: float = 2.0,
        norm_eps: float = 1e-6,
        norm_bias: bool = False,
        kmer: int = 6,
        shift: int = 1,
        pad_token_id: int = 0,
        eos_token_id: int = 3,
        bos_token_id: int = 2,
        cls_token_id: int = 2,
        sep_token_id: int = 3,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: list[str] | None = None,
        rope_parameters: RopeParameters | dict | None = None,
        local_attention: int = 256,
        embedding_dropout: float = 0.0,
        mlp_bias: bool = False,
        mlp_dropout: float = 0.0,
        decoder_bias: bool = True,
        classifier_pooling: Literal["attention", "cls", "mean"] = "attention",
        classifier_dropout: float = 0.0,
        classifier_bias: bool = False,
        classifier_activation: str = "gelu",
        deterministic_flash_attn: bool = False,
        sparse_prediction: bool = False,
        sparse_pred_ignore_index: int = -100,
        reference_compile: bool | None = None,
        repad_logits_with_grad: bool = False,
        norm_type: str = "rms",
        tie_word_embeddings: bool = True,
        num_labels: int = 2,
        problem_type: str | None = None,
        curricular_margin: float = 0.5,
        curricular_scale: float = 64.0,
        curricular_embedding_size: int | None = None,
        **kwargs,
    ):
        # Pop all legacy kwargs up-front so they never leak into `super().__init__`.
        legacy_global_attn_every_n_layers = int(kwargs.pop("global_attn_every_n_layers", 1))
        legacy_global_rope_theta = float(kwargs.pop("global_rope_theta", self.default_theta["full_attention"]))
        legacy_local_rope_theta = float(kwargs.pop("local_rope_theta", self.default_theta["sliding_attention"]))
        legacy_num_class_labels = kwargs.pop("num_class_labels", None)
        legacy_curricular_num_labels = kwargs.pop("curricular_num_labels", None)
        legacy_classifier_dropout = kwargs.pop("classification_dropout_rate", None)
        legacy_curricular_margin = kwargs.pop("curricular_face_m", None)
        legacy_curricular_scale = kwargs.pop("curricular_face_s", None)
        legacy_curricular_embedding_size = kwargs.pop("curriculum_hidden_size", None)
        kwargs.pop("bert_base_model", None)
        # Label-count precedence: explicit id2label > legacy curricular > legacy classification.
        loaded_id2label = kwargs.get("id2label")
        if loaded_id2label is not None:
            num_labels = len(loaded_id2label)
        elif legacy_curricular_num_labels is not None:
            num_labels = int(legacy_curricular_num_labels)
        elif legacy_num_class_labels is not None:
            num_labels = int(legacy_num_class_labels)
        if legacy_classifier_dropout is not None:
            classifier_dropout = float(legacy_classifier_dropout)
        if legacy_curricular_margin is not None:
            curricular_margin = float(legacy_curricular_margin)
        if legacy_curricular_scale is not None:
            curricular_scale = float(legacy_curricular_scale)
        # -1 was a legacy sentinel for "unset".
        if curricular_embedding_size is None and legacy_curricular_embedding_size not in (None, -1):
            curricular_embedding_size = int(legacy_curricular_embedding_size)
        if layer_types is None:
            layer_types = self._build_layer_types(
                num_hidden_layers=num_hidden_layers,
                global_attn_every_n_layers=legacy_global_attn_every_n_layers,
            )
        else:
            layer_types = list(layer_types)
        rope_parameters = self._normalize_rope_parameters(
            rope_parameters=rope_parameters,
            global_rope_theta=legacy_global_rope_theta,
            local_rope_theta=legacy_local_rope_theta,
        )
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            tie_word_embeddings=tie_word_embeddings,
            num_labels=num_labels,
            problem_type=problem_type,
            **kwargs,
        )
        # Tokenizer contract (k-mer length and window shift) travels with the config.
        self.kmer = int(kmer)
        self.shift = int(shift)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types
        self.rope_parameters = rope_parameters
        self.local_attention = local_attention
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.decoder_bias = decoder_bias
        self.classifier_pooling = classifier_pooling
        self.classifier_dropout = classifier_dropout
        self.classifier_bias = classifier_bias
        self.classifier_activation = classifier_activation
        self.deterministic_flash_attn = deterministic_flash_attn
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
        self.repad_logits_with_grad = repad_logits_with_grad
        self.norm_type = norm_type
        self.tie_word_embeddings = tie_word_embeddings
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.curricular_margin = curricular_margin
        self.curricular_scale = curricular_scale
        self.curricular_embedding_size = curricular_embedding_size
        # Validation: fail fast on inconsistent configurations.
        if self.kmer <= 0:
            raise ValueError(f"`kmer` must be a positive integer, got {self.kmer}.")
        if self.shift <= 0:
            raise ValueError(f"`shift` must be a positive integer, got {self.shift}.")
        if len(self.layer_types) != self.num_hidden_layers:
            raise ValueError(
                "`layer_types` must contain one entry per hidden layer: "
                f"expected {self.num_hidden_layers}, got {len(self.layer_types)}."
            )
        invalid_layer_types = sorted(set(self.layer_types) - {"full_attention", "sliding_attention"})
        if invalid_layer_types:
            raise ValueError(
                f"Unsupported values in `layer_types`: {invalid_layer_types}. "
                'Allowed values are ["full_attention", "sliding_attention"].'
            )
        if self.classifier_pooling not in ["attention", "cls", "mean"]:
            raise ValueError(
                f'Invalid value for `classifier_pooling`, should be one of ["attention", "cls", "mean"], but is {self.classifier_pooling}.'
            )
        if self.norm_type not in {"rms", "layernorm"}:
            raise ValueError(
                f'Invalid value for `norm_type`, should be either "rms" or "layernorm", but is {self.norm_type}.'
            )

    def get_rope_parameters(self, layer_type: str) -> dict:
        """Return the RoPE parameter dict for the given attention layer type."""
        if layer_type not in {"full_attention", "sliding_attention"}:
            raise ValueError(
                f"Unsupported `layer_type`={layer_type!r}. Expected 'full_attention' or 'sliding_attention'."
            )
        rope_params = self.rope_parameters.get(layer_type)
        if rope_params is None:
            # Defensive default in case `rope_parameters` was mutated after init.
            rope_params = {"rope_type": "default", "rope_theta": self.default_theta[layer_type]}
        return rope_params

    @property
    def sliding_window(self) -> int:
        # Half the local-attention span: the one-sided window radius.
        return self.local_attention // 2

    @sliding_window.setter
    def sliding_window(self, value: int):
        self.local_attention = int(value) * 2

    def to_dict(self):
        # `reference_compile` is a runtime knob, not part of the serialized config.
        output = super().to_dict()
        output.pop("reference_compile", None)
        return output
_CHECKPOINT_FOR_DOC = "example/prokbert-base"
_CONFIG_FOR_DOC = "ProkBertConfig"

PROK_BERT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic
methods the library implements for all its models (such as downloading or saving, resizing the
input embeddings, pruning heads, etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
subclass. Use it as a regular PyTorch module and refer to the PyTorch documentation for general
usage and behavior.

Parameters:
    config ([`ProkBertConfig`]):
        Model configuration class with all the parameters of the model. Initializing with a config
        file does not load the model weights; see [`PreTrainedModel.from_pretrained`] for weight
        loading.
"""


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Carries a learnable scale and, optionally, a learnable bias that is added
    to the normalized activations before the scale is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6, bias: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1 / RMS over the last dimension, with eps inside the sqrt for stability.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normalized = x * inv_rms
        if self.bias is not None:
            normalized = normalized + self.bias
        return self.weight * normalized


def rotate_half(x):
    """Rotate half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    The rotation is computed in float32 and the result is cast back to the
    query's original dtype.

    Args:
        q (torch.Tensor): The query tensor.
        k (torch.Tensor): The key tensor.
        cos (torch.Tensor): The cosine part of the rotary embedding.
        sin (torch.Tensor): The sine part of the rotary embedding.
        position_ids (torch.Tensor, optional): Deprecated and unused.
        unsqueeze_dim (int, optional): The dimension along which to unsqueeze cos and sin.

    Returns:
        tuple(torch.Tensor): The rotated query and key tensors.
    """
    out_dtype = q.dtype
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q32, k32 = q.float(), k.float()
    rotated_q = q32 * cos + rotate_half(q32) * sin
    rotated_k = k32 * cos + rotate_half(k32) * sin
    return rotated_q.to(out_dtype), rotated_k.to(out_dtype)
def eager_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    output_attentions: Optional[bool] = False,
    **_kwargs,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
    """Pure-PyTorch attention over a padded batch.

    Args:
        qkv: Packed projections, shape [batch_size, seqlen, 3, nheads, headdim].
        attention_mask: Additive bias for full attention, (batch, 1, q, k).
        sliding_window_mask: Additive bias restricted to the local window.
        local_attention: (-1, -1) means global attention; anything else
            switches this layer to the sliding-window mask.
        bs, dim: Output reshape targets (batch size, hidden size).
        output_attentions: Also return the attention probabilities.
    """
    # qkv: [batch_size, seqlen, 3, nheads, headdim]
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # transpose(3, 1) -> [batch, nheads, 3, seqlen, headdim]; unbind splits q/k/v.
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    # Apply rotary positional embedding to query and key.
    query, key = apply_rotary_pos_emb(query, key, cos, sin)
    scale = module.head_dim ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
    # Local layers swap in the sliding-window additive mask.
    if local_attention != (-1, -1):
        attention_mask = sliding_window_mask
    attn_weights = attn_weights + attention_mask
    # Upcast attention to fp32 for stability.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bs, -1, dim)
    if output_attentions:
        return (attn_output, attn_weights)
    return (attn_output,)


def flash_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    rotary_emb: "ProkBertUnpaddedRotaryEmbedding",
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Flash-attention over an unpadded (packed) batch of variable-length sequences.

    Args:
        qkv: Packed projections, shape (total_seqlen, 3, nheads, headdim).
        cu_seqlens: Cumulative sequence lengths (batch + 1,).
        target_dtype: dtype to cast to when qkv is not fp16/bf16, since the
            flash kernels only support half precision.
    """
    # qkv: (total_seqlen, 3, nheads, headdim)
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
    if convert_dtype:
        # Round-trip through a half-precision dtype for the flash kernel.
        orig_dtype = qkv.dtype
        qkv = qkv.to(target_dtype)
        attn = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            dropout_p=module.attention_dropout if module.training else 0.0,
            deterministic=module.deterministic_flash_attn,
            window_size=local_attention,
        )
        attn = attn.to(orig_dtype)
    else:
        attn = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            dropout_p=module.attention_dropout if module.training else 0.0,
            deterministic=module.deterministic_flash_attn,
            window_size=local_attention,
        )
    # Collapse (nheads, headdim) into the hidden dimension.
    return (attn.view(bs, dim),)


def sdpa_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Attention via `torch.nn.functional.scaled_dot_product_attention` (padded batch)."""
    # qkv: [batch_size, seqlen, 3, nheads, headdim]
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)
    # Local layers swap in the sliding-window additive mask.
    if local_attention != (-1, -1):
        attention_mask = sliding_window_mask
    attn_output = (
        F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=module.attention_dropout if module.training else 0.0,
            attn_mask=attention_mask,
        )
        .transpose(1, 2)
        .contiguous()
    )
    attn_output = attn_output.view(bs, -1, dim)
    return (attn_output,)
# Dispatch table mapping the configured attention implementation to its forward.
PROK_BERT_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}


def _unpad_prokbert_input(
    inputs: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Remove padding from input sequences.

    Args:
        inputs: (batch, seqlen, ...) or (batch, seqlen)
        attention_mask: (batch, seqlen), where 1 means valid and 0 means padding.
        position_ids: (batch, seqlen), optional position ids.
        labels: (batch, seqlen), optional labels.

    Returns:
        unpadded_inputs: Tensor of shape (total_nnz, ...) containing only valid tokens.
        indices: Tensor of indices corresponding to valid tokens.
        cu_seqlens: Cumulative sequence lengths of the unpadded tokens (shape: batch + 1).
        max_seqlen_in_batch: Maximum sequence length among all sequences (excluding padding).
        unpadded_position_ids: (total_nnz,) or None.
        unpadded_labels: (total_nnz,) or None.
    """
    lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(lengths.max().item())
    # Prepend a leading zero so cu_seqlens[i]:cu_seqlens[i+1] delimits sequence i.
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))

    if inputs.dim() == 2:
        gathered = inputs.flatten()[indices]
    else:
        batch, seqlen, *feature_dims = inputs.shape
        gathered = inputs.view(batch * seqlen, *feature_dims)[indices]

    gathered_position_ids = None if position_ids is None else position_ids.flatten()[indices]
    gathered_labels = None if labels is None else labels.flatten()[indices]
    return gathered, indices, cu_seqlens, max_seqlen_in_batch, gathered_position_ids, gathered_labels


def _pad_prokbert_output(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    batch: int,
    seqlen: int,
) -> torch.Tensor:
    """
    Add padding back to the output tensor.

    Args:
        inputs: Tensor of shape (total_nnz, ...) containing outputs for only valid tokens.
        indices: Tensor of indices indicating positions of valid tokens.
        batch: Batch size.
        seqlen: Maximum sequence length (including padding).

    Returns:
        Tensor of shape (batch, seqlen, ...) with outputs in their original padded positions.
    """
    if inputs.dim() == 1:
        scattered = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        scattered[indices] = inputs
        return scattered.view(batch, seqlen)

    _, *feature_dims = inputs.shape
    scattered = torch.zeros(batch * seqlen, *feature_dims, dtype=inputs.dtype, device=inputs.device)
    scattered[indices] = inputs
    return scattered.view(batch, seqlen, *feature_dims)
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Autograd wrapper around the triton `apply_rotary` kernel for packed QKV.

    The rotation is applied in place to the query/key slices; the backward
    pass applies the conjugate rotation to the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        # qkv: (total_nnz, 3, nheads, headdim)
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # Combine the (3, nheads) dimensions for the first two channels to create a (total_nnz, 2*nheads, headdim) tensor.
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )
        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        # Rotate the q/k gradient slices with the conjugate (inverse) rotation.
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )
        # One gradient per forward input; cos/sin/cu_seqlens/max_seqlen get None.
        return do, None, None, None, None


def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Apply rotary embeddings to an unpadded (packed) QKV tensor.

    Args:
        qkv: Tensor of shape (total_nnz, 3, nheads, headdim) for packed QKV.
        cos, sin: Precomputed cosine and sine caches.
        cu_seqlens: Cumulative sequence lengths (batch + 1,).
        max_seqlen: Maximum sequence length in the batch.

    Returns:
        Tensor with rotary embeddings applied.
    """
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)


class ProkBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    Rotary embeddings for unpadded (packed) sequences used in ProkBERT.
    """

    def __init__(
        self,
        dim: int,
        base: float = 16000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Args:
            dim: Dimension of each head.
            base: Base for the rotary frequency computation.
            max_seqlen: Maximum sequence length to precompute the cosine and sine cache.
            device: Device on which to create the cache.
            dtype: Data type for the cache.
        """
        # super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
        super().__init__(dim=dim, base=base, device=device, interleaved=False)
        self.max_seqlen = max_seqlen
        # Pre-warm the cos/sin cache only when everything needed is known up front.
        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Apply rotary embeddings *inplace* to a packed QKV tensor.

        Args:
            qkv: Tensor of shape (total_nnz, 3, nheads, headdim).
            cu_seqlens: Cumulative sequence lengths tensor (batch + 1,).
            max_seqlen: Maximum sequence length in the current batch.

        Returns:
            Tensor with rotary embeddings applied.
        """
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        qkv = apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return qkv

    def extra_repr(self) -> str:
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"


class ProkBertEmbeddings(nn.Module):
    """
    Construct the embeddings from token embeddings, layer normalization, and dropout.
    """

    def __init__(self, config: ProkBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # Norm flavor is configurable: RMSNorm ("rms") or standard LayerNorm.
        if config.norm_type == "rms":
            self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        else:
            self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.drop = nn.Dropout(config.embedding_dropout)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for the embeddings layer.

        Args:
            input_ids: Tensor of input token ids.
            inputs_embeds: Alternatively, a pre-computed embedding tensor.

        Returns:
            Tensor of embeddings with normalization and dropout applied.
        """
        # `inputs_embeds` takes precedence over `input_ids` when both are given.
        if inputs_embeds is not None:
            hidden_states = self.drop(self.norm(inputs_embeds))
        else:
            hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
        return hidden_states
""" def __init__(self, config: ProkBertConfig): super().__init__() self.config = config self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) if config.norm_type == "rms": self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) else: self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.drop = nn.Dropout(config.embedding_dropout) def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Forward pass for the embeddings layer. Args: input_ids: Tensor of input token ids. inputs_embeds: Alternatively, a pre-computed embedding tensor. Returns: Tensor of embeddings with normalization and dropout applied. """ if inputs_embeds is not None: hidden_states = self.drop(self.norm(inputs_embeds)) else: hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids))) return hidden_states class ProkBertRotaryEmbedding(nn.Module): def __init__(self, config: ProkBertConfig, layer_type: str, device: Optional[torch.device] = None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.layer_type = layer_type rope_params = config.get_rope_parameters(layer_type) self.rope_type = rope_params["rope_type"] self.rope_init_fn: Callable = self.compute_default_rope_parameters if self.rope_type != "default": self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(config, device, layer_type=layer_type) self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) @staticmethod def compute_default_rope_parameters( config: ProkBertConfig, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, layer_type: Optional[str] = None, ) -> tuple["torch.Tensor", 
float]: """ Computes the inverse frequencies according to the original RoPE implementation. """ current_layer_type = layer_type or "full_attention" rope_params = config.get_rope_parameters(current_layer_type) base = rope_params["rope_theta"] dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads attention_factor = 1.0 inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) return inv_freq, attention_factor def _dynamic_frequency_update(self, position_ids, device): """ Dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - Growing beyond the cached sequence length (allow scaling) 2 - The current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len, layer_type=self.layer_type, ) self.register_buffer("inv_freq", inv_freq, persistent=False) self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos 
= emb.cos() sin = emb.sin() cos = cos * self.attention_scaling sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class ProkBertMLP(nn.Module): """Applies the GLU at the end of each ModernBERT layer. Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. """ def __init__(self, config: ProkBertConfig): super().__init__() self.config = config self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) self.act = ACT2FN[config.hidden_activation] self.drop = nn.Dropout(config.mlp_dropout) self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input, gate = self.Wi(hidden_states).chunk(2, dim=-1) return self.Wo(self.drop(self.act(input) * gate)) class ProkBertAttention(nn.Module): """Performs multi-headed self attention on a batch of unpadded sequences. If Flash Attention 2 is available, this module uses it to improve throughput. Otherwise, it falls back on PyTorch's SDPA (or eager) implementation. 
""" def __init__(self, config: ProkBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config self.layer_id = layer_id if config.hidden_size % config.num_attention_heads != 0: raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" ) self.attention_dropout = config.attention_dropout self.deterministic_flash_attn = config.deterministic_flash_attn self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads self.all_head_size = self.head_dim * self.num_heads self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) if layer_id is None: self.attention_type = "full_attention" else: self.attention_type = config.layer_types[layer_id] if self.attention_type == "sliding_attention": local_window = config.sliding_window if config._attn_implementation == "flash_attention_2": # FlashAttention uses inclusive local-attention boundaries. local_window = local_window + 1 self.local_attention = (local_window, local_window) max_position_embeddings = config.local_attention elif self.attention_type == "full_attention": self.local_attention = (-1, -1) max_position_embeddings = config.max_position_embeddings else: raise ValueError( f"Unsupported attention type {self.attention_type!r}. " "Expected 'full_attention' or 'sliding_attention'." 
) rope_theta = float(config.get_rope_parameters(self.attention_type)["rope_theta"]) if config._attn_implementation == "flash_attention_2": self.rotary_emb = ProkBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta, ) else: self.rotary_emb = ProkBertRotaryEmbedding(config=config, layer_type=self.attention_type) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() self.pruned_heads = set() def forward( self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] if self.config._attn_implementation == "flash_attention_2": qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) attn_outputs = PROK_BERT_ATTENTION_FUNCTION[self.config._attn_implementation]( self, qkv=qkv, rotary_emb=self.rotary_emb, local_attention=self.local_attention, bs=bs, dim=self.all_head_size, output_attentions=output_attentions, **kwargs, ) hidden_states = attn_outputs[0] hidden_states = self.out_drop(self.Wo(hidden_states)) return (hidden_states,) + attn_outputs[1:] class ProkBertEncoderLayer(nn.Module): def __init__(self, config: ProkBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config self.attention_type = "full_attention" if layer_id is None else config.layer_types[layer_id] Norm = RMSNorm if config.norm_type == "rms" else nn.LayerNorm self.attn_norm = Norm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.mlp_norm = Norm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.attn = ProkBertAttention(config=config, layer_id=layer_id) self.mlp = ProkBertMLP(config) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: 
Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: attn_outputs = self.attn( self.attn_norm(hidden_states), attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, output_attentions=output_attentions, ) hidden_states = hidden_states + attn_outputs[0] mlp_output = self.mlp(self.mlp_norm(hidden_states)) hidden_states = hidden_states + mlp_output return (hidden_states,) + attn_outputs[1:] PROK_BERT_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch module and refer to the PyTorch documentation for general usage and behavior. Parameters: config ([`ProkBertConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the model weights; see [`PreTrainedModel.from_pretrained`] for weight loading. 
""" @add_start_docstrings( "The bare ProkBert Model outputting raw hidden-states without any specific head on top.", PROK_BERT_START_DOCSTRING, ) class ProkBertPreTrainedModel(PreTrainedModel): config_class = ProkBertConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["ProkBertEmbeddings", "ProkBertEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = False def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor if cutoff_factor is None: cutoff_factor = 3 def init_weight(module: nn.Module, std: float): nn.init.trunc_normal_( module.weight, mean=0.0, std=std, a=-cutoff_factor * std, b=cutoff_factor * std, ) if isinstance(module, nn.Linear): if module.bias is not None: nn.init.zeros_(module.bias) stds = { "in": self.config.initializer_range, "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), "embedding": self.config.initializer_range, "final_out": self.config.hidden_size ** -0.5, } if isinstance(module, ProkBertEmbeddings): init_weight(module.tok_embeddings, stds["embedding"]) elif isinstance(module, ProkBertMLP): init_weight(module.Wi, stds["in"]) init_weight(module.Wo, stds["out"]) elif isinstance(module, ProkBertAttention): init_weight(module.Wqkv, stds["in"]) init_weight(module.Wo, stds["out"]) elif isinstance(module, ProkBertPredictionHead): init_weight(module.dense, stds["out"]) elif isinstance(module, (ProkBertForMaskedLM, ProkBertForMaskedLM2)): init_weight(module.decoder, stds["out"]) @classmethod def _autoset_attn_implementation( cls, config, use_flash_attention_2: bool = False, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None, check_device_map: bool = True, ): if config._attn_implementation_internal is None: config._attn_implementation_internal = "flash_attention_2" try: return cls._check_and_enable_flash_attn_2( config, torch_dtype=torch.float16, 
device_map=device_map, hard_check_only=False, check_device_map=check_device_map, ) except (ValueError, ImportError): config._attn_implementation_internal = None return super()._autoset_attn_implementation( config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch.float16, device_map=device_map, check_device_map=check_device_map, ) def resize_token_embeddings(self, *args, **kwargs): model_embeds = super().resize_token_embeddings(*args, **kwargs) return model_embeds @add_start_docstrings( "The bare ProkBert Model outputting raw hidden-states without any specific head on top.", PROK_BERT_START_DOCSTRING, ) class ProkBertModel(_SafeFromPretrainedMixin, ProkBertPreTrainedModel): def __init__(self, config: ProkBertConfig): #config = self._autoset_attn_implementation(config) super().__init__(config) self.config = config self.embeddings = ProkBertEmbeddings(config) self.layers = nn.ModuleList( [ProkBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] ) if config.norm_type == "rms": self.final_norm = RMSNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) else: self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.gradient_checkpointing = False self.post_init() def get_input_embeddings(self): return self.embeddings.tok_embeddings def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, 
output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions and self.config._attn_implementation == "flash_attention_2": logger.warning_once( "`output_attentions=True` is not supported with flash_attention_2 in ProkBertModel. " "Falling back to `output_attentions=False`." ) output_attentions = False if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None if input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None or seq_len is None: if inputs_embeds is not None: batch_size, seq_len = inputs_embeds.shape[:2] else: batch_size, seq_len = input_ids.shape[:2] device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False restore_attn_implementation = None if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if inputs_embeds is None: with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_prokbert_input( inputs=input_ids, attention_mask=attention_mask, ) else: inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_prokbert_input( inputs=inputs_embeds, attention_mask=attention_mask, ) else: if output_attentions and 
self.config._attn_implementation == "sdpa": restore_attn_implementation = self.config._attn_implementation if position_ids is None: position_ids = torch.arange(seq_len, device=device).unsqueeze(0) attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions, ) hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, sliding_window_mask, position_ids, cu_seqlens, max_seqlen, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions and len(layer_outputs) > 1: all_self_attentions = all_self_attentions + (layer_outputs[1],) hidden_states = self.final_norm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if repad: hidden_states = _pad_prokbert_output(inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len) if all_hidden_states is not None: all_hidden_states = tuple( _pad_prokbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) for hs in all_hidden_states ) if restore_attn_implementation is not None: self.config._attn_implementation = restore_attn_implementation if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) def _update_attention_mask( self, attention_mask: torch.Tensor, output_attentions: bool, ) 
    -> Tuple[torch.Tensor, torch.Tensor]:
        # Eager attention is the only implementation that can return attention
        # probabilities; warn (and for sdpa, switch) when attentions were requested.
        if output_attentions:
            if self.config._attn_implementation == "sdpa":
                logger.warning_once(
                    "Outputting attentions is only supported with the 'eager' attention implementation, "
                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
                )
                self.config._attn_implementation = "eager"
            elif self.config._attn_implementation != "eager":
                logger.warning_once(
                    "Outputting attentions is only supported with the eager attention implementation, "
                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`. '
                    "Setting `output_attentions=False`."
                )
        # Returns (global attention bias, sliding-window attention bias).
        return _build_bidirectional_attention_biases(
            attention_mask=attention_mask,
            dtype=self.dtype,
            sliding_window=self.config.sliding_window,
        )


class ProkBertPredictionHead(nn.Module):
    """Dense -> activation -> norm transform applied before the MLM decoder."""

    def __init__(self, config: ProkBertConfig):
        super().__init__()
        Norm = RMSNorm if getattr(config, "norm_type", "layernorm") == "rms" else nn.LayerNorm
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = Norm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(self.act(self.dense(hidden_states)))


@add_start_docstrings(
    "The ProkBert Model with a decoder head on top that is used for masked language modeling.",
    PROK_BERT_START_DOCSTRING,
)
class ProkBertForMaskedLM(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    # Decoder weights are tied to the input token embeddings.
    _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"}

    def __init__(self, config: ProkBertConfig):
        super().__init__(config)
        self.config = config
        self.model = ProkBertModel(config)
        self.head = ProkBertPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)

        # Sparse prediction computes logits only at masked (labelled) positions.
        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index

        self.post_init()

    def get_output_embeddings(self):
        return self.decoder

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    def _prediction_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Shared head transform followed by the (tied) vocab projection.
        return self.decoder(self.head(hidden_states))

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None or seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device

                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

                # Unpad here (rather than inside the inner model) so that labels
                # stay aligned with the packed tokens.
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_prokbert_input(
                            inputs=input_ids,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            labels=labels,
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_prokbert_input(
                        inputs=inputs_embeds,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        labels=labels,
                    )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        logits_are_sparse = False
        if self.sparse_prediction and labels is not None:
            # Keep only the labelled positions before the (expensive) vocab projection.
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]
            logits_are_sparse = True

        logits = self._prediction_logits(last_hidden_state)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

        # Sparse logits cannot be repadded (they no longer map 1:1 to sequence positions).
        should_repad_logits = (
            self.config._attn_implementation == "flash_attention_2"
            and indices is not None
            and batch_size is not None
            and seq_len is not None
            and not logits_are_sparse
        )
        if should_repad_logits:
            # Repad under no_grad unless gradients through repadded logits were requested.
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_prokbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class ProkBertForSequenceClassification(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    """Standalone ProkBERT sequence classifier with mask-aware pooling."""

    def __init__(self, config: ProkBertConfig):
        super().__init__(config)
        self.num_labels = int(config.num_labels)
        self.config = config
        self.model = ProkBertModel(config)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        # Kept for backward compatibility with existing sequence-classification checkpoints.
        self.weighting_layer = nn.Linear(config.hidden_size, 1)
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.post_init()

    def _init_weights(self, module: nn.Module):
        # Extend the base init with task-head-specific schemes.
        super()._init_weights(module)
        if module is self.weighting_layer:
            # Zero init -> uniform attention-pooling weights at the start of training.
            nn.init.zeros_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if module is self.classifier:
            nn.init.xavier_uniform_(module.weight, gain=1.0)
            # Extra downscaling by sqrt(fan_in) on top of Xavier.
            module.weight.data /= math.sqrt(self.classifier.in_features)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _pool_cls(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        # Take the first *unmasked* token per row (argmax finds the first True).
        if token_mask is None:
            return sequence_output[:, 0]
        first_token_indices = token_mask.to(dtype=torch.long).argmax(dim=-1)
        batch_indices = torch.arange(sequence_output.shape[0], device=sequence_output.device)
        return sequence_output[batch_indices, first_token_indices]

    def _pool_mean(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        # Mask-weighted mean; denom clamped so fully-masked rows don't divide by zero.
        if token_mask is None:
            return sequence_output.mean(dim=1)
        weights = token_mask.unsqueeze(-1).to(dtype=sequence_output.dtype)
        denom = weights.sum(dim=1).clamp(min=1.0)
        return (sequence_output * weights).sum(dim=1) / denom

    def _pool_attention(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        # Learned softmax weighting over tokens; masked positions get -inf scores.
        scores = self.weighting_layer(sequence_output)
        if token_mask is not None:
            # An all-masked row would make softmax degenerate; un-mask its first token.
            empty_rows = token_mask.sum(dim=1) == 0
            if empty_rows.any():
                token_mask = token_mask.clone()
                token_mask[empty_rows, 0] = True
            scores = scores.masked_fill(~token_mask.unsqueeze(-1), torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores.float(), dim=1).to(dtype=sequence_output.dtype)
        return torch.sum(weights * sequence_output, dim=1)

    def _pool_sequence(self, sequence_output: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        # Dispatch to the pooling strategy configured via classifier_pooling.
        token_mask = _normalize_token_attention_mask(
            attention_mask,
            batch_size=sequence_output.shape[0],
            seq_len=sequence_output.shape[1],
            device=sequence_output.device,
        )
        pooling = self.config.classifier_pooling
        if pooling == "attention":
            return self._pool_attention(sequence_output, token_mask)
        if pooling == "mean":
            return self._pool_mean(sequence_output, token_mask)
        if pooling == "cls":
            return self._pool_cls(sequence_output, token_mask)
        raise ValueError(
            f"Unsupported `classifier_pooling`={pooling!r}. Expected one of ['attention', 'cls', 'mean']."
        )

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Infer problem_type once (HF convention), then select the matching loss.
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif labels.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.long, torch.uint8):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = nn.MSELoss()
            if self.num_labels == 1:
                return loss_fct(logits.squeeze(), labels.squeeze().to(logits.dtype))
            return loss_fct(logits, labels.to(logits.dtype))
        if self.config.problem_type == "single_label_classification":
            loss_fct = nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        if self.config.problem_type == "multi_label_classification":
            loss_fct = nn.BCEWithLogitsLoss()
            return loss_fct(logits, labels.to(logits.dtype))
        raise ValueError(
            f"Unsupported `problem_type`={self.config.problem_type!r}. "
            "Expected 'regression', 'single_label_classification', or 'multi_label_classification'."
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        sliding_window_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        inputs_embeds: torch.Tensor = None,
        labels: torch.Tensor = None,
        indices: torch.Tensor = None,
        cu_seqlens: torch.Tensor = None,
        max_seqlen: int = None,
        batch_size: int = None,
        seq_len: int = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        # A 2-D output means the inner model returned packed (unpadded) rows;
        # pooling needs the padded (batch, seq, dim) layout.
        if sequence_output.dim() == 2:
            if indices is None or batch_size is None or seq_len is None:
                raise ValueError(
                    "Received unpadded hidden states from `ProkBertModel`, but `indices`, `batch_size`, and "
                    "`seq_len` were not provided to repad them for sequence classification."
                )
            sequence_output = _pad_prokbert_output(
                inputs=sequence_output,
                indices=indices,
                batch=batch_size,
                seqlen=seq_len,
            )

        pooled_output = self._pool_sequence(sequence_output, attention_mask)
        pooled_output = self.norm(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@dataclass
class CurricularSequenceClassifierOutput(ModelOutput):
    # Extends the usual classifier output with the pooled `embeddings` that feed
    # the CurricularFace head.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    embeddings: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class CurricularFace(nn.Module):
    """CurricularFace margin head: angular margin on the target class plus an
    EMA-driven re-weighting of hard negatives."""

    def __init__(self, in_features: int, out_features: int, m: float = 0.5, s: float = 64.0, ema_alpha: float = 0.01):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.m = float(m)
        self.s = float(s)
        self.ema_alpha = float(ema_alpha)

        # Precomputed margin constants (cos m, sin m, and the easy-margin fallback terms).
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.threshold = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

        self.kernel = nn.Parameter(torch.empty(self.in_features, self.out_features))
        # EMA of the mean target cosine; modulates the hard-negative re-weighting.
        self.register_buffer("t", torch.zeros(1, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.kernel)
        self.t.zero_()

    def cosine(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Cosine similarity computed in fp32 with autocast disabled for numerical stability.
        with _autocast_disabled(embeddings.device.type):
            x = F.normalize(embeddings.float(), p=2.0, dim=1, eps=1e-12)
            w = F.normalize(self.kernel.float(), p=2.0, dim=0, eps=1e-12)
            cos_theta = F.linear(x, w.t()).clamp(-1.0, 1.0)
        return cos_theta

    def inference_logits(self, embeddings:
torch.Tensor) -> torch.Tensor: return self.cosine(embeddings) * self.s def margin_logits_from_cosine( self, cos_theta: torch.Tensor, labels: torch.LongTensor, update_t: bool = False, ) -> torch.Tensor: labels = labels.reshape(-1).long() target = cos_theta.gather(1, labels.unsqueeze(1)) sin_theta = torch.sqrt((1.0 - target.square()).clamp(min=0.0)) cos_theta_m = target * self.cos_m - sin_theta * self.sin_m hard_mask = cos_theta > cos_theta_m final_target = torch.where( target > self.threshold, cos_theta_m, target - self.mm, ) if update_t: with torch.no_grad(): target_mean = target.mean().to(dtype=self.t.dtype).view_as(self.t) self.t.lerp_(target_mean, self.ema_alpha) t = self.t.to(device=cos_theta.device, dtype=cos_theta.dtype) adjusted = torch.where(hard_mask, cos_theta * (t + cos_theta), cos_theta) adjusted = adjusted.scatter(1, labels.unsqueeze(1), final_target) return adjusted * self.s class ProkBertForCurricularClassification(_SafeFromPretrainedMixin, ProkBertPreTrainedModel): """ProkBERT sequence classifier with CurricularFace logits for single-label classification.""" def __init__(self, config: ProkBertConfig): super().__init__(config) self.config = config self.num_labels = int(config.num_labels) if self.num_labels < 2: raise ValueError( "`ProkBertForCurricularClassification` requires `config.num_labels >= 2`. " "CurricularFace is intended for single-label classification." ) if self.config.problem_type is None: self.config.problem_type = "single_label_classification" elif self.config.problem_type != "single_label_classification": raise ValueError( "`ProkBertForCurricularClassification` only supports `problem_type='single_label_classification'`." 
            )
        self.model = ProkBertModel(config)
        # Learned attention-pooling scorer (one scalar score per token).
        self.weighting_layer = nn.Linear(config.hidden_size, 1)
        self.dropout = nn.Dropout(config.classifier_dropout)
        # Optional projection of the pooled hidden state into a smaller
        # embedding space; `None`/-1 means "use hidden_size as-is".
        use_projection = config.curricular_embedding_size not in (None, -1)
        embedding_dim = config.hidden_size if not use_projection else int(config.curricular_embedding_size)
        self.linear = nn.Linear(config.hidden_size, embedding_dim) if use_projection else nn.Identity()
        self.curricular_face = CurricularFace(
            in_features=embedding_dim,
            out_features=self.num_labels,
            m=float(config.curricular_margin),
            s=float(config.curricular_scale),
        )
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()
        # Zero-init the attention scorer so pooling starts as a uniform mean;
        # done after post_init() so the generic weight init does not overwrite it.
        with torch.no_grad():
            nn.init.zeros_(self.weighting_layer.weight)
            if self.weighting_layer.bias is not None:
                nn.init.zeros_(self.weighting_layer.bias)
        if isinstance(self.linear, nn.Linear):
            initialize_linear_kaiming(self.linear)

    def get_input_embeddings(self):
        # Delegate to the backbone so weight tying/resizing works as usual.
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def _pool_cls(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Return the hidden state of the first *unmasked* token per row
        (plain position 0 when no mask is given)."""
        if token_mask is None:
            return sequence_output[:, 0]
        # argmax over a 0/1 mask gives the index of the first True entry.
        first_token_indices = token_mask.to(dtype=torch.long).argmax(dim=-1)
        batch_indices = torch.arange(sequence_output.shape[0], device=sequence_output.device)
        return sequence_output[batch_indices, first_token_indices]

    def _pool_mean(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Masked mean over the sequence dimension; clamps the denominator so
        an all-masked row cannot divide by zero."""
        if token_mask is None:
            return sequence_output.mean(dim=1)
        weights = token_mask.unsqueeze(-1).to(dtype=sequence_output.dtype)
        denom = weights.sum(dim=1).clamp(min=1.0)
        return (sequence_output * weights).sum(dim=1) / denom

    def _pool_attention(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Softmax-weighted pooling using `weighting_layer` scores; masked
        positions get -inf scores, and fully-masked rows fall back to token 0."""
        scores = self.weighting_layer(sequence_output)
        if token_mask is not None:
            empty_rows = token_mask.sum(dim=1) == 0
            if empty_rows.any():
                # Clone before mutating so the caller's mask is untouched.
                token_mask = token_mask.clone()
                token_mask[empty_rows, 0] = True
            scores = scores.masked_fill(~token_mask.unsqueeze(-1), torch.finfo(scores.dtype).min)
        # Softmax in float32 for stability, then cast back to the hidden dtype.
        weights = torch.softmax(scores.float(), dim=1).to(dtype=sequence_output.dtype)
        return torch.sum(weights * sequence_output, dim=1)

    def _pool_sequence(self, sequence_output: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Pool token states to one vector per sequence according to
        `config.classifier_pooling` ('attention' | 'mean' | 'cls')."""
        token_mask = _normalize_token_attention_mask(
            attention_mask,
            batch_size=sequence_output.shape[0],
            seq_len=sequence_output.shape[1],
            device=sequence_output.device,
        )
        pooling = self.config.classifier_pooling
        if pooling == "attention":
            return self._pool_attention(sequence_output, token_mask)
        if pooling == "mean":
            return self._pool_mean(sequence_output, token_mask)
        if pooling == "cls":
            return self._pool_cls(sequence_output, token_mask)
        raise ValueError(
            f"Unsupported `classifier_pooling`={pooling!r}. Expected one of ['attention', 'cls', 'mean']."
        )

    def _repad_sequence_output_if_needed(
        self,
        sequence_output: torch.Tensor,
        indices: Optional[torch.Tensor],
        batch_size: Optional[int],
        seq_len: Optional[int],
    ) -> torch.Tensor:
        """Re-pad flash-attention-style packed hidden states (rank 2) back to
        (batch, seq_len, hidden); rank-3 inputs pass through unchanged.

        Raises:
            ValueError: packed states were received without the metadata
                needed to restore the padded layout.
        """
        if sequence_output.dim() != 2:
            return sequence_output
        if indices is None or batch_size is None or seq_len is None:
            raise ValueError(
                "Received unpadded hidden states from `ProkBertModel`, but `indices`, `batch_size`, and "
                "`seq_len` were not provided to repad them for curricular classification."
            )
        return _pad_prokbert_output(
            inputs=sequence_output,
            indices=indices,
            batch=batch_size,
            seqlen=seq_len,
        )

    def _compute_embeddings(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        apply_dropout: bool = True,
    ) -> tuple[torch.Tensor, BaseModelOutput]:
        """Run the backbone, re-pad if needed, pool, (optionally) drop out,
        and project — the shared path for `encode` and `forward`.

        Returns:
            (embeddings, backbone_outputs) where embeddings are the pooled,
            projected vectors fed to the CurricularFace head.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = self._repad_sequence_output_if_needed(
            outputs.last_hidden_state,
            indices=indices,
            batch_size=batch_size,
            seq_len=seq_len,
        )
        pooled_output = self._pool_sequence(sequence_output, attention_mask)
        # Dropout is opt-out so deterministic inference paths (`encode`) can
        # skip it explicitly.
        if apply_dropout:
            pooled_output = self.dropout(pooled_output)
        embeddings = self.linear(pooled_output)
        return embeddings, outputs

    def encode(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Return pooled sequence embeddings (L2-normalized by default),
        with dropout disabled regardless of training mode."""
        embeddings, _ = self._compute_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            apply_dropout=False,
        )
        return l2_norm(embeddings, axis=1) if normalize else embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_embeddings: bool = False,
        normalize_embeddings: bool = True,
        **kwargs,
    ) -> Union[Tuple, CurricularSequenceClassifierOutput]:
        """Classification forward pass.

        The *returned* logits are always plain scaled cosines
        (margin-free, usable for inference); the CurricularFace margin is
        applied only to the internal logits used to compute the loss.
        `return_embeddings=True` additionally exposes the pooled embeddings.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        embeddings, outputs = self._compute_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Dropout only while training, mirroring `encode`'s eval behavior.
            apply_dropout=self.training,
        )
        exported_embeddings = None
        if return_embeddings:
            exported_embeddings = l2_norm(embeddings, axis=1) if normalize_embeddings else embeddings
        cos_theta = self.curricular_face.cosine(embeddings)
        logits = cos_theta * self.curricular_face.s
        loss = None
        if labels is not None:
            labels = labels.view(-1).long()
            # Margined logits are used for the loss only; the EMA statistic
            # `t` is updated only during training.
            train_logits = self.curricular_face.margin_logits_from_cosine(
                cos_theta,
                labels,
                update_t=self.training,
            )
            loss = self.loss_fct(train_logits,
                                 labels)
        if not return_dict:
            # Tuple output mirrors the dataclass field order, including only
            # the optional pieces the caller asked for.
            out = (logits,)
            if return_embeddings:
                out = out + (exported_embeddings,)
            if output_hidden_states:
                out = out + (outputs.hidden_states,)
            if output_attentions:
                out = out + (outputs.attentions,)
            return ((loss,) + out) if loss is not None else out
        return CurricularSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            embeddings=exported_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class ProkBertForMaskedLM2(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    """ProkBERT masked-LM head supporting integer labels, legacy sparse
    masking, and soft (distributional) targets via KL divergence."""

    # Decoder weights are tied to the token embedding table.
    _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = ProkBertModel(config)
        self.head = ProkBertPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
        # Legacy sparse integer masking: compute the head only on positions
        # whose label differs from the ignore index.
        self.sparse_prediction = config.sparse_prediction
        self.sparse_pred_ignore_index = config.sparse_pred_ignore_index
        self.post_init()

    def get_output_embeddings(self):
        return self.decoder

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        labels_dist: Optional[torch.FloatTensor] = None,
        loss_mask: Optional[torch.BoolTensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] =
                                 None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 1) Optional unpad for flash_attention_2: pack (B, L) inputs into a
        #    flat token stream only when the caller has not already done so
        #    (all of indices/cu_seqlens/max_seqlen are None).
        if self.config._attn_implementation == "flash_attention_2" \
                and indices is None and cu_seqlens is None and max_seqlen is None:
            # Infer batch_size / seq_len from whichever input was provided.
            if batch_size is None or seq_len is None:
                if inputs_embeds is not None:
                    batch_size, seq_len = inputs_embeds.shape[:2]
                else:
                    batch_size, seq_len = input_ids.shape[:2]
            # EXPLICIT device pick: follow the actual input tensor.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            attention_mask = attention_mask if attention_mask is not None else \
                torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
            if inputs_embeds is None:
                # no_grad: unpadding integer ids is pure bookkeeping.
                with torch.no_grad():
                    input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = \
                        _unpad_prokbert_input(
                            inputs=input_ids,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            labels=labels,
                        )
            else:
                # Embeddings may require grad, so no no_grad here.
                inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = \
                    _unpad_prokbert_input(
                        inputs=inputs_embeds,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        labels=labels,
                    )

        # 2) Core encoder
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]  # (B,L,H) or packed (N,H)

        # 3) Legacy sparse integer mask: restrict the LM head to positions
        #    with a real label before running the (expensive) decoder.
        if self.sparse_prediction and labels is not None:
            flat_labels = labels.view(-1)
            flat_hidden = sequence_output.view(flat_labels.shape[0], -1)
            mask_tokens = flat_labels != self.sparse_pred_ignore_index
            sequence_output = flat_hidden[mask_tokens]
            labels = flat_labels[mask_tokens]

        # 4) Prediction head + (tied) decoder projection to vocab logits.
        hidden = self.head(sequence_output)
        logits = self.decoder(hidden)

        loss = None
        V = self.config.vocab_size

        # 5a) Integer-label MLM: standard token-level cross entropy.
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, V), labels.view(-1))

        # 5b) Soft-distribution MLM (no re-pad): KL(target || softmax(logits))
        #     over the positions selected by `loss_mask`. Three logit layouts
        #     are handled below, distinguished by rank and row count.
        elif labels_dist is not None and loss_mask is not None:
            B, L = loss_mask.shape
            flat_mask = loss_mask.view(-1)        # (B*L,)
            flat_dist = labels_dist.view(-1, V)   # (B*L, V)

            # Logits packed by attention_mask (one row per attended token).
            if logits.dim() == 2 and logits.shape[0] != flat_mask.sum().item():
                full_attn = attention_mask.view(-1)   # (B*L,)
                assert logits.shape[0] == full_attn.sum().item()
                dist_attn = flat_dist[full_attn]      # (Natt, V)
                mask_in_attn = flat_mask[full_attn]   # (Natt,)
                pred = logits[mask_in_attn]           # (N_mask, V)
                targ = dist_attn[mask_in_attn]        # (N_mask, V)
            # Logits packed exactly by loss_mask (one row per masked token).
            elif logits.dim() == 2 and logits.shape[0] == flat_mask.sum().item():
                pred = logits
                targ = flat_dist[flat_mask]
            # Full (B, L, V) logits.
            else:
                flat_logits = logits.view(-1, V)      # (B*L, V)
                pred = flat_logits[flat_mask]         # (N_mask, V)
                targ = flat_dist[flat_mask]           # (N_mask, V)

            # Clamp + renormalize targets so log(0) never appears in the KL,
            # and detach so gradients flow only through the predictions.
            eps = 1e-8
            targ = targ.clamp_min(eps)
            targ = targ / targ.sum(dim=-1, keepdim=True)
            targ = targ.to(pred.dtype).detach()
            logp = F.log_softmax(pred, dim=-1)
            loss = F.kl_div(logp, targ, reduction="batchmean")

        # Re-pad packed logits back to (B, L, V) unless the sparse path has
        # already collapsed them to labelled positions only.
        should_repad_logits = (
            self.config._attn_implementation == "flash_attention_2"
            and indices is not None
            and batch_size is not None
            and seq_len is not None
            and not (self.sparse_prediction and labels is not None)
        )
        if should_repad_logits:
            # Grad through the repad only when configured or when no loss
            # consumes the packed logits anyway.
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_prokbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        # 6) Return
        if not return_dict:
            out = (logits,) + outputs[1:]
            return ((loss,) + out) if loss is not None else out
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )