| import enum |
| import math |
| import warnings |
| from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union |
|
|
| try: |
| |
| |
| import mamba_ssm |
| except ModuleNotFoundError: |
| warnings.warn("mamba_ssm could not be imported", stacklevel=2) |
| try: |
| |
| |
| import causal_conv1d.causal_conv1d_interface as causal_conv1d |
| except ModuleNotFoundError: |
| warnings.warn("causal_conv1d could not be imported", stacklevel=2) |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from transformers import PretrainedConfig, PreTrainedModel |
| from transformers.cache_utils import DynamicCache |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
|
|
| def _is_first_token(mask: torch.Tensor) -> torch.Tensor: |
| assert mask.dtype == torch.bool |
| B, Nh, q_len, kv_len = mask.shape |
| mask = mask[:, :, :, -q_len:] |
| # If q_len == kv_len there is no cached prefix, so the first query position really is the |
| # first token of its sequence; with a cache (q_len != kv_len) it is a continuation. |
| v = q_len == kv_len |
| out = torch.logical_not(torch.diagonal(mask, offset=-1, dim1=-2, dim2=-1).bool()) |
| out = torch.cat( |
| [ |
| torch.full(size=(B, Nh, 1), dtype=torch.bool, device=out.device, fill_value=v), |
| out, |
| ], |
| dim=-1, |
| ) |
| return out |
|
|
|
|
| def _swiglu(h: torch.Tensor) -> torch.Tensor: |
| h0, h1 = h.chunk(2, dim=-1) |
| return torch.nn.functional.silu(h0) * h1 |
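| # Shape note: _swiglu halves the last dimension, i.e. for h of shape (..., 2 * d) it returns |
| # silu(h[..., :d]) * h[..., d:], of shape (..., d). |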
|
|
|
|
| class RotaryEmbedding(torch.nn.Module): |
| def __init__( |
| self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None |
| ) -> None: |
| super().__init__() |
|
|
| self.dim = dim |
| self.max_position_embeddings = max_position_embeddings |
| self.base = base |
| inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
| |
| self._set_cos_sin_cache( |
| seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() |
| ) |
|
|
| def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None: |
| self.max_seq_len_cached = seq_len |
| t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
|
| freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
| |
| emb = torch.cat((freqs, freqs), dim=-1) |
| self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
| self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
| def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: |
| |
| if seq_len > self.max_seq_len_cached: |
| self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) |
|
|
| return ( |
| self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), |
| self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), |
| ) |
|
|
|
|
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
| """Rotates half the hidden dims of the input.""" |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: |
| |
| cos = cos.squeeze(1).squeeze(0) |
| sin = sin.squeeze(1).squeeze(0) |
| cos = cos[position_ids].unsqueeze(1) |
| sin = sin[position_ids].unsqueeze(1) |
| x_embed = (x * cos) + (_rotate_half(x) * sin) |
| return x_embed |
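| # Shape note: for query/key tensors of shape (B, num_heads, L, head_dim), RotaryEmbedding.forward |
| # returns cos/sin of shape (1, 1, seq_len, head_dim); _rotary_pos_emb gathers them at |
| # `position_ids` and applies x * cos + _rotate_half(x) * sin position-wise. |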
|
|
|
|
| class LinearType(str, enum.Enum): |
| Normal = "normal" |
| Fp8 = "fp8" |
| Fp8Retain = "fp8-retain" |
|
|
|
|
| class Plamo2Config(PretrainedConfig): |
| model_type: str = "plamo2" |
|
|
| def __init__( |
| self, |
| hidden_size: int = 4096, |
| num_hidden_layers: int = 32, |
| rms_norm_eps: float = 1e-6, |
| tie_word_embeddings: bool = True, |
| |
| num_attention_heads: int = 32, |
| num_key_value_heads: int = 4, |
| hidden_size_per_head: int = 128, |
| max_position_embeddings: int = 2048, |
| attention_window_size: int = 2048, |
| full_attention_idx: list[int] | None = None, |
| rope_theta: int = 10000, |
| rope_local_theta: int = 10000, |
| |
| mamba_d_state: int = 64, |
| mamba_d_conv: int = 4, |
| mamba_num_heads: int = 64, |
| mamba_step: int = 2, |
| mamba_chunk_size: int = 256, |
| mamba_enabled: bool = True, |
| |
| intermediate_size: int = 13312, |
| |
| vocab_size: int = 32000, |
| tokenizer_class: str = "Plamo2Tokenizer", |
| pad_token_id: Optional[int] = None, |
| bos_token_id: int = 1, |
| eos_token_id: int = 2, |
| |
| image_token_id: Optional[int] = None, |
| image_feature_size: Optional[int] = None, |
| image_proj_type: Literal["linear", "mlp"] = "linear", |
| |
| linear_type: LinearType = LinearType.Normal, |
| fp8_accum_dtype: Optional[str] = None, |
| |
| eval_attention_n_bit: Optional[int] = None, |
| eval_mlp_n_bit: Optional[int] = None, |
| use_cache: bool = True, |
| **kwargs: Any, |
| ) -> None: |
| |
| |
| # The configured value is only ever raised, never lowered: a very large floor (~10M tokens) |
| # keeps utilities that read max_position_embeddings from artificially capping the usable |
| # context length. |
| self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings) |
| self.hidden_size = hidden_size |
| self.rms_norm_eps = rms_norm_eps |
|
|
| self.num_hidden_layers = num_hidden_layers |
| self.num_attention_heads = num_attention_heads |
| self.hidden_size_per_head = hidden_size_per_head |
| self.num_key_value_heads = num_key_value_heads |
| self.attention_window_size = attention_window_size |
| self.full_attention_idx = full_attention_idx if full_attention_idx is not None else [] |
| self.rope_theta = rope_theta |
| self.rope_local_theta = rope_local_theta |
|
|
| self.mamba_d_state = mamba_d_state |
| self.mamba_d_conv = mamba_d_conv |
| self.mamba_num_heads = mamba_num_heads |
| self.mamba_step = mamba_step |
| self.mamba_chunk_size = mamba_chunk_size |
| self.mamba_enabled = mamba_enabled |
|
|
| self.intermediate_size = intermediate_size |
|
|
| self.vocab_size = vocab_size |
|
|
| self.image_token_id = image_token_id |
| self.image_feature_size = image_feature_size |
| self.image_proj_type = image_proj_type |
|
|
| self.linear_type = linear_type |
| self.fp8_accum_dtype = fp8_accum_dtype |
|
|
| self.eval_attention_n_bit = eval_attention_n_bit |
| self.eval_mlp_n_bit = eval_mlp_n_bit |
| self.use_cache = use_cache |
|
|
| |
| self.sliding_window = attention_window_size |
|
|
| super().__init__( |
| tokenizer_class=tokenizer_class, |
| pad_token_id=pad_token_id, |
| bos_token_id=bos_token_id, |
| eos_token_id=eos_token_id, |
| tie_word_embeddings=tie_word_embeddings, |
| **kwargs, |
| ) |
|
|
| @property |
| def layers_block_type(self) -> list[str]: |
| return ["mamba" if is_mamba(self, i) else "attention" for i in range(self.num_hidden_layers)] |
|
|
| @property |
| def rope_local_base_freq(self) -> int: |
| return self.rope_local_theta |
|
|
|
|
| class Plamo2AttentionCache(torch.nn.Module): |
| def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None: |
| super().__init__() |
| B, nh, L, c = key.shape |
| assert len(value.shape) == 4 |
| assert value.shape[0] == B |
| assert value.shape[2] == L |
| self.register_parameter("key", torch.nn.Parameter(key, requires_grad=False)) |
| self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False)) |
|
|
|
|
| class Plamo2MambaCache(torch.nn.Module): |
| def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None: |
| super().__init__() |
| |
| |
| assert len(conv_state.shape) == 3 |
| assert len(ssm_state.shape) == 4 |
| assert conv_state.shape[0] == ssm_state.shape[0] |
| self.register_parameter("conv_state", torch.nn.Parameter(conv_state, requires_grad=False)) |
| self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False)) |
|
|
|
|
| Plamo2LayerCache = Plamo2AttentionCache | Plamo2MambaCache |
|
|
|
|
| class Plamo2Cache(torch.nn.Module): |
| """ |
| Stores per-layer states of the model for fast decoding. |
| `transformers` uses `transformers.Cache` for this purpose, but its interface and variable names |
| (e.g., `key_states`) are tied to the Transformers attention architecture, which makes it hard to |
| reuse for other architectures (e.g., Mamba). |
| This class provides a similar interface to `transformers.Cache`, but is designed to also handle |
| the state of Mamba properly. |
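| |
| Usage sketch (illustrative only; `key_states`, `value_states`, `conv_state`, `ssm_state` and the |
| layer indices are hypothetical tensors/values of the shapes produced by the layers below): |
| |
|     cache = Plamo2Cache(config) |
|     # attention layer: append the new K/V; sliding-window layers keep at most |
|     # `attention_window_size` positions |
|     cache.update_attention(key_states, value_states, layer_idx=attention_layer_idx) |
|     # mamba layer: overwrite the recurrent conv/ssm states for that layer |
|     cache.update_mamba(conv_state, ssm_state, layer_idx=mamba_layer_idx) |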
| """ |
|
|
| def __init__(self, config: Plamo2Config) -> None: |
| super().__init__() |
| self.config = config |
| self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) |
|
|
| def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: |
| c = self.cache[layer_idx] |
| if c is None: |
| return key, value |
| assert isinstance(c, Plamo2AttentionCache) |
|
|
| def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None: |
| assert len(cache.shape) == 4 |
| assert len(new_tensor.shape) == 4 |
| assert cache.shape[0] == new_tensor.shape[0] |
| assert cache.shape[1] == new_tensor.shape[1] |
| assert cache.shape[3] == new_tensor.shape[3] |
|
|
| _validate(c.key, key) |
| _validate(c.value, value) |
| assert key.shape[2] == value.shape[2] |
| return torch.cat([c.key, key], dim=2), torch.cat([c.value, value], dim=2) |
|
|
| def update_attention( |
| self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int |
| ) -> Plamo2AttentionCache: |
| full_attn = layer_idx in self.config.full_attention_idx |
| window_size = self.config.attention_window_size |
|
|
| if self.cache[layer_idx] is None: |
| if full_attn: |
| self.cache[layer_idx] = Plamo2AttentionCache(key_states, value_states) |
| else: |
| self.cache[layer_idx] = Plamo2AttentionCache( |
| key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :] |
| ) |
| else: |
| c = self.cache[layer_idx] |
| assert isinstance(c, Plamo2AttentionCache) |
| k, v = self.append_kv(key_states, value_states, layer_idx) |
| if full_attn: |
| c.key.data = k |
| c.value.data = v |
| else: |
| c.key.data = k[:, :, -window_size:, :] |
| c.value.data = v[:, :, -window_size:, :] |
| return self.cache[layer_idx] |
|
|
| def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> Plamo2MambaCache: |
| if self.cache[layer_idx] is None: |
| self.cache[layer_idx] = Plamo2MambaCache(conv_state, ssm_state) |
| else: |
| c = self.cache[layer_idx] |
| assert isinstance(c, Plamo2MambaCache) |
| assert c.conv_state.shape == conv_state.shape |
| assert c.ssm_state.shape == ssm_state.shape |
| c.conv_state.data = conv_state |
| c.ssm_state.data = ssm_state |
| return self.cache[layer_idx] |
|
|
| def __getitem__(self, layer_idx: int) -> Plamo2LayerCache | None: |
| assert layer_idx < len(self.cache) |
| layer_cache = self.cache[layer_idx] |
| return layer_cache |
|
|
| def __len__(self) -> int: |
| return len(self.cache) |
|
|
| def get_seq_length(self, layer_idx: Optional[int] = None) -> int: |
| if layer_idx is not None: |
| c = self.cache[layer_idx] |
| assert isinstance(c, Plamo2AttentionCache) |
| return c.key.shape[2] |
|
|
| sequence_length: int | None = None |
| for layer_cache in self.cache: |
| if isinstance(layer_cache, Plamo2AttentionCache): |
| sequence_length = ( |
| max(layer_cache.key.shape[2], sequence_length) |
| if sequence_length is not None |
| else layer_cache.key.shape[2] |
| ) |
| if sequence_length is None: |
| return 0 |
| return sequence_length |
|
|
| def get_max_length(self) -> int | None: |
| return None |
|
|
| def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: |
| """Given the sequence length of the new inputs, returns the usable length of the cache.""" |
| |
| |
| |
| max_length = self.get_max_length() |
| previous_seq_length = self.get_seq_length(layer_idx) |
| if max_length is not None and previous_seq_length + new_seq_length > max_length: |
| return max_length - new_seq_length |
| return previous_seq_length |
|
|
| def reorder_cache(self, beam_idx: torch.Tensor) -> None: |
| def _mamba(cache: Plamo2MambaCache) -> Plamo2MambaCache: |
| return Plamo2MambaCache( |
| conv_state=cache.conv_state.index_select(0, beam_idx), |
| ssm_state=cache.ssm_state.index_select(0, beam_idx), |
| ) |
|
|
| def _attention(cache: Plamo2AttentionCache) -> Plamo2AttentionCache: |
| return Plamo2AttentionCache( |
| key=cache.key.index_select(0, beam_idx), |
| value=cache.value.index_select(0, beam_idx), |
| ) |
|
|
| for i in range(len(self.cache)): |
| if self.cache[i] is None: |
| continue |
| layer_cache = self.cache[i] |
| if isinstance(layer_cache, Plamo2MambaCache): |
| self.cache[i] = _mamba(layer_cache) |
| else: |
| assert isinstance(layer_cache, Plamo2AttentionCache) |
| self.cache[i] = _attention(layer_cache) |
|
|
| @property |
| def seen_tokens(self) -> int | None: |
| return None |
|
|
|
|
| class DecoderInput(NamedTuple): |
| hidden_states: torch.Tensor |
| attention_mask: Optional[torch.Tensor] = None |
| past_states: Optional[Plamo2Cache] = None |
| output_hidden_states: Optional[bool] = False |
| output_attentions: Optional[bool] = False |
| gradient_checkpointing: bool = False |
| input_ids: Optional[torch.Tensor] = None |
|
|
|
|
| class DecoderOutput(NamedTuple): |
| hidden_states: torch.Tensor |
| all_hidden_states: Optional[Tuple[torch.Tensor, ...]] |
| all_self_attns: Optional[Tuple[torch.Tensor, ...]] |
|
|
|
|
| |
| def _make_causal_mask( |
| input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 |
| ) -> torch.Tensor: |
| """ |
| Make the additive causal mask used for autoregressive self-attention. |
| """ |
| bsz, tgt_len = input_ids_shape |
| mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device) |
| mask_cond = torch.arange(mask.size(-1), device=device) |
| mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) |
| mask = mask.to(dtype) |
|
|
| if past_key_values_length > 0: |
| mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) |
| return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) |
|
|
|
|
| |
| def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor: |
| """ |
| Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. |
| """ |
| bsz, src_len = mask.size() |
| tgt_len = tgt_len if tgt_len is not None else src_len |
|
|
| expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) |
|
|
| inverted_mask = 1.0 - expanded_mask |
|
|
| return inverted_mask.masked_fill(inverted_mask.to(torch.bool), float("-inf")) |
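| # Example for `_expand_mask`: a padding mask [[1, 1, 0]] (1 = attend, 0 = padded) becomes an |
| # additive mask with 0.0 at the first two key positions and -inf at the padded one, broadcast |
| # over the target length. |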
|
|
|
|
| def _rms_norm( |
| hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float, offset: float = 1.0 |
| ) -> torch.Tensor: |
| input_dtype = hidden_states.dtype |
| hidden_states = hidden_states.to(torch.float32) |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + eps) |
| hidden_states = hidden_states.to(input_dtype) |
| if weight is not None: |
| hidden_states = (offset + weight) * hidden_states |
| return hidden_states |
|
|
|
|
| class RMSNorm(nn.Module): |
| def __init__( |
| self, |
| hidden_size: int, |
| eps: float = 1e-6, |
| offset: float = 1.0, |
| device: Optional[Union[torch.device, str]] = None, |
| ) -> None: |
| super().__init__() |
| self.weight = nn.Parameter(torch.zeros(hidden_size, device=device)) |
| self.variance_epsilon = eps |
| self.offset = offset |
|
|
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset) |
|
|
|
|
| def get_initial_dt_bias(num_heads: int) -> torch.Tensor: |
| dt_min = 0.001 |
| dt_max = 0.1 |
| # Sample dt log-uniformly in [dt_min, dt_max] ... |
| dt = torch.exp(torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) |
| dt = torch.clamp(dt, 1e-4) |
| # ... and store its inverse softplus, so that softplus(dt_bias) recovers dt. |
| inv_dt = dt + torch.log(-torch.expm1(-dt)) |
| return inv_dt |
|
|
|
|
| def get_initial_A(num_heads: int) -> torch.Tensor: |
| A = torch.arange(1, num_heads + 1, dtype=torch.float32) |
| return torch.log(A) |
|
|
|
|
| def _bf16_supported_in_triton() -> bool: |
| |
| |
| # bfloat16 support in the Triton kernels requires compute capability >= 8.0 (Ampere or newer). |
| major, _ = torch.cuda.get_device_capability() |
| return major >= 8 |
|
|
|
|
| def _get_triton_dtype(dtype: torch.dtype) -> torch.dtype: |
| if dtype != torch.bfloat16: |
| return dtype |
| if _bf16_supported_in_triton(): |
| return dtype |
| return torch.float32 |
|
|
|
|
| def ssd_update_state( |
| ssm_state: torch.Tensor, |
| x: torch.Tensor, |
| dt: torch.Tensor, |
| A: torch.Tensor, |
| B: torch.Tensor, |
| C: torch.Tensor, |
| D: torch.Tensor, |
| z: torch.Tensor, |
| dt_bias: torch.Tensor, |
| dt_softplus: bool, |
| ) -> torch.Tensor: |
| assert ssm_state.dtype == torch.float32 |
| if dt.is_cuda: |
| dtype = _get_triton_dtype(x.dtype) |
| else: |
| dtype = x.dtype |
| if dt.is_cuda: |
| f = mamba_ssm.ops.triton.selective_state_update.selective_state_update |
| else: |
| f = mamba_ssm.ops.triton.selective_state_update.selective_state_update_ref |
|
|
| hidden_size_per_head = x.shape[-1] |
| d_state = B.shape[-1] |
| A = A[:, None, None].expand(-1, hidden_size_per_head, d_state).float() |
| dt = dt[..., None].expand(-1, -1, hidden_size_per_head) |
| dt_bias = dt_bias[:, None].expand(-1, hidden_size_per_head) |
| D = D[:, None].expand(-1, hidden_size_per_head) |
| assert ssm_state.dtype == torch.float32 |
| out = f( |
| ssm_state, |
| x.to(dtype), |
| dt.to(dtype), |
| A.float(), |
| B.to(dtype), |
| C.to(dtype), |
| D.float(), |
| z.to(dtype), |
| dt_bias.float(), |
| dt_softplus=dt_softplus, |
| ) |
| return out[:, None] |
|
|
|
|
| def _ssd_chunk_scan_combined_naive( |
| x: torch.Tensor, |
| dt: torch.Tensor, |
| A: torch.Tensor, |
| B: torch.Tensor, |
| C: torch.Tensor, |
| D: torch.Tensor, |
| z: torch.Tensor, |
| dt_bias: torch.Tensor, |
| dt_softplus: bool, |
| seq_idx: torch.Tensor | None, |
| ssm_state: torch.Tensor, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| assert ssm_state.dtype == torch.float32 |
| length = x.shape[1] |
| ys = [] |
| for i in range(length): |
| if i != 0 and seq_idx is not None: |
| ssm_state = torch.where( |
| (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None], |
| torch.zeros_like(ssm_state), |
| ssm_state, |
| ) |
| y = ssd_update_state( |
| ssm_state, |
| x[:, i], |
| dt[:, i], |
| A, |
| B[:, i], |
| C[:, i], |
| D, |
| z=z[:, i], |
| dt_bias=dt_bias, |
| dt_softplus=dt_softplus, |
| ) |
| ys.append(y) |
| return torch.cat(ys, dim=1), ssm_state |
|
|
|
|
| def _ssd_chunk_scan_combined_cpu( |
| x: torch.Tensor, |
| dt: torch.Tensor, |
| A: torch.Tensor, |
| B: torch.Tensor, |
| C: torch.Tensor, |
| chunk_size: int, |
| D: torch.Tensor, |
| z: torch.Tensor, |
| dt_bias: torch.Tensor, |
| dt_softplus: bool, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| |
| dt = dt.float() |
| dt = dt.permute(0, 2, 1).unflatten(2, (-1, chunk_size)) |
| if dt_bias is not None: |
| dt = dt + dt_bias[None, :, None, None] |
| if dt_softplus: |
| dt = F.softplus(dt) |
| dA = dt * A[None, :, None, None] |
| dA_cumsum = torch.cumsum(dA, dim=-1) |
|
|
| _, _, nheads, _ = x.shape |
| dstate = B.shape[-1] |
| _ = dt.shape[2] |
|
|
| with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_state"): |
| |
| |
| x_ = torch.unflatten(x, 1, (-1, chunk_size)) |
| assert B.shape[2] == nheads |
| B_ = torch.unflatten(B, 1, (-1, chunk_size)).to(x.dtype) |
| decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)).to(x.dtype) |
| dt_ = dt.to(x.dtype) |
|
|
| |
| B_ = B_.permute(0, 1, 3, 4, 2) |
| tmp = dt_ * decay_states |
| tmp = tmp.permute(0, 2, 1, 3)[:, :, :, None] |
| tmp = B_ * tmp |
| x_ = x_.permute(0, 1, 3, 2, 4) |
| tmp = tmp @ x_ |
| states = tmp.permute(0, 1, 2, 4, 3) |
|
|
| states_dtype = states.dtype |
| if states.dtype not in [torch.float32, torch.float64]: |
| states = states.to(torch.float32) |
| with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_state_passing"): |
| out, last_state = mamba_ssm.ops.triton.ssd_combined.state_passing_ref( |
| states.flatten(start_dim=-2, end_dim=-1), |
| dA_cumsum[:, :, :, -1], |
| ) |
| states = torch.unflatten(out, -1, (-1, dstate)) |
| last_state = torch.unflatten(last_state, -1, (-1, dstate)) |
| states = states.to(states_dtype) |
| with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_scan"): |
| out = mamba_ssm.ops.triton.ssd_combined.chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z) |
|
|
| return out, last_state |
|
|
|
|
| @torch.profiler.record_function("ssd_chunk_scan_combined") |
| def ssd_chunk_scan_combined( |
| x: torch.Tensor, |
| dt: torch.Tensor, |
| A: torch.Tensor, |
| B: torch.Tensor, |
| C: torch.Tensor, |
| chunk_size: int, |
| D: torch.Tensor, |
| z: torch.Tensor, |
| dt_bias: torch.Tensor, |
| dt_softplus: bool, |
| return_final_states: bool, |
| seq_idx: torch.Tensor | None, |
| ssm_state: torch.Tensor | None, |
| ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor: |
| if seq_idx is not None: |
| assert seq_idx.dtype == torch.int32 |
| assert ssm_state is None |
| assert not return_final_states |
| if ssm_state is not None: |
| assert ssm_state.dtype == torch.float32 |
| assert seq_idx is None |
|
|
| length = x.shape[1] |
|
|
| """ |
| The state is updated as follows: |
| ``` |
| dt = softplus(dt) |
| dA = exp(dt * A) |
| state_next = state * dA + dB * x |
| ``` |
| |
| To keep the padded positions from updating the state, we set dt to -inf and x to 0, |
| because `softplus(-inf) = 0` and `exp(0) = 1`. |
| """ |
| pad = (chunk_size - length % chunk_size) % chunk_size |
| x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0) |
| dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf")) |
| B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0) |
| C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0) |
| z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0) |
| if seq_idx is not None: |
| seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0) |
|
|
| length = x.shape[1] |
| assert length % chunk_size == 0, (length, chunk_size) |
|
|
| if dt.is_cuda: |
| dtype = _get_triton_dtype(x.dtype) |
| out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( |
| x.to(dtype), |
| dt.to(dtype), |
| A.float(), |
| B.to(dtype), |
| C.to(dtype), |
| chunk_size, |
| D=D.float(), |
| z=z.to(dtype), |
| initial_states=ssm_state, |
| dt_bias=dt_bias.float(), |
| dt_softplus=dt_softplus, |
| seq_idx=seq_idx, |
| return_final_states=return_final_states, |
| ) |
| if return_final_states: |
| return out[0][:, pad:], out[1] |
| else: |
| assert isinstance(out, torch.Tensor) |
| return out[:, pad:] |
| else: |
| if ssm_state is None and seq_idx is None: |
| tmp = _ssd_chunk_scan_combined_cpu( |
| x, |
| dt, |
| A, |
| B, |
| C, |
| chunk_size, |
| D=D, |
| z=z, |
| dt_bias=dt_bias.float(), |
| dt_softplus=dt_softplus, |
| ) |
| else: |
| if ssm_state is None: |
| bsize, _, num_heads, channel = x.shape |
| state = B.shape[-1] |
| ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device) |
| tmp = _ssd_chunk_scan_combined_naive( |
| x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state |
| ) |
| tmp = (tmp[0][:, pad:], tmp[1]) |
| if return_final_states: |
| return tmp |
| else: |
| return tmp[0] |
|
|
|
|
| def _causal_conv1d_update( |
| conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| dtype = conv_state.dtype |
| xBC = xBC.to(dtype) |
| weight = weight.to(dtype) |
| if conv_state.is_cuda: |
| x = causal_conv1d.causal_conv1d_update( |
| x=xBC, |
| conv_state=conv_state, |
| weight=weight[:, 0, :], |
| activation="silu", |
| ) |
| return x, conv_state |
| else: |
| x = causal_conv1d.causal_conv1d_update_ref( |
| x=xBC, |
| conv_state=conv_state, |
| weight=weight[:, 0, :], |
| activation="silu", |
| ) |
| return x, conv_state |
|
|
|
|
| def _causal_conv1d_naive( |
| conv_state: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| length = x.shape[-1] |
| out = torch.zeros_like(x) |
| for i in range(length): |
| if i != 0 and seq_idx is not None: |
| conv_state = torch.where( |
| (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None], |
| torch.zeros_like(conv_state), |
| conv_state, |
| ) |
| out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1]) |
| return out, conv_state |
|
|
|
|
| @torch.profiler.record_function("causal_conv1d") |
| def _causal_conv1d( |
| conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None |
| ) -> tuple[torch.Tensor, torch.Tensor | None]: |
| dtype = x.dtype |
| if conv_state is not None: |
| dtype = conv_state.dtype |
| assert seq_idx is None |
| if seq_idx is not None: |
| assert seq_idx.dtype == torch.int32 |
| assert conv_state is None |
| weight = weight.to(dtype) |
| x = x.to(dtype) |
|
|
| return_final_states = conv_state is not None |
| if weight.is_cuda: |
| if x.stride(1) != 1: |
| # make the channel dimension contiguous (stride 1), which the kernel path below expects |
| x = x.transpose(-1, -2).contiguous().transpose(-1, -2) |
| if conv_state is not None: |
| if conv_state.stride(1) != 1: |
| # same stride requirement for the initial conv state |
| conv_state = conv_state.transpose(-1, -2).contiguous().transpose(-1, -2) |
| tmp = causal_conv1d.causal_conv1d_fn( |
| x=x, |
| weight=weight[:, 0, :], |
| initial_states=conv_state, |
| return_final_states=conv_state is not None, |
| activation="silu", |
| seq_idx=seq_idx, |
| ) |
| if conv_state is not None: |
| x, conv_state = tmp |
| else: |
| x = tmp |
| else: |
| if seq_idx is None: |
| x, conv_state = causal_conv1d.causal_conv1d_ref( |
| x=x, |
| initial_states=conv_state, |
| return_final_states=True, |
| weight=weight[:, 0, :], |
| activation="silu", |
| ) |
| else: |
| if conv_state is None: |
| bsize = x.shape[0] |
| dim = weight.shape[0] |
| d_conv = weight.shape[-1] |
| conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device) |
| x, conv_state = _causal_conv1d_naive(conv_state, weight, x, seq_idx) |
| if return_final_states: |
| return x, conv_state |
| else: |
| return x, None |
|
|
|
|
| class Mamba(torch.nn.Module): |
| def __init__(self, config: Plamo2Config, layer_idx: int) -> None: |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.hidden_size = config.hidden_size |
| self.d_state = config.mamba_d_state |
| self.d_conv = config.mamba_d_conv |
| self.chunk_size = config.mamba_chunk_size |
| self.num_heads = config.mamba_num_heads |
| |
| self.hidden_size_per_head = config.hidden_size_per_head |
|
|
| self.intermediate_size = self.num_heads * self.hidden_size_per_head |
|
|
| self.in_proj = torch.nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) |
| self.conv1d = torch.nn.Conv1d( |
| in_channels=self.intermediate_size, |
| out_channels=self.intermediate_size, |
| bias=False, |
| kernel_size=self.d_conv, |
| groups=self.intermediate_size, |
| padding=0, |
| ) |
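| # groups == in_channels == out_channels above makes this a depthwise (per-channel) convolution |
| # over the sequence dimension; causality is handled by the causal_conv1d kernels in forward(), |
| # which use only this module's weight. |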
| self.dt_dim = max(64, self.hidden_size // 16) |
| |
| |
| |
| self.bcdt_proj = torch.nn.Linear( |
| self.intermediate_size, |
| self.dt_dim + 2 * self.d_state, |
| bias=False, |
| ) |
| self.dt_proj = torch.nn.Linear(self.dt_dim, self.num_heads, bias=False) |
|
|
| self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads)) |
| self.A_log = torch.nn.Parameter(get_initial_A(self.num_heads)) |
| self.D = torch.nn.Parameter(torch.ones(self.num_heads)) |
|
|
| |
| self.dt_norm_weight = torch.nn.Parameter(torch.ones(self.dt_dim)) |
| self.B_norm_weight = torch.nn.Parameter(torch.ones(self.d_state)) |
| self.C_norm_weight = torch.nn.Parameter(torch.ones(self.d_state)) |
|
|
| self.out_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
|
| def _no_weight_decay_param_names(self) -> set[str]: |
| return set(["D", "dt_bias", "A_log"]) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| past_states: Optional[Plamo2Cache] = None, |
| ) -> Tuple[torch.Tensor, Optional[Plamo2Cache]]: |
| bsize, length, _ = hidden_states.shape |
| is_update = length == 1 and past_states is not None |
|
|
| bool_mask: torch.Tensor | None = None |
| seq_idx: torch.Tensor | None = None |
| if attention_mask is not None: |
| if len(attention_mask.shape) == 2: |
| attention_mask = attention_mask[None, None].expand(bsize, 1, -1, -1) |
| assert len(attention_mask.shape) == 4 |
|
|
| if past_states is None: |
| |
| bool_mask_4d = attention_mask == 0 |
| is_first_token = _is_first_token(bool_mask_4d)[:, 0, :] |
| seq_idx = torch.cumsum(is_first_token, dim=-1) - 1 |
| seq_idx = seq_idx.to(torch.int32) |
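| # e.g. if the per-position first-token flags of a packed batch row are [1, 0, 0, 1, 0], |
| # seq_idx becomes [0, 0, 0, 1, 1], labelling which packed sequence each position belongs to. |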
|
|
| |
| |
| attention_mask = attention_mask[:, 0, -length:, -length:] |
| bool_mask = torch.diagonal(attention_mask, dim1=-2, dim2=-1) == 0 |
|
|
| conv_state: torch.Tensor | None |
| ssm_state: torch.Tensor | None |
| if past_states is None: |
| conv_state = None |
| ssm_state = None |
| elif past_states[self.layer_idx] is None: |
| conv_state = torch.zeros( |
| bsize, self.intermediate_size, self.d_conv - 1, dtype=hidden_states.dtype, device=hidden_states.device |
| ) |
| ssm_state = torch.zeros( |
| bsize, |
| self.num_heads, |
| self.hidden_size_per_head, |
| self.d_state, |
| dtype=torch.float32, |
| device=hidden_states.device, |
| ) |
| else: |
| c = past_states[self.layer_idx] |
| assert isinstance(c, Plamo2MambaCache) |
| conv_state = c.conv_state |
| ssm_state = c.ssm_state |
|
|
| zx = self.in_proj(hidden_states) |
| zx = zx.reshape(bsize, length, self.num_heads, -1) |
| |
| |
| z, x = torch.split(zx, [self.hidden_size_per_head, self.hidden_size_per_head], dim=-1) |
|
|
| |
| x = x.reshape(bsize, length, -1).transpose(1, 2) |
| if bool_mask is not None: |
| x = torch.where(bool_mask[:, None, :], x, 0.0) |
| if is_update: |
| assert conv_state is not None |
| x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x) |
| else: |
| x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx) |
| x = x.to(dtype=hidden_states.dtype) |
| x = x.transpose(1, 2) |
| x = x.reshape(bsize, length, -1) |
| |
| |
| |
| |
| BCdt = self.bcdt_proj(x) |
| x = x.reshape(bsize, length, self.num_heads, -1) |
| B, C, dt = torch.split(BCdt, [self.d_state, self.d_state, self.dt_dim], dim=-1) |
| B = B[:, :, None, :] |
| C = C[:, :, None, :] |
|
|
| A = -torch.exp(self.A_log.float()) |
| dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :] |
| B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :] |
| C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :] |
|
|
| |
| dt = self.dt_proj(dt)[..., None] |
|
|
| |
| B = B.expand(-1, -1, self.num_heads, -1) |
| C = C.expand(-1, -1, self.num_heads, -1) |
|
|
| if bool_mask is not None: |
| """ |
| The state is updated as follows: |
| ``` |
| dt = softplus(dt) |
| dA = exp(dt * A) |
| state_next = state * dA + dB * x |
| ``` |
| |
| To keep masked positions from updating the state, we set dt to -inf and x to 0, |
| because `softplus(-inf) = 0` and `exp(0) = 1`. |
| """ |
| dt = torch.where(bool_mask[:, :, None, None], dt, float("-inf")) |
| x = torch.where(bool_mask[:, :, None, None], x, 0.0) |
|
|
| |
| if is_update: |
| assert ssm_state is not None |
| out = ssd_update_state( |
| ssm_state, |
| x[:, 0], |
| dt[:, 0].reshape(bsize, -1), |
| A, |
| B[:, 0], |
| C[:, 0], |
| D=self.D, |
| z=z[:, 0], |
| dt_bias=self.dt_bias, |
| dt_softplus=True, |
| ) |
| else: |
| tmp = ssd_chunk_scan_combined( |
| x, |
| dt.reshape(bsize, length, -1), |
| A, |
| B, |
| C, |
| self.chunk_size, |
| D=self.D, |
| z=z, |
| dt_bias=self.dt_bias, |
| dt_softplus=True, |
| return_final_states=past_states is not None, |
| seq_idx=seq_idx, |
| ssm_state=ssm_state, |
| ) |
| if past_states is not None: |
| out, ssm_state = tmp |
| else: |
| assert isinstance(tmp, torch.Tensor) |
| out = tmp |
|
|
| y = self.out_proj(out.reshape(bsize, length, -1)) |
|
|
| if past_states is not None: |
| assert ssm_state is not None |
| assert conv_state is not None |
| past_states.update_mamba(conv_state, ssm_state, self.layer_idx) |
|
|
| return y, past_states |
|
|
|
|
| def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) -> torch.Tensor: |
| max_len = max(q_len, kv_len) |
| mask = ( |
| torch.ones(max_len, max_len, dtype=torch.bool, device=device) |
| .triu(diagonal=-window_size) |
| .tril(diagonal=window_size) |
| ) |
| return mask[-q_len:, -kv_len:] |
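| # For example, swa_mask(3, 3, device, window_size=1) allows each query to see keys within one |
| # position of itself: |
| #   [[True, True, False], |
| #    [True, True, True ], |
| #    [False, True, True ]] |
| # The causal constraint is applied separately by the additive attention mask. |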
|
|
|
|
| class Attention(torch.nn.Module): |
| def __init__(self, config: Plamo2Config, layer_idx: int) -> None: |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.hidden_size = config.hidden_size |
| head_dim = config.hidden_size_per_head |
| self.max_position_embeddings = config.max_position_embeddings |
|
|
| self.q_num_heads = config.num_attention_heads |
| self.qk_dim = self.v_dim = head_dim |
| self.k_num_heads = self.v_num_heads = config.num_key_value_heads |
| assert self.q_num_heads % self.k_num_heads == 0 |
| self.n_group = self.q_num_heads // self.k_num_heads |
|
|
| self.q_proj_dim = self.q_num_heads * self.qk_dim |
| self.k_proj_dim = self.k_num_heads * self.qk_dim |
| self.v_proj_dim = self.k_num_heads * self.v_dim |
| self.qkv_proj = nn.Linear(self.hidden_size, self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False) |
| self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False) |
|
|
| self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim))) |
| self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim))) |
|
|
| self.full_attn = self.layer_idx in self.config.full_attention_idx |
| base = self.config.rope_theta if self.full_attn else self.config.rope_local_theta |
| self.rotary_emb = RotaryEmbedding( |
| self.qk_dim, max_position_embeddings=self.config.attention_window_size, base=base |
| ) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| past_states: Optional[Plamo2Cache] = None, |
| output_attentions: bool = False, |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Plamo2Cache]]: |
| bsz, q_len, _ = hidden_states.size() |
|
|
| qkv = self.qkv_proj(hidden_states) |
| query_states, key_states, value_states = torch.split( |
| qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1 |
| ) |
| query_states = query_states.view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2) |
| key_states = key_states.view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2) |
| value_states = value_states.view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2) |
|
|
| attn_dtype = query_states.dtype |
|
|
| query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None] |
| key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None] |
|
|
| if past_states is not None: |
| |
| # Concatenate the cached K/V for this forward pass, but hand only the new K/V to the cache; |
| # update_attention re-appends them internally and applies the sliding window itself. |
| key_states_new = key_states |
| value_states_new = value_states |
| key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) |
| past_states.update_attention(key_states_new, value_states_new, self.layer_idx) |
|
|
| kv_seq_len = key_states.shape[-2] |
| device = hidden_states.device |
| position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=device)[None] |
| q_position_ids = position_ids[:, -query_states.shape[2] :] |
| cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
| query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids) |
| key_states = _rotary_pos_emb(key_states, cos, sin, position_ids) |
| |
|
|
| def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor: |
| t = torch.repeat_interleave(t, repeat, dim=1) |
| return t[:, :target] |
|
|
| |
| assert self.k_num_heads == self.v_num_heads |
| key_states = _expand_kv(key_states, self.n_group, self.q_num_heads) |
| value_states = _expand_kv(value_states, self.n_group, self.q_num_heads) |
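| # Grouped-query attention: with the default config (num_attention_heads=32, |
| # num_key_value_heads=4), n_group = 8 and each key/value head is repeated 8 times so that |
| # every query head has a matching key/value head. |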
|
|
| query_states = query_states.to(attn_dtype) |
| key_states = key_states.to(attn_dtype) |
| value_states = value_states.to(attn_dtype) |
| if attention_mask is not None and attention_mask.dtype != torch.bool: |
| attention_mask = attention_mask.to(attn_dtype) |
| if attention_mask is None: |
| if not self.full_attn: |
| assert key_states.shape[2] <= self.config.attention_window_size + 1 |
| attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True) |
| else: |
| if attention_mask.dtype == torch.bool: |
| attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float), float("-inf")) |
| if len(attention_mask.shape) == 2: |
| attention_mask = attention_mask[None, None] |
| assert len(attention_mask.shape) == 4 |
|
|
| if not self.full_attn: |
| m_swa = swa_mask( |
| query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size |
| ) |
| |
| m_swa = m_swa[None, None] |
| attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :] |
| attention_mask = torch.where(m_swa, attention_mask, float("-inf")) |
|
|
| |
| |
| # Rows whose keys are all masked (-inf everywhere) would produce NaNs in softmax, so their |
| # mask row is zeroed out; such rows correspond to fully padded queries anyway. |
| bool_mask = torch.logical_not(torch.isneginf(attention_mask)) |
| valid_tokens = torch.sum(bool_mask, dim=-1).bool() |
| attention_mask = torch.where(valid_tokens[..., None], attention_mask, float(0.0)) |
| attn_output = F.scaled_dot_product_attention( |
| query_states, key_states, value_states, attn_mask=attention_mask |
| ) |
|
|
| attn_output = attn_output.transpose(1, 2) |
|
|
| attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim) |
| attn_output = self.o_proj(attn_output) |
|
|
| # F.scaled_dot_product_attention does not expose attention weights, so None is returned |
| # even when `output_attentions` is requested. |
| attn_weights = None |
|
|
| return attn_output, attn_weights, past_states |
|
|
|
|
| class MLP(nn.Module): |
| def __init__(self, config: Plamo2Config) -> None: |
| super().__init__() |
| self.config = config |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
| self.gate_up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) |
| self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| h = self.gate_up_proj(x) |
| h = _swiglu(h) |
| return self.down_proj(h) |
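| # With the default config (hidden_size=4096, intermediate_size=13312), gate_up_proj maps |
| # 4096 -> 2 * 13312, _swiglu halves that back to 13312, and down_proj returns to 4096. |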
|
|
|
|
| class Plamo2DecoderLayer(torch.nn.Module): |
| def __init__(self, config: Plamo2Config, layer_idx: int) -> None: |
| super().__init__() |
| self.config = config |
| self.hidden_size = config.hidden_size |
| self.is_mamba = config.layers_block_type[layer_idx] == "mamba" |
| self.mixer: torch.nn.Module |
| if self.is_mamba: |
| self.mixer = Mamba(config, layer_idx) |
| else: |
| self.mixer = Attention(config, layer_idx) |
| self.mlp = MLP(config) |
| """ |
| Note: model quality degraded when all RMSNorm offsets were set to 1.0, hence the scaled |
| post-norm offsets below. |
| """ |
| self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0) |
| self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5) |
| self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0) |
| self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| past_state: Optional[Plamo2Cache] = None, |
| output_attentions: Optional[bool] = False, |
| ) -> Tuple[Any, ...]: |
| |
| residual = hidden_states |
| hidden_states = self.pre_mixer_norm(hidden_states) |
|
|
| |
| if self.is_mamba: |
| hidden_states_sa, present_key_value = self.mixer( |
| hidden_states=hidden_states, |
| attention_mask=attention_mask, |
| past_states=past_state, |
| ) |
| self_attn_weights = None |
| else: |
| hidden_states_sa, self_attn_weights, present_key_value = self.mixer( |
| hidden_states=hidden_states, |
| attention_mask=attention_mask, |
| past_states=past_state, |
| output_attentions=output_attentions, |
| ) |
|
|
| hidden_states_sa = self.post_mixer_norm(hidden_states_sa) |
| hidden_states = residual + hidden_states_sa |
|
|
| residual = hidden_states |
| hidden_states = self.pre_mlp_norm(hidden_states) |
|
|
| |
| hidden_states_mlp = self.mlp(hidden_states) |
|
|
| |
| hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp) |
| hidden_states = residual + hidden_states_mlp |
|
|
| outputs: Any = (hidden_states,) |
|
|
| if output_attentions: |
| outputs += (self_attn_weights,) |
|
|
| return outputs |
|
|
|
|
| def is_mamba(config: Plamo2Config, i: int) -> bool: |
| if not config.mamba_enabled: |
| return False |
| assert config.mamba_step > 1 |
| assert i < config.num_hidden_layers |
|
|
| if config.num_hidden_layers <= (config.mamba_step // 2): |
| |
| return i != config.num_hidden_layers - 1 |
| return (i % config.mamba_step) != (config.mamba_step // 2) |
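| # With mamba_step=2 the layers alternate, e.g. for num_hidden_layers=4 the pattern is |
| # ["mamba", "attention", "mamba", "attention"]; very shallow models |
| # (num_hidden_layers <= mamba_step // 2) use Mamba everywhere except the last layer. |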
|
|
|
|
| class Plamo2Decoder(torch.nn.Module): |
| def __init__(self, config: Plamo2Config) -> None: |
| super().__init__() |
|
|
| self.layers = torch.nn.ModuleList( |
| [Plamo2DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] |
| ) |
| self.gradient_checkpointing = False |
|
|
| def forward(self, x: DecoderInput) -> DecoderOutput: |
| all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None |
| all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if x.output_attentions else None |
| hidden_states = x.hidden_states |
|
|
| for decoder_layer in self.layers: |
| if x.output_hidden_states: |
| assert all_hidden_states is not None |
| all_hidden_states += (hidden_states,) |
|
|
| if self.training and x.gradient_checkpointing: |
| layer_outputs = self._gradient_checkpointing_func( |
| decoder_layer.__call__, |
| hidden_states, |
| x.attention_mask, |
| x.past_states, |
| x.output_attentions, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| hidden_states, |
| attention_mask=x.attention_mask, |
| past_state=x.past_states, |
| output_attentions=x.output_attentions, |
| ) |
|
|
| hidden_states = layer_outputs[0] |
|
|
| if x.output_attentions: |
| assert layer_outputs[1] is not None |
| assert all_self_attns is not None |
| all_self_attns += (layer_outputs[1],) |
| return DecoderOutput(hidden_states, all_hidden_states, all_self_attns) |
|
|
|
|
| class Plamo2PreTrainedModel(PreTrainedModel): |
| config_class = Plamo2Config |
| _no_split_modules: List[str] |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["PlamoDecoderLayer"] |
| _skip_keys_device_placement = "past_key_values" |
| _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] |
|
|
| def _init_weights(self, module: torch.nn.Module) -> None: |
| std = 0.02 |
| if isinstance(module, nn.Linear): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
|
|
|
|
| class Plamo2Model(Plamo2PreTrainedModel): |
| def __init__(self, config: Plamo2Config): |
| super().__init__(config) |
| assert config.eval_attention_n_bit is None |
| assert config.eval_mlp_n_bit is None |
|
|
| self.padding_idx = config.pad_token_id |
| self.vocab_size = config.vocab_size |
|
|
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
| if config.image_feature_size is not None: |
| if config.image_proj_type == "mlp": |
| self.image_proj = MLPImageProjector(config) |
| elif config.image_proj_type == "linear": |
| self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) |
| else: |
| raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}") |
| self.layers = Plamo2Decoder(config) |
| self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
| self.gradient_checkpointing = False |
| |
| self.post_init() |
|
|
| def get_input_embeddings(self) -> torch.nn.Embedding: |
| return self.embed_tokens |
|
|
| def set_input_embeddings(self, value: torch.nn.Embedding) -> None: |
| self.embed_tokens = value |
|
|
| |
| def _prepare_decoder_attention_mask( |
| self, |
| attention_mask: torch.Tensor, |
| input_shape: Tuple[int, int], |
| inputs_embeds: Optional[torch.Tensor], |
| past_key_values_length: int, |
| ) -> Optional[torch.Tensor]: |
| |
| |
| combined_attention_mask: Optional[torch.Tensor] = None |
| if input_shape[-1] > 1: |
| assert inputs_embeds is not None |
| combined_attention_mask = _make_causal_mask( |
| input_shape, |
| inputs_embeds.dtype, |
| device=inputs_embeds.device, |
| past_key_values_length=past_key_values_length, |
| ) |
| input_shape = (input_shape[0], combined_attention_mask.shape[2]) |
|
|
| if attention_mask is not None: |
| if attention_mask.dim() == 4: |
| |
| expanded_attn_mask = attention_mask |
| else: |
| |
| assert inputs_embeds is not None |
| expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( |
| inputs_embeds.device |
| ) |
| combined_attention_mask = ( |
| expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask |
| ) |
|
|
| return combined_attention_mask |
|
|
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.Tensor] = None, |
| past_key_values: Optional[Plamo2Cache | DynamicCache] = None, |
| inputs_embeds: Optional[torch.Tensor] = None, |
| image_features: Optional[torch.Tensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Any, |
| ) -> Union[Tuple, BaseModelOutputWithPast]: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| |
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
| if self.gradient_checkpointing and self.training and use_cache: |
| use_cache = False |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) |
| batch_size, seq_length, _ = inputs_embeds.shape |
|
|
| seq_length_with_past = seq_length |
| past_key_values_length = 0 |
| if past_key_values is not None: |
| |
| if not isinstance(past_key_values, Plamo2Cache): |
| past_key_values_prev = past_key_values |
| past_key_values = Plamo2Cache(self.config) |
|
|
| |
| assert len(past_key_values_prev) == 0 or not any( |
| layer_cache.get_seq_length() for layer_cache in past_key_values_prev.layers |
| ) |
| assert isinstance(past_key_values, Plamo2Cache) |
| past_key_values_length = past_key_values.get_seq_length() |
| seq_length_with_past = seq_length_with_past + past_key_values_length |
| assert cache_position is None, "cache_position is not supported yet" |
|
|
| if image_features is not None: |
| assert self.config.image_token_id is not None |
| image_embeds = self.image_proj(image_features) |
| assert image_embeds.shape == inputs_embeds.shape, (image_embeds.shape, inputs_embeds.shape) |
| mask = input_ids == self.config.image_token_id |
| inputs_embeds[mask] = image_embeds[mask] |
|
|
| |
| require_attn_mask = False |
| if not self.training or past_key_values is not None: |
| require_attn_mask = True |
| if seq_length_with_past > self.config.attention_window_size + 1: |
| require_attn_mask = True |
| if require_attn_mask and attention_mask is None: |
| attention_mask = torch.ones( |
| (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device |
| ) |
| if attention_mask is not None: |
| attention_mask = self._prepare_decoder_attention_mask( |
| attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length |
| ) |
|
|
| hidden_states = inputs_embeds |
|
|
| if use_cache and past_key_values is None: |
| past_key_values = Plamo2Cache(self.config) |
|
|
| |
| out = self.layers( |
| DecoderInput( |
| hidden_states, |
| attention_mask, |
| past_key_values, |
| output_hidden_states, |
| output_attentions, |
| self.gradient_checkpointing, |
| ) |
| ) |
| assert isinstance(out, DecoderOutput) |
| hidden_states = out.hidden_states |
| all_hidden_states = out.all_hidden_states |
| all_self_attns = out.all_self_attns |
|
|
| hidden_states = self.norm(hidden_states) |
|
|
| |
| if output_hidden_states: |
| assert all_hidden_states is not None |
| all_hidden_states += (hidden_states,) |
|
|
| if not return_dict: |
| return tuple( |
| v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None |
| ) |
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=past_key_values, |
| hidden_states=all_hidden_states, |
| attentions=all_self_attns, |
| ) |
|
|
|
|
| class Plamo2ForCausalLM(Plamo2PreTrainedModel): |
| _tied_weights_keys = ["lm_head.weight"] |
|
|
| |
| |
| |
| |
| |
| _supports_param_buffer_assignment = False |
|
|
| def __init__(self, config: Plamo2Config) -> None: |
| super().__init__(config) |
| self.model = Plamo2Model(config) |
|
|
| self.vocab_size = config.vocab_size |
| # Pad the output dimension up to a multiple of 16; logits are sliced back to the true |
| # vocab_size in forward(). |
| vocab_size = ((self.vocab_size + 15) // 16) * 16 |
| self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False) |
|
|
| |
| self.post_init() |
|
|
| def get_input_embeddings(self) -> torch.nn.Embedding: |
| return self.model.embed_tokens |
|
|
| def set_input_embeddings(self, value: torch.nn.Embedding) -> None: |
| self.model.embed_tokens = value |
|
|
| def get_output_embeddings(self) -> torch.nn.Module: |
| return self.lm_head |
|
|
| def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None: |
| self.lm_head = new_embeddings |
|
|
| def set_decoder(self, decoder: Plamo2Model) -> None: |
| self.model = decoder |
|
|
| def get_decoder(self) -> Plamo2Model: |
| return self.model |
|
|
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.Tensor] = None, |
| past_key_values: Optional[Plamo2Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| image_features: Optional[torch.Tensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: int | torch.Tensor = 0, |
| **kwargs: Any, |
| ) -> Union[Tuple, CausalLMOutputWithPast]: |
| r""" |
| Args: |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
| config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
| |
| Returns: |
| |
| Example: |
| |
| ```python |
| >>> from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
| >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, trust_remote_code=True) |
| >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, trust_remote_code=True) |
| |
| >>> prompt = "Hey, are you conscious? Can you talk to me?" |
| >>> inputs = tokenizer(prompt, return_tensors="pt") |
| |
| >>> # Generate |
| >>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
| >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| ```""" |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| |
| outputs = self.model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| image_features=image_features, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
|
|
| hidden_states = outputs[0] |
| logits = self.lm_head(hidden_states) |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = logits[:, slice_indices, : self.vocab_size] |
|
|
| loss = None |
| if labels is not None: |
| if len(kwargs) > 0 and set(kwargs.keys()) != set(["ignore_index"]): |
| warnings.warn( |
| f"The following kwargs may not be supported: {', '.join(kwargs.keys())}. ", |
| stacklevel=2, |
| ) |
| loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) |
|
|
| if not return_dict: |
| output = (logits,) + outputs[1:] |
| return (loss,) + output if loss is not None else output |
|
|
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids: torch.Tensor, |
| past_key_values: Optional[Plamo2Cache] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| inputs_embeds: Optional[torch.Tensor] = None, |
| image_features: Optional[torch.Tensor] = None, |
| **kwargs: Any, |
| ) -> Dict[str, Any]: |
| |
| |
| |
| |
| if isinstance(past_key_values, Plamo2Cache): |
| input_ids = input_ids[:, -1:] |
| if image_features is not None: |
| image_features = image_features[:, -1:, :] |
|
|
| position_ids = kwargs.get("position_ids", None) |
| if attention_mask is not None and position_ids is None: |
| |
| position_ids = attention_mask.long().cumsum(-1) - 1 |
| position_ids.masked_fill_(attention_mask == 0, 1) |
| if isinstance(past_key_values, Plamo2Cache): |
| position_ids = position_ids[:, -1].unsqueeze(-1) |
|
|
| |
| if inputs_embeds is not None and past_key_values is None: |
| model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds} |
| else: |
| model_inputs = {"input_ids": input_ids} |
|
|
| model_inputs.update( |
| { |
| "position_ids": position_ids, |
| "past_key_values": past_key_values, |
| "use_cache": kwargs.get("use_cache"), |
| "attention_mask": attention_mask, |
| "image_features": image_features, |
| } |
| ) |
| return model_inputs |
|
|
| @staticmethod |
| def _reorder_cache(past_key_values: Plamo2Cache, beam_idx: torch.Tensor) -> Plamo2Cache: |
| past_key_values.reorder_cache(beam_idx) |
| return past_key_values |
|
|
|
|
| class MLPImageProjector(nn.Module): |
| def __init__(self, config: Plamo2Config) -> None: |
| super().__init__() |
| self.config = config |
|
|
| assert config.image_feature_size is not None |
|
|
| |
| self.norm0 = RMSNorm(config.image_feature_size, eps=config.rms_norm_eps) |
| self.bias0 = Bias(config.image_feature_size) |
|
|
| |
| self.linear1 = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) |
| self.bias1 = Bias(config.hidden_size) |
| self.act1 = nn.GELU() |
|
|
| self.linear2 = nn.Linear(config.hidden_size, config.hidden_size, bias=False) |
| self.bias2 = Bias(config.hidden_size) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| ) -> torch.Tensor: |
| hidden_states = self.norm0(hidden_states) |
| hidden_states = self.bias0(hidden_states) |
|
|
| hidden_states = self.linear1(hidden_states) |
| hidden_states = self.bias1(hidden_states) |
| hidden_states = self.act1(hidden_states) |
|
|
| hidden_states = self.linear2(hidden_states) |
| hidden_states = self.bias2(hidden_states) |
|
|
| return hidden_states |
|
|
|
|
| class Bias(nn.Module): |
| def __init__(self, num_features: int) -> None: |
| super().__init__() |
| self._bias = nn.Parameter(torch.zeros((num_features,))) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| ) -> torch.Tensor: |
| return x + self._bias |
|
|