nbrg-transformers / models2.py
ligeti's picture
Updating the curriculum head and the base mini-c model
35a7e78
from typing import Optional, Tuple, Union, Dict, Literal, Callable
import math
import os
from contextlib import nullcontext
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.utils import ModelOutput, logging
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.configuration_utils import PretrainedConfig
from transformers.utils.doc import add_code_sample_docstrings, add_start_docstrings
from transformers.utils.import_utils import is_triton_available, is_flash_attn_2_available
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput, MaskedLMOutput
try:
from transformers.utils.hub import cached_file
except ImportError: # transformers<5 compatibility
from transformers.utils import cached_file # type: ignore
try:
from transformers.modeling_rope_utils import RopeParameters
except ImportError:
RopeParameters = object
if is_flash_attn_2_available():
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.ops.triton.rotary import apply_rotary
else:
RotaryEmbedding = object
logger = logging.get_logger(__name__)
# Keyword arguments that `_split_pretrained_kwargs` routes to the config / weight-file
# resolution step (`config_class.from_pretrained` and `cached_file`) instead of the model
# constructor.
_HF_CONFIG_LOAD_KWARGS = {
    "cache_dir",
    "force_download",
    "local_files_only",
    "token",
    "revision",
    "subfolder",
    "proxies",
}
# Keyword arguments that `_split_pretrained_kwargs` discards outright.
_HF_NON_MODEL_INIT_KWARGS = {
    "trust_remote_code",
    "_from_auto",
    "adapter_kwargs",
}
# Keyword arguments removed from the init kwargs right before the model class is
# instantiated in `_SafeFromPretrainedMixin.from_pretrained`.
_HF_MODEL_INIT_BLACKLIST = {
    "device_map",
    "low_cpu_mem_usage",
    "offload_folder",
    "offload_state_dict",
    "max_memory",
    "quantization_config",
    "tp_plan",
    "tp_size",
    "weights_only",
    #"use_flash_attention_2",
    #"attn_implementation",
    #"torch_dtype"
}
# NOTE(review): this repeats the `logger` assignment made right after the imports.
logger = logging.get_logger(__name__)
def _split_pretrained_kwargs(kwargs):
    """Partition `from_pretrained` keyword arguments by the component that consumes them.

    The caller's mapping is never mutated; a filtered copy is used instead.

    Args:
        kwargs: Keyword arguments received by `from_pretrained`.

    Returns:
        A 4-tuple `(config_load_kwargs, use_safetensors, weights_only, remaining)`:
        the entries named in `_HF_CONFIG_LOAD_KWARGS`, the `use_safetensors` value
        (`None` when absent), the `weights_only` value (`True` when absent), and every
        other entry except those named in `_HF_NON_MODEL_INIT_KWARGS`, which are dropped.
    """
    remaining = {
        key: value
        for key, value in dict(kwargs).items()
        if key not in _HF_NON_MODEL_INIT_KWARGS
    }

    config_load_kwargs = {}
    # Iterate over a snapshot of the keys because entries are removed while scanning.
    for key in tuple(remaining):
        if key in _HF_CONFIG_LOAD_KWARGS:
            config_load_kwargs[key] = remaining.pop(key)

    use_safetensors = remaining.pop("use_safetensors", None)
    weights_only = remaining.pop("weights_only", True)
    return config_load_kwargs, use_safetensors, weights_only, remaining
def _resolve_weights_file(pretrained_model_name_or_path, use_safetensors=None, **load_kwargs) -> str:
    """Locate the checkpoint file for a local directory or a hub/cache identifier.

    Args:
        pretrained_model_name_or_path: Local directory or identifier understood by
            `cached_file`. Path-like objects are accepted.
        use_safetensors: `True` to accept only `model.safetensors`, `False` to accept only
            `pytorch_model.bin`, `None` to try safetensors first and then the `.bin` file.
        **load_kwargs: Forwarded unchanged to `cached_file`; `subfolder` is also honoured
            for the local-directory lookup.

    Returns:
        Path of the first candidate file that could be resolved.

    Raises:
        FileNotFoundError: If no candidate could be resolved. The last error raised by
            `cached_file` (if any) is chained as the cause so the real reason
            (authentication, network, wrong revision, ...) is not lost.
    """
    pretrained_model_name_or_path = os.fspath(pretrained_model_name_or_path)
    if use_safetensors is True:
        candidates = ("model.safetensors",)
    elif use_safetensors is False:
        candidates = ("pytorch_model.bin",)
    else:
        candidates = ("model.safetensors", "pytorch_model.bin")

    # Fast path: look directly inside a local directory before asking the hub utilities.
    if os.path.isdir(pretrained_model_name_or_path):
        subfolder = load_kwargs.get("subfolder")
        base_dir = (
            os.path.join(pretrained_model_name_or_path, subfolder)
            if subfolder
            else pretrained_model_name_or_path
        )
        for name in candidates:
            path = os.path.join(base_dir, name)
            if os.path.exists(path):
                return path

    last_error = None
    for name in candidates:
        try:
            path = cached_file(pretrained_model_name_or_path, name, **load_kwargs)
        except Exception as exc:
            # Best effort: a failure for one candidate must not prevent trying the next,
            # but the error is kept so it can be reported if every candidate fails.
            logger.debug("Could not resolve %r for %r: %s", name, pretrained_model_name_or_path, exc)
            last_error = exc
            continue
        if path is not None:
            return path

    raise FileNotFoundError(
        f"No checkpoint file found in {pretrained_model_name_or_path!r} "
        f"(candidates: {', '.join(candidates)})"
    ) from last_error
def _read_state_dict(weights_path, weights_only: bool = True) -> dict[str, torch.Tensor]:
    """Read a checkpoint file into a CPU state dict.

    Args:
        weights_path: Path to a `.safetensors` file or a `torch.save` checkpoint.
        weights_only: Forwarded to `torch.load` for non-safetensors files.

    Returns:
        The object stored in the checkpoint (expected to be a name -> tensor mapping),
        loaded on CPU.
    """
    weights_path = os.fspath(weights_path)
    if weights_path.endswith(".safetensors"):
        from safetensors.torch import load_file as safe_load_file

        return safe_load_file(weights_path, device="cpu")

    try:
        return torch.load(weights_path, map_location="cpu", weights_only=weights_only)
    except TypeError:  # older torch versions without the `weights_only` argument
        # SECURITY: this fallback performs unrestricted pickle deserialization. Make the
        # downgrade visible instead of silently ignoring the caller's `weights_only=True`.
        if weights_only:
            logger.warning(
                "torch.load() rejected `weights_only`; loading %s with unrestricted pickle "
                "deserialization. Only load checkpoints from trusted sources.",
                weights_path,
            )
        return torch.load(weights_path, map_location="cpu")
def _align_state_dict_with_base_prefix(model: nn.Module, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Reconcile the `<base_model_prefix>.` key prefix between a checkpoint and a model.

    Args:
        model: Module to be loaded; its `base_model_prefix` attribute (if any) names the
            prefix and its own state-dict keys decide whether it is "wrapped".
        state_dict: Checkpoint mapping. It is copied, never mutated.

    Returns:
        A new dict. When exactly one side uses the prefix the keys are rewritten to match
        the model; otherwise the copy is returned unchanged.
    """
    aligned = dict(state_dict)
    base_prefix = getattr(model, "base_model_prefix", None)
    if not base_prefix:
        return aligned

    prefix = f"{base_prefix}."
    model_keys = set(model.state_dict().keys())
    model_is_wrapped = any(key.startswith(prefix) for key in model_keys)
    ckpt_is_wrapped = any(key.startswith(prefix) for key in aligned.keys())

    # Both sides agree (both wrapped or both bare): nothing to rewrite.
    if ckpt_is_wrapped == model_is_wrapped:
        return aligned

    if ckpt_is_wrapped:
        # Wrapped checkpoint into a bare encoder: keep only the prefixed entries and
        # strip the prefix; un-prefixed entries are discarded.
        cut = len(prefix)
        stripped = {
            key[cut:]: tensor for key, tensor in aligned.items() if key.startswith(prefix)
        }
        return stripped if stripped else aligned

    # Bare encoder checkpoint into a wrapped model: add the prefix only where the
    # prefixed name actually exists in the model, leave every other key untouched.
    remapped = {}
    for key, tensor in aligned.items():
        candidate = f"{prefix}{key}"
        target = candidate if candidate in model_keys else key
        remapped[target] = tensor
    return remapped
class _SafeFromPretrainedMixin:
    """Mixin providing an explicit `from_pretrained` implementation.

    The loader resolves the config, instantiates the model on its own, reads the state
    dict from a safetensors / `torch.save` file, reconciles the base-model key prefix and
    finally calls `load_state_dict`. The host class is expected to define `config_class`.
    """

    @classmethod
    def _adapt_state_dict(cls, model, state_dict):
        """Hook for subclasses to rewrite the state dict before loading; identity here."""
        return state_dict

    @staticmethod
    def _filter_keys_with_patterns(keys, patterns):
        """Return the entries of `keys` that match none of `patterns`.

        Args:
            keys: Iterable of parameter names.
            patterns: Iterable of regex strings or pre-compiled patterns; falsy means
                "no filtering".

        Returns:
            A new list, in the original order.
        """
        if not patterns:
            return list(keys)
        import re

        compiled = [re.compile(p) if isinstance(p, str) else p for p in patterns]
        return [k for k in keys if not any(p.search(k) for p in compiled)]

    @classmethod
    def _resolve_config_and_init_kwargs(
        cls,
        pretrained_model_name_or_path,
        config,
        config_load_kwargs,
        other_kwargs,
    ):
        """Resolve the config object and the kwargs left over for the model constructor.

        Args:
            pretrained_model_name_or_path: Default source for the config.
            config: `None`, a path-like config source, or a ready `PretrainedConfig`.
            config_load_kwargs: Loader kwargs forwarded to `config_class.from_pretrained`.
            other_kwargs: Remaining kwargs; those the config does not consume are returned.

        Returns:
            `(config, init_kwargs)`. A ready `PretrainedConfig` is returned as-is together
            with `other_kwargs` untouched.

        Raises:
            TypeError: If `config` is none of the supported kinds.
        """
        if isinstance(config, PretrainedConfig):
            return config, other_kwargs
        if config is None:
            config_source = pretrained_model_name_or_path
        elif isinstance(config, (str, os.PathLike)):
            config_source = config
        else:
            raise TypeError(
                "`config` must be None, a path-like object, or an instance of PretrainedConfig"
            )
        config, init_kwargs = cls.config_class.from_pretrained(
            config_source,
            return_unused_kwargs=True,
            **config_load_kwargs,
            **other_kwargs,
        )
        return config, init_kwargs

    @staticmethod
    def _remove_mismatched_keys(model, state_dict):
        """Drop checkpoint tensors whose shape differs from the model's tensor of that name.

        Args:
            model: Module providing the reference shapes via `state_dict()`.
            state_dict: Checkpoint mapping; it is copied, never mutated.

        Returns:
            `(filtered_state_dict, mismatched_keys)` where each mismatch is a tuple
            `(key, checkpoint_shape, model_shape)`. Keys absent from the model and
            non-tensor values are left in place.
        """
        state_dict = dict(state_dict)
        model_state = model.state_dict()
        mismatched_keys = []
        for key in list(state_dict.keys()):
            if key not in model_state:
                continue
            loaded_value = state_dict[key]
            model_value = model_state[key]
            if not isinstance(loaded_value, torch.Tensor) or not isinstance(model_value, torch.Tensor):
                continue
            if tuple(loaded_value.shape) != tuple(model_value.shape):
                mismatched_keys.append((key, tuple(loaded_value.shape), tuple(model_value.shape)))
                state_dict.pop(key)
        return state_dict, mismatched_keys

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Instantiate the model and load weights from a local path or hub identifier.

        Recognised keyword arguments (all optional): `output_loading_info`, `state_dict`,
        `config`, `ignore_mismatched_sizes`, `strict` (default `False`), `use_safetensors`,
        `weights_only`, plus the loader kwargs in `_HF_CONFIG_LOAD_KWARGS`. Anything left
        over, minus `_HF_MODEL_INIT_BLACKLIST`, is passed to the model constructor.

        Returns:
            The model in eval mode, or `(model, info)` when `output_loading_info` is true,
            where `info` holds `missing_keys`, `unexpected_keys`, `mismatched_keys` and an
            always-empty `error_msgs`.

        Raises:
            TypeError: If the loaded or supplied state dict is not a `dict`.
        """
        output_loading_info = kwargs.pop("output_loading_info", False)
        state_dict = kwargs.pop("state_dict", None)
        config = kwargs.pop("config", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        strict = kwargs.pop("strict", False)

        config_load_kwargs, use_safetensors, weights_only, other_kwargs = _split_pretrained_kwargs(kwargs)
        config, init_kwargs = cls._resolve_config_and_init_kwargs(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            config=config,
            config_load_kwargs=config_load_kwargs,
            other_kwargs=other_kwargs,
        )
        # NOTE: `_autoset_attn_implementation` is not invoked here; the resolved config
        # is used exactly as returned above.

        for k in _HF_MODEL_INIT_BLACKLIST:
            init_kwargs.pop(k, None)
        model = cls(config, *model_args, **init_kwargs)

        if state_dict is None:
            weights_path = _resolve_weights_file(
                pretrained_model_name_or_path,
                use_safetensors=use_safetensors,
                **config_load_kwargs,
            )
            state_dict = _read_state_dict(
                weights_path,
                weights_only=True if weights_only is None else bool(weights_only),
            )
        if not isinstance(state_dict, dict):
            raise TypeError(f"Expected a state dict, got {type(state_dict).__name__}")

        state_dict = _align_state_dict_with_base_prefix(model, state_dict)
        state_dict = cls._adapt_state_dict(model, state_dict)

        mismatched_keys = []
        if ignore_mismatched_sizes:
            state_dict, mismatched_keys = cls._remove_mismatched_keys(model, state_dict)
            if mismatched_keys:
                # These parameters keep their freshly initialised values; say so.
                logger.warning(
                    "Skipped %d checkpoint tensor(s) whose shape does not match the model: %s",
                    len(mismatched_keys),
                    ", ".join(
                        f"{key} (checkpoint {ckpt_shape}, model {model_shape})"
                        for key, ckpt_shape, model_shape in mismatched_keys
                    ),
                )

        incompatible = model.load_state_dict(state_dict, strict=strict)
        if hasattr(model, "tie_weights"):
            model.tie_weights()
        model.eval()

        missing_keys = cls._filter_keys_with_patterns(
            list(incompatible.missing_keys),
            getattr(model, "_keys_to_ignore_on_load_missing", None),
        )
        unexpected_keys = cls._filter_keys_with_patterns(
            list(incompatible.unexpected_keys),
            getattr(model, "_keys_to_ignore_on_load_unexpected", None),
        )
        info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
            "error_msgs": [],
        }
        return (model, info) if output_loading_info else model
def l2_norm(input, axis=1, epsilon=1e-12):
    """Scale `input` to unit L2 norm along `axis`.

    Args:
        input: Tensor to normalise.
        axis: Dimension along which the norm is taken.
        epsilon: Lower bound applied to the norm so that all-zero slices do not divide
            by zero (they stay zero).

    Returns:
        Tensor with the same shape as `input`.
    """
    magnitude = input.norm(p=2, dim=axis, keepdim=True).clamp(min=epsilon)
    return input / magnitude
def initialize_linear_kaiming(layer: nn.Linear):
    """Re-initialise an `nn.Linear` in place; any other module is left untouched.

    The weight is drawn with Kaiming-uniform initialisation using the "linear"
    nonlinearity setting, and the bias (when present) is zeroed.
    """
    if not isinstance(layer, nn.Linear):
        return
    nn.init.kaiming_uniform_(layer.weight, nonlinearity='linear')
    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)
