| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from collections.abc import Callable |
| from typing import Optional, Union |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.integrations.hub_kernels import use_kernel_forward_from_hub |
| from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple |
| from .configuration_gpt_oss_dense import GptOssDenseConfig |
|
|
|
|
@use_kernel_forward_from_hub("RMSNorm")
class GptOssDenseRMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale (equivalent to T5LayerNorm).

    Statistics are computed in float32; the scaled result is cast back to the dtype of the input.
    There is no mean subtraction and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learned scale, starting as the identity transform.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Added to the mean square before the reciprocal square root to avoid dividing by zero.
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        # The scale is applied (in the promoted dtype) before the cast back to the input dtype.
        return (self.weight * normalized).to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
class GptOssDenseMLP(nn.Module):
    """Dense MLP layer for GptOssDense. Uses the same activation as GptOss experts but without routing.

    One fused projection yields interleaved channels: even indices are the gate, odd indices the
    up-projection. The gate is clamped from above and the up-projection on both sides by
    ``config.swiglu_limit``; the activation ``(up + 1) * gate * sigmoid(alpha * gate)`` is then
    projected back to ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        # Fused gate/up projection (outputs interleaved channel-wise), then the output projection.
        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=True)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        # Multiplier applied to the gate inside the sigmoid.
        self.alpha = 1.702
        # Clamp bound for the gate (upper only) and the up-projection (both sides).
        self.limit = config.swiglu_limit

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
        Returns:
            torch.Tensor: (batch_size, seq_len, hidden_size)
        """
        fused = self.gate_up_proj(hidden_states)
        gate = fused[..., 0::2].clamp(max=self.limit)
        up = fused[..., 1::2].clamp(min=-self.limit, max=self.limit)
        activated = gate * torch.sigmoid(self.alpha * gate)
        return self.down_proj((up + 1) * activated)
