| from typing import Tuple |
|
|
| import torch |
| from torch import Tensor |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import PreTrainedModel, GenerationMixin |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MaskedLMOutput |
| from transformers.cache_utils import Cache, DynamicCache |
|
|
| from rotary_embedding_torch import RotaryEmbedding |
| from .config import FSTConfig |
|
|
| |
|
|
class Residual(nn.Module):
    """Parameter-free residual connection.

    Kept as its own module (rather than an inline ``+``) so the addition
    shows up as a named submodule of the blocks that use it.
    """

    def forward(self, x: Tensor, delta: Tensor):
        """Return the element-wise sum of the stream ``x`` and the update ``delta``."""
        updated = x + delta
        return updated
|
|
| |
|
|
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        hidden_size: Size of the input and output feature dimension.
        intermediate_size: Size of the inner (expanded) feature dimension.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()

        # Submodule names and creation order are part of the checkpoint
        # layout, so they are kept exactly as they were.
        self.fc_up = nn.Linear(in_features=hidden_size, out_features=intermediate_size)
        self.activation = nn.GELU()
        self.fc_down = nn.Linear(in_features=intermediate_size, out_features=hidden_size)

    def forward(self, x: Tensor):
        """Project up, apply the activation, and project back down."""
        expanded = self.fc_up(x)
        activated = self.activation(expanded)
        return self.fc_down(activated)