def _autocast_disabled(device_type: str):
    """Return a context manager that switches autocast off for `device_type`.

    The generic `torch.amp.autocast` entry point is tried first. If it is missing or
    rejects the arguments, the device-specific legacy entry points are used, and a no-op
    context is returned when none applies.
    """
    try:
        return torch.amp.autocast(device_type=device_type, enabled=False)
    except (AttributeError, TypeError):
        # Fall through to the legacy per-device APIs below.
        pass

    if device_type == "cuda":
        return torch.cuda.amp.autocast(enabled=False)
    if device_type == "cpu":
        cpu_namespace = getattr(torch, "cpu", None)
        if cpu_namespace is not None and hasattr(cpu_namespace, "amp"):
            return torch.cpu.amp.autocast(enabled=False)
    return nullcontext()
def _normalize_token_attention_mask(
    attention_mask: Optional[torch.Tensor],
    *,
    batch_size: Optional[int] = None,
    seq_len: Optional[int] = None,
    device: Optional[torch.device] = None,
) -> Optional[torch.Tensor]:
    """Reduce an attention mask to a `(batch, seq_len)` boolean keep-mask.

    Accepted layouts: 2D token masks, 3D masks shaped `(batch, 1, seq_len)`, and 4D masks
    shaped `(batch, 1, query_len, key_len)`. In the result `True` marks a valid
    (non-padding) token.

    Interpretation rules:
      * 4D bool: a key is kept if any query row keeps it; 4D non-bool is read as an
        additive bias, so a key is kept if its largest entry over queries is `>= 0`.
      * 2D/3D bool: returned as-is; integer: non-zero means keep.
      * 2D/3D float: if any entry is negative the mask is read as an additive bias
        (`>= 0` keeps), otherwise as a 0/1 mask (`> 0` keeps).

    Args:
        attention_mask: The mask, or `None`.
        batch_size, seq_len, device: Only used when `attention_mask` is `None`; if both
            sizes are given an all-`True` mask of that shape is created on `device`.

    Returns:
        The boolean keep-mask, or `None` when no mask and no sizes were supplied.

    Raises:
        ValueError: For a 3D mask whose second dimension is not 1, or any other rank.
    """
    if attention_mask is None:
        if batch_size is not None and seq_len is not None:
            return torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
        return None

    rank = attention_mask.dim()

    if rank == 4:
        per_query = attention_mask[:, 0]
        if attention_mask.dtype == torch.bool:
            keep = per_query.any(dim=-2)
        else:
            keep = per_query.amax(dim=-2) >= 0
        return keep.to(dtype=torch.bool)

    mask = attention_mask
    if rank == 3:
        if mask.size(1) != 1:
            raise ValueError(
                "3D attention masks must have shape (batch, 1, seq_len) when used in ProkBert."
            )
        mask = mask[:, 0, :]
    elif rank != 2:
        raise ValueError(
            "Attention masks for ProkBert must be 2D, 3D `(batch, 1, seq_len)`, or 4D "
            "`(batch, 1, query_len, key_len)`."
        )

    if mask.dtype == torch.bool:
        return mask
    if not mask.is_floating_point():
        return mask != 0
    if mask.numel() == 0:
        return mask.to(dtype=torch.bool)
    # Any negative entry signals an additive-bias mask; otherwise treat it as 0/1.
    if torch.any(mask < 0):
        return mask >= 0
    return mask > 0
def _to_additive_attention_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a boolean keep-mask into an additive bias of the given floating dtype.

    Kept positions (`True`) become `0`; masked positions (`False`) become the most
    negative finite value of `dtype`. The result has the shape and device of `mask`.
    """
    most_negative = torch.finfo(dtype).min
    zeros = mask.new_zeros(mask.shape, dtype=dtype)
    return zeros.masked_fill(torch.bitwise_not(mask), most_negative)
def _build_bidirectional_attention_biases(
    attention_mask: torch.Tensor,
    dtype: torch.dtype,
    sliding_window: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create full-attention and sliding-window additive masks for eager/SDPA attention.

    Args:
        attention_mask: Either a 4D mask `(batch, 1, query_len, key_len)` (bool keep-mask
            or an already-additive bias) or any layout accepted by
            `_normalize_token_attention_mask`.
        dtype: Floating dtype of the returned biases.
        sliding_window: Maximum allowed `|query_pos - key_pos|` for the local mask.

    Returns:
        `(global_attention_mask, sliding_window_mask)`, both 4D additive biases. The
        sliding-window mask is the global mask with every position outside the window
        additionally set to the most negative finite value of `dtype`.

    Raises:
        ValueError: If `attention_mask` is `None`.
    """
    # Validate up front: the rank check below would otherwise fail on `None` with an
    # AttributeError before the intended error could be raised.
    if attention_mask is None:
        raise ValueError("`attention_mask` cannot be None when creating additive attention biases.")

    if attention_mask.dim() == 4:
        if attention_mask.dtype == torch.bool:
            global_attention_mask = _to_additive_attention_bias(attention_mask, dtype)
        else:
            # Non-bool 4D masks are taken to be additive already; only the dtype is aligned.
            global_attention_mask = attention_mask.to(dtype=dtype)
        query_length = global_attention_mask.shape[-2]
        key_length = global_attention_mask.shape[-1]
    else:
        token_mask = _normalize_token_attention_mask(attention_mask)
        query_length = token_mask.shape[-1]
        key_length = token_mask.shape[-1]
        # Broadcast the per-key keep flags over every query row.
        expanded_token_mask = token_mask[:, None, None, :].expand(-1, 1, query_length, -1)
        global_attention_mask = _to_additive_attention_bias(expanded_token_mask, dtype)

    row_positions = torch.arange(query_length, device=global_attention_mask.device)[:, None]
    col_positions = torch.arange(key_length, device=global_attention_mask.device)[None, :]
    local_window = (row_positions - col_positions).abs() <= sliding_window
    min_dtype = torch.finfo(dtype).min
    sliding_window_mask = global_attention_mask.masked_fill(
        ~local_window.view(1, 1, query_length, key_length),
        min_dtype,
    )
    return global_attention_mask, sliding_window_mask