|
|
|
|
class GptOssDenseRotaryEmbedding(nn.Module):
    """Builds rotary position-embedding ``(cos, sin)`` tables for a batch of position ids.

    Inverse frequencies are computed once at construction, either by the default RoPE formula
    (``compute_default_rope_parameters``) or by the initializer registered in ``ROPE_INIT_FUNCTIONS``
    under ``config.rope_parameters["rope_type"]``.
    """

    # Declared for type checkers; the tensor itself is registered as a buffer in __init__.
    inv_freq: torch.Tensor

    def __init__(self, config: GptOssDenseConfig, device=None):
        super().__init__()
        # Sequence-length bookkeeping; presumably consulted by the `dynamic_rope_update` wrapper
        # on `forward` (TODO confirm against transformers.modeling_rope_utils).
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffer: follows the module across devices but is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = inv_freq

    @staticmethod
    def compute_default_rope_parameters(
        config: Optional[GptOssDenseConfig] = None,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # An explicit `head_dim=None` falls back to the size derived from hidden_size / heads.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        # Default RoPE applies no extra scaling to cos/sin.
        attention_factor = 1.0

        even_indices = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (base ** (even_indices / dim))
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # wraps forward; see transformers.modeling_rope_utils for when it rescales inv_freq
    def forward(self, x, position_ids):
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Disable autocast so the frequency product is done in float32; MPS uses the "cpu" autocast context.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            cos = freqs.cos() * self.attention_scaling
            sin = freqs.sin() * self.attention_scaling

        return cos.to(x.dtype), sin.to(x.dtype)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times along the head axis.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: a tensor of shape
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With ``n_rep == 1`` the input object
    is returned as-is (the input must still be 4-D).
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Broadcast a new repeat axis right after the head axis (a view), then fold it into the heads.
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
| def _apply_rotary_emb( |
| x: torch.Tensor, |
| cos: torch.Tensor, |
| sin: torch.Tensor, |
| ) -> torch.Tensor: |
| first_half, second_half = torch.chunk(x, 2, dim=-1) |
| first_ = first_half * cos - second_half * sin |
| second_ = second_half * cos + first_half * sin |
| return torch.cat((first_, second_), dim=-1) |
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to a query/key pair.

    Args:
        q, k: tensors to rotate (split-half convention along the last axis).
        cos, sin: rotary tables; a singleton axis is inserted at ``unsqueeze_dim`` so they broadcast
            over the head axis (with the default of 1 this assumes q/k are laid out as
            (batch, heads, seq, head_dim) — confirm with the caller).
        position_ids: unused here.
        unsqueeze_dim: axis at which ``cos``/``sin`` get their singleton dimension.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos_b, sin_b = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return _apply_rotary_emb(q, cos_b, sin_b), _apply_rotary_emb(k, cos_b, sin_b)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Pure-PyTorch attention with learned per-head "sink" logits.

    ``module`` must expose ``num_key_value_groups``, ``sinks`` (one logit per attention head) and
    ``training``. Each head's sink logit is appended as one extra key column before the softmax and
    removed afterwards, so it absorbs probability mass without contributing a value vector.
    Extra keyword arguments are accepted and ignored.

    Returns:
        ``(attn_output, attn_weights)``: ``attn_output`` has axes 1 and 2 swapped relative to the
        ``probs @ values`` product; ``attn_weights`` are the post-dropout probabilities without the sink column.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    logits = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        logits = logits + attention_mask[:, :, :, :kv_len]

    batch_size, q_len = query.shape[0], query.shape[-2]
    sink_column = module.sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, q_len, -1)
    logits_with_sink = torch.cat([logits, sink_column], dim=-1)

    # Subtracting the row max leaves the softmax unchanged mathematically but guards bf16/fp16 overflow.
    row_max = logits_with_sink.max(dim=-1, keepdim=True).values
    stabilized = logits_with_sink - row_max
    probs = F.softmax(stabilized, dim=-1, dtype=stabilized.dtype)

    # Drop the sink column; it only renormalizes the probabilities of the real keys.
    attn_weights = nn.functional.dropout(probs[..., :-1], p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights
|
|
|
|
class GptOssDenseAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Adds grouped key/value heads (``num_attention_heads // num_key_value_heads`` query heads share each
    key/value head), rotary position embeddings, a sliding window on layers typed
    ``"sliding_attention"``, and one learned attention-sink logit per query head.
    """

    def __init__(self, config: GptOssDenseConfig, layer_idx: int):
        super().__init__()
        # A missing *or* None `layer_types` means this layer uses full attention (no sliding window).
        layer_types = getattr(config, "layer_types", None)
        self.layer_type = layer_types[layer_idx] if layer_types is not None else None
        self.config = config
        self.layer_idx = layer_idx
        # Treat `head_dim=None` like an absent attribute (as the rotary embedding does); otherwise
        # `self.scaling = None**-0.5` below would raise.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
        # One learned sink logit per attention head. The eager path reads `module.sinks`; other
        # backends receive it through the `s_aux` keyword in `forward`.
        self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, rotate, optionally cache, and attend.

        Args:
            hidden_states: input activations; the last axis is the hidden size.
            position_embeddings: ``(cos, sin)`` rotary tables.
            attention_mask: passed unchanged to the attention backend.
            past_key_values: optional cache; this layer's keys/values go through ``past_key_values.update``.
            cache_position: forwarded to the cache update.
            position_ids: forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the leading shape of ``hidden_states``;
            ``attn_weights`` is whatever the selected backend returns (possibly None for fused kernels — confirm).
        """
        input_shape = hidden_states.shape[:-1]
        # -1 lets the same shape serve query (num_attention_heads) and key/value (num_key_value_heads).
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            position_ids=position_ids,
            s_aux=self.sinks,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
class GptOssDenseDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: RMSNorm -> self-attention -> residual, then RMSNorm -> MLP -> residual."""

    def __init__(self, config: GptOssDenseConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GptOssDenseAttention(config=config, layer_idx=layer_idx)
        self.mlp = GptOssDenseMLP(config)
        self.input_layernorm = GptOssDenseRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GptOssDenseRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Selects this layer's mask from the model's {"full_attention", "sliding_attention"} mapping.
        # Without `layer_types` (absent or None) fall back to full attention, matching the attention
        # module, which uses no sliding window when `layer_types` is absent.
        layer_types = getattr(config, "layer_types", None)
        self.attention_type = layer_types[layer_idx] if layer_types is not None else "full_attention"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run the attention and MLP sub-blocks, each wrapped in a residual connection.

        Returns only the updated hidden states; the attention weights are discarded here.
        """
        # Self-attention sub-block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
|
|
|
|
@auto_docstring
class GptOssDensePreTrainedModel(PreTrainedModel):
    config: GptOssDenseConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GptOssDenseDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = False
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Module types whose outputs are meant to be captured as `hidden_states` / `attentions`.
    _can_record_outputs = {
        "hidden_states": GptOssDenseDecoderLayer,
        "attentions": GptOssDenseAttention,
    }
    # Presumably kept in float32 when loading in lower precision (transformers convention) — confirm.
    _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
    # NOTE(review): these legacy-named flags disagree with `_supports_flash_attn` / `_supports_flex_attn`
    # above; confirm which names the installed transformers version actually reads.
    _supports_flash_attention = False
    _supports_flex_attention = False

    def _init_weights(self, module):
        """Initialize the parameters owned directly by ``module``.

        Linear/embedding weights and attention sinks are drawn from N(0, config.initializer_range),
        biases and the padding embedding row are zeroed, and RMSNorm scales are reset to 1.

        Each module is handled on its own: composite modules (the MLP, the attention) rely on their
        nn.Linear children being visited separately, so they never re-initialize those children.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, GptOssDenseRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, GptOssDenseAttention):
            # `sinks` is a bare Parameter owned by the attention module itself, so it needs its own branch.
            module.sinks.data.normal_(mean=0.0, std=std)
|
|
|
|
@auto_docstring
class GptOssDenseModel(GptOssDensePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GptOssDenseDecoderLayer`]

    Args:
        config: GptOssDenseConfig
    """

    _no_split_modules = ["GptOssDenseDecoderLayer"]

    def __init__(self, config: GptOssDenseConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Registration order: token embeddings, decoder stack, final norm, rotary tables.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            GptOssDenseDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )
        self.norm = GptOssDenseRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GptOssDenseRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # NOTE(review): `_can_record_outputs` is declared on the base class, but nothing in this forward
        # collects per-layer hidden states or attentions, so the returned output only carries
        # `last_hidden_state` / `past_key_values`. Confirm whether an output-recording wrapper was intended.

        # Exactly one of the two input forms must be given.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Absolute positions of the new tokens, offset by how many tokens the cache already holds.
        if cache_position is None:
            already_cached = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                already_cached, already_cached + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Callers may supply a prepared {layer_type: mask} dict; otherwise build both mask kinds once.
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds
        # The rotary tables are shared by every layer, so they are computed once.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
|
|
|
|
@auto_docstring
class GptOssDenseForCausalLM(GptOssDensePreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        # Decoder backbone followed by a bias-free projection onto the vocabulary.
        self.model = GptOssDenseModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GptOssDenseForCausalLM

        >>> model = GptOssDenseForCausalLM.from_pretrained("USERNAME/MODEL_NAME")
        >>> tokenizer = AutoTokenizer.from_pretrained("USERNAME/MODEL_NAME")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        model_outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        last_hidden = model_outputs.last_hidden_state
        # An int keeps the last `logits_to_keep` positions (0 keeps all, because slice(-0, None) is
        # slice(0, None)); a tensor is used directly as the position index.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(last_hidden[:, keep, :])

        loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) if labels is not None else None

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
|
|
|
|
# Public classes exported by this module.
__all__ = [
    "GptOssDenseForCausalLM",
    "GptOssDenseModel",
    "GptOssDensePreTrainedModel",
]
|
|