|
|
| |
|
|
class MHAttention(nn.Module):
    """Multi-head attention with rotary position embeddings and optional KV cache.

    Queries, keys and values are projected with separate linear layers
    (``q_proj``/``k_proj`` without bias, ``v_proj``/``o_proj`` with bias),
    rotary embeddings are applied to queries and keys, and the attention itself
    is computed by ``F.scaled_dot_product_attention``.

    Args:
        hidden_size: Model width; must be divisible by ``num_attention_heads``.
        num_attention_heads: Number of attention heads.
        use_causal_attention: When True and no explicit mask is supplied, a
            causal mask is applied.
        layer_idx: Index used to address this layer's slot in the KV cache.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        use_causal_attention: bool = True,
        layer_idx: int | None = None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads

        assert self.head_dim * self.num_attention_heads == self.hidden_size, (
            "hidden_size must be divisible by num_attention_heads"
        )

        self.use_causal_attention = use_causal_attention
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)

        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        q: Tensor,
        k: Tensor | None = None,
        v: Tensor | None = None,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        """Compute attention output for ``q`` against ``k``/``v``.

        Args:
            q: Query source of shape ``(batch, q_len, hidden_size)``.
            k: Key source; defaults to ``q`` when omitted.
            v: Value source; defaults to ``q`` when omitted.
            attention_mask: Optional mask forwarded to
                ``scaled_dot_product_attention`` as ``attn_mask``. When given,
                no implicit causal mask is added.
            past_key_values: Optional cache. When given, the new keys/values
                are rotated with an offset equal to the cached length for
                ``layer_idx`` and appended to the cache.

        Returns:
            Tensor of shape ``(batch, q_len, hidden_size)``.
        """
        B, T, _ = q.size()

        if k is None:
            k = q
        if v is None:
            v = q

        # Keys/values may have a different length than the queries, so each
        # tensor is reshaped with its own sequence length rather than q's.
        T_k = k.size(1)
        T_v = v.size(1)

        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        q = q.view(B, T, self.num_attention_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T_k, self.num_attention_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T_v, self.num_attention_heads, self.head_dim).transpose(1, 2)

        if past_key_values is None:
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)
        else:
            # New tokens continue after whatever this layer already cached.
            cache_position = past_key_values.get_seq_length(self.layer_idx)

            q = self.rotary_emb.rotate_queries_or_keys(q, offset=cache_position)
            k = self.rotary_emb.rotate_queries_or_keys(k, offset=cache_position)

            k, v = past_key_values.update(k, v, self.layer_idx)

        kv_len = k.size(-2)
        is_causal = False
        if attention_mask is None and self.use_causal_attention:
            if kv_len == T:
                is_causal = True
            else:
                # SDPA's ``is_causal`` aligns the triangle to the top-left
                # corner for non-square shapes, which is wrong when the keys
                # include cached positions. Build the bottom-right aligned
                # mask explicitly: query i may see keys j <= i + (kv_len - T).
                attention_mask = torch.ones(
                    T, kv_len, dtype=torch.bool, device=q.device
                ).tril(diagonal=kv_len - T)

        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, scale=self.scale, is_causal=is_causal
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.hidden_size)
        out = self.o_proj(attn_output)

        return out
|
|
| |
|
|
class FeatureBlock(nn.Module):
    """Pre-norm block: self-attention followed by an MLP, each added residually.

    Args:
        config: Supplies ``hidden_size``, ``num_attention_heads``,
            ``use_causal_attention`` and ``intermediate_size``.
        layer_idx: Passed through to the attention module for cache addressing.
    """

    def __init__(self, config: FSTConfig, layer_idx: int = None):
        super().__init__()

        width = config.hidden_size

        # Submodule names and creation order are preserved (checkpoint layout).
        self.attn = MHAttention(
            hidden_size=width,
            num_attention_heads=config.num_attention_heads,
            use_causal_attention=config.use_causal_attention,
            layer_idx=layer_idx,
        )
        self.mlp = MLP(hidden_size=width, intermediate_size=config.intermediate_size)

        self.norm_attn = nn.LayerNorm(width)
        self.norm_mlp = nn.LayerNorm(width)

        self.resid_attn = Residual()
        self.resid_mlp = Residual()

    def forward(
        self,
        x: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        """Apply the attention and MLP sub-layers to ``x`` and return the result."""
        # Attention sub-layer: normalise, attend, add back onto the stream.
        normed_for_attn = self.norm_attn(x)
        attn_delta = self.attn(
            normed_for_attn,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )
        x = self.resid_attn(x, attn_delta)

        # Feed-forward sub-layer on the updated stream.
        normed_for_mlp = self.norm_mlp(x)
        x = self.resid_mlp(x, self.mlp(normed_for_mlp))

        return x
|
|
class PredictiveBlock(nn.Module):
    """Block that updates the stream ``f`` by attending with ``phi`` over ``e``.

    Queries and keys are both derived from ``phi``, values from ``e``; the
    attention output and a following MLP are added residually to ``f``.

    Args:
        config: Supplies ``hidden_size``, ``num_attention_heads``,
            ``use_causal_attention`` and ``intermediate_size``.
        layer_idx: Passed through to the attention module for cache addressing.
    """

    def __init__(self, config: FSTConfig, layer_idx: int = None):
        super().__init__()

        width = config.hidden_size

        # Submodule names and creation order are preserved (checkpoint layout).
        self.attn = MHAttention(
            hidden_size=width,
            num_attention_heads=config.num_attention_heads,
            use_causal_attention=config.use_causal_attention,
            layer_idx=layer_idx,
        )
        self.mlp = MLP(hidden_size=width, intermediate_size=config.intermediate_size)

        self.norm_attn_qk = nn.LayerNorm(width)
        self.norm_attn_v = nn.LayerNorm(width)
        self.norm_mlp = nn.LayerNorm(width)

        self.resid_attn = Residual()
        self.resid_mlp = Residual()

    def forward(
        self,
        phi: Tensor,
        f: Tensor,
        e: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        """Return the updated ``f``.

        Args:
            phi: Source of queries and keys (normalised by ``norm_attn_qk``).
            f: Stream that receives the residual updates.
            e: Source of values (normalised by ``norm_attn_v``).
            attention_mask: Forwarded to the attention module.
            past_key_values: Forwarded to the attention module.
        """
        query_key_src = self.norm_attn_qk(phi)
        value_src = self.norm_attn_v(e)

        # The same normalised tensor is used for both queries and keys.
        attn_delta = self.attn(
            query_key_src,
            query_key_src,
            value_src,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )
        f = self.resid_attn(f, attn_delta)

        mlp_delta = self.mlp(self.norm_mlp(f))
        f = self.resid_mlp(f, mlp_delta)

        return f
|
|
| |
|
|
class FSTPreTrainedModel(PreTrainedModel):
    """Shared base for the FST models: config binding, loading flags, weight init."""

    config_class = FSTConfig
    base_model_prefix = "model"
    # These must name real module classes to have any effect. The blocks in
    # this file are FeatureBlock and PredictiveBlock (there is no "FSTBlock");
    # each contains residual adds, so a block must stay on a single device.
    _no_split_modules = ["FeatureBlock", "PredictiveBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    # NOTE(review): the attention module in this file always calls
    # F.scaled_dot_product_attention; confirm this flag is intended.
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` and ``nn.Embedding`` weights.

        Weights are drawn from N(0, ``config.initializer_range``); linear
        biases and the embedding row at ``padding_idx`` (when set) are zeroed.
        Other module types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
|
|
class FSTModel(FSTPreTrainedModel):
    """Backbone that interleaves feature blocks and predictive blocks.

    Two streams are maintained: ``phi`` (starts as the token embeddings and is
    refined by the feature blocks) and ``f`` (starts at zero and is updated by
    the predictive blocks, which attend with ``phi`` over the raw embeddings).
    The final output is ``LayerNorm(f)``.

    Feature blocks take the even layer indices and predictive blocks the odd
    ones. The forward pass pairs them with ``zip``, so with an odd
    ``num_hidden_layers`` the last feature block is constructed but not used.
    """

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)

        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        self.feature_blocks = nn.ModuleList(
            [FeatureBlock(config, layer_idx) for layer_idx in range(0, config.num_hidden_layers, 2)]
        )
        self.predictive_blocks = nn.ModuleList(
            [PredictiveBlock(config, layer_idx) for layer_idx in range(1, config.num_hidden_layers, 2)]
        )
        self.norm_out = nn.LayerNorm(config.hidden_size)

        self.post_init()

    def _prepare_attention_mask(
        self,
        x: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None,
        use_causal_attention: bool = True
    ):
        """Build the boolean mask (True = may attend) passed to the blocks.

        Args:
            x: Current inputs; only ``x.shape[0]``, ``x.shape[1]`` and
                ``x.device`` are read.
            attention_mask: Optional 2-D padding mask in which ``1`` marks
                positions that may be attended to. It is left-padded with ones
                or left-truncated so that it covers cached plus new positions.
            past_key_values: Optional cache; its current length offsets the
                causal mask.
            use_causal_attention: Whether to combine with a causal mask.

        Returns:
            A bool tensor broadcastable to ``(B, 1, T, T_past + T)``.
        """
        device = x.device
        B = x.shape[0]
        T = x.shape[1]

        T_past = past_key_values.get_seq_length() if past_key_values is not None else 0
        T_total = T + T_past

        expanded_mask = None
        if attention_mask is not None:
            attn_len = attention_mask.shape[-1]

            if attn_len < T_total:
                # Pad exactly the missing columns. Padding a fixed T_past only
                # yields T_total columns when the mask length equals T.
                pad = torch.ones(B, T_total - attn_len, device=device, dtype=attention_mask.dtype)
                attention_mask = torch.cat([pad, attention_mask], dim=-1)
            elif attn_len > T_total:
                attention_mask = attention_mask[:, -T_total:]

            expanded_mask = (attention_mask == 1).view(B, 1, 1, T_total)

        if use_causal_attention:
            # Query i (absolute position i + T_past) may see keys j <= i + T_past.
            causal_mask = ~torch.triu(
                torch.ones((T, T_total), dtype=torch.bool, device=device),
                diagonal=(1 + T_past)
            ).unsqueeze(0).unsqueeze(0)
            mask = causal_mask if expanded_mask is None else causal_mask & expanded_mask
        elif expanded_mask is not None:
            mask = expanded_mask
        else:
            return torch.ones((1, 1, T, T_total), dtype=torch.bool, device=device)

        if expanded_mask is not None:
            # A query row with no permitted key (e.g. left padding under a
            # causal mask) can produce NaNs in scaled_dot_product_attention,
            # and those NaNs then leak into other positions through the value
            # tensors of later layers. Let such rows attend everywhere; the
            # positions themselves are still masked out as keys for others.
            mask = mask | ~mask.any(dim=-1, keepdim=True)

        return mask

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        past_key_values = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the backbone.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``use_cache``, ``output_hidden_states`` and ``return_dict`` fall back to
        the corresponding config values when None.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is true, otherwise
            the tuple ``(last_hidden_state, past_key_values, hidden_states)``
            (entries may be None). ``hidden_states``, when requested, holds
            ``phi`` and ``f`` after every block pair, in that order.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        assert not (input_ids is not None and inputs_embeds is not None), "You cannot specify both input_ids and inputs_embeds"
        assert not (input_ids is None and inputs_embeds is None), "You must specify either input_ids or inputs_embeds"

        e = self.embedding(input_ids) if input_ids is not None else inputs_embeds

        B, T, _ = e.shape
        device = e.device
        dtype = e.dtype

        if not use_cache:
            past_key_values = None
        elif past_key_values is None:
            past_key_values = DynamicCache()

        # With neither a padding mask nor a cache, the mask stays None and the
        # attention modules apply their own default masking.
        if attention_mask is not None or past_key_values is not None:
            attention_mask = self._prepare_attention_mask(
                e,
                attention_mask=attention_mask,
                use_causal_attention=self.config.use_causal_attention,
                past_key_values=past_key_values,
            )

        hidden_states = [] if output_hidden_states else None

        phi = e
        f = torch.zeros(B, T, self.config.hidden_size, dtype=dtype, device=device)

        for feature_block, predictive_block in zip(self.feature_blocks, self.predictive_blocks):
            phi = feature_block(phi, attention_mask=attention_mask, past_key_values=past_key_values)
            f = predictive_block(phi, f, e, attention_mask=attention_mask, past_key_values=past_key_values)

            if output_hidden_states:
                hidden_states.append(phi)
                hidden_states.append(f)

        if hidden_states is not None:
            hidden_states = tuple(hidden_states)

        f = self.norm_out(f)

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=f,
                past_key_values=past_key_values,
                hidden_states=hidden_states
            )

        return f, past_key_values, hidden_states
|
|
| |
|
|
class FSTForCausalLM(GenerationMixin, FSTPreTrainedModel):
    """FST backbone with a linear language-modelling head for next-token prediction."""

    accepts_loss_kwargs = False

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)

        self.model = FSTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.tie_weights()
            self._dynamic_tied_weights_keys = {"lm_head.weight": "model.embedding.weight"}

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embedding

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module of the backbone."""
        self.model.embedding = new_embeddings

    def get_output_embeddings(self):
        """Return the language-modelling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modelling head."""
        self.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Share the input embedding matrix with the LM head when configured.

        The base class may call this hook regardless of configuration, so the
        ``tie_word_embeddings`` check lives here; otherwise the head would be
        tied even for configs that ask for separate weights. ``missing_keys``
        and ``recompute_mapping`` are accepted but unused (presumably to match
        the base-class signature).
        """
        if not self.config.tie_word_embeddings:
            return
        self.lm_head.weight = self.get_input_embeddings().weight

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        past_key_values = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Compute logits and, when ``labels`` is given, the shifted LM loss.

        Args:
            labels: Token ids aligned with the inputs. They are shifted
                internally so position ``t`` predicts token ``t + 1``; entries
                equal to ``-100`` are ignored. Supplying labels forces a
                dict-style return.

        Returns:
            ``CausalLMOutputWithPast`` when a dict-style return is in effect,
            otherwise a tuple starting with ``logits`` (preceded by ``loss``
            when one was computed).
        """
        if labels is not None:
            return_dict = True
        else:
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            # The dict-style return below reads attributes off the backbone
            # output, so the backbone must not fall back to a plain tuple
            # (which it would when config.use_return_dict is False). For
            # tuple-style returns the backbone default is left as it was.
            return_dict=True if return_dict else None,
        )

        logits = self.lm_head(model_output[0])

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

        if not return_dict:
            output = (logits,) + model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output.past_key_values,
            hidden_states=model_output.hidden_states
        )

    def _prepare_inputs_for_generation(
        self,
        input_ids: Tensor,
        past_key_values: Cache | None = None,
        attention_mask: Tensor | None = None,
        **kwargs
    ):
        """Assemble the keyword arguments for one decoding step.

        NOTE(review): nothing in this file calls this underscore-prefixed
        method — confirm whether it is meant to be the generation hook.
        """
        # Only drop the prefix once something is actually cached; an empty
        # cache on the first step must still see the whole prompt.
        if past_key_values is not None and past_key_values.get_seq_length() > 0:
            input_ids = input_ids[:, -1:]

        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}

        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask

        # Extra keyword arguments override the defaults assembled above.
        model_inputs.update(kwargs)

        return model_inputs

    def _reorder_cache(self, past_key_values: Cache, beam_idx: Tensor):
        """Reorder the cache along the batch axis for beam search and return it."""
        # Return the cache object itself rather than the call's result, so the
        # caller always gets a usable cache back (assumes reorder_cache works
        # in place — TODO confirm).
        past_key_values.reorder_cache(beam_idx)
        return past_key_values
