| import typing as tp |
| import warnings |
| from functools import partial |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn.attention.flex_attention import flex_attention |
| from transformers import PreTrainedModel |
| from transformers.generation.utils import GenerationMixin |
| from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
| from .configuration_gidd import GiddConfig |
|
|
|
|
| @dataclass |
| class AttentionLayerOutput: |
| hidden_states: torch.Tensor |
| attentions: tp.Optional[torch.Tensor] = None |
    past_key_values: tp.Optional[tp.Tuple[torch.Tensor, torch.Tensor]] = None
|
|
| @dataclass |
| class DecoderLayerOutput: |
| hidden_states: torch.Tensor |
| attentions: tp.Optional[torch.Tensor] = None |
    past_key_values: tp.Optional[tp.Tuple[torch.Tensor, torch.Tensor]] = None
|
|
|
|
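# Cast every tensor-like entry of `args` to a common dtype; None entries pass through.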
| def promote_dtype(args: tuple, *, dtype: torch.dtype | None = None) -> tuple: |
| return tuple( |
| torch.as_tensor(x, dtype=dtype) if x is not None else None |
| for x in args |
| ) |
|
|
|
|
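# Linear layer whose output is multiplied by a fixed scale (a float, or "fan_in" /
# "fan_out" for 1/sqrt(features)). Weights are zero-initialized here and filled in
# by `GiddPreTrainedModel._init_weights`.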
| class ScaledLinear(nn.Module): |
| def __init__( |
| self, |
| in_features: int, |
| out_features: int, |
| *, |
| scale: float | tp.Literal["fan_in", "fan_out"] = 1.0, |
| use_bias: bool = True, |
| dtype: torch.dtype | None = None, |
| ): |
| super().__init__() |
|
|
| if scale == "fan_in": |
| scale = in_features**-0.5 |
| elif scale == "fan_out": |
| scale = out_features**-0.5 |
|
|
| if scale != 1.0: |
| def _scale_operator(x): |
| return x * scale |
| else: |
| def _scale_operator(x): |
| return x |
|
|
| self._scale_operator = _scale_operator |
| self.in_features = in_features |
| self.out_features = out_features |
|
|
| self.use_bias = use_bias |
|
|
| weight_shape = (out_features, in_features) |
| weight = torch.zeros(weight_shape, dtype=dtype) |
| self.weight = nn.Parameter(weight) |
|
|
| if use_bias: |
| bias = torch.zeros((out_features,), dtype=dtype) |
| self.bias = nn.Parameter(bias) |
| else: |
| self.bias = None |
|
|
| def forward( |
| self, |
| inputs: torch.Tensor, |
| w: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| dtype = inputs.dtype |
| weight = self.weight if w is None else w |
| bias = self.bias if self.use_bias else None |
|
|
| if bias is not None: |
| inputs, weight, bias = promote_dtype((inputs, weight, bias), dtype=dtype) |
| else: |
| inputs, weight = promote_dtype((inputs, weight), dtype=dtype) |
|
|
| y = torch.matmul( |
| inputs, |
| weight.T, |
| ) |
|
|
| y = self._scale_operator(y) |
|
|
| if bias is not None: |
| y = y + bias.reshape((1,) * (y.ndim - 1) + (-1,)) |
|
|
| return y |
|
|
|
|
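# Rotate query/key features by the given cos/sin tables. Neox style rotates the two
# contiguous halves of the feature dim; otherwise even/odd interleaved pairs are used.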
| def _apply_rotary_emb( |
| x: torch.Tensor, |
| cos: torch.Tensor, |
| sin: torch.Tensor, |
| is_neox_style: bool, |
| ) -> torch.Tensor: |
| cos = cos.unsqueeze(2).to(dtype=x.dtype) |
| sin = sin.unsqueeze(2).to(dtype=x.dtype) |
| assert sin.ndim == x.ndim |
| if is_neox_style: |
| x1, x2 = torch.chunk(x, 2, dim=-1) |
| else: |
| x1 = x[..., ::2] |
| x2 = x[..., 1::2] |
|
|
| o1 = x1 * cos - x2 * sin |
| o2 = x2 * cos + x1 * sin |
|
|
| if is_neox_style: |
| return torch.cat((o1, o2), dim=-1) |
| else: |
| return torch.stack((o1, o2), dim=-1).reshape(x.shape) |
|
|
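# Apply rotary position embeddings to query and key, rotating only the first
# `rotary_dim` features and passing the remainder through unchanged.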
| def apply_basic_rope( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| positions: torch.Tensor, |
| frequencies: torch.Tensor, |
| rotary_dim: int, |
| is_neox_style: bool, |
| offsets: torch.Tensor | None = None, |
| dtype: torch.dtype = torch.float32, |
| ): |
| if offsets is not None: |
| positions = positions + offsets |
| cos, sin = torch.chunk(frequencies[positions], 2, dim=-1) |
| if rotary_dim != query.shape[-1]: |
| query_rot = _apply_rotary_emb(query[..., :rotary_dim], cos, sin, is_neox_style) |
| query = torch.cat((query_rot, query[..., rotary_dim:]), dim=-1) |
| key_rot = _apply_rotary_emb(key[..., :rotary_dim], cos, sin, is_neox_style) |
| key = torch.cat((key_rot, key[..., rotary_dim:]), dim=-1) |
| return query.to(dtype), key.to(dtype), cos, sin |
| else: |
| query = _apply_rotary_emb(query, cos, sin, is_neox_style) |
| key = _apply_rotary_emb(key, cos, sin, is_neox_style) |
| return query.to(dtype), key.to(dtype), cos, sin |
|
|
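# Precompute the RoPE table: a (max_position_embeddings, rotary_dim) tensor laid out
# as [cos | sin] along the last dimension.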
| def compute_basic_frequencies( |
| base: int, |
| rotary_dim: int, |
| max_position_embeddings: int, |
| ): |
| inv = 1.0 / torch.pow( |
| base, |
| torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim, |
| ) |
| freqs = torch.einsum( |
| "i,j->ij", |
| torch.arange(max_position_embeddings, dtype=torch.float32), |
| inv, |
| ) |
| freqs = torch.cat([freqs.cos(), freqs.sin()], dim=-1) |
| return freqs |
|
|
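# nn.Module wrapper around the RoPE helpers above; a precomputed frequency table can
# be passed in to avoid rebuilding it on every forward.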
| class RotaryEmbedding(nn.Module): |
| def __init__( |
| self, |
| head_size: int, |
| rotary_dim: int, |
| max_position_embeddings: int, |
| base: int, |
| is_neox_style: bool, |
| dtype: torch.dtype, |
| ): |
| super().__init__() |
| self.head_size = head_size |
| self.rotary_dim = rotary_dim |
| self.max_position_embeddings = max_position_embeddings |
| self.base = base |
| self.is_neox_style = is_neox_style |
| self.dtype = dtype |
|
|
| def forward( |
| self, |
| positions: torch.Tensor, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| offsets: torch.Tensor | None = None, |
| frequencies: torch.Tensor | None = None, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| if frequencies is None: |
| frequencies = compute_basic_frequencies( |
| base=self.base, |
| rotary_dim=self.rotary_dim, |
| max_position_embeddings=self.max_position_embeddings, |
| ) |
        # Unwrap buffer-like containers that expose the underlying tensor via `.value`.
        if hasattr(frequencies, "value"):
            frequencies = frequencies.value
| return apply_basic_rope( |
| query=query, |
| key=key, |
| positions=positions, |
| frequencies=frequencies, |
| rotary_dim=self.rotary_dim, |
| is_neox_style=self.is_neox_style, |
| offsets=offsets, |
| dtype=self.dtype, |
| ) |
|
|
|
|
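# RMSNorm with a (1 + weight) parameterization, so the zero-initialized weight
# corresponds to an identity scale.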
| class GiddRMSNorm(nn.Module): |
| def __init__( |
| self, |
| config: GiddConfig, |
| dtype=torch.float32, |
| ): |
| super().__init__() |
| self.config = config |
| self.epsilon = self.config.rms_norm_eps |
| self.weight = nn.Parameter(torch.zeros(self.config.hidden_size, dtype=dtype)) |
| |
|
|
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| dtype = hidden_states.dtype |
| variance = hidden_states.to(torch.float32) |
| variance = variance.pow(2.0) |
| variance = variance.mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) |
|
|
| hidden_states = ((1 + self.weight) * hidden_states) |
| return hidden_states.to(dtype) |
|
|
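# Register the norm so utilities that exclude layer norms from weight decay pick it up.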
| ALL_LAYERNORM_LAYERS.append(GiddRMSNorm) |
|
|
|
|
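# Two-layer feed-forward block with a squared-ReLU activation (relu(x) ** 2).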
| class GiddMLP(nn.Module): |
| def __init__( |
| self, |
| config: GiddConfig, |
| dtype=torch.float32, |
| ): |
| super().__init__() |
| self.config = config |
| self.dtype = dtype |
|
|
| linear_class = partial( |
| ScaledLinear, |
| scale=config.weight_scaling, |
| dtype=dtype, |
| use_bias=self.config.mlp_bias, |
| ) |
| self.up_proj = linear_class(config.hidden_size, config.intermediate_size) |
| self.down_proj = linear_class(config.intermediate_size, config.hidden_size) |
|
|
| def forward(self, h: torch.Tensor) -> torch.Tensor: |
| h = self.up_proj(h) |
| h = torch.relu(h) ** 2 |
| h = self.down_proj(h) |
| return h |
|
|
|
|
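# Attention with tanh soft-capped scores, implemented via torch's flex_attention
# using a score_mod that applies the cap and the boolean mask.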
| class FlexSoftcapAttention(nn.Module): |
| def __init__(self, head_dim, n_heads, softmax_scale, soft_cap): |
| super().__init__() |
| self.d_model = head_dim * n_heads |
| self.n_heads = n_heads |
| self.head_dim = head_dim |
| self.scale = float(softmax_scale) |
| self.soft_cap = float(soft_cap) |
|
|
| def forward( |
| self, |
| q: torch.Tensor, |
| k: torch.Tensor, |
| v: torch.Tensor, |
| attention_mask: torch.Tensor | None = None, |
| ): |
| B, _, L = q.shape[:3] |
|
|
        def score_mod(score, b, h, q_idx, kv_idx):
            # Apply tanh soft-capping, then mask out disallowed key positions.
            score = self.soft_cap * torch.tanh(score / self.soft_cap)
            if attention_mask is not None:
                keep = attention_mask[b, q_idx, kv_idx]
                score = torch.where(keep, score, torch.finfo(score.dtype).min)
            return score
|
|
| out = flex_attention( |
| q, |
| k, |
| v, |
| score_mod=score_mod, |
| scale=self.scale, |
| ) |
| out = out.transpose(1, 2).contiguous().view(B, L, self.d_model) |
| return out, None |
|
|
|
|
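# Eager reference implementation of soft-capped attention; unlike the flex variant,
# it also returns the attention probabilities for `output_attentions`.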
| class VanillaSoftcapAttention(nn.Module): |
| def __init__(self, head_dim, n_heads, softmax_scale, soft_cap): |
| super().__init__() |
| self.d_model = head_dim * n_heads |
| self.n_heads = n_heads |
| self.head_dim = head_dim |
| self.scale = float(softmax_scale) |
| self.soft_cap = float(soft_cap) |
|
|
| def forward( |
| self, |
| q: torch.Tensor, |
| k: torch.Tensor, |
| v: torch.Tensor, |
| attention_mask: torch.Tensor | None = None, |
| ): |
| B, _, L = q.shape[:3] |
| scores = torch.einsum( |
| "bhqd,bhkd->bhqk", |
| q * self.scale, |
| k, |
| ) |
| scores = self.soft_cap * torch.tanh(scores / self.soft_cap) |
| if attention_mask is not None: |
| scores = scores.masked_fill(~attention_mask.unsqueeze(1), torch.finfo(scores.dtype).min) |
| probs = torch.softmax(scores.to(torch.float32), dim=-1).to(scores.dtype) |
| out = torch.einsum( |
| "bhqk,bhkd->bhqd", |
| probs, |
| v, |
| ) |
| out = out.transpose(1, 2).contiguous().view(B, L, self.d_model) |
| return out, probs |
|
|
|
|
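# Attention block with optional QK-norm, rotary embeddings, an optional learned
# "sink" key/value bias token prepended to the sequence, and a pluggable soft-capped
# attention backend ("flex" or "eager").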
| class GiddAttention(nn.Module): |
| def __init__( |
| self, |
| config: GiddConfig, |
| layer_idx: int, |
| dtype=torch.float32, |
| ): |
| super().__init__() |
|
|
| self.hidden_size = config.hidden_size |
| head_dim = config.hidden_size // config.num_attention_heads |
| self.head_dim = getattr(config, "head_dim", head_dim) |
| self.num_attention_heads = self.hidden_size // self.head_dim |
| self.is_causal = config.is_causal |
| self.layer_idx = layer_idx |
|
|
| self.use_qk_norm = config.use_qk_norm |
| if self.use_qk_norm: |
| self.q_norm = GiddRMSNorm(config, dtype=torch.float32) |
| self.k_norm = GiddRMSNorm(config, dtype=torch.float32) |
| else: |
| self.q_norm = None |
| self.k_norm = None |
|
|
| self.attention_bias = config.attention_bias |
| if self.attention_bias: |
| self.k_bias = nn.Parameter( |
| torch.zeros((self.num_attention_heads, self.head_dim), dtype=dtype), |
| ) |
| self.v_bias = nn.Parameter( |
| torch.zeros((self.num_attention_heads, self.head_dim), dtype=dtype), |
| ) |
| else: |
| self.k_bias = None |
| self.v_bias = None |
|
|
| linear_class = partial( |
| ScaledLinear, |
| scale=config.weight_scaling, |
| dtype=dtype, |
| use_bias=False, |
| ) |
| self.q_proj = linear_class( |
| self.hidden_size, |
| self.num_attention_heads * self.head_dim, |
| ) |
| self.k_proj = linear_class( |
| self.hidden_size, |
| self.num_attention_heads * self.head_dim, |
| ) |
| self.v_proj = linear_class( |
| self.hidden_size, |
| self.num_attention_heads * self.head_dim, |
| ) |
| self.o_proj = linear_class( |
| self.num_attention_heads * self.head_dim, |
| self.hidden_size, |
| ) |
|
|
| self.rotary = RotaryEmbedding( |
| head_size=self.head_dim, |
| rotary_dim=self.head_dim, |
| max_position_embeddings=config.max_position_embeddings, |
| base=config.rope_theta, |
| is_neox_style=True, |
| dtype=dtype, |
| ) |
|
|
| if config.attn_performer == "flex": |
| self.attention_performer = FlexSoftcapAttention( |
| head_dim=self.head_dim, |
| n_heads=self.num_attention_heads, |
| softmax_scale=self.head_dim**-0.5, |
| soft_cap=config.attn_soft_cap, |
| ) |
| elif config.attn_performer == "eager": |
| self.attention_performer = VanillaSoftcapAttention( |
| head_dim=self.head_dim, |
| n_heads=self.num_attention_heads, |
| softmax_scale=self.head_dim**-0.5, |
| soft_cap=config.attn_soft_cap, |
| ) |
| else: |
| raise ValueError(f"Unknown attn_performer: {config.attn_performer}") |
|
|
| def concatenate( |
| self, |
| *, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attention_mask: torch.Tensor, |
| past_key_values: tp.Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| ): |
| assert query.shape[1] == key.shape[1], "Query and Key lengths must match for GIDD attention." |
| if attention_mask is not None: |
| if attention_mask.dtype != torch.bool: |
| warnings.warn("attention_mask should be a boolean array", stacklevel=1) |
| attention_mask = (attention_mask == 1) |
|
|
        batch_size = query.shape[0]

        # Broadcast a 2-D padding mask (batch, kv_len) to a per-query mask
        # (batch, q_len, kv_len); a 3-D mask is assumed to already be per-query.
        if attention_mask.ndim == 2:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.expand(-1, query.shape[1], -1)
|
|
| if self.attention_bias: |
| ones = torch.ones( |
| attention_mask.shape[:2] + (1,), |
| dtype=attention_mask.dtype, |
| device=attention_mask.device, |
| ) |
| attention_mask = torch.cat( |
| [ |
| ones, |
| attention_mask, |
| ], |
| dim=-1, |
| ) |
|
|
| if past_key_values is not None: |
| past_keys, past_values = past_key_values |
| key = torch.cat([past_keys, key], dim=1) |
| value = torch.cat([past_values, value], dim=1) |
| elif self.attention_bias: |
| n_heads = self.num_attention_heads |
| bias_shape = (batch_size, 1, n_heads, self.head_dim) |
| k_bias = self.k_bias.view(1, 1, n_heads, self.head_dim).expand(bias_shape) |
| v_bias = self.v_bias.view(1, 1, n_heads, self.head_dim).expand(bias_shape) |
| key = torch.cat([k_bias, key], dim=1) |
| value = torch.cat([v_bias, value], dim=1) |
|
|
| |
        # The updated (key, value) pair doubles as the cache entry for this layer.
        return query, key, value, attention_mask, (key, value)
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: torch.Tensor, |
| position_ids: torch.Tensor, |
| past_key_values: tp.Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| frequencies: tp.Optional[torch.Tensor] = None, |
| output_attentions: bool = False, |
| ) -> AttentionLayerOutput: |
| batch_size, sequence_length = hidden_states.shape[:2] |
| query_states = self.q_proj(hidden_states) |
| key_states = self.k_proj(hidden_states) |
| value_states = self.v_proj(hidden_states) |
|
|
| if self.use_qk_norm: |
| query_states = self.q_norm(query_states) |
| key_states = self.k_norm(key_states) |
|
|
        # No grouped-query attention: queries, keys, and values share the same shape.
        qkv_shape = (
            batch_size,
            sequence_length,
            self.num_attention_heads,
            self.head_dim,
        )
        query_states = query_states.view(qkv_shape)
        key_states = key_states.view(qkv_shape)
        value_states = value_states.view(qkv_shape)
|
|
| query_states, key_states, cos, sin = self.rotary( |
| positions=position_ids, |
| query=query_states, |
| key=key_states, |
| frequencies=frequencies, |
| ) |
|
|
| ( |
| query_states, |
| key_states, |
| value_states, |
| attention_mask, |
| past_key_values, |
| ) = self.concatenate( |
| query=query_states, |
| key=key_states, |
| value=value_states, |
| attention_mask=attention_mask, |
| past_key_values=past_key_values, |
| ) |
|
|
        attention_out, attentions = self.attention_performer(
            q=query_states.transpose(1, 2),
            k=key_states.transpose(1, 2),
            v=value_states.transpose(1, 2),
            attention_mask=attention_mask,
        )
|
|
| attn_output = self.o_proj(attention_out) |
|
|
| return AttentionLayerOutput( |
| hidden_states=attn_output, |
| attentions=attentions if output_attentions else None, |
| past_key_values=past_key_values, |
| ) |
|
|
|
|
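# Pre-norm transformer block; both residual branches are scaled by `resid_scale`
# (set to config.resid_scale / num_hidden_layers by GiddModel).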
| class GiddLayer(nn.Module): |
| def __init__( |
| self, |
| config: GiddConfig, |
| layer_idx: int, |
| dtype=torch.float32, |
| resid_scale: float = 1.0, |
| ): |
| super().__init__() |
| self.config = config |
| self.resid_scale = resid_scale |
| self.layer_idx = layer_idx |
|
|
| self.self_attn = GiddAttention( |
| layer_idx=layer_idx, |
| config=config, |
| dtype=dtype, |
| ) |
|
|
| self.mlp = GiddMLP( |
| config=config, |
| dtype=dtype, |
| ) |
| self.attn_layernorm = GiddRMSNorm( |
| config=config, |
| dtype=torch.float32, |
| ) |
| self.mlp_layernorm = GiddRMSNorm( |
| config=config, |
| dtype=torch.float32, |
| ) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: torch.Tensor, |
| position_ids: torch.Tensor, |
| past_key_values: tp.Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| frequencies: tp.Optional[torch.Tensor] = None, |
| output_attentions: bool = False, |
| ) -> DecoderLayerOutput: |
| attn_inputs = self.attn_layernorm(hidden_states) |
| attn_outputs = self.self_attn( |
| attn_inputs, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| frequencies=frequencies, |
| output_attentions=output_attentions, |
| ) |
| hidden_states = hidden_states + self.resid_scale * attn_outputs.hidden_states |
|
|
| mlp_inputs = self.mlp_layernorm(hidden_states) |
| mlp_output = self.mlp(mlp_inputs) |
| hidden_states = hidden_states + self.resid_scale * mlp_output |
|
|
| return DecoderLayerOutput( |
| hidden_states=hidden_states, |
| attentions=attn_outputs.attentions, |
| past_key_values=attn_outputs.past_key_values, |
| ) |
| |
|
|
| class GiddPreTrainedModel(PreTrainedModel): |
| config_class = GiddConfig |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = False |
| _no_split_modules = ["GiddLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| _supports_flash_attn = False |
| _supports_sdpa = False |
| _supports_flex_attn = False |
| _can_compile_fullgraph = False |
| _supports_attention_backend = False |
| _can_record_outputs = { |
| "hidden_states": GiddLayer, |
| "attentions": GiddAttention, |
| } |
|
|
    def _init_weights(self, module):
        super()._init_weights(module)
        # Only ScaledLinear carries zero-initialized weights that still need filling in;
        # embeddings and norms are initialized elsewhere. Guarding on the module type
        # also avoids touching submodules that have no `weight` attribute at all.
        if isinstance(module, ScaledLinear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
|
|
|
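# Backbone: token embeddings, a shared RoPE frequency table, a stack of GiddLayers,
# and a final RMSNorm.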
| class GiddModel(GiddPreTrainedModel): |
| def __init__( |
| self, |
| config: GiddConfig, |
| ): |
| super().__init__(config=config) |
|
|
| self.resid_scale = config.resid_scale / config.num_hidden_layers |
| dtype = config.torch_dtype |
|
|
| self.embed_tokens = nn.Embedding( |
| num_embeddings=self.config.vocab_size, |
| embedding_dim=self.config.hidden_size, |
| ) |
| self.embed_tokens.weight.data = self.embed_tokens.weight.data.to(dtype) |
| nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=self.config.emb_init_scale) |
|
|
| freqs = compute_basic_frequencies( |
| base=config.rope_theta, |
| rotary_dim=config.hidden_size // config.num_attention_heads, |
| max_position_embeddings=config.max_position_embeddings, |
| ) |
| self.frequencies = nn.Buffer(freqs, persistent=False) |
|
|
| self.layers = nn.ModuleList( |
| [ |
| GiddLayer( |
| config=config, |
| layer_idx=i, |
| resid_scale=self.resid_scale, |
| dtype=dtype, |
| ) |
| for i in range(self.config.num_hidden_layers) |
| ] |
| ) |
| self.norm = GiddRMSNorm( |
| config=config, |
| dtype=torch.float32, |
| ) |
|
|
| def forward( |
| self, |
| input_ids: tp.Optional[torch.Tensor] = None, |
| inputs_embeds: tp.Optional[torch.Tensor] = None, |
| attention_mask: tp.Optional[torch.Tensor] = None, |
| position_ids: tp.Optional[torch.Tensor] = None, |
| past_key_values: tp.Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None, |
| use_cache: bool = False, |
| cache_position: tp.Optional[torch.LongTensor] = None, |
| output_attentions: tp.Optional[bool] = None, |
| output_hidden_states: tp.Optional[bool] = None, |
| ) -> BaseModelOutputWithPast: |
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids.to(torch.long)) |
|
|
| if use_cache and past_key_values is None: |
| past_key_values = [None] * self.config.num_hidden_layers |
| elif past_key_values is not None: |
| past_key_values = list(past_key_values) |
|
|
        if position_ids is None:
            past_seen_tokens = 0
            if past_key_values is not None and any(kv is not None for kv in past_key_values):
                past_seen_tokens = next(kv[0].shape[1] for kv in past_key_values if kv is not None)
                # Cached keys include the prepended attention-bias token; discount it.
                if self.config.attention_bias:
                    past_seen_tokens -= 1
            cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = cache_position.unsqueeze(0)
|
|
| batch_size, sequence_length, _ = inputs_embeds.shape |
|
|
        assert sequence_length <= self.config.max_position_embeddings, (
            f"Sequence length {sequence_length} exceeds max_position_embeddings "
            f"({self.config.max_position_embeddings})."
        )
| if attention_mask is None: |
| attention_mask = torch.ones( |
| (batch_size, sequence_length), |
| dtype=torch.bool, |
| device=inputs_embeds.device, |
| ) |
| else: |
| if attention_mask.dtype != torch.bool: |
| attention_mask = (attention_mask == 1) |
|
|
|
|
| hidden_states = inputs_embeds |
|
|
| all_attentions = () if output_attentions else None |
| all_hidden_states = () if output_hidden_states else None |
| for idx, block in enumerate(self.layers): |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| layer_outputs = block( |
| hidden_states=hidden_states, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| output_attentions=output_attentions, |
| frequencies=self.frequencies, |
| past_key_values=past_key_values[idx] if past_key_values is not None else None, |
| ) |
| hidden_states = layer_outputs.hidden_states |
|
|
| if output_attentions: |
| all_attentions += (layer_outputs.attentions,) |
|
|
| if use_cache: |
| past_key_values[idx] = layer_outputs.past_key_values |
|
|
| hidden_states = self.norm(hidden_states) |
|
|
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| past_key_values=past_key_values, |
| ) |
|
|
|
|
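# Backbone plus LM head, with a block-wise denoising-diffusion `generate` loop in
# place of autoregressive decoding.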
| class GiddForDiffusionLM(GiddPreTrainedModel, GenerationMixin): |
| def __init__( |
| self, |
| config: GiddConfig, |
| ): |
| super().__init__(config=config) |
|
|
| self.model = GiddModel(config=config) |
|
|
| self.lm_head = ScaledLinear( |
| config.hidden_size, |
| config.vocab_size, |
| scale=config.head_scaling, |
| dtype=config.torch_dtype, |
| use_bias=False, |
| ) |
|
|
| def forward( |
| self, |
| input_ids: tp.Optional[torch.Tensor] = None, |
| inputs_embeds: tp.Optional[torch.Tensor] = None, |
| attention_mask: tp.Optional[torch.Tensor] = None, |
| position_ids: tp.Optional[torch.Tensor] = None, |
| past_key_values: tp.Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None, |
| use_cache: bool = False, |
| output_attentions: tp.Optional[bool] = None, |
| output_hidden_states: tp.Optional[bool] = None, |
| ) -> CausalLMOutputWithPast: |
| outputs = self.model( |
| input_ids=input_ids, |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| use_cache=use_cache, |
| ) |
|
|
| hidden_states = outputs.last_hidden_state |
|
|
| if self.config.tie_word_embeddings: |
| logits = hidden_states @ self.model.embed_tokens.weight.t() |
| else: |
| logits = self.lm_head(hidden_states) |
|
|
| return CausalLMOutputWithPast( |
| loss=None, |
| logits=logits, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| past_key_values=outputs.past_key_values, |
| ) |
| |
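    # Sample from the prior at maximum noise: a mixture of the mask token and uniform
    # random tokens, with uniform weight sigmoid(min_log_snr + noise_type).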
| def _sample_prior(self, shape: tuple[int, ...], device: torch.device, mask_token_id: int = 3) -> torch.Tensor: |
| p_unif = torch.sigmoid( |
| torch.ones(shape, device=device) * self.config.min_log_snr + self.config.noise_type |
| ) |
| r = torch.rand(shape, device=device) |
| unif = torch.randint(0, self.config.vocab_size, shape, device=device) |
| samples = torch.where(r < p_unif, unif, mask_token_id) |
| return samples |
| |
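    # Turn logits into sampling probabilities with temperature, top-k, and top-p
    # (nucleus) filtering; temperature 0.0 yields a one-hot argmax distribution.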
| def _probs_with_topk_topp(self, logits, temperature: float, top_p: float | None, top_k: int | None): |
| if temperature == 0.0: |
| probs = torch.zeros_like(logits) |
| indices = torch.argmax(logits, dim=-1, keepdim=True) |
| probs.scatter_(-1, indices, 1.0) |
| return probs |
| |
| x = logits / temperature |
|
|
| if top_k is not None and 0 < top_k < x.size(-1): |
| kth = torch.topk(x, top_k, dim=-1).values[..., -1, None] |
| x = torch.where(x < kth, torch.full_like(x, float("-inf")), x) |
|
|
| if top_p is not None and 0.0 < top_p < 1.0: |
| sorted_logits, sorted_idx = torch.sort(x, descending=True, dim=-1) |
| sorted_probs = torch.softmax(sorted_logits, dim=-1) |
| cumprobs = sorted_probs.cumsum(dim=-1) |
|
|
| remove = cumprobs > top_p |
| remove[..., 1:] = remove[..., :-1].clone() |
| remove[..., 0] = False |
|
|
| sorted_logits = sorted_logits.masked_fill(remove, float("-inf")) |
| x = x.scatter(-1, sorted_idx, sorted_logits) |
|
|
| probs = torch.softmax(x, dim=-1) |
|
|
| return probs |
| |
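    # Noise distribution pi at a given log-SNR: mass 1 - alpha on the mask token and
    # alpha spread uniformly over the remaining vocabulary.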
| def _pi_lambda(self, log_snr, mask_token_id=3): |
| unif_vec = torch.ones((self.config.vocab_size,), device=log_snr.device) / (self.config.vocab_size - 1) |
| unif_vec[mask_token_id] = 0.0 |
| alpha = torch.sigmoid(log_snr + self.config.noise_type) |
| pi = alpha * unif_vec |
| pi[..., mask_token_id] = 1.0 - alpha |
| return pi |
| |
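    # One ancestral step: sample z_s from the posterior p(z_s | z_t, x_hat) of the
    # interpolating discrete diffusion process, moving from noise level t down to s.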
| def _sample_ancestral( |
| self, |
| z: torch.Tensor, |
| x_hat: torch.Tensor, |
| log_snr_t: torch.Tensor, |
| log_snr_s: torch.Tensor, |
| mask_token_id: int = 3, |
| ): |
| alpha_s = log_snr_s.sigmoid() |
| alpha_t = log_snr_t.sigmoid() |
| beta_s, beta_t = 1.0 - alpha_s, 1.0 - alpha_t |
| alpha_t_s = alpha_t / alpha_s |
|
|
| pi_s = self._pi_lambda(log_snr_s, mask_token_id=mask_token_id) |
| pi_t = self._pi_lambda(log_snr_t, mask_token_id=mask_token_id) |
        beta_pi_t_s = beta_t * pi_t - alpha_t_s * beta_s * pi_s

        # Marginals q_t(z | x_hat) and q_s(z | x_hat) of the interpolating noise process.
        q_t = alpha_t * x_hat + beta_t * pi_t[None, None, :]
        q_s = alpha_s * x_hat + beta_s * pi_s[None, None, :]
        q_t_at_z = q_t.gather(-1, z.unsqueeze(-1)).squeeze(-1)
|
|
        # q_{t|s}(z_t | z_s) as a function of z_s: a diagonal term plus a constant
        # offset that depends only on z_t.
        z_vec = torch.nn.functional.one_hot(z, num_classes=self.config.vocab_size).to(q_t.dtype)
        q_t_s_at_z = alpha_t_s * z_vec + beta_pi_t_s[z, None]

        # Bayes rule: p(z_s | z_t, x_hat) is proportional to q_{t|s}(z_t | z_s) * q_s(z_s | x_hat).
        p_s_t = q_s * q_t_s_at_z / q_t_at_z[..., None]

        # torch.multinomial renormalizes each row, so p_s_t need not sum to one.
        z_next = torch.multinomial(p_s_t.flatten(0, 1), num_samples=1).view_as(z)
        return z_next
|
|
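    # Adaptive step: resample only the `n_tokens` positions where the model's best
    # guess improves most over the current token, weighted by how likely the current
    # token is under the noise distribution.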
| def _sample_adaptive( |
| self, |
| z: torch.Tensor, |
| logits: torch.Tensor, |
| log_snr: torch.Tensor, |
| n_tokens: int = 1, |
| mask_token_id: int = 3, |
| temperature: float = 0.0, |
| top_p: float | None = None, |
| top_k: int | None = None, |
| ): |
| pi_vec = self._pi_lambda(log_snr, mask_token_id=mask_token_id) |
| p_noise = pi_vec[z] |
| p_noise = p_noise / p_noise.sum(dim=-1, keepdim=True) |
|
|
| x_hat = logits.softmax(dim=-1) |
| p_max = x_hat.max(dim=-1).values |
| p_curr = x_hat.gather(-1, z.unsqueeze(-1)).squeeze(-1) |
| p_delta = (p_max - p_curr) * p_noise |
|
|
| next_poss = torch.topk(p_delta, n_tokens, dim=-1).indices |
| probs = self._probs_with_topk_topp( |
| logits=logits, |
| temperature=temperature, |
| top_p=top_p, |
| top_k=top_k, |
| ) |
| next_tokens = torch.multinomial(probs.flatten(0, 1), num_samples=1).view_as(z) |
|
|
| z_next = z.clone() |
| batch_indices = torch.arange(z.shape[0], device=z.device).unsqueeze(-1) |
| z_next[batch_indices, next_poss] = next_tokens[batch_indices, next_poss] |
| return z_next |
| |
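    # Block-wise diffusion sampler: denoise `block_length` tokens at a time with
    # `steps` denoising steps per block, serving the finished prefix from a KV cache.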
| @torch.no_grad() |
| def generate( |
| self, |
| inputs: tp.Optional[torch.Tensor] = None, |
| max_length: int = 2048, |
| min_length: int = 0, |
| temperature: float = 1.0, |
| block_length: int = 128, |
| steps: int = 128, |
| top_p: tp.Optional[float] = None, |
| top_k: tp.Optional[int] = None, |
| bos_token_id: int = 0, |
| eos_token_id: int = 1, |
| pad_token_id: int = 2, |
| mask_token_id: int = 3, |
| sampling_method: tp.Literal["ancestral", "adaptive"] = "ancestral", |
| noise_schedule: tp.Literal["linear", "cosine"] | tp.Callable[[torch.Tensor], torch.Tensor] = "cosine", |
| tokens_per_step: int = 1, |
| show_progress: bool = False, |
| ): |
| r""" |
| Generates tokens with block-wise denoising diffusion. |
| |
| Parameters: |
| inputs (`torch.Tensor`): |
| The token sequence used as a prompt for the generation. |
| temperature (`float`, *optional*, defaults to 0.0): |
| The value used to module the next token probabilities. A value of 0.0 corresponds to greedy decoding. |
| block_length (`int`, *optional*, defaults to 32): |
| The size of each generation block. The model generates text in parallel within these blocks. This is a |
| key parameter for controlling the granularity of the generation process. |
| steps (`int`, *optional*, defaults to 32): |
| The number of denoising steps to perform for each block. |
| max_length (`int`, *optional*, defaults to 2048): |
| The maximum length of the sequence to be generated. |
| min_length (`int`, *optional*, defaults to 0): |
| The minimum length of the sequence to be generated. |
| top_p (`float`, *optional*): |
| If set to a float value between 0 and 1, only the most probable tokens with probabilities that add up to |
| `top_p` or higher are kept for generation (nucleus sampling). |
| top_k (`int`, *optional*): |
| The number of highest probability vocabulary tokens to keep for top-k-filtering. |
| bos_token_id (`int`, *optional*, defaults to 0): |
| The token ID for the beginning-of-sequence token. |
| eos_token_id (`int`, *optional*, defaults to 1): |
| The token ID for the end-of-sequence token. |
| pad_token_id (`int`, *optional*, defaults to 2): |
| The token ID for the padding token. |
| mask_token_id (`int`, *optional*, defaults to 3): |
| The token ID used as a placeholder for tokens that are yet to be generated. |
| Return: |
| `torch.Tensor`: A string containing the generated token IDs, starting |
| after the prompt and stopping at the first `eos_id` or `gen_length`. |
| """ |
| if sampling_method not in ["ancestral", "adaptive"]: |
| raise ValueError(f"Unsupported sampling method: {sampling_method}") |
| if noise_schedule not in ["linear", "cosine"] and not callable(noise_schedule): |
| raise ValueError("noise_schedule must be 'linear', 'cosine', or a callable function.") |
|
|
        if inputs is None:
            # Fall back to a single BOS token as the prompt.
            inputs = torch.tensor([[bos_token_id]], device=self.device, dtype=torch.long)
        batch_size = inputs.shape[0]
        prompt_length = inputs.shape[1]
        if eos_token_id in inputs:
            warnings.warn("Input prompt contains eos_token_id. Generation may stop earlier than expected.", stacklevel=1)
        input_ids = inputs.to(self.device)
|
|
| total_length = self.config.max_position_embeddings |
|
|
| if noise_schedule == "linear": |
| noise_schedule_fn = lambda t: 1.0 - t |
| elif noise_schedule == "cosine": |
| noise_schedule_fn = lambda t: 0.5 + 0.5 * torch.cos(t * torch.pi) |
| else: |
| noise_schedule_fn = noise_schedule |
|
|
| x_prior = self._sample_prior( |
| shape=(batch_size, total_length), |
| device=self.device, |
| mask_token_id=mask_token_id, |
| ) |
| x = x_prior.clone() |
| if prompt_length > 0: |
| x[:, :prompt_length] = input_ids.clone() |
|
|
| position_ids = torch.arange(total_length, device=self.device) |
| position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) |
|
|
| noise_mask = torch.ones_like(x, dtype=torch.bool) |
| noise_mask[:, :prompt_length] = False |
|
|
| min_log_snr = torch.tensor(self.config.min_log_snr, device=self.device) |
| max_log_snr = torch.tensor(self.config.max_log_snr, device=self.device) |
| alpha_min = torch.sigmoid(min_log_snr) |
| alpha_max = torch.sigmoid(max_log_snr) |
| ts = torch.linspace(0.0, 1.0, steps=steps + 1, device=self.device) |
| alpha_t = (alpha_max - alpha_min) * noise_schedule_fn(ts) + alpha_min |
| log_snrs = torch.log(alpha_t / (1.0 - alpha_t)).clip(min_log_snr, max_log_snr) |
|
|
| if show_progress: |
| import tqdm.auto as tqdm |
| est_num_blocks = (max_length + block_length - 1) // block_length |
| est_num_steps = est_num_blocks * steps |
| pbar = tqdm.tqdm(total=est_num_steps) |
| update_pbar = lambda n: pbar.update(n) |
| def stop_pbar(): |
| pbar.total = pbar.n |
| pbar.refresh() |
| close_pbar = lambda: pbar.close() |
| else: |
| update_pbar = lambda n: None |
| stop_pbar = lambda: None |
| close_pbar = lambda: None |
|
|
| try: |
| num_blocks = 0 |
| while True: |
| current_window_start = prompt_length + num_blocks * block_length |
| current_window_end = current_window_start + block_length |
                # A query may attend to a key iff the key is at most as noisy: clean
                # tokens attend only to clean tokens, noisy tokens attend to everything.
                attn_mask = (noise_mask[..., :, None] >= noise_mask[..., None, :])
|
|
| keep_logits = False |
| past_key_values = None |
| for step in range(steps, 0, -1): |
                    if past_key_values is None:
                        # Build the KV cache over the prompt and previously finished
                        # blocks once; every denoising step of this block reuses it.
                        output = self.forward(
                            input_ids=x[:, :current_window_start],
                            attention_mask=attn_mask[:, :current_window_start, :current_window_start],
                            position_ids=position_ids[:, :current_window_start],
                            use_cache=True,
                        )
                        past_key_values = output.past_key_values
|
|
                    if not keep_logits:
                        # Re-predict x_hat over the remaining (still noisy) suffix; the
                        # clean prefix is served from the KV cache.
                        logits = self.forward(
                            input_ids=x[:, current_window_start:],
                            attention_mask=attn_mask[:, current_window_start:],
                            position_ids=position_ids[:, current_window_start:],
                            past_key_values=past_key_values,
                        ).logits
                        active_logits = logits[:, :block_length, :]

                        # Never sample the mask token, and suppress EOS until at least
                        # `min_length` tokens have been generated.
                        active_logits[..., mask_token_id] = float("-inf")
                        min_eos_idx = max(0, min_length + prompt_length - current_window_start)
                        active_logits[:, :min_eos_idx, eos_token_id] = float("-inf")

| z_t = x[:, current_window_start:current_window_end] |
| if sampling_method == "ancestral": |
| x_hat = self._probs_with_topk_topp( |
| active_logits.to(torch.float32), |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| ) |
|
|
| z_s = self._sample_ancestral( |
| z=z_t, |
| x_hat=x_hat, |
| log_snr_t=log_snrs[step], |
| log_snr_s=log_snrs[step - 1], |
| mask_token_id=mask_token_id, |
| ) |
| elif sampling_method == "adaptive": |
| z_s = self._sample_adaptive( |
| z=z_t, |
| logits=active_logits.to(torch.float32), |
| log_snr=log_snrs[step], |
| n_tokens=tokens_per_step, |
| mask_token_id=mask_token_id, |
| temperature=temperature, |
| top_p=top_p, |
| top_k=top_k, |
| ) |
                        # If no token changed, the logits remain valid; skip the next forward pass.
                        keep_logits = (z_s == z_t).all().item()
|
|
| x[:, current_window_start:current_window_end] = z_s.clone() |
|
|
| update_pbar(1) |
|
|
| num_blocks += 1 |
| noise_mask[:, :current_window_end] = False |
|
|
                    # Only the already-denoised prefix can contain a real EOS; later
                    # positions still hold prior noise, which may spuriously include
                    # eos_token_id.
                    has_eos = (x[:, :current_window_end] == eos_token_id).any(-1).all().item()
                    all_done = current_window_end >= max_length + prompt_length or has_eos
| if all_done: |
| stop_pbar() |
| break |
| finally: |
| close_pbar() |
|
|
| generated_answer = x[:, :max_length + prompt_length] |
|
|
        # Pad out everything from the first EOS onward (rows without EOS have idx == 0).
        eos_idx = (generated_answer == eos_token_id).int().argmax(dim=-1)
        for i, idx in enumerate(eos_idx):
            if idx > 0:
                generated_answer[i, idx:] = pad_token_id
|
|
| return generated_answer |
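

# Minimal usage sketch (illustrative only; the checkpoint path and tokenizer are
# placeholders, not a published artifact):
#
#   from transformers import AutoTokenizer
#
#   model = GiddForDiffusionLM.from_pretrained("path/to/checkpoint")
#   tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
#   prompt = tokenizer("Once upon a time", return_tensors="pt").input_ids
#   out = model.generate(prompt, max_length=256, steps=128, show_progress=True)
#   print(tokenizer.decode(out[0]))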
|
|