| """ |
| LSTM language model compatible with HuggingFace Trainer and AutoModelForCausalLM. |
| |
| Designed to be a drop-in replacement for LlamaForCausalLM in the Fourier |
| emergence experiments: same tokenizer, same data pipeline, same analysis code. |
| """ |
|
|
import warnings
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
|
|
|
|
| |
|
|
|
|
class LSTMLanguageModelConfig(PretrainedConfig):
    """Hyper-parameters for the LSTM causal language model.

    Args:
        vocab_size: Number of rows in the token embedding table and number
            of output logits per position.
        embed_dim: Width of the token embeddings fed into the LSTM.
        hidden_size: Width of the LSTM hidden state (and of the LM head input).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability stored for the model to use.
        tie_word_embeddings: Whether the output projection should share its
            weight with the input embedding table.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "lstm_lm"

    def __init__(
        self,
        vocab_size: int = 128256,
        embed_dim: int = 1024,
        hidden_size: int = 1024,
        num_layers: int = 2,
        dropout: float = 0.1,
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        # Let the base class consume the generic options first; the
        # model-specific fields are written afterwards, in a fixed order.
        super().__init__(**kwargs)

        # Embedding table geometry.
        self.vocab_size, self.embed_dim = vocab_size, embed_dim
        # Recurrent stack geometry.
        self.hidden_size, self.num_layers = hidden_size, num_layers
        # Regularisation and weight sharing.
        self.dropout, self.tie_word_embeddings = dropout, tie_word_embeddings
|
|
|
|
| |
|
|
|
|
class LSTMForCausalLM(PreTrainedModel):
    """Causal language model: token embedding -> stacked LSTM -> linear LM head.

    The pipeline is ``embed_tokens`` -> dropout -> ``lstm`` -> dropout ->
    ``lm_head``. The LSTM always starts from its default initial state; no
    recurrent state is carried between calls.

    Weight tying between ``lm_head`` and ``embed_tokens`` is only possible
    when ``config.embed_dim == config.hidden_size``. If tying is requested
    with mismatching sizes, it is disabled on the config (with a warning) so
    that the base class does not tie two incompatible matrices later on.
    """

    config_class = LSTMLanguageModelConfig
    _tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}

    def __init__(self, config: LSTMLanguageModelConfig):
        """Build the embedding table, the LSTM stack and the output head.

        Args:
            config: Model hyper-parameters. ``config.tie_word_embeddings`` is
                reset to ``False`` in place when tying is impossible because
                ``embed_dim`` and ``hidden_size`` differ.
        """
        # The base class ties the weights listed in `_tied_weights_keys`
        # whenever `config.tie_word_embeddings` is true, regardless of shape.
        # A local shape guard alone is therefore not enough: the flag itself
        # has to be cleared, otherwise `lm_head.weight` would be replaced by a
        # (vocab_size, embed_dim) matrix that cannot multiply the LSTM output.
        if config.tie_word_embeddings and config.embed_dim != config.hidden_size:
            warnings.warn(
                "tie_word_embeddings=True requires embed_dim == hidden_size "
                f"(got embed_dim={config.embed_dim}, "
                f"hidden_size={config.hidden_size}); weight tying is disabled.",
                stacklevel=2,
            )
            config.tie_word_embeddings = False

        super().__init__(config)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.lstm = nn.LSTM(
            input_size=config.embed_dim,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout *between* layers, so it is
            # meaningless (and warned about) for a single-layer stack.
            dropout=config.dropout if config.num_layers > 1 else 0.0,
        )
        self.drop = nn.Dropout(config.dropout)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            # Shapes are guaranteed to match here (see the check above).
            self.lm_head.weight = self.embed_tokens.weight

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the output projection (LM head)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        """Replace the output projection (LM head)."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: Token ids; the LSTM is built with ``batch_first=True``,
                so the batch dimension comes first.
            attention_mask: Accepted for interface compatibility only; it is
                not used by the computation.
            labels: Optional targets aligned with ``input_ids``. They are
                shifted by one position internally; entries equal to ``-100``
                are ignored by the loss.
            output_hidden_states: Whether to return hidden states. Falls back
                to ``config.output_hidden_states`` when ``None``.
            return_dict: Whether to return a ``CausalLMOutput`` instead of a
                plain tuple. Falls back to ``config.use_return_dict`` when
                ``None``.

        Returns:
            A ``CausalLMOutput`` with ``loss`` (if ``labels`` was given),
            ``logits`` and ``hidden_states`` (if requested), or the tuple of
            its non-``None`` fields when ``return_dict`` is false.
            ``hidden_states`` is the pair ``(embeddings, lstm_output)``, both
            taken *after* dropout; only the last LSTM layer's output is
            available, not one entry per layer.
        """
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        if return_dict is None:
            return_dict = self.config.use_return_dict

        embeds = self.drop(self.embed_tokens(input_ids))

        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.drop(lstm_out)

        logits = self.lm_head(lstm_out)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two line up.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Use the logits' own width instead of config.vocab_size so the
            # reshape stays valid even if the head was swapped or resized.
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        hidden_states = (embeds, lstm_out) if output_hidden_states else None

        outputs = CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
        )
        # to_tuple() keeps only the non-None fields, in declaration order.
        return outputs if return_dict else outputs.to_tuple()
|
|
|
|
| |
| |
| |
| |
from transformers import AutoConfig, AutoModelForCausalLM

# Configuration: mark the class for auto-class export and make the
# "lstm_lm" model type resolvable through AutoConfig.
LSTMLanguageModelConfig.register_for_auto_class("AutoConfig")
AutoConfig.register("lstm_lm", LSTMLanguageModelConfig)

# Model: same two steps for the causal-LM auto class, keyed by the config class.
LSTMForCausalLM.register_for_auto_class("AutoModelForCausalLM")
AutoModelForCausalLM.register(LSTMLanguageModelConfig, LSTMForCausalLM)
|
|