# lstm-12layer-v5 / lstm.py
# Last synced by deqing in commit 2e0af2f (verified): "Sync main at tokens-200M tokens"
"""
LSTM language model compatible with HuggingFace Trainer and AutoModelForCausalLM.
Designed to be a drop-in replacement for LlamaForCausalLM in the Fourier
emergence experiments: same tokenizer, same data pipeline, same analysis code.
"""
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
# ── Config ─────────────────────────────────────────────────────────
class LSTMLanguageModelConfig(PretrainedConfig):
    """Hyper-parameters for :class:`LSTMForCausalLM`.

    Args:
        vocab_size: Number of token ids covered by the input embedding and
            the output projection.
        embed_dim: Width of the token embeddings fed into the LSTM.
        hidden_size: Hidden-state width of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability used by the model.
        tie_word_embeddings: Whether the output projection should share its
            weight matrix with the input embedding.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "lstm_lm"

    def __init__(
        self,
        vocab_size: int = 128256,
        embed_dim: int = 1024,
        hidden_size: int = 1024,
        num_layers: int = 2,
        dropout: float = 0.1,
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        # Let the base class consume the generic HF options first; the
        # model-specific fields are set afterwards so they are the final word.
        super().__init__(**kwargs)

        # Embedding / vocabulary geometry.
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        # Recurrent stack.
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Regularisation.
        self.dropout = dropout

        # Set explicitly (rather than only via **kwargs) so the value passed
        # here is what the config ends up holding.
        self.tie_word_embeddings = tie_word_embeddings
# ── Model ──────────────────────────────────────────────────────────
class LSTMForCausalLM(PreTrainedModel):
    """Stacked-LSTM causal language model exposing the HF ``PreTrainedModel`` API.

    Pipeline: token embedding -> dropout -> ``nn.LSTM`` (batch_first) ->
    dropout -> bias-free linear head producing vocabulary logits.

    The output head shares its weight with the input embedding only when
    ``config.tie_word_embeddings`` is set AND ``embed_dim == hidden_size``;
    otherwise tying is disabled (with a warning) and recorded in the config.
    """

    config_class = LSTMLanguageModelConfig
    # Declares that lm_head.weight is an alias of embed_tokens.weight when tied.
    _tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}

    def __init__(self, config: LSTMLanguageModelConfig):
        super().__init__(config)
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.lstm = nn.LSTM(
            input_size=config.embed_dim,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            batch_first=True,
            # nn.LSTM's own dropout only acts between stacked layers, so it is
            # meaningless (and warned about by PyTorch) for a single layer.
            dropout=config.dropout if config.num_layers > 1 else 0.0,
        )
        self.drop = nn.Dropout(config.dropout)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            if config.embed_dim == config.hidden_size:
                self.lm_head.weight = self.embed_tokens.weight
            else:
                # A (vocab, hidden_size) head cannot share a (vocab, embed_dim)
                # matrix without a projection. Merely skipping the assignment
                # is not enough: HF's tie_weights() (run from post_init) keys
                # off config.tie_word_embeddings and would still overwrite
                # lm_head.weight with the mismatched embedding, breaking
                # forward(). Disable the flag so nothing is tied and saved
                # configs reflect the real architecture.
                import warnings

                warnings.warn(
                    "tie_word_embeddings=True requires embed_dim == hidden_size "
                    f"(got embed_dim={config.embed_dim}, "
                    f"hidden_size={config.hidden_size}); "
                    "output embeddings will not be tied.",
                    stacklevel=2,
                )
                self.config.tie_word_embeddings = False

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding layer."""
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding):
        """Replace the token embedding layer."""
        self.embed_tokens = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the output projection (vocabulary head)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        """Replace the output projection (vocabulary head)."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], CausalLMOutput]:
        """Compute next-token logits and, if ``labels`` are given, the LM loss.

        Args:
            input_ids: Token ids, ``(batch, seq_len)`` (the LSTM is batch_first).
            attention_mask: Accepted for API compatibility but IGNORED; the
                LSTM runs over every position, padding included. Since the
                recurrence is left-to-right, right padding cannot affect real
                positions, but left padding would leak into their state.
            labels: Target ids aligned with ``input_ids``. They are shifted
                internally so position t is scored against token t+1; entries
                equal to -100 are ignored.
            output_hidden_states: If true, also return
                ``(post-dropout embeddings, post-dropout final LSTM output)``.
                Defaults to ``config.output_hidden_states``.
            return_dict: If false, return a plain tuple
                ``(loss?, logits, hidden_states?)`` instead of a
                ``CausalLMOutput``. Defaults to ``config.use_return_dict``.

        Returns:
            ``CausalLMOutput`` (or a tuple, see ``return_dict``).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)

        embeds = self.drop(self.embed_tokens(input_ids))
        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.drop(lstm_out)
        logits = self.lm_head(lstm_out)

        loss = None
        if labels is not None:
            # Upcast so the softmax over the (large) vocabulary is computed in
            # float32 even when the model runs in fp16/bf16; a no-op for fp32.
            shift_logits = logits[..., :-1, :].float()
            # Labels may sit on a different device than the head (e.g. when
            # the model is split across devices).
            shift_labels = labels[..., 1:].to(shift_logits.device)
            loss = nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )

        # Embedding + LSTM output, kept for compatibility with analysis code.
        hidden_states = (embeds, lstm_out) if output_hidden_states else None

        if not return_dict:
            output = (logits,) if hidden_states is None else (logits, hidden_states)
            return output if loss is None else (loss,) + output

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
        )
# Register so AutoModelForCausalLM.from_pretrained works on saved checkpoints.
# - register_for_auto_class() makes save_pretrained() write an ``auto_map``
#   entry into config.json.
# - AutoConfig/AutoModelForCausalLM.register() make the classes findable in
#   any process that imports this module — no trust_remote_code needed.
# The model type is taken from the config class itself: AutoConfig.register()
# requires the two to match, so they must never be spelled out separately.
from transformers import AutoConfig, AutoModelForCausalLM

LSTMLanguageModelConfig.register_for_auto_class()
LSTMForCausalLM.register_for_auto_class("AutoModelForCausalLM")
AutoConfig.register(LSTMLanguageModelConfig.model_type, LSTMLanguageModelConfig)
AutoModelForCausalLM.register(LSTMLanguageModelConfig, LSTMForCausalLM)