class ProkBertConfig(PretrainedConfig):
    r"""
    Configuration for the standalone ProkBERT ModernBERT-style encoder stack.

    Canonical config names follow the current ModernBERT conventions:

    - `layer_types` instead of `global_attn_every_n_layers`
    - nested `rope_parameters` keyed by attention type instead of a single flat RoPE dict
    - `tie_word_embeddings` explicitly tracked in the config

    `classifier_pooling` is also standardized as a single field and extended with the custom
    `"attention"` option used by the standalone ProkBERT sequence-classification head.

    The tokenizer metadata fields `kmer` and `shift` are also stored on the config so the model artifact keeps
    the sequence-tokenization contract alongside the architectural settings.

    Legacy names are still accepted for backward compatibility when loading older checkpoints/configs.
    """
    model_type = "prokbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    # RoPE base used per attention type when a config does not specify one.
    default_theta = {"full_attention": 160_000.0, "sliding_attention": 10_000.0}
    # Legacy attribute name -> canonical attribute name.
    attribute_map = {
        "classification_dropout_rate": "classifier_dropout",
        "num_class_labels": "num_labels",
        "curricular_num_labels": "num_labels",
        "curricular_face_m": "curricular_margin",
        "curricular_face_s": "curricular_scale",
        "curriculum_hidden_size": "curricular_embedding_size",
    }

    @classmethod
    def _build_layer_types(
        cls,
        num_hidden_layers: int,
        global_attn_every_n_layers: int,
    ) -> list[str]:
        """Derive `layer_types` from the legacy `global_attn_every_n_layers` setting.

        Layer `i` is `"full_attention"` when `i` is a multiple of
        `global_attn_every_n_layers` and `"sliding_attention"` otherwise.

        Raises:
            ValueError: If `global_attn_every_n_layers` is not positive.
        """
        if global_attn_every_n_layers <= 0:
            raise ValueError("`global_attn_every_n_layers` must be a positive integer.")
        return [
            "sliding_attention" if bool(i % global_attn_every_n_layers) else "full_attention"
            for i in range(num_hidden_layers)
        ]

    @classmethod
    def _normalize_rope_parameters(
        cls,
        rope_parameters: RopeParameters | dict | None,
        global_rope_theta: float,
        local_rope_theta: float,
    ) -> dict[str, dict]:
        """Return RoPE settings as a dict keyed by `"full_attention"` / `"sliding_attention"`.

        Three input shapes are handled:
          * `None`: defaults built from `global_rope_theta` / `local_rope_theta`.
          * flat dict (contains `rope_type` or `rope_theta`): duplicated for both layer
            types; the flat `rope_theta` (or `global_rope_theta`) is used for full
            attention and `local_rope_theta` for sliding attention.
          * nested dict: each layer type's entry is copied and missing `rope_type` /
            `rope_theta` are filled in; here a missing `rope_theta` comes from
            `cls.default_theta`, not from the two theta arguments.
        """
        default_rope_parameters = {
            "full_attention": {"rope_type": "default", "rope_theta": float(global_rope_theta)},
            "sliding_attention": {"rope_type": "default", "rope_theta": float(local_rope_theta)},
        }
        if rope_parameters is None:
            return default_rope_parameters
        rope_parameters = dict(rope_parameters)
        # Backward compatibility: older configs used a single flat dict plus `global_rope_theta` / `local_rope_theta`.
        if "rope_type" in rope_parameters or "rope_theta" in rope_parameters:
            shared_rope_type = rope_parameters.get("rope_type", "default")
            full_theta = float(rope_parameters.get("rope_theta", global_rope_theta))
            return {
                "full_attention": {
                    **{k: v for k, v in rope_parameters.items() if k != "rope_theta"},
                    "rope_type": shared_rope_type,
                    "rope_theta": full_theta,
                },
                "sliding_attention": {
                    **{k: v for k, v in rope_parameters.items() if k != "rope_theta"},
                    "rope_type": shared_rope_type,
                    "rope_theta": float(local_rope_theta),
                },
            }
        normalized_rope_parameters = {}
        for layer_type in ("full_attention", "sliding_attention"):
            layer_params = rope_parameters.get(layer_type)
            if layer_params is None:
                layer_params = {"rope_type": "default"}
            else:
                # Copy so the caller's nested dict is not modified by `setdefault` below.
                layer_params = dict(layer_params)
            layer_params.setdefault("rope_type", "default")
            layer_params.setdefault("rope_theta", cls.default_theta[layer_type])
            normalized_rope_parameters[layer_type] = layer_params
        return normalized_rope_parameters

    def __init__(
        self,
        vocab_size: int = 4608,
        hidden_size: int = 384,
        intermediate_size: int = 1152,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 6,
        hidden_activation: str = "gelu",
        max_position_embeddings: int = 16384,
        initializer_range: float = 0.02,
        initializer_cutoff_factor: float = 2.0,
        norm_eps: float = 1e-6,
        norm_bias: bool = False,
        kmer: int = 6,
        shift: int = 1,
        pad_token_id: int = 0,
        eos_token_id: int = 3,
        bos_token_id: int = 2,
        cls_token_id: int = 2,
        sep_token_id: int = 3,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: list[str] | None = None,
        rope_parameters: RopeParameters | dict | None = None,
        local_attention: int = 256,
        embedding_dropout: float = 0.0,
        mlp_bias: bool = False,
        mlp_dropout: float = 0.0,
        decoder_bias: bool = True,
        classifier_pooling: Literal["attention", "cls", "mean"] = "attention",
        classifier_dropout: float = 0.0,
        classifier_bias: bool = False,
        classifier_activation: str = "gelu",
        deterministic_flash_attn: bool = False,
        sparse_prediction: bool = False,
        sparse_pred_ignore_index: int = -100,
        reference_compile: bool | None = None,
        repad_logits_with_grad: bool = False,
        norm_type: str = "rms",
        tie_word_embeddings: bool = True,
        num_labels: int = 2,
        problem_type: str | None = None,
        curricular_margin: float = 0.5,
        curricular_scale: float = 64.0,
        curricular_embedding_size: int | None = None,
        **kwargs,
    ):
        """Build the config, translating legacy keys and validating the result.

        Legacy keys (`global_attn_every_n_layers`, `global_rope_theta`, `local_rope_theta`,
        `num_class_labels`, `curricular_num_labels`, `classification_dropout_rate`,
        `curricular_face_m`, `curricular_face_s`, `curriculum_hidden_size`,
        `bert_base_model`) are removed from `kwargs` before the remainder is forwarded to
        `PretrainedConfig.__init__`.

        Raises:
            ValueError: If `kmer` or `shift` is not positive, `layer_types` has the wrong
                length or unsupported entries, or `classifier_pooling` / `norm_type` is
                not one of the allowed values.
        """
        # Pull legacy keys out of kwargs so they are not forwarded to the base class.
        legacy_global_attn_every_n_layers = int(kwargs.pop("global_attn_every_n_layers", 1))
        legacy_global_rope_theta = float(kwargs.pop("global_rope_theta", self.default_theta["full_attention"]))
        legacy_local_rope_theta = float(kwargs.pop("local_rope_theta", self.default_theta["sliding_attention"]))
        legacy_num_class_labels = kwargs.pop("num_class_labels", None)
        legacy_curricular_num_labels = kwargs.pop("curricular_num_labels", None)
        legacy_classifier_dropout = kwargs.pop("classification_dropout_rate", None)
        legacy_curricular_margin = kwargs.pop("curricular_face_m", None)
        legacy_curricular_scale = kwargs.pop("curricular_face_s", None)
        legacy_curricular_embedding_size = kwargs.pop("curriculum_hidden_size", None)
        # Accepted but unused here.
        kwargs.pop("bert_base_model", None)
        # Label-count precedence: `id2label` > `curricular_num_labels` > `num_class_labels`
        # > the `num_labels` argument. `id2label` stays in kwargs for the base class.
        loaded_id2label = kwargs.get("id2label")
        if loaded_id2label is not None:
            num_labels = len(loaded_id2label)
        elif legacy_curricular_num_labels is not None:
            num_labels = int(legacy_curricular_num_labels)
        elif legacy_num_class_labels is not None:
            num_labels = int(legacy_num_class_labels)
        # A legacy value, when present, overrides the corresponding canonical argument.
        if legacy_classifier_dropout is not None:
            classifier_dropout = float(legacy_classifier_dropout)
        if legacy_curricular_margin is not None:
            curricular_margin = float(legacy_curricular_margin)
        if legacy_curricular_scale is not None:
            curricular_scale = float(legacy_curricular_scale)
        # The legacy embedding size is only a fallback; `None` and `-1` are ignored.
        if curricular_embedding_size is None and legacy_curricular_embedding_size not in (None, -1):
            curricular_embedding_size = int(legacy_curricular_embedding_size)
        if layer_types is None:
            layer_types = self._build_layer_types(
                num_hidden_layers=num_hidden_layers,
                global_attn_every_n_layers=legacy_global_attn_every_n_layers,
            )
        else:
            layer_types = list(layer_types)
        rope_parameters = self._normalize_rope_parameters(
            rope_parameters=rope_parameters,
            global_rope_theta=legacy_global_rope_theta,
            local_rope_theta=legacy_local_rope_theta,
        )
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            tie_word_embeddings=tie_word_embeddings,
            num_labels=num_labels,
            problem_type=problem_type,
            **kwargs,
        )
        self.kmer = int(kmer)
        self.shift = int(shift)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types
        self.rope_parameters = rope_parameters
        self.local_attention = local_attention
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.decoder_bias = decoder_bias
        self.classifier_pooling = classifier_pooling
        self.classifier_dropout = classifier_dropout
        self.classifier_bias = classifier_bias
        self.classifier_activation = classifier_activation
        self.deterministic_flash_attn = deterministic_flash_attn
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
        self.repad_logits_with_grad = repad_logits_with_grad
        self.norm_type = norm_type
        self.tie_word_embeddings = tie_word_embeddings
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.curricular_margin = curricular_margin
        self.curricular_scale = curricular_scale
        self.curricular_embedding_size = curricular_embedding_size
        # Validation of the assembled configuration.
        if self.kmer <= 0:
            raise ValueError(f"`kmer` must be a positive integer, got {self.kmer}.")
        if self.shift <= 0:
            raise ValueError(f"`shift` must be a positive integer, got {self.shift}.")
        if len(self.layer_types) != self.num_hidden_layers:
            raise ValueError(
                "`layer_types` must contain one entry per hidden layer: "
                f"expected {self.num_hidden_layers}, got {len(self.layer_types)}."
            )
        invalid_layer_types = sorted(set(self.layer_types) - {"full_attention", "sliding_attention"})
        if invalid_layer_types:
            raise ValueError(
                f"Unsupported values in `layer_types`: {invalid_layer_types}. "
                'Allowed values are ["full_attention", "sliding_attention"].'
            )
        if self.classifier_pooling not in ["attention", "cls", "mean"]:
            raise ValueError(
                f'Invalid value for `classifier_pooling`, should be one of ["attention", "cls", "mean"], but is {self.classifier_pooling}.'
            )
        if self.norm_type not in {"rms", "layernorm"}:
            raise ValueError(
                f'Invalid value for `norm_type`, should be either "rms" or "layernorm", but is {self.norm_type}.'
            )

    def get_rope_parameters(self, layer_type: str) -> dict:
        """Return the RoPE settings for `layer_type`, falling back to the class defaults.

        Raises:
            ValueError: If `layer_type` is neither `"full_attention"` nor
                `"sliding_attention"`.
        """
        if layer_type not in {"full_attention", "sliding_attention"}:
            raise ValueError(
                f"Unsupported `layer_type`={layer_type!r}. Expected 'full_attention' or 'sliding_attention'."
            )
        rope_params = self.rope_parameters.get(layer_type)
        if rope_params is None:
            rope_params = {"rope_type": "default", "rope_theta": self.default_theta[layer_type]}
        return rope_params

    @property
    def sliding_window(self) -> int:
        """Half of `local_attention` (integer division)."""
        return self.local_attention // 2

    @sliding_window.setter
    def sliding_window(self, value: int):
        # Stored indirectly: `local_attention` is set to twice the given value.
        self.local_attention = int(value) * 2

    def to_dict(self):
        """Serialize like the base class, but omit `reference_compile` from the output."""
        output = super().to_dict()
        output.pop("reference_compile", None)
        return output
