"""
LSTM language model compatible with HuggingFace Trainer and AutoModelForCausalLM.

Designed to be a drop-in replacement for LlamaForCausalLM in the Fourier
emergence experiments: same tokenizer, same data pipeline, same analysis code.
"""

import logging
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutput

logger = logging.getLogger(__name__)


# ── Config ─────────────────────────────────────────────────────────

class LSTMLanguageModelConfig(PretrainedConfig):
    """Configuration for :class:`LSTMForCausalLM`.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output logits per position.
        embed_dim: Width of the token embeddings (the LSTM input size).
        hidden_size: Width of the LSTM hidden state (the LM head input size).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the embeddings, to the LSTM
            output, and between LSTM layers when ``num_layers > 1``.
        tie_word_embeddings: Share the LM head weight with the embedding
            table. Only possible when ``embed_dim == hidden_size``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "lstm_lm"

    def __init__(
        self,
        vocab_size: int = 128256,
        embed_dim: int = 1024,
        hidden_size: int = 1024,
        num_layers: int = 2,
        dropout: float = 0.1,
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        # Assigned after super().__init__ so the explicit argument wins over
        # whatever default the base class installed.
        self.tie_word_embeddings = tie_word_embeddings


# ── Model ──────────────────────────────────────────────────────────

class LSTMForCausalLM(PreTrainedModel):
    """Embedding -> dropout -> stacked LSTM -> dropout -> linear LM head."""

    config_class = LSTMLanguageModelConfig
    _tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}

    def __init__(self, config: LSTMLanguageModelConfig):
        """Build the layers described by ``config`` and initialise weights.

        If ``config.tie_word_embeddings`` is set but ``embed_dim`` and
        ``hidden_size`` differ, tying is disabled on ``config`` (with a
        warning) because the two weight matrices have different shapes.
        """
        super().__init__(config)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.lstm = nn.LSTM(
            input_size=config.embed_dim,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between layers, so it is
            # meaningless (and warns) for a single-layer stack.
            dropout=config.dropout if config.num_layers > 1 else 0.0,
        )
        self.drop = nn.Dropout(config.dropout)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            if config.embed_dim == config.hidden_size:
                self.lm_head.weight = self.embed_tokens.weight
            else:
                # The flag must be cleared, not merely ignored: post_init()
                # runs the base-class tie_weights(), which would otherwise
                # assign the (vocab, embed_dim) embedding matrix to lm_head
                # and make forward() fail with a shape mismatch.
                logger.warning(
                    "tie_word_embeddings=True requires embed_dim == hidden_size "
                    "(got embed_dim=%d, hidden_size=%d); disabling weight tying.",
                    config.embed_dim,
                    config.hidden_size,
                )
                config.tie_word_embeddings = False

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        """Run the language model and optionally compute the next-token loss.

        Args:
            input_ids: Token ids; the LSTM is ``batch_first``.
            attention_mask: Accepted for interface compatibility only; it is
                not used by the computation.
            labels: Target ids aligned with ``input_ids``. They are shifted
                by one position internally; entries equal to ``-100`` are
                ignored by the loss.
            output_hidden_states: Whether to return hidden states. Falls back
                to ``config.output_hidden_states`` when ``None``.
            return_dict: Whether to return a ``CausalLMOutput`` instead of a
                plain tuple. Falls back to ``config.use_return_dict`` when
                ``None``.

        Returns:
            A ``CausalLMOutput`` with ``loss`` (if ``labels`` given),
            ``logits`` and ``hidden_states`` (if requested), or the tuple of
            its non-``None`` fields when ``return_dict`` is false.
        """
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        if return_dict is None:
            return_dict = self.config.use_return_dict

        embeds = self.drop(self.embed_tokens(input_ids))
        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.drop(lstm_out)
        logits = self.lm_head(lstm_out)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        hidden_states = None
        if output_hidden_states:
            # Return embedding + lstm output for compatibility with analysis.
            # Both tensors are the post-dropout values; only the final LSTM
            # layer's output is available, not one entry per layer.
            hidden_states = (embeds, lstm_out)

        output = CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
        )
        return output if return_dict else output.to_tuple()


# Register so AutoModelForCausalLM.from_pretrained works on saved checkpoints.
# register_for_auto_class() saves auto_map in config.json when saved.
# AutoConfig/AutoModel.register() makes the class findable in any process
# that imports this module — no trust_remote_code needed.
LSTMLanguageModelConfig.register_for_auto_class()
LSTMForCausalLM.register_for_auto_class("AutoModelForCausalLM")
AutoConfig.register("lstm_lm", LSTMLanguageModelConfig)
AutoModelForCausalLM.register(LSTMLanguageModelConfig, LSTMForCausalLM)