| |
class FSTForMaskedLM(FSTPreTrainedModel):
    """FST backbone with a linear head for masked-token prediction.

    Requires a bidirectional, cache-free configuration
    (``use_causal_attention=False`` and ``use_cache=False``).
    """

    accepts_loss_kwargs = False

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)

        assert not config.use_causal_attention, "FSTForMaskedLM requires use_causal_attention=False"
        assert not config.use_cache, "FSTForMaskedLM requires use_cache=False (caching not supported for bidirectional models)"

        self.model = FSTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.tie_weights()
            self._dynamic_tied_weights_keys = {"lm_head.weight": "model.embedding.weight"}

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embedding

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module of the backbone."""
        self.model.embedding = new_embeddings

    def get_output_embeddings(self):
        """Return the language-modelling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modelling head."""
        self.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Share the input embedding matrix with the LM head when configured.

        The base class may call this hook regardless of configuration, so the
        ``tie_word_embeddings`` check lives here; otherwise the head would be
        tied even for configs that ask for separate weights. ``missing_keys``
        and ``recompute_mapping`` are accepted but unused (presumably to match
        the base-class signature).
        """
        if not self.config.tie_word_embeddings:
            return
        self.lm_head.weight = self.get_input_embeddings().weight

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Compute logits and, when ``labels`` is given, the masked-LM loss.

        Args:
            labels: Target token ids aligned one-to-one with the inputs (no
                shifting). Positions equal to ``-100`` or to
                ``config.pad_token_id`` (when set) do not contribute to the
                loss. Supplying labels forces a dict-style return.

        Returns:
            ``MaskedLMOutput`` when a dict-style return is in effect, otherwise
            a tuple starting with ``logits`` (preceded by ``loss`` when one was
            computed).
        """
        if labels is not None:
            return_dict = True
        else:
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=False,
            output_hidden_states=output_hidden_states,
            # The dict-style return below reads attributes off the backbone
            # output, so the backbone must not fall back to a plain tuple
            # (which it would when config.use_return_dict is False). For
            # tuple-style returns the backbone default is left as it was.
            return_dict=True if return_dict else None,
        )

        logits = self.lm_head(model_output[0])

        loss = None
        if labels is not None:
            flat_labels = labels.reshape(-1)
            pad_token_id = self.config.pad_token_id
            if pad_token_id is not None:
                # Fold padding targets into the -100 sentinel so both
                # conventions are ignored. Using pad_token_id as ignore_index
                # would leave -100 labels as out-of-range class indices.
                flat_labels = flat_labels.masked_fill(flat_labels == pad_token_id, -100)

            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                flat_labels,
                ignore_index=-100
            )

        if not return_dict:
            output = (logits,) + model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=model_output.hidden_states
        )
|
|