# Documentation constants. NOTE(review): presumably consumed by the
# `add_start_docstrings` / `add_code_sample_docstrings` decorators imported at the top of
# the file; the usage sites are outside this view.
_CHECKPOINT_FOR_DOC = "example/prokbert-base"
_CONFIG_FOR_DOC = "ProkBertConfig"
PROK_BERT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods
the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch module and refer to the PyTorch documentation for general usage and behavior.
Parameters:
config ([`ProkBertConfig`]):
Model configuration class with all the parameters of the model.
Initializing with a config file does not load the model weights; see [`PreTrainedModel.from_pretrained`] for weight loading.
"""
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Args:
        dim: Size of the normalized (last) dimension.
        eps: Constant added to the mean square before the reciprocal square root.
        bias: When true, a learnable offset (initialised to zero) is created.
    """

    def __init__(self, dim: int, eps: float = 1e-6, bias: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `weight * (x / rms(x) [+ bias])`.

        Note that the optional bias is added *before* the weight is applied, so the
        weight scales the bias as well.
        """
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x * inv_rms
        if self.bias is not None:
            normed = normed + self.bias
        return self.weight * normed
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the one moved to the front.

    For a last dimension split as `[front | back]` the result is `[-back | front]`.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with rotary position embeddings.

    The rotation is computed in float32 and both results are cast back to the dtype
    that `q` had on entry.

    Args:
        q (torch.Tensor): The query tensor.
        k (torch.Tensor): The key tensor.
        cos (torch.Tensor): Cosine table of the rotary embedding.
        sin (torch.Tensor): Sine table of the rotary embedding.
        position_ids (torch.Tensor, optional): Deprecated and unused.
        unsqueeze_dim (int, optional): Dimension inserted into `cos` / `sin` so they
            broadcast against `q` and `k`.

    Returns:
        tuple(torch.Tensor): The rotated query and key tensors.
    """
    result_dtype = q.dtype
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        states_fp32 = states.float()
        rotated = (states_fp32 * cos_b) + (rotate_half(states_fp32) * sin_b)
        return rotated.to(result_dtype)

    return _rotate(q), _rotate(k)
def eager_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    output_attentions: Optional[bool] = False,
    **_kwargs,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
    """Reference (matmul + softmax) attention over a packed QKV tensor.

    Args:
        module: Attention module providing `rotary_emb`, `head_dim`, `attention_dropout`
            and `training`.
        qkv: Packed projections, `[batch_size, seqlen, 3, nheads, headdim]`.
        attention_mask: Additive bias used when `local_attention == (-1, -1)`.
        sliding_window_mask: Additive bias used for every other `local_attention` value.
        position_ids: Forwarded to `module.rotary_emb`.
        local_attention: Window spec; `(-1, -1)` selects the global mask.
        bs: Batch size used to reshape the output.
        dim: Size of the last output dimension.
        output_attentions: When truthy, the attention probabilities are returned too.

    Returns:
        `(attn_output,)` or `(attn_output, attn_weights)`; `attn_output` is viewed as
        `(bs, -1, dim)`.
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # After the transpose the layout is [batch, nheads, 3, seqlen, headdim].
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    scale = module.head_dim ** -0.5
    scores = torch.matmul(query, key.transpose(2, 3)) * scale

    uses_global_mask = local_attention == (-1, -1)
    additive_bias = attention_mask if uses_global_mask else sliding_window_mask
    scores = scores + additive_bias

    # Softmax is evaluated in fp32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=module.attention_dropout, training=module.training)

    context = torch.matmul(probs, value)
    context = context.transpose(1, 2).contiguous().view(bs, -1, dim)

    if output_attentions:
        return (context, probs)
    return (context,)
def flash_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    rotary_emb: "ProkBertUnpaddedRotaryEmbedding",
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Variable-length FlashAttention over an unpadded, packed QKV tensor.

    Args:
        module: Attention module providing `attention_dropout`, `training` and
            `deterministic_flash_attn`.
        qkv: Packed projections, `(total_seqlen, 3, nheads, headdim)`.
        rotary_emb: Callable applying rotary embeddings to the packed tensor.
        cu_seqlens: Cumulative sequence lengths.
        max_seqlen: Longest sequence in the batch.
        local_attention: Passed to the kernel as `window_size`.
        bs: First dimension of the returned view.
        dim: Second dimension of the returned view.
        target_dtype: Dtype the inputs are cast to when they are neither fp16 nor bf16;
            the output is cast back to the original dtype in that case.

    Returns:
        A 1-tuple holding the attention output viewed as `(bs, dim)`.
    """
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    input_dtype = qkv.dtype
    needs_cast = input_dtype not in (torch.float16, torch.bfloat16)
    if needs_cast:
        qkv = qkv.to(target_dtype)

    dropout_p = module.attention_dropout if module.training else 0.0
    attn = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=dropout_p,
        deterministic=module.deterministic_flash_attn,
        window_size=local_attention,
    )

    if needs_cast:
        attn = attn.to(input_dtype)
    return (attn.view(bs, dim),)
