# fst_1_3B / model.py
# Last commit by williamconvertino: 3ff1b69 "Update model.py" (verified)
from typing import Tuple
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MaskedLMOutput
from transformers.cache_utils import Cache, DynamicCache
from rotary_embedding_torch import RotaryEmbedding
from .config import FSTConfig
# === Util ===
class Residual(nn.Module):
    """Parameter-free residual connection: returns ``x + delta``.

    Kept as a module (rather than an inline ``+``) so the skip connection is a
    named, hookable submodule of each block.
    """

    def forward(self, x: Tensor, delta: Tensor):
        combined = x + delta
        return combined
# === MLP ===
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Expands the last dimension from ``hidden_size`` to ``intermediate_size``
    and projects it back; both projections carry a bias.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Attribute names (and creation order) define the checkpoint key layout.
        self.fc_up = nn.Linear(hidden_size, intermediate_size)
        self.activation = nn.GELU()
        self.fc_down = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: Tensor):
        expanded = self.fc_up(x)
        activated = self.activation(expanded)
        return self.fc_down(activated)
# === Attention ===
class MHAttention(nn.Module):
    """Multi-head scaled dot-product attention with rotary position embeddings.

    Queries/keys are projected without bias, values/output with bias. Heads are
    rotated with ``RotaryEmbedding`` and combined via
    ``F.scaled_dot_product_attention``. When a ``Cache`` is passed, this
    layer's keys/values are stored in it under ``layer_idx``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        use_causal_attention: bool = True,
        layer_idx: int | None = None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        assert self.head_dim * self.num_attention_heads == self.hidden_size
        self.use_causal_attention = use_causal_attention
        self.layer_idx = layer_idx
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
        self.scale = self.head_dim ** -0.5

    def _split_heads(self, x: Tensor) -> Tensor:
        """Reshape ``(B, L, hidden_size)`` into ``(B, num_heads, L, head_dim)``."""
        B, L, _ = x.size()
        return x.view(B, L, self.num_attention_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        q: Tensor,
        k: Tensor | None = None,
        v: Tensor | None = None,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        """Attend from ``q`` to ``k``/``v`` and return ``(B, T, hidden_size)``.

        Args:
            q: query input of shape ``(B, T, hidden_size)``.
            k: key input; defaults to ``q``. May have its own sequence length.
            v: value input; defaults to ``q``. Must match ``k``'s length.
            attention_mask: passed to SDPA as ``attn_mask``. When given, no
                causal masking is added on top, so the caller's mask must
                already encode causality if it is wanted.
            past_key_values: optional cache. Its current length for this
                layer is used as the rotary position offset, and the new
                keys/values are appended before attention.
        """
        B, T, _ = q.size()
        if k is None:
            k = q
        if v is None:
            v = q
        q = self._split_heads(self.q_proj(q))
        k = self._split_heads(self.k_proj(k))
        v = self._split_heads(self.v_proj(v))

        # New tokens continue after whatever this layer already has cached.
        offset = 0 if past_key_values is None else past_key_values.get_seq_length(self.layer_idx)
        q = self.rotary_emb.rotate_queries_or_keys(q, offset=offset)
        k = self.rotary_emb.rotate_queries_or_keys(k, offset=offset)
        if past_key_values is not None:
            k, v = past_key_values.update(k, v, self.layer_idx)

        attn_mask = attention_mask
        is_causal = False
        if self.use_causal_attention and attention_mask is None:
            q_len, k_len = q.size(-2), k.size(-2)
            if k_len > q_len:
                # SDPA's is_causal aligns the mask to the top-left, which is wrong
                # when cached keys precede the queries. Align bottom-right instead:
                # query i (absolute position k_len - q_len + i) sees keys <= it.
                attn_mask = torch.ones(q_len, k_len, dtype=torch.bool, device=q.device).tril(diagonal=k_len - q_len)
            else:
                is_causal = True

        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale, is_causal=is_causal)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.hidden_size)
        return self.o_proj(attn_output)
# === Blocks ===
class FeatureBlock(nn.Module):
    """Pre-norm transformer block that updates the feature stream.

    Computes ``x + attn(LN(x))`` and then ``x + mlp(LN(x))``. Attention is
    self-attention: the normalized input serves as query, key and value.
    """

    def __init__(
        self,
        config: FSTConfig,
        layer_idx: int | None = None
    ):
        super().__init__()
        hidden = config.hidden_size
        # Submodules are registered in a fixed order; this order defines the
        # state_dict layout and the RNG consumption during construction.
        self.attn = MHAttention(
            hidden_size=hidden,
            num_attention_heads=config.num_attention_heads,
            use_causal_attention=config.use_causal_attention,
            layer_idx=layer_idx,
        )
        self.mlp = MLP(hidden, config.intermediate_size)
        self.norm_attn = nn.LayerNorm(hidden)
        self.norm_mlp = nn.LayerNorm(hidden)
        self.resid_attn = Residual()
        self.resid_mlp = Residual()

    def forward(
        self,
        x: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        normed = self.norm_attn(x)
        attended = self.attn(normed, attention_mask=attention_mask, past_key_values=past_key_values)
        x = self.resid_attn(x, attended)
        return self.resid_mlp(x, self.mlp(self.norm_mlp(x)))
class PredictiveBlock(nn.Module):
    """Pre-norm block that updates the predictive stream ``f``.

    Queries and keys come from the normalized feature stream ``phi``. Values
    come from the normalized embeddings ``e``. The attention result and then an
    MLP are added residually onto ``f``.
    """

    def __init__(
        self,
        config: FSTConfig,
        layer_idx: int | None = None
    ):
        super().__init__()
        hidden = config.hidden_size
        # Registration order is fixed: it defines the state_dict layout.
        self.attn = MHAttention(
            hidden_size=hidden,
            num_attention_heads=config.num_attention_heads,
            use_causal_attention=config.use_causal_attention,
            layer_idx=layer_idx,
        )
        self.mlp = MLP(hidden, config.intermediate_size)
        self.norm_attn_qk = nn.LayerNorm(hidden)
        self.norm_attn_v = nn.LayerNorm(hidden)
        self.norm_mlp = nn.LayerNorm(hidden)
        self.resid_attn = Residual()
        self.resid_mlp = Residual()

    def forward(
        self,
        phi: Tensor,
        f: Tensor,
        e: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None
    ):
        query_key = self.norm_attn_qk(phi)
        values = self.norm_attn_v(e)
        update = self.attn(query_key, query_key, values, attention_mask=attention_mask, past_key_values=past_key_values)
        f = self.resid_attn(f, update)
        return self.resid_mlp(f, self.mlp(self.norm_mlp(f)))
# === Base Model ===
class FSTPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face plumbing for FST models.

    Declares the config class, the checkpoint prefix, device-placement hints
    and the weight initialization used by ``FSTModel`` and its task heads.
    """

    config_class = FSTConfig
    base_model_prefix = "model"
    # Modules that must stay on one device under device_map. These must be the
    # class names actually used by FSTModel's layer lists.
    _no_split_modules = ["FeatureBlock", "PredictiveBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    # NOTE(review): attention in this file always goes through
    # F.scaled_dot_product_attention; there is no flash-attn-2 specific path.
    # Confirm this flag is intended.
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    # Initialization taken from Deepseek and Falcon
    def _init_weights(self, module):
        """Initialize one submodule in place.

        Linear and Embedding weights are drawn from N(0, initializer_range).
        Biases and the padding row are zeroed. LayerNorm is reset to the
        identity affine transform.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Matches nn.LayerNorm's constructor defaults. This is needed when
            # the module is materialized without those defaults, e.g. for
            # parameters missing from a checkpoint.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class FSTModel(FSTPreTrainedModel):
    """Two-stream FST backbone.

    The token embedding ``e`` seeds the feature stream ``phi``. The predictive
    stream ``f`` starts at zero. Each step runs one ``FeatureBlock`` on ``phi``,
    then one ``PredictiveBlock`` that updates ``f`` from ``phi`` (queries/keys)
    and ``e`` (values). The final ``f``, after LayerNorm, is the output.
    """

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        # Feature blocks take even layer indices and predictive blocks odd ones,
        # so every attention layer gets its own cache slot.
        # NOTE(review): with an odd num_hidden_layers there is one extra feature
        # block, and zip() in forward() never runs it. Confirm configs use an
        # even layer count.
        self.feature_blocks = nn.ModuleList([FeatureBlock(config, layer_idx) for layer_idx in range(0, config.num_hidden_layers, 2)])
        self.predictive_blocks = nn.ModuleList([PredictiveBlock(config, layer_idx) for layer_idx in range(1, config.num_hidden_layers, 2)])
        self.norm_out = nn.LayerNorm(config.hidden_size)
        self.post_init()

    def _prepare_attention_mask(
        self,
        x: Tensor,
        attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None,
        use_causal_attention: bool = True
    ):
        """Build a boolean SDPA mask; True marks keys a query may attend to.

        Args:
            x: current-step input. Only its batch size, length ``T`` and device
                are read.
            attention_mask: optional 2-D padding mask where 1 marks real tokens.
                If it is shorter than ``T_past + T``, the missing leading
                (cached) positions are treated as visible. If it is longer,
                only its last ``T_past + T`` columns are used.
            past_key_values: optional cache. Its current length offsets the
                causal diagonal.
            use_causal_attention: whether to add the causal constraint.

        Returns:
            A bool tensor broadcastable to ``(B, 1, T, T_past + T)``.
        """
        device = x.device
        B = x.shape[0]
        T = x.shape[1]
        T_past = past_key_values.get_seq_length() if past_key_values is not None else 0
        T_total = T + T_past

        mask = None
        if use_causal_attention:
            # Query i sits at absolute position T_past + i and sees keys <= it.
            mask = ~torch.triu(
                torch.ones((T, T_total), dtype=torch.bool, device=device),
                diagonal=(1 + T_past)
            ).view(1, 1, T, T_total)

        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
            attn_len = attention_mask.shape[-1]
            if attn_len < T_total:
                # Pad exactly the missing prefix so the mask always has T_total
                # columns; cached positions are treated as visible.
                pad = torch.ones(B, T_total - attn_len, device=device, dtype=attention_mask.dtype)
                attention_mask = torch.cat([pad, attention_mask], dim=-1)
            elif attn_len > T_total:
                attention_mask = attention_mask[:, -T_total:]
            padding_mask = (attention_mask == 1).view(B, 1, 1, T_total)
            mask = padding_mask if mask is None else mask & padding_mask
            # A query row with no visible key (e.g. a left-padding position under
            # causal masking) makes SDPA's softmax produce NaN. Later layers then
            # spread that NaN to every token through the value vectors. Unmask
            # such rows; their outputs belong to padding and are meaningless anyway.
            fully_masked = ~mask.any(dim=-1, keepdim=True)
            mask = mask | fully_masked

        if mask is None:
            mask = torch.ones((1, 1, T, T_total), dtype=torch.bool, device=device)
        return mask

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        past_key_values = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the backbone over ``input_ids`` or ``inputs_embeds``.

        Args:
            input_ids: token ids. Exactly one of ``input_ids`` and
                ``inputs_embeds`` must be given.
            attention_mask: optional padding mask (1 = real token).
            inputs_embeds: embeddings used instead of ``self.embedding(input_ids)``.
            past_key_values: optional cache. It is dropped when caching is off;
                when caching is on and none is given, a fresh DynamicCache is
                created.
            use_cache, output_hidden_states, return_dict: default to the config.
            **kwargs: accepted and ignored.

        Returns:
            ``BaseModelOutputWithPast``, or the tuple
            ``(last_hidden_state, past_key_values, hidden_states)`` when
            ``return_dict`` is false. When requested, ``hidden_states`` holds
            ``phi`` then ``f`` after each feature/predictive pair.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert not (input_ids is not None and inputs_embeds is not None), "You cannot specify both input_ids and inputs_embeds"
        assert not (input_ids is None and inputs_embeds is None), "You must specify either input_ids or inputs_embeds"

        e = self.embedding(input_ids) if input_ids is not None else inputs_embeds
        B, T, _ = e.shape
        device = e.device
        dtype = e.dtype

        if not use_cache:
            past_key_values = None
        elif past_key_values is None:
            past_key_values = DynamicCache()

        # An explicit mask is required whenever a cache is in use. Otherwise SDPA
        # falls back to is_causal, which aligns the mask top-left and is wrong
        # once cached keys precede the new queries.
        if attention_mask is not None or past_key_values is not None:
            attention_mask = self._prepare_attention_mask(e, attention_mask=attention_mask, use_causal_attention=self.config.use_causal_attention, past_key_values=past_key_values)

        hidden_states = [] if output_hidden_states else None
        phi = e
        f = torch.zeros(B, T, self.config.hidden_size, dtype=dtype, device=device)  # Initialize f as zero for purity, but f=e also works fine
        for feature_block, predictive_block in zip(self.feature_blocks, self.predictive_blocks):
            phi = feature_block(phi, attention_mask=attention_mask, past_key_values=past_key_values)
            f = predictive_block(phi, f, e, attention_mask=attention_mask, past_key_values=past_key_values)
            if output_hidden_states:
                hidden_states.append(phi)
                hidden_states.append(f)
        if hidden_states is not None:
            hidden_states = tuple(hidden_states)

        f = self.norm_out(f)
        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=f,
                past_key_values=past_key_values,
                hidden_states=hidden_states
            )
        return f, past_key_values, hidden_states
# === Applied Models ===
class FSTForCausalLM(GenerationMixin, FSTPreTrainedModel):
    """FST backbone plus a bias-free linear head for next-token prediction."""

    accepts_loss_kwargs = False

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)
        self.model = FSTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.tie_weights()
            self._dynamic_tied_weights_keys = {"lm_head.weight": "model.embedding.weight"}  # Avoids safetensor naming issues
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embedding

    def set_input_embeddings(self, new_embeddings):
        self.model.embedding = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Share the input embedding matrix with ``lm_head`` if the config asks for it.

        ``missing_keys`` and ``recompute_mapping`` are accepted for signature
        compatibility and are unused. The config check matters because this
        override may also be invoked by the framework (e.g. during post_init),
        not only from ``__init__``.
        """
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.get_input_embeddings().weight

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        past_key_values = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Compute logits and, when ``labels`` are given, the shifted next-token loss.

        Labels equal to -100 are ignored in the loss. Passing ``labels`` always
        yields a ``CausalLMOutputWithPast``.
        """
        if labels is not None:
            return_dict = True
        else:
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the backbone for a ModelOutput, so the attribute access below
        # works even when config.use_return_dict is False.
        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        logits = self.lm_head(model_output.last_hidden_state)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

        if not return_dict:
            output = (logits,) + model_output.to_tuple()[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output.past_key_values,
            hidden_states=model_output.hidden_states
        )

    def _prepare_inputs_for_generation(
        self,
        input_ids: Tensor,
        past_key_values: Cache | None = None,
        attention_mask: Tensor | None = None,
        **kwargs
    ):
        """Assemble ``forward`` kwargs for one decoding step.

        Only tokens not yet in the cache are fed. With an empty cache (the
        first step) the whole prompt is kept.

        NOTE(review): GenerationMixin's hook is named
        ``prepare_inputs_for_generation`` (no leading underscore), so
        ``generate()`` presumably does not call this helper. Confirm before
        relying on it.
        """
        if isinstance(past_key_values, Cache):
            past_length = past_key_values.get_seq_length()
            if past_length > 0:
                # Feed the uncached suffix; fall back to the last token if the
                # cache already covers all of input_ids.
                input_ids = input_ids[:, past_length:] if input_ids.shape[1] > past_length else input_ids[:, -1:]
        elif past_key_values is not None:
            # Non-Cache (legacy) caches: keep the original last-token behaviour.
            input_ids = input_ids[:, -1:]
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        model_inputs.update(kwargs)
        return model_inputs

    def _reorder_cache(self, past_key_values: Cache, beam_idx: Tensor):
        """Reorder cached entries along the batch for beam search and return the cache."""
        # Cache.reorder_cache mutates the cache in place (and, in the transformers
        # versions reviewed, returns None), so hand the cache back explicitly.
        past_key_values.reorder_cache(beam_idx)
        return past_key_values
class FSTForMaskedLM(FSTPreTrainedModel):
    """FST backbone plus a bias-free linear head for masked-language modeling.

    Requires a bidirectional (``use_causal_attention=False``) configuration
    with caching disabled.
    """

    accepts_loss_kwargs = False

    def __init__(
        self,
        config: FSTConfig
    ):
        super().__init__(config)
        assert not config.use_causal_attention, "FSTForMaskedLM requires use_causal_attention=False"
        assert not config.use_cache, "FSTForMaskedLM requires use_cache=False (caching not supported for bidirectional models)"
        self.model = FSTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.tie_weights()
            self._dynamic_tied_weights_keys = {"lm_head.weight": "model.embedding.weight"}  # Avoids safetensor naming issues
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embedding

    def set_input_embeddings(self, new_embeddings):
        self.model.embedding = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Share the input embedding matrix with ``lm_head`` if the config asks for it.

        ``missing_keys`` and ``recompute_mapping`` are accepted for signature
        compatibility and are unused.
        """
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.get_input_embeddings().weight

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Compute per-position logits and, when ``labels`` are given, the MLM loss.

        Positions labelled -100 are excluded from the loss. Positions labelled
        ``config.pad_token_id`` are also excluded when that id is set. Passing
        ``labels`` always yields a ``MaskedLMOutput``.
        """
        if labels is not None:
            return_dict = True
        else:
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the backbone for a ModelOutput, so attribute access works
        # regardless of config.use_return_dict.
        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=False,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        logits = self.lm_head(model_output.last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            pad_token_id = self.config.pad_token_id
            if pad_token_id is not None:
                # Ignore pad-labelled positions and also the conventional -100.
                # Using pad_token_id as ignore_index alone would make -100
                # targets an out-of-range class error.
                labels = labels.masked_fill(labels == pad_token_id, -100)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100
            )

        if not return_dict:
            output = (logits,) + model_output.to_tuple()[1:]
            return ((loss,) + output) if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=model_output.hidden_states
        )