from typing import Tuple

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MaskedLMOutput
from transformers.cache_utils import Cache, DynamicCache
from rotary_embedding_torch import RotaryEmbedding

from .config import FSTConfig


# === Util ===

class Residual(nn.Module):
    """Additive residual connection: returns ``x + delta``.

    Wrapping the addition in a module makes each residual junction a named
    submodule (e.g. ``resid_attn`` / ``resid_mlp`` in the blocks below).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor, delta: Tensor):
        return x + delta


# === MLP ===

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        hidden_size: input/output width.
        intermediate_size: width of the hidden projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int
    ):
        super().__init__()
        self.fc_up = nn.Linear(hidden_size, intermediate_size)
        self.activation = nn.GELU()
        self.fc_down = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: Tensor):
        return self.fc_down(self.activation(self.fc_up(x)))


# === Attention ===

class MHAttention(nn.Module):
    """Multi-head attention with rotary position embeddings and optional KV cache.

    Queries and keys are rotated per head with ``RotaryEmbedding``. When a
    ``Cache`` is supplied, the rotation is offset by the number of tokens
    already cached for ``layer_idx`` and the new keys/values are appended to
    the cache before attention is computed.

    Args:
        hidden_size: model width; must be divisible by ``num_attention_heads``.
        num_attention_heads: number of attention heads.
        use_causal_attention: ask SDPA for a causal mask when no explicit
            ``attention_mask`` is passed to ``forward``.
        layer_idx: this layer's slot in the KV cache.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        use_causal_attention: bool = True,
        layer_idx: int | None = None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        assert self.head_dim * self.num_attention_heads == self.hidden_size

        self.use_causal_attention = use_causal_attention
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)

        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        q: Tensor,
        k: Tensor | None = None,
        v: Tensor | None = None,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: query input, (B, T, hidden_size).
            k: key input (B, S, hidden_size); defaults to ``q``.
            v: value input (B, S, hidden_size); defaults to ``q``.
            attention_mask: mask passed to SDPA as ``attn_mask`` (boolean,
                True = attend, broadcastable to (B, heads, T, S_total)). When
                given it replaces the implicit causal mask.
            past_key_values: optional cache; updated in place for ``layer_idx``.

        Returns:
            Tensor of shape (B, T, hidden_size).
        """
        B, T, _ = q.size()
        if k is None:
            k = q
        if v is None:
            v = q

        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # (B, L, hidden) -> (B, heads, L, head_dim). Keys/values are reshaped with
        # their own length so inputs whose length differs from q are accepted.
        q = q.view(B, T, self.num_attention_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, k.size(1), self.num_attention_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, v.size(1), self.num_attention_heads, self.head_dim).transpose(1, 2)

        if past_key_values is None:
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)
        else:
            # New tokens are positioned right after the ones already cached for this layer.
            cache_position = past_key_values.get_seq_length(self.layer_idx)
            q = self.rotary_emb.rotate_queries_or_keys(q, offset=cache_position)
            k = self.rotary_emb.rotate_queries_or_keys(k, offset=cache_position)
            k, v = past_key_values.update(k, v, self.layer_idx)

        # An explicit mask takes precedence; otherwise let SDPA apply the causal mask itself.
        is_causal = self.use_causal_attention and attention_mask is None
        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, scale=self.scale, is_causal=is_causal
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.hidden_size)
        return self.o_proj(attn_output)


# === Blocks ===

class FeatureBlock(nn.Module):
    """Pre-norm transformer block over the feature stream ``x``.

    ``x + attn(LN(x))`` followed by ``x + mlp(LN(x))``.
    """

    def __init__(
        self,
        config: FSTConfig,
        layer_idx: int | None = None
    ):
        super().__init__()
        self.attn = MHAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            use_causal_attention=config.use_causal_attention,
            layer_idx=layer_idx,
        )
        self.mlp = MLP(
            config.hidden_size,
            config.intermediate_size
        )
        self.norm_attn = nn.LayerNorm(config.hidden_size)
        self.norm_mlp = nn.LayerNorm(config.hidden_size)
        self.resid_attn = Residual()
        self.resid_mlp = Residual()

    def forward(
        self,
        x: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        attn_out = self.attn(self.norm_attn(x), attention_mask=attention_mask, past_key_values=past_key_values)
        x = self.resid_attn(x, attn_out)
        mlp_out = self.mlp(self.norm_mlp(x))
        x = self.resid_mlp(x, mlp_out)
        return x


class PredictiveBlock(nn.Module):
    """Block that updates the predictive stream ``f``.

    Queries and keys come from the normalized feature stream ``phi``; values
    come from the normalized embeddings ``e``. The attention output and then
    an MLP are added residually to ``f``.
    """

    def __init__(
        self,
        config: FSTConfig,
        layer_idx: int | None = None
    ):
        super().__init__()
        self.attn = MHAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            use_causal_attention=config.use_causal_attention,
            layer_idx=layer_idx,
        )
        self.mlp = MLP(
            config.hidden_size,
            config.intermediate_size
        )
        self.norm_attn_qk = nn.LayerNorm(config.hidden_size)
        self.norm_attn_v = nn.LayerNorm(config.hidden_size)
        self.norm_mlp = nn.LayerNorm(config.hidden_size)
        self.resid_attn = Residual()
        self.resid_mlp = Residual()

    def forward(
        self,
        phi: Tensor,
        f: Tensor,
        e: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        qk = self.norm_attn_qk(phi)
        v = self.norm_attn_v(e)
        attn_out = self.attn(qk, qk, v, attention_mask=attention_mask, past_key_values=past_key_values)
        f = self.resid_attn(f, attn_out)
        mlp_out = self.mlp(self.norm_mlp(f))
        f = self.resid_mlp(f, mlp_out)
        return f


# === Base Model ===

class FSTPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face plumbing (config class, weight init) for FST models."""

    config_class = FSTConfig
    base_model_prefix = "model"
    # Blocks that must not be split across devices by accelerate's device_map.
    _no_split_modules = ["FeatureBlock", "PredictiveBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    # NOTE(review): attention always runs through F.scaled_dot_product_attention
    # regardless of config._attn_implementation; this flag only lets loading
    # with attn_implementation="flash_attention_2" succeed.
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    # Initialization taken from Deepseek and Falcon
    def _init_weights(self, module):
        """Normal(0, initializer_range) for Linear/Embedding, zero biases, identity LayerNorm."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Covers LayerNorm weights missing from a checkpoint at load time.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()


class FSTModel(FSTPreTrainedModel):
    """Backbone that alternates feature blocks and predictive blocks.

    Feature blocks take even layer indices and predictive blocks odd ones. They
    are run in (feature, predictive) pairs via ``zip``, so with an odd
    ``num_hidden_layers`` the last feature block is constructed but never run.
    """

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        self.feature_blocks = nn.ModuleList([FeatureBlock(config, layer_idx) for layer_idx in range(0, config.num_hidden_layers, 2)])
        self.predictive_blocks = nn.ModuleList([PredictiveBlock(config, layer_idx) for layer_idx in range(1, config.num_hidden_layers, 2)])

        self.norm_out = nn.LayerNorm(config.hidden_size)

        self.post_init()

    def _prepare_attention_mask(
        self,
        x: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None,
        use_causal_attention: bool = True
    ):
        """Build a boolean SDPA mask (True = attend) of shape broadcastable to (B, 1, T, T_total).

        Args:
            x: current-step inputs, (B, T, ...).
            attention_mask: optional 2D padding mask (1 = real token). It may cover
                only the new tokens, the full cached + new length, or more; it is
                left-padded with ones or truncated to ``T_total``.
            past_key_values: cache whose length gives the number of past tokens.
            use_causal_attention: combine with a causal mask offset by the past length.
        """
        device = x.device
        B = x.shape[0]
        T = x.shape[1]
        T_past = past_key_values.get_seq_length() if past_key_values is not None else 0
        T_total = T + T_past

        mask = None
        if use_causal_attention:
            # Query i sits at absolute position T_past + i and may see keys 0..T_past + i.
            mask = ~torch.triu(
                torch.ones((T, T_total), dtype=torch.bool, device=device),
                diagonal=(1 + T_past)
            ).view(1, 1, T, T_total)

        if attention_mask is not None:
            attn_len = attention_mask.shape[-1]
            if attn_len < T_total:
                # Missing leading positions are already-cached tokens: keep them attendable.
                pad = torch.ones(B, T_total - attn_len, device=device, dtype=attention_mask.dtype)
                attention_mask = torch.cat([pad, attention_mask], dim=-1)
            elif attn_len > T_total:
                attention_mask = attention_mask[:, -T_total:]
            key_mask = (attention_mask == 1).view(B, 1, 1, T_total)
            mask = key_mask if mask is None else mask & key_mask

        if mask is None:
            return torch.ones((1, 1, T, T_total), dtype=torch.bool, device=device)

        # A query row with no attendable key (e.g. a left-padding position) makes
        # SDPA's softmax produce NaN, which then contaminates every position via
        # k/v in later layers. Let such rows attend everything instead; they are
        # padding queries and are themselves masked out as keys for real tokens.
        fully_masked = ~mask.any(dim=-1, keepdim=True)
        return mask | fully_masked

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        past_key_values = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the backbone.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` (default from config),
            otherwise the tuple ``(f, past_key_values, hidden_states)``. ``f`` is the
            normalized predictive stream; ``hidden_states`` alternates phi and f per
            block pair when ``output_hidden_states`` is set.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        assert not (input_ids is not None and inputs_embeds is not None), "You cannot specify both input_ids and inputs_embeds"
        assert not (input_ids is None and inputs_embeds is None), "You must specify either input_ids or inputs_embeds"

        e = self.embedding(input_ids) if input_ids is not None else inputs_embeds
        B, T, _ = e.shape
        device = e.device
        dtype = e.dtype

        if not use_cache:
            past_key_values = None
        elif past_key_values is None:
            past_key_values = DynamicCache()

        # With a cache, queries (T new tokens) are shorter than keys (T_past + T).
        # SDPA's is_causal aligns its mask to the top-left in that case and would
        # hide the cached keys, so an explicit mask is always built when caching.
        if attention_mask is not None or past_key_values is not None:
            attention_mask = self._prepare_attention_mask(e, attention_mask=attention_mask, use_causal_attention=self.config.use_causal_attention, past_key_values=past_key_values)

        hidden_states = [] if output_hidden_states else None
        phi = e
        # The predictive stream starts at zero (starting from e is an alternative).
        f = torch.zeros(B, T, self.config.hidden_size, dtype=dtype, device=device)

        for feature_block, predictive_block in zip(self.feature_blocks, self.predictive_blocks):
            phi = feature_block(phi, attention_mask=attention_mask, past_key_values=past_key_values)
            f = predictive_block(phi, f, e, attention_mask=attention_mask, past_key_values=past_key_values)
            if output_hidden_states:
                hidden_states.append(phi)
                hidden_states.append(f)

        if hidden_states is not None:
            hidden_states = tuple(hidden_states)

        f = self.norm_out(f)

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=f,
                past_key_values=past_key_values,
                hidden_states=hidden_states
            )
        return f, past_key_values, hidden_states


# === Applied Models ===

class FSTForCausalLM(GenerationMixin, FSTPreTrainedModel):
    """FST backbone plus a vocabulary projection, trained with next-token prediction."""

    accepts_loss_kwargs = False

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)
        self.model = FSTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.tie_weights()
            self._dynamic_tied_weights_keys = {"lm_head.weight": "model.embedding.weight"}  # Avoids safetensor naming issues

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embedding

    def set_input_embeddings(self, new_embeddings):
        self.model.embedding = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Share the embedding matrix with ``lm_head`` when ``config.tie_word_embeddings`` is set.

        Hugging Face also calls this from ``post_init`` / ``from_pretrained``, so
        the config flag is checked here rather than only in ``__init__``.
        """
        if getattr(self.config, "tie_word_embeddings", True):
            self.lm_head.weight = self.get_input_embeddings().weight

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        past_key_values = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Compute logits and, if ``labels`` are given, the shifted next-token loss.

        Label value -100 is ignored. Passing ``labels`` forces a ``CausalLMOutputWithPast``.
        """
        if labels is not None:
            return_dict = True
        else:
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a ModelOutput so attribute access below works even when
        # config.use_return_dict is False.
        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        logits = self.lm_head(model_output.last_hidden_state)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

        if not return_dict:
            output = (logits,) + model_output.to_tuple()[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output.past_key_values,
            hidden_states=model_output.hidden_states
        )

    def _prepare_inputs_for_generation(
        self,
        input_ids: Tensor,
        past_key_values: Cache | None = None,
        attention_mask: Tensor | None = None,
        **kwargs
    ):
        """Assemble model inputs for one decoding step, feeding only uncached tokens.

        NOTE(review): GenerationMixin calls ``prepare_inputs_for_generation`` (no
        leading underscore), so ``generate()`` does not use this helper.
        """
        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if past_length > 0:
            # Drop the cached prefix, but always keep at least the last token.
            input_ids = input_ids[:, min(past_length, input_ids.shape[1] - 1):]

        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        for key, value in kwargs.items():
            model_inputs[key] = value
        return model_inputs

    def _reorder_cache(self, past_key_values: Cache, beam_idx: Tensor):
        return past_key_values.reorder_cache(beam_idx)


class FSTForMaskedLM(FSTPreTrainedModel):
    """Bidirectional FST backbone plus a vocabulary projection for masked-token prediction."""

    accepts_loss_kwargs = False

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)
        assert not config.use_causal_attention, "FSTForMaskedLM requires use_causal_attention=False"
        assert not config.use_cache, "FSTForMaskedLM requires use_cache=False (caching not supported for bidirectional models)"

        self.model = FSTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.tie_weights()
            self._dynamic_tied_weights_keys = {"lm_head.weight": "model.embedding.weight"}  # Avoids safetensor naming issues

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embedding

    def set_input_embeddings(self, new_embeddings):
        self.model.embedding = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Share the embedding matrix with ``lm_head`` when ``config.tie_word_embeddings`` is set.

        Hugging Face also calls this from ``post_init`` / ``from_pretrained``, so
        the config flag is checked here rather than only in ``__init__``.
        """
        if getattr(self.config, "tie_word_embeddings", True):
            self.lm_head.weight = self.get_input_embeddings().weight

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Compute per-position logits and, if ``labels`` are given, the MLM loss.

        Positions labelled -100, or with ``config.pad_token_id`` when that is set,
        do not contribute to the loss. Passing ``labels`` forces a ``MaskedLMOutput``.
        """
        if labels is not None:
            return_dict = True
        else:
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a ModelOutput so attribute access below works even when
        # config.use_return_dict is False.
        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=False,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        logits = self.lm_head(model_output.last_hidden_state)

        loss = None
        if labels is not None:
            # -100 marks non-target positions (standard HF MLM collators emit it);
            # pad-token labels are mapped to it as well so both are skipped.
            if self.config.pad_token_id is not None:
                labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100
            )

        if not return_dict:
            output = (logits,) + model_output.to_tuple()[1:]
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=model_output.hidden_states
        )