def sdpa_attention_forward(
    module: "ProkBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Attention via `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        module: Attention module providing `rotary_emb`, `attention_dropout` and
            `training`.
        qkv: Packed projections, `[batch_size, seqlen, 3, nheads, headdim]`.
        attention_mask: Mask used when `local_attention == (-1, -1)`.
        sliding_window_mask: Mask used for every other `local_attention` value.
        position_ids: Forwarded to `module.rotary_emb`.
        local_attention: Window spec; `(-1, -1)` selects the global mask.
        bs: Batch size used to reshape the output.
        dim: Size of the last output dimension.

    Returns:
        A 1-tuple holding the attention output viewed as `(bs, -1, dim)`.
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    mask = attention_mask if local_attention == (-1, -1) else sliding_window_mask
    dropout_p = module.attention_dropout if module.training else 0.0

    context = F.scaled_dot_product_attention(
        query,
        key,
        value,
        dropout_p=dropout_p,
        attn_mask=mask,
    )
    context = context.transpose(1, 2).contiguous()
    return (context.view(bs, -1, dim),)
# Dispatch table: attention-implementation name -> forward function defined above.
# NOTE(review): the keys presumably correspond to the config's attention-implementation
# setting; the lookup site is outside this view.
PROK_BERT_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}
def _unpad_prokbert_input(
    inputs: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Remove padding from input sequences.

    Args:
        inputs: (batch, seqlen, ...) or (batch, seqlen). Non-contiguous tensors are accepted.
        attention_mask: (batch, seqlen), where 1 means valid and 0 means padding.
        position_ids: (batch, seqlen), optional position ids.
        labels: (batch, seqlen), optional labels.

    Returns:
        unpadded_inputs: Tensor of shape (total_nnz, ...) containing only valid tokens.
        indices: Flat indices (into batch * seqlen) of the valid tokens.
        cu_seqlens: int32 cumulative sequence lengths of the unpadded tokens (shape: batch + 1).
        max_seqlen_in_batch: Maximum sequence length among all sequences (excluding padding).
        unpadded_position_ids: (total_nnz,) or None.
        unpadded_labels: (total_nnz,) or None.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # Prepend a zero so that cu_seqlens[i]:cu_seqlens[i + 1] delimits sequence i.
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    if inputs.dim() == 2:
        unpadded_inputs = inputs.flatten()[indices]
    else:
        batch, seqlen, *rest = inputs.shape
        # `reshape` rather than `view`: identical for contiguous inputs, but it also
        # works for non-contiguous ones (e.g. transposed or sliced hidden states).
        unpadded_inputs = inputs.reshape(batch * seqlen, *rest)[indices]

    unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
    unpadded_labels = labels.flatten()[indices] if labels is not None else None
    return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
def _pad_prokbert_output(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    batch: int,
    seqlen: int,
) -> torch.Tensor:
    """
    Scatter unpadded outputs back into a zero-filled padded layout.

    Args:
        inputs: Tensor of shape (total_nnz, ...) containing outputs for only valid tokens.
        indices: Flat positions (into batch * seqlen) of the valid tokens.
        batch: Batch size.
        seqlen: Maximum sequence length (including padding).

    Returns:
        Tensor of shape (batch, seqlen, ...) with outputs in their original padded positions
        and zeros everywhere else.
    """
    # Trailing feature dimensions; empty for a 1-D input, so one code path serves both.
    _total_nnz, *trailing = inputs.shape
    padded = torch.zeros(batch * seqlen, *trailing, dtype=inputs.dtype, device=inputs.device)
    padded[indices] = inputs
    return padded.view(batch, seqlen, *trailing)
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Autograd function applying rotary embeddings in place to packed (unpadded) QKV.

    Only the query and key slots (`qkv[:, :2]`) are rotated; the value slot is not
    touched. The rotation itself is done by flash-attn's `apply_rotary` kernel.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        """Rotate q and k inside `qkv` and return `qkv`.

        Args:
            qkv: (total_nnz, 3, nheads, headdim) packed projections.
            cos, sin: Rotary tables passed through to `apply_rotary`.
            cu_seqlens: Cumulative sequence lengths, passed through to `apply_rotary`.
            max_seqlen: Maximum sequence length, passed through to `apply_rotary`.

        Note: if `qkv` is already contiguous, `.contiguous()` returns the same tensor, so
        the caller's tensor is modified in place.
        """
        # qkv: (total_nnz, 3, nheads, headdim)
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # Combine the (3, nheads) dimensions for the first two channels to create a (total_nnz, 2*nheads, headdim) tensor.
        # `qk` is a view, so the in-place rotation below writes into `qkv`.
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )
        # Keep what backward needs to undo the rotation on the gradient.
        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        """Apply the conjugate rotation to the q/k part of the incoming gradient.

        Returns the gradient for `qkv` followed by `None` for `cos`, `sin`, `cu_seqlens`
        and `max_seqlen`.
        """
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )
        return do, None, None, None, None
def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Functional entry point for `ApplyRotaryEmbUnpad` on a packed QKV tensor.

    Args:
        qkv: Tensor of shape (total_nnz, 3, nheads, headdim) for packed QKV.
        cos, sin: Precomputed cosine and sine caches.
        cu_seqlens: Cumulative sequence lengths (batch + 1,).
        max_seqlen: Maximum sequence length in the batch.

    Returns:
        Tensor with rotary embeddings applied to the query and key slots.
    """
    # `Function.apply` is called positionally, in the order of `forward`'s parameters.
    rotate = ApplyRotaryEmbUnpad.apply
    return rotate(qkv, cos, sin, cu_seqlens, max_seqlen)
class ProkBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    Rotary position embedding for packed (unpadded) ProkBERT sequences.

    Relies on the base ``RotaryEmbedding`` for the cosine/sine cache
    (``_update_cos_sin_cache``, ``_cos_cached``, ``_sin_cached``).
    """

    def __init__(
        self,
        dim: int,
        base: float = 16000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Args:
            dim: Per-head dimension.
            base: Base of the rotary frequency schedule.
            max_seqlen: Sequence length for which the cache may be built up front.
            device: Device for the up-front cache.
            dtype: Data type for the up-front cache.
        """
        super().__init__(dim=dim, base=base, device=device, interleaved=False)
        self.max_seqlen = max_seqlen
        # The cache is only built eagerly when length, device and dtype are all known.
        can_precompute = max_seqlen is not None and device is not None and dtype is not None
        if can_precompute:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Rotate the Q and K parts of a packed QKV tensor (performed in place downstream).

        Args:
            qkv: Packed tensor of shape (total_nnz, 3, nheads, headdim).
            cu_seqlens: Cumulative sequence lengths, shape (batch + 1,).
            max_seqlen: Longest sequence in this batch; when given, the cache is
                refreshed for that length on ``qkv``'s device and dtype first.

        Returns:
            The rotated QKV tensor.
        """
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        return apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

    def extra_repr(self) -> str:
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
class ProkBertEmbeddings(nn.Module):
    """
    Token embedding lookup followed by normalization and dropout.
    """

    def __init__(self, config: ProkBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # "rms" selects RMSNorm; any other value selects LayerNorm.
        norm_cls = RMSNorm if config.norm_type == "rms" else nn.LayerNorm
        self.norm = norm_cls(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.drop = nn.Dropout(config.embedding_dropout)

    def forward(
        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Produce normalized, dropout-regularized embeddings.

        Args:
            input_ids: Token ids to look up; used only when ``inputs_embeds`` is None.
            inputs_embeds: Pre-computed embeddings; when given they bypass the lookup.

        Returns:
            The embeddings after normalization and dropout.
        """
        embeds = inputs_embeds if inputs_embeds is not None else self.tok_embeddings(input_ids)
        return self.drop(self.norm(embeds))
class ProkBertRotaryEmbedding(nn.Module):
    """Computes rotary cos/sin tables for padded inputs from per-layer-type RoPE parameters."""

    def __init__(self, config: ProkBertConfig, layer_type: str, device: Optional[torch.device] = None):
        super().__init__()
        self.config = config
        self.layer_type = layer_type
        self.original_max_seq_len = config.max_position_embeddings
        self.max_seq_len_cached = config.max_position_embeddings
        self.rope_type = config.get_rope_parameters(layer_type)["rope_type"]
        # "default" uses the local implementation; any other type comes from the registry.
        if self.rope_type == "default":
            self.rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            self.rope_init_fn: Callable = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(config, device, layer_type=layer_type)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Separate copy so the initial frequencies can be reinstated after a dynamic update.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: ProkBertConfig,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
        layer_type: Optional[str] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Build the inverse frequencies ``1 / base ** (2i / dim)`` for the given layer type.

        ``layer_type`` falls back to "full_attention" when falsy; ``seq_len`` is accepted
        but not used. Returns the frequencies together with an attention factor of 1.0.
        """
        rope_params = config.get_rope_parameters(layer_type or "full_attention")
        base = rope_params["rope_theta"]
        dim = getattr(config, "head_dim", None)
        if not dim:
            dim = config.hidden_size // config.num_attention_heads
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base ** exponents)
        attention_factor = 1.0
        return inv_freq, attention_factor

    def _dynamic_frequency_update(self, position_ids, device):
        """
        Refresh ``inv_freq`` for dynamic RoPE variants.

        Frequencies are recomputed when the largest position exceeds the cached length,
        and the initial frequencies are reinstated once a sequence fits within the
        original length again after such a growth.
        """
        seq_len = torch.max(position_ids) + 1
        grew_past_cache = seq_len > self.max_seq_len_cached
        if grew_past_cache:
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config,
                device,
                seq_len=seq_len,
                layer_type=self.layer_type,
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len
        back_in_original_range = (
            seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len
        )
        if back_in_original_range:
            # Move the saved copy to the requested device before re-registering it.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to the dtype of ``x``."""
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        # Autocast is switched off so the angle computation runs on float32 inputs.
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class ProkBertMLP(nn.Module):
    """Gated feed-forward block used at the end of each encoder layer.

    ``Wi`` projects to twice the intermediate size; the result is split into two halves,
    the activated first half is multiplied by the second, dropout is applied, and ``Wo``
    projects back to the hidden size.
    """

    def __init__(self, config: ProkBertConfig):
        super().__init__()
        self.config = config
        doubled_width = int(config.intermediate_size) * 2
        self.Wi = nn.Linear(config.hidden_size, doubled_width, bias=config.mlp_bias)
        self.act = ACT2FN[config.hidden_activation]
        self.drop = nn.Dropout(config.mlp_dropout)
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        value, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        gated = self.act(value) * gate
        return self.Wo(self.drop(gated))
class ProkBertAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and rotary embeddings.

    The actual attention computation is looked up in ``PROK_BERT_ATTENTION_FUNCTION`` by
    ``config._attn_implementation``. Under "flash_attention_2" the input is handled as a
    packed token stream; otherwise it is handled as a padded (batch, seq) tensor.
    """

    def __init__(self, config: ProkBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
            )
        self.attention_dropout = config.attention_dropout
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)

        # Without a layer index the layer is treated as a full-attention layer.
        self.attention_type = "full_attention" if layer_id is None else config.layer_types[layer_id]
        uses_flash = config._attn_implementation == "flash_attention_2"
        if self.attention_type == "full_attention":
            self.local_attention = (-1, -1)
            max_position_embeddings = config.max_position_embeddings
        elif self.attention_type == "sliding_attention":
            # FlashAttention uses inclusive local-attention boundaries, hence the +1.
            local_window = config.sliding_window + 1 if uses_flash else config.sliding_window
            self.local_attention = (local_window, local_window)
            max_position_embeddings = config.local_attention
        else:
            raise ValueError(
                f"Unsupported attention type {self.attention_type!r}. "
                "Expected 'full_attention' or 'sliding_attention'."
            )

        rope_theta = float(config.get_rope_parameters(self.attention_type)["rope_theta"])
        if uses_flash:
            self.rotary_emb = ProkBertUnpaddedRotaryEmbedding(
                dim=self.head_dim,
                max_seqlen=max_position_embeddings,
                base=rope_theta,
            )
        else:
            self.rotary_emb = ProkBertRotaryEmbedding(config=config, layer_type=self.attention_type)
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """Project to QKV, run the configured attention function, then apply ``Wo`` and dropout.

        Returns a tuple whose first item is the projected attention output; any further
        items produced by the attention function are passed through unchanged.
        """
        attn_impl = self.config._attn_implementation
        bs = hidden_states.shape[0]
        # Packed layout: (total_nnz, 3, heads, head_dim); padded: (bs, seq, 3, heads, head_dim).
        leading_dims = (-1,) if attn_impl == "flash_attention_2" else (bs, -1)
        qkv = self.Wqkv(hidden_states).view(*leading_dims, 3, self.num_heads, self.head_dim)
        attn_outputs = PROK_BERT_ATTENTION_FUNCTION[attn_impl](
            self,
            qkv=qkv,
            rotary_emb=self.rotary_emb,
            local_attention=self.local_attention,
            bs=bs,
            dim=self.all_head_size,
            output_attentions=output_attentions,
            **kwargs,
        )
        projected = self.out_drop(self.Wo(attn_outputs[0]))
        return (projected,) + attn_outputs[1:]
class ProkBertEncoderLayer(nn.Module):
    """Pre-norm encoder layer: attention and MLP sub-blocks, each wrapped in a residual connection."""

    def __init__(self, config: ProkBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        if layer_id is None:
            self.attention_type = "full_attention"
        else:
            self.attention_type = config.layer_types[layer_id]
        # One norm class for both sub-blocks: RMSNorm for "rms", LayerNorm otherwise.
        norm_cls = RMSNorm if config.norm_type == "rms" else nn.LayerNorm
        norm_kwargs = {"eps": config.norm_eps, "bias": config.norm_bias}
        self.attn_norm = norm_cls(config.hidden_size, **norm_kwargs)
        self.mlp_norm = norm_cls(config.hidden_size, **norm_kwargs)
        self.attn = ProkBertAttention(config=config, layer_id=layer_id)
        self.mlp = ProkBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run the layer and return ``(hidden_states, *extra_attention_outputs)``."""
        attn_out, *attn_extras = self.attn(
            self.attn_norm(hidden_states),
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
        )
        after_attention = hidden_states + attn_out
        after_mlp = after_attention + self.mlp(self.mlp_norm(after_attention))
        return (after_mlp, *attn_extras)
# Shared introductory docstring text; passed to `add_start_docstrings` on the ProkBert
# model classes defined in this file.
PROK_BERT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods
the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch module and refer to the PyTorch documentation for general usage and behavior.
Parameters:
config ([`ProkBertConfig`]):
Model configuration class with all the parameters of the model.
Initializing with a config file does not load the model weights; see [`PreTrainedModel.from_pretrained`] for weight loading.
"""
@add_start_docstrings(
    "The bare ProkBert Model outputting raw hidden-states without any specific head on top.",
    PROK_BERT_START_DOCSTRING,
)
class ProkBertPreTrainedModel(PreTrainedModel):
    # Common base for the ProkBert models: class-level Hugging Face flags, weight
    # initialization, and attention-implementation selection.
    config_class = ProkBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ProkBertEmbeddings", "ProkBertEncoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module: nn.Module):
        """Initialize the weights of known ProkBert modules with a truncated normal.

        Values are truncated at ``cutoff_factor`` standard deviations (3 when the config
        leaves it as None). Biases of ``nn.Linear`` layers touched here are zeroed.
        Modules of any other type are left untouched by this method.
        """
        cutoff_factor = self.config.initializer_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        def init_weight(layer: nn.Module, std: float):
            bound = cutoff_factor * std
            nn.init.trunc_normal_(layer.weight, mean=0.0, std=std, a=-bound, b=bound)
            if isinstance(layer, nn.Linear) and layer.bias is not None:
                nn.init.zeros_(layer.bias)

        stds = {
            "in": self.config.initializer_range,
            "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
            "embedding": self.config.initializer_range,
            "final_out": self.config.hidden_size ** -0.5,
        }

        # Collect (layer, std) pairs for the module type at hand, then initialize in order.
        if isinstance(module, ProkBertEmbeddings):
            targets = [(module.tok_embeddings, stds["embedding"])]
        elif isinstance(module, ProkBertMLP):
            targets = [(module.Wi, stds["in"]), (module.Wo, stds["out"])]
        elif isinstance(module, ProkBertAttention):
            targets = [(module.Wqkv, stds["in"]), (module.Wo, stds["out"])]
        elif isinstance(module, ProkBertPredictionHead):
            targets = [(module.dense, stds["out"])]
        elif isinstance(module, (ProkBertForMaskedLM, ProkBertForMaskedLM2)):
            targets = [(module.decoder, stds["out"])]
        else:
            targets = []
        for layer, std in targets:
            init_weight(layer, std)

    @classmethod
    def _autoset_attn_implementation(
        cls,
        config,
        use_flash_attention_2: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
    ):
        """Try FlashAttention 2 first when nothing was requested, else defer to the parent.

        The ``torch_dtype`` argument is not forwarded; ``torch.float16`` is always passed
        to the underlying checks instead.
        """
        # NOTE(review): presumably float16 is passed to sidestep a dtype check in the
        # FlashAttention 2 validation -- confirm against the installed transformers version.
        shared_kwargs = {
            "torch_dtype": torch.float16,
            "device_map": device_map,
            "check_device_map": check_device_map,
        }
        if config._attn_implementation_internal is None:
            config._attn_implementation_internal = "flash_attention_2"
            try:
                return cls._check_and_enable_flash_attn_2(config, hard_check_only=False, **shared_kwargs)
            except (ValueError, ImportError):
                # FlashAttention 2 could not be enabled: clear the tentative choice.
                config._attn_implementation_internal = None
        return super()._autoset_attn_implementation(
            config,
            use_flash_attention_2=use_flash_attention_2,
            **shared_kwargs,
        )

    def resize_token_embeddings(self, *args, **kwargs):
        """Forward to the parent implementation and return its result."""
        return super().resize_token_embeddings(*args, **kwargs)
@add_start_docstrings(
    "The bare ProkBert Model outputting raw hidden-states without any specific head on top.",
    PROK_BERT_START_DOCSTRING,
)
class ProkBertModel(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    # Encoder stack: embeddings -> `num_hidden_layers` encoder layers -> final norm.

    def __init__(self, config: ProkBertConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = ProkBertEmbeddings(config)
        self.layers = nn.ModuleList(
            [ProkBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
        )
        if config.norm_type == "rms":
            self.final_norm = RMSNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        else:
            self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.tok_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.tok_embeddings = value

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
        """Encode the inputs and return the final hidden states.

        Args:
            input_ids / inputs_embeds: Exactly one of the two must be given.
            attention_mask: Defaults to an all-ones boolean mask of shape (batch_size, seq_len).
            sliding_window_mask: Passed to the layers as given on the flash-attention path;
                on the other paths it is replaced by the mask built in `_update_attention_mask`.
            position_ids: On the non-flash paths defaults to ``arange(seq_len)``.
            indices, cu_seqlens, max_seqlen: Unpadding metadata. When all three are None
                under flash_attention_2, the inputs are unpadded here and every returned
                hidden state is padded back to (batch_size, seq_len, ...).
            batch_size, seq_len: Inferred from the inputs when either one is None.
            output_attentions, output_hidden_states, return_dict: Fall back to the config
                values when None. Attention outputs are disabled under flash_attention_2.

        Raises:
            ValueError: If both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if output_attentions and self.config._attn_implementation == "flash_attention_2":
            logger.warning_once(
                "`output_attentions=True` is not supported with flash_attention_2 in ProkBertModel. "
                "Falling back to `output_attentions=False`."
            )
            output_attentions = False
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
        if batch_size is None or seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

        repad = False
        restore_attn_implementation = None
        # `_update_attention_mask` may switch the shared config from "sdpa" to "eager" for
        # this call. The try/finally guarantees the original value is put back even when
        # the encoder pass raises, so a failed call cannot leave the model stuck on "eager".
        try:
            if self.config._attn_implementation == "flash_attention_2":
                if indices is None and cu_seqlens is None and max_seqlen is None:
                    # The caller passed padded inputs: unpad here and re-pad before returning.
                    repad = True
                    if inputs_embeds is None:
                        with torch.no_grad():
                            input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_prokbert_input(
                                inputs=input_ids,
                                attention_mask=attention_mask,
                            )
                    else:
                        inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_prokbert_input(
                            inputs=inputs_embeds,
                            attention_mask=attention_mask,
                        )
            else:
                if output_attentions and self.config._attn_implementation == "sdpa":
                    restore_attn_implementation = self.config._attn_implementation
                if position_ids is None:
                    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
                attention_mask, sliding_window_mask = self._update_attention_mask(
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
            for encoder_layer in self.layers:
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
                if self.gradient_checkpointing and self.training:
                    # Positional order must match `ProkBertEncoderLayer.forward`.
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        sliding_window_mask,
                        position_ids,
                        cu_seqlens,
                        max_seqlen,
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        sliding_window_mask=sliding_window_mask,
                        position_ids=position_ids,
                        cu_seqlens=cu_seqlens,
                        max_seqlen=max_seqlen,
                        output_attentions=output_attentions,
                    )
                hidden_states = layer_outputs[0]
                if output_attentions and len(layer_outputs) > 1:
                    all_self_attentions = all_self_attentions + (layer_outputs[1],)
            hidden_states = self.final_norm(hidden_states)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            if repad:
                hidden_states = _pad_prokbert_output(
                    inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
                )
                if all_hidden_states is not None:
                    all_hidden_states = tuple(
                        _pad_prokbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
                        for hs in all_hidden_states
                    )
        finally:
            if restore_attn_implementation is not None:
                self.config._attn_implementation = restore_attn_implementation

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_attention_mask(
        self,
        attention_mask: torch.Tensor,
        output_attentions: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the global and sliding-window attention biases for the padded paths.

        Side effect: when ``output_attentions`` is set and the configured implementation
        is "sdpa", ``self.config._attn_implementation`` is switched to "eager". The caller
        is responsible for restoring it. For any other non-eager implementation only a
        warning is logged; this method does not change ``output_attentions`` itself.
        """
        if output_attentions:
            if self.config._attn_implementation == "sdpa":
                logger.warning_once(
                    "Outputting attentions is only supported with the 'eager' attention implementation, "
                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
                )
                self.config._attn_implementation = "eager"
            elif self.config._attn_implementation != "eager":
                logger.warning_once(
                    "Outputting attentions is only supported with the eager attention implementation, "
                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`. '
                    "Setting `output_attentions=False`."
                )
        return _build_bidirectional_attention_biases(
            attention_mask=attention_mask,
            dtype=self.dtype,
            sliding_window=self.config.sliding_window,
        )
class ProkBertPredictionHead(nn.Module):
    """Dense projection, activation and normalization applied before the decoder."""

    def __init__(self, config: ProkBertConfig):
        super().__init__()
        # Configs without a `norm_type` attribute fall back to LayerNorm.
        if getattr(config, "norm_type", "layernorm") == "rms":
            norm_cls = RMSNorm
        else:
            norm_cls = nn.LayerNorm
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = norm_cls(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
@add_start_docstrings(
    "The ProkBert Model with a decoder head on top that is used for masked language modeling.",
    PROK_BERT_START_DOCSTRING,
)
class ProkBertForMaskedLM(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    # Masked-LM wrapper: encoder -> prediction head -> vocabulary decoder.
    _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"}

    def __init__(self, config: ProkBertConfig):
        super().__init__(config)
        self.config = config
        self.model = ProkBertModel(config)
        self.head = ProkBertPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
        self.post_init()

    def get_output_embeddings(self):
        return self.decoder

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    def _prediction_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map hidden states to vocabulary logits via the head and the decoder."""
        return self.decoder(self.head(hidden_states))

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        uses_flash = self.config._attn_implementation == "flash_attention_2"

        # On the flash path with no unpadding metadata supplied, unpad the inputs here
        # (together with position ids and labels) before calling the encoder.
        if uses_flash and indices is None and cu_seqlens is None and max_seqlen is None:
            if batch_size is None or seq_len is None:
                source_shape = inputs_embeds.shape if inputs_embeds is not None else input_ids.shape
                batch_size, seq_len = source_shape[:2]
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if attention_mask is None:
                attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
            from_ids = inputs_embeds is None
            # Token ids are unpadded without gradient tracking; embeddings keep it.
            with torch.no_grad() if from_ids else nullcontext():
                unpadded, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_prokbert_input(
                    inputs=input_ids if from_ids else inputs_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    labels=labels,
                )
            if from_ids:
                input_ids = unpadded
            else:
                inputs_embeds = unpadded

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden = outputs[0]

        logits_are_sparse = False
        if self.sparse_prediction and labels is not None:
            # Flatten, then keep only positions whose label is not the ignore index so the
            # head and decoder run on those tokens alone.
            labels = labels.view(-1)
            hidden = hidden.view(labels.shape[0], -1)
            keep = labels != self.sparse_pred_ignore_index
            hidden = hidden[keep]
            labels = labels[keep]
            logits_are_sparse = True

        logits = self._prediction_logits(hidden)
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

        # Dense logits from the flash path are scattered back to (batch, seq, vocab)
        # whenever the metadata needed to do so is available.
        can_repad = indices is not None and batch_size is not None and seq_len is not None
        if uses_flash and can_repad and not logits_are_sparse:
            track_grad = self.config.repad_logits_with_grad or labels is None
            with nullcontext() if track_grad else torch.no_grad():
                logits = _pad_prokbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        if return_dict:
            return MaskedLMOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        output = (logits,)
        return ((loss,) + output) if loss is not None else output
class ProkBertForSequenceClassification(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    """ProkBERT encoder with a mask-aware pooling step and a linear classification layer."""

    def __init__(self, config: ProkBertConfig):
        super().__init__(config)
        self.num_labels = int(config.num_labels)
        self.config = config
        self.model = ProkBertModel(config)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        # Kept for backward compatibility with existing sequence-classification checkpoints.
        self.weighting_layer = nn.Linear(config.hidden_size, 1)
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.post_init()

    def _init_weights(self, module: nn.Module):
        """Run the shared initialization, then special-case the pooling and classifier layers."""
        super()._init_weights(module)
        if module is self.weighting_layer:
            # Zero scores give uniform attention-pooling weights at the start.
            nn.init.zeros_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif module is self.classifier:
            # Xavier-uniform weights, scaled down by sqrt(in_features).
            nn.init.xavier_uniform_(module.weight, gain=1.0)
            module.weight.data /= math.sqrt(self.classifier.in_features)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _pool_cls(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Pick position 0, or the first position where the mask is set when a mask is given."""
        if token_mask is None:
            return sequence_output[:, 0]
        rows = torch.arange(sequence_output.shape[0], device=sequence_output.device)
        first_valid = token_mask.to(dtype=torch.long).argmax(dim=-1)
        return sequence_output[rows, first_valid]

    def _pool_mean(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Average over the sequence axis, counting only masked-in tokens when a mask is given."""
        if token_mask is None:
            return sequence_output.mean(dim=1)
        mask = token_mask.unsqueeze(-1).to(dtype=sequence_output.dtype)
        summed = (sequence_output * mask).sum(dim=1)
        # Clamp so rows without any valid token do not divide by zero.
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _pool_attention(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Softmax-weighted sum over the sequence axis using scores from ``weighting_layer``."""
        scores = self.weighting_layer(sequence_output)
        if token_mask is not None:
            no_valid_token = token_mask.sum(dim=1) == 0
            if no_valid_token.any():
                # Enable position 0 for fully masked rows so the softmax has one live entry.
                token_mask = token_mask.clone()
                token_mask[no_valid_token, 0] = True
            masked_out = ~token_mask.unsqueeze(-1)
            scores = scores.masked_fill(masked_out, torch.finfo(scores.dtype).min)
        # Softmax runs in float32; the weights are cast back afterwards.
        weights = torch.softmax(scores.float(), dim=1).to(dtype=sequence_output.dtype)
        return torch.sum(weights * sequence_output, dim=1)

    def _pool_sequence(self, sequence_output: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Pool token states according to ``config.classifier_pooling``.

        Raises:
            ValueError: If the configured pooling is not 'attention', 'mean' or 'cls'.
        """
        batch, length = sequence_output.shape[0], sequence_output.shape[1]
        token_mask = _normalize_token_attention_mask(
            attention_mask,
            batch_size=batch,
            seq_len=length,
            device=sequence_output.device,
        )
        pooling = self.config.classifier_pooling
        if pooling == "attention":
            pooled = self._pool_attention(sequence_output, token_mask)
        elif pooling == "mean":
            pooled = self._pool_mean(sequence_output, token_mask)
        elif pooling == "cls":
            pooled = self._pool_cls(sequence_output, token_mask)
        else:
            raise ValueError(
                f"Unsupported `classifier_pooling`={pooling!r}. Expected one of ['attention', 'cls', 'mean']."
            )
        return pooled

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the loss for the configured problem type.

        When ``config.problem_type`` is None it is inferred and written back to the config:
        one label means regression, integer labels mean single-label classification, and
        anything else means multi-label classification.

        Raises:
            ValueError: If the problem type is not one of the three supported values.
        """
        if self.config.problem_type is None:
            integer_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
            if self.num_labels == 1:
                inferred = "regression"
            elif labels.dtype in integer_dtypes:
                inferred = "single_label_classification"
            else:
                inferred = "multi_label_classification"
            self.config.problem_type = inferred

        problem_type = self.config.problem_type
        if problem_type == "regression":
            mse = nn.MSELoss()
            if self.num_labels == 1:
                return mse(logits.squeeze(), labels.squeeze().to(logits.dtype))
            return mse(logits, labels.to(logits.dtype))
        if problem_type == "single_label_classification":
            cross_entropy = nn.CrossEntropyLoss()
            return cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            bce = nn.BCEWithLogitsLoss()
            return bce(logits, labels.to(logits.dtype))
        raise ValueError(
            f"Unsupported `problem_type`={self.config.problem_type!r}. "
            "Expected 'regression', 'single_label_classification', or 'multi_label_classification'."
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        sliding_window_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        inputs_embeds: torch.Tensor = None,
        labels: torch.Tensor = None,
        indices: torch.Tensor = None,
        cu_seqlens: torch.Tensor = None,
        max_seqlen: int = None,
        batch_size: int = None,
        seq_len: int = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Encode, pool, normalize, apply dropout and classify; compute a loss when labels are given.

        Raises:
            ValueError: If the encoder returns 2-D (unpadded) hidden states and ``indices``,
                ``batch_size`` or ``seq_len`` is missing, so they cannot be re-padded.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # A 2-D result is still in packed form and must be padded back before pooling.
        if sequence_output.dim() == 2:
            missing_metadata = indices is None or batch_size is None or seq_len is None
            if missing_metadata:
                raise ValueError(
                    "Received unpadded hidden states from `ProkBertModel`, but `indices`, `batch_size`, and "
                    "`seq_len` were not provided to repad them for sequence classification."
                )
            sequence_output = _pad_prokbert_output(
                inputs=sequence_output,
                indices=indices,
                batch=batch_size,
                seqlen=seq_len,
            )

        pooled = self._pool_sequence(sequence_output, attention_mask)
        logits = self.classifier(self.dropout(self.norm(pooled)))
        loss = self._compute_loss(logits, labels) if labels is not None else None

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        output = (logits,) + outputs[1:]
        return ((loss,) + output) if loss is not None else output
@dataclass
class CurricularSequenceClassifierOutput(ModelOutput):
    """Output container with loss, logits, embeddings, hidden states and attentions.

    Every field is optional and defaults to ``None``.
    """

    # Loss value, when one was computed.
    loss: Optional[torch.FloatTensor] = None
    # Classification logits.
    logits: Optional[torch.FloatTensor] = None
    # NOTE(review): presumably the pooled embeddings fed to the margin head -- confirm
    # against the model that fills this output.
    embeddings: Optional[torch.FloatTensor] = None
    # Per-layer hidden states, as a tuple of tensors.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention tensors, as a tuple of tensors.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class CurricularFace(nn.Module):
    """Cosine-margin classification head with a running difficulty statistic ``t``.

    Logits are cosine similarities between L2-normalized embeddings and L2-normalized
    class columns of ``kernel``, scaled by ``s``. With labels, the target logit receives an
    additive angular margin ``m`` and non-target logits above the margin-adjusted target
    are re-weighted by ``cos * (t + cos)``.
    """

    def __init__(self, in_features: int, out_features: int, m: float = 0.5, s: float = 64.0, ema_alpha: float = 0.01):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.m = float(m)
        self.s = float(s)
        self.ema_alpha = float(ema_alpha)
        # Constants derived from the margin, computed once.
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.threshold = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m
        # One column per class: (in_features, out_features).
        self.kernel = nn.Parameter(torch.empty(self.in_features, self.out_features))
        self.register_buffer("t", torch.zeros(1, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize the class kernel and reset the running statistic to zero."""
        nn.init.xavier_uniform_(self.kernel)
        self.t.zero_()

    def cosine(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the (batch, out_features) cosine similarities, clamped to [-1, 1]."""
        # NOTE(review): `_autocast_disabled` is defined elsewhere; presumably it turns
        # autocast off so the float32 casts below are honored -- confirm.
        with _autocast_disabled(embeddings.device.type):
            unit_rows = F.normalize(embeddings.float(), p=2.0, dim=1, eps=1e-12)
            unit_cols = F.normalize(self.kernel.float(), p=2.0, dim=0, eps=1e-12)
            similarities = F.linear(unit_rows, unit_cols.t()).clamp(-1.0, 1.0)
        return similarities

    def inference_logits(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Margin-free logits: cosine similarities scaled by ``s``."""
        return self.s * self.cosine(embeddings)

    def margin_logits_from_cosine(
        self,
        cos_theta: torch.Tensor,
        labels: torch.LongTensor,
        update_t: bool = False,
    ) -> torch.Tensor:
        """Apply the margin and hard-example re-weighting to precomputed cosines.

        Args:
            cos_theta: Cosine similarities of shape (batch, out_features).
            labels: Class indices; flattened to shape (batch,).
            update_t: When True, ``t`` is moved toward the batch mean of the target
                cosines by a factor of ``ema_alpha`` before it is used.

        Returns:
            Adjusted logits of the same shape as ``cos_theta``, scaled by ``s``.
        """
        target_index = labels.reshape(-1).long().unsqueeze(1)
        target = cos_theta.gather(1, target_index)
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        sin_target = torch.sqrt((1.0 - target.square()).clamp(min=0.0))
        target_with_margin = target * self.cos_m - sin_target * self.sin_m
        # Entries whose cosine exceeds the margin-adjusted target of their row.
        is_hard = cos_theta > target_with_margin
        # At or below the threshold the margin is replaced by the linear penalty `mm`.
        penalized_target = torch.where(
            target > self.threshold,
            target_with_margin,
            target - self.mm,
        )
        if update_t:
            with torch.no_grad():
                batch_mean = target.mean().to(dtype=self.t.dtype).view_as(self.t)
                self.t.lerp_(batch_mean, self.ema_alpha)
        t = self.t.to(device=cos_theta.device, dtype=cos_theta.dtype)
        reweighted = torch.where(is_hard, cos_theta * (t + cos_theta), cos_theta)
        # The target column is overwritten last, so its re-weighting is discarded.
        with_target = reweighted.scatter(1, target_index, penalized_target)
        return with_target * self.s
class ProkBertForCurricularClassification(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    """ProkBERT sequence classifier with CurricularFace logits for single-label classification.

    Pipeline: ``ProkBertModel`` encoder -> sequence pooling (``attention``, ``mean``
    or ``cls``, chosen by ``config.classifier_pooling``) -> optional dropout ->
    optional linear projection -> ``CurricularFace`` head.
    """

    def __init__(self, config: ProkBertConfig):
        """Build the encoder, pooling/projection layers and the CurricularFace head.

        Raises:
            ValueError: If ``config.num_labels < 2`` or ``config.problem_type`` is set
                to anything other than ``"single_label_classification"``.
        """
        super().__init__(config)
        self.config = config
        self.num_labels = int(config.num_labels)
        if self.num_labels < 2:
            raise ValueError(
                "`ProkBertForCurricularClassification` requires `config.num_labels >= 2`. "
                "CurricularFace is intended for single-label classification."
            )
        if self.config.problem_type is None:
            self.config.problem_type = "single_label_classification"
        elif self.config.problem_type != "single_label_classification":
            raise ValueError(
                "`ProkBertForCurricularClassification` only supports `problem_type='single_label_classification'`."
            )

        self.model = ProkBertModel(config)
        # Scores each token for attention pooling.
        self.weighting_layer = nn.Linear(config.hidden_size, 1)
        self.dropout = nn.Dropout(config.classifier_dropout)

        # A projection is used only when an explicit embedding size is configured.
        use_projection = config.curricular_embedding_size not in (None, -1)
        embedding_dim = config.hidden_size if not use_projection else int(config.curricular_embedding_size)
        self.linear = nn.Linear(config.hidden_size, embedding_dim) if use_projection else nn.Identity()

        self.curricular_face = CurricularFace(
            in_features=embedding_dim,
            out_features=self.num_labels,
            m=float(config.curricular_margin),
            s=float(config.curricular_scale),
        )
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

        # Applied after post_init() so these values are the ones the layers start with:
        # zeroed scores make attention pooling begin as a uniform average over valid tokens.
        with torch.no_grad():
            nn.init.zeros_(self.weighting_layer.weight)
            if self.weighting_layer.bias is not None:
                nn.init.zeros_(self.weighting_layer.bias)
            if isinstance(self.linear, nn.Linear):
                initialize_linear_kaiming(self.linear)

    def get_input_embeddings(self):
        """Return the encoder's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the encoder's input embedding module."""
        self.model.set_input_embeddings(value)

    def _pool_cls(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Pick the hidden state of the first unmasked token of each sequence.

        Without a mask, position 0 is used. A row whose mask is all zeros also
        resolves to position 0 because ``argmax`` returns the first maximum.
        """
        if token_mask is None:
            return sequence_output[:, 0]
        first_token_indices = token_mask.to(dtype=torch.long).argmax(dim=-1)
        batch_indices = torch.arange(sequence_output.shape[0], device=sequence_output.device)
        return sequence_output[batch_indices, first_token_indices]

    def _pool_mean(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Average hidden states over the sequence axis, ignoring masked-out tokens."""
        if token_mask is None:
            return sequence_output.mean(dim=1)
        weights = token_mask.unsqueeze(-1).to(dtype=sequence_output.dtype)
        # Clamp so fully masked rows divide by 1 instead of 0.
        denom = weights.sum(dim=1).clamp(min=1.0)
        return (sequence_output * weights).sum(dim=1) / denom

    def _pool_attention(self, sequence_output: torch.Tensor, token_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Pool with softmax weights derived from ``weighting_layer`` scores.

        ``token_mask`` is expected to be boolean (it is inverted with ``~``).
        """
        scores = self.weighting_layer(sequence_output)
        if token_mask is not None:
            empty_rows = token_mask.sum(dim=1) == 0
            if empty_rows.any():
                # Unmask position 0 on empty rows so the softmax has at least one valid token.
                token_mask = token_mask.clone()
                token_mask[empty_rows, 0] = True
            scores = scores.masked_fill(~token_mask.unsqueeze(-1), torch.finfo(scores.dtype).min)
        # Softmax is computed in float32, then cast back to the hidden-state dtype.
        weights = torch.softmax(scores.float(), dim=1).to(dtype=sequence_output.dtype)
        return torch.sum(weights * sequence_output, dim=1)

    def _pool_sequence(self, sequence_output: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Dispatch to the pooling strategy named by ``config.classifier_pooling``.

        The attention mask is first passed through ``_normalize_token_attention_mask``
        (defined elsewhere in the file; presumably yields a per-token boolean mask or
        ``None``).

        Raises:
            ValueError: If the configured pooling mode is not supported.
        """
        token_mask = _normalize_token_attention_mask(
            attention_mask,
            batch_size=sequence_output.shape[0],
            seq_len=sequence_output.shape[1],
            device=sequence_output.device,
        )
        pooling = self.config.classifier_pooling
        if pooling == "attention":
            return self._pool_attention(sequence_output, token_mask)
        if pooling == "mean":
            return self._pool_mean(sequence_output, token_mask)
        if pooling == "cls":
            return self._pool_cls(sequence_output, token_mask)
        raise ValueError(
            f"Unsupported `classifier_pooling`={pooling!r}. Expected one of ['attention', 'cls', 'mean']."
        )

    def _repad_sequence_output_if_needed(
        self,
        sequence_output: torch.Tensor,
        indices: Optional[torch.Tensor],
        batch_size: Optional[int],
        seq_len: Optional[int],
    ) -> torch.Tensor:
        """Re-pad 2-D (unpadded) encoder output; return any other rank unchanged.

        Raises:
            ValueError: If the output is 2-D but ``indices``, ``batch_size`` or
                ``seq_len`` is missing.
        """
        if sequence_output.dim() != 2:
            return sequence_output
        if indices is None or batch_size is None or seq_len is None:
            raise ValueError(
                "Received unpadded hidden states from `ProkBertModel`, but `indices`, `batch_size`, and "
                "`seq_len` were not provided to repad them for curricular classification."
            )
        return _pad_prokbert_output(
            inputs=sequence_output,
            indices=indices,
            batch=batch_size,
            seqlen=seq_len,
        )

    def _compute_embeddings(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        apply_dropout: bool = True,
    ) -> tuple[torch.Tensor, BaseModelOutput]:
        """Run the encoder and turn its output into one embedding per sequence.

        Args:
            apply_dropout: When False the dropout module is skipped entirely,
                regardless of the module's train/eval mode.

        Returns:
            ``(embeddings, outputs)`` where ``embeddings`` is the pooled (and, when
            configured, projected) representation and ``outputs`` is the encoder
            output obtained with ``return_dict=True``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = self._repad_sequence_output_if_needed(
            outputs.last_hidden_state,
            indices=indices,
            batch_size=batch_size,
            seq_len=seq_len,
        )
        pooled_output = self._pool_sequence(sequence_output, attention_mask)
        if apply_dropout:
            pooled_output = self.dropout(pooled_output)
        embeddings = self.linear(pooled_output)
        return embeddings, outputs

    def encode(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Return sequence embeddings computed without dropout.

        Args:
            normalize: When True the embeddings are passed through ``l2_norm`` along
                axis 1 (helper defined elsewhere in the file).
        """
        embeddings, _ = self._compute_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            apply_dropout=False,
        )
        return l2_norm(embeddings, axis=1) if normalize else embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_embeddings: bool = False,
        normalize_embeddings: bool = True,
        **kwargs,
    ) -> Union[Tuple, CurricularSequenceClassifierOutput]:
        """Classify sequences and optionally compute the CurricularFace loss.

        The returned ``logits`` are always the margin-free scaled cosines. The margin
        is applied only to the logits fed to the loss, and the head's moving average
        is updated only in training mode.

        Args:
            labels: Class indices; flattened to 1-D before use.
            return_embeddings: Also return the pooled embeddings.
            normalize_embeddings: L2-normalise the returned embeddings via ``l2_norm``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A ``CurricularSequenceClassifierOutput``, or, when ``return_dict`` is
            falsy, a tuple ``([loss,] logits[, embeddings][, hidden_states][, attentions])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        embeddings, outputs = self._compute_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            apply_dropout=self.training,
        )

        exported_embeddings = None
        if return_embeddings:
            exported_embeddings = l2_norm(embeddings, axis=1) if normalize_embeddings else embeddings

        cos_theta = self.curricular_face.cosine(embeddings)
        logits = cos_theta * self.curricular_face.s

        loss = None
        if labels is not None:
            # reshape (not view) so non-contiguous label tensors are accepted too.
            labels = labels.reshape(-1).long()
            train_logits = self.curricular_face.margin_logits_from_cosine(
                cos_theta,
                labels,
                update_t=self.training,
            )
            loss = self.loss_fct(train_logits, labels)

        if not return_dict:
            out = (logits,)
            if return_embeddings:
                out = out + (exported_embeddings,)
            # Also include tensors the encoder returned when the call-time flag was left
            # as None, so the tuple carries the same information as the dict output.
            if output_hidden_states or outputs.hidden_states is not None:
                out = out + (outputs.hidden_states,)
            if output_attentions or outputs.attentions is not None:
                out = out + (outputs.attentions,)
            return ((loss,) + out) if loss is not None else out

        return CurricularSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            embeddings=exported_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class ProkBertForMaskedLM2(_SafeFromPretrainedMixin, ProkBertPreTrainedModel):
    """ProkBERT masked-language-model head supporting hard and soft targets.

    Two training signals are supported by ``forward``:

    * integer ``labels`` -> cross-entropy (optionally restricted to labelled
      positions when ``config.sparse_prediction`` is set);
    * ``labels_dist`` together with ``loss_mask`` -> KL divergence against a
      per-token target distribution at the masked positions.
    """

    # Declares that the decoder weight is shared with the token embedding matrix.
    _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"}

    def __init__(self, config):
        """Build the encoder, the prediction head and the vocabulary decoder."""
        super().__init__(config)
        self.config = config
        self.model = ProkBertModel(config)
        self.head = ProkBertPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
        # Legacy sparse integer-label masking settings.
        self.sparse_prediction = config.sparse_prediction
        self.sparse_pred_ignore_index = config.sparse_pred_ignore_index
        self.post_init()

    def get_output_embeddings(self):
        """Return the vocabulary decoder layer."""
        return self.decoder

    def get_input_embeddings(self):
        """Return the encoder's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the encoder's input embedding module."""
        self.model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        """Replace the vocabulary decoder layer."""
        self.decoder = new_embeddings

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        labels_dist: Optional[torch.FloatTensor] = None,
        loss_mask: Optional[torch.BoolTensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """Compute MLM logits and, when targets are supplied, the training loss.

        Args:
            labels: Integer targets for cross-entropy. Takes precedence over
                ``labels_dist`` when both are given.
            labels_dist: Per-token target distributions over the vocabulary.
            loss_mask: Marks the positions scored in the soft-target loss. Any
                dtype is accepted; non-zero means "score this position".
            indices, cu_seqlens, max_seqlen, batch_size, seq_len: Unpadding metadata.
                Under ``flash_attention_2``, when the first three are all ``None``,
                the inputs are unpadded here and these values are derived.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``MaskedLMOutput`` or, when ``return_dict`` is falsy, a tuple
            ``([loss,] logits, *encoder_extras)``.

        Raises:
            ValueError: In the soft-target branch, if packed logits cannot be matched
                to either ``loss_mask`` or ``attention_mask``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        uses_flash = self.config._attn_implementation == "flash_attention_2"

        # 1) Unpad the inputs for flash_attention_2 unless the caller already did.
        if uses_flash and indices is None and cu_seqlens is None and max_seqlen is None:
            if batch_size is None or seq_len is None:
                if inputs_embeds is not None:
                    batch_size, seq_len = inputs_embeds.shape[:2]
                else:
                    batch_size, seq_len = input_ids.shape[:2]
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if attention_mask is None:
                attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
            if inputs_embeds is None:
                # Token ids carry no gradient, so unpadding them needs no autograd tracking.
                with torch.no_grad():
                    input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_prokbert_input(
                        inputs=input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        labels=labels,
                    )
            else:
                inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_prokbert_input(
                    inputs=inputs_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    labels=labels,
                )

        # 2) Core encoder.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]  # padded (B, L, H) or packed (N, H)

        # 3) Legacy sparse prediction: keep only positions that carry a real label.
        if self.sparse_prediction and labels is not None:
            flat_labels = labels.view(-1)
            flat_hidden = sequence_output.view(flat_labels.shape[0], -1)
            mask_tokens = flat_labels != self.sparse_pred_ignore_index
            sequence_output = flat_hidden[mask_tokens]
            labels = flat_labels[mask_tokens]

        # 4) Prediction head and vocabulary projection.
        hidden = self.head(sequence_output)
        logits = self.decoder(hidden)

        loss = None
        V = self.config.vocab_size

        if labels is not None:
            # 5a) Integer-label MLM.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, V), labels.view(-1))
        elif labels_dist is not None and loss_mask is not None:
            # 5b) Soft-distribution MLM, computed on the logits as they are (no re-pad).
            # Masks are coerced to bool: indexing with an integer 0/1 mask would gather
            # rows 0 and 1 instead of selecting the marked positions.
            flat_mask = loss_mask.reshape(-1).to(torch.bool)  # (B*L,)
            flat_dist = labels_dist.reshape(-1, V)  # (B*L, V)

            if logits.dim() == 2:
                num_masked = int(flat_mask.sum().item())
                if logits.shape[0] == num_masked:
                    # Logits are packed exactly by loss_mask.
                    pred = logits
                    targ = flat_dist[flat_mask]
                else:
                    # Logits are packed by attention_mask; narrow them to the loss positions.
                    if attention_mask is None:
                        raise ValueError(
                            "Packed logits do not match `loss_mask`, and no `attention_mask` was provided "
                            "to align them with `labels_dist`."
                        )
                    full_attn = attention_mask.reshape(-1).to(torch.bool)  # (B*L,)
                    num_attended = int(full_attn.sum().item())
                    if logits.shape[0] != num_attended:
                        raise ValueError(
                            f"Packed logits have {logits.shape[0]} rows, which matches neither the "
                            f"{num_masked} positions in `loss_mask` nor the {num_attended} positions in "
                            "`attention_mask`."
                        )
                    dist_attn = flat_dist[full_attn]  # (N_att, V)
                    mask_in_attn = flat_mask[full_attn]  # (N_att,)
                    pred = logits[mask_in_attn]  # (N_mask, V)
                    targ = dist_attn[mask_in_attn]  # (N_mask, V)
            else:
                # Full padded logits (B, L, V).
                flat_logits = logits.view(-1, V)  # (B*L, V)
                pred = flat_logits[flat_mask]  # (N_mask, V)
                targ = flat_dist[flat_mask]  # (N_mask, V)

            # Floor and renormalise the targets so they form a strictly positive distribution.
            eps = 1e-8
            targ = targ.clamp_min(eps)
            targ = targ / targ.sum(dim=-1, keepdim=True)
            targ = targ.to(pred.dtype).detach()
            logp = F.log_softmax(pred, dim=-1)
            loss = F.kl_div(logp, targ, reduction="batchmean")

        # Sparse prediction already reduced the logits to labelled positions, so they
        # cannot be scattered back to (B, L, V).
        should_repad_logits = (
            uses_flash
            and indices is not None
            and batch_size is not None
            and seq_len is not None
            and not (self.sparse_prediction and labels is not None)
        )
        if should_repad_logits:
            # Gradients through the re-pad are kept only when configured or when no
            # integer labels were supplied.
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_prokbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        # 6) Return.
        if not return_dict:
            out = (logits,) + outputs[1:]
            return ((loss,) + out) if loss is not None else out
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )