import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    GenerationMixin,
)
from transformers.modeling_outputs import CausalLMOutput


class RecursiveLanguageModelConfig(PretrainedConfig):
    model_type = "recursive_language_model"

    def __init__(self, vocab_size=50257, hidden_size=512, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size


class RecursiveLanguageModel(PreTrainedModel, GenerationMixin):
    config_class = RecursiveLanguageModelConfig
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Standard Hugging Face hook: initialize weights and run any final setup.
        self.post_init()

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Embed the tokens and project straight to vocabulary logits.
        x = self.token_embedding(input_ids)
        logits = self.lm_head(x)
        return CausalLMOutput(logits=logits)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Minimal hook used by GenerationMixin.generate(): re-run the full
        # sequence on every decoding step (no KV cache is kept).
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
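# Minimal usage sketch, not part of the original listing: the prompt token ids
# and generation settings below are illustrative assumptions. It exercises the
# GenerationMixin.generate() path wired up by prepare_inputs_for_generation.
if __name__ == "__main__":
    config = RecursiveLanguageModelConfig(vocab_size=50257, hidden_size=512)
    model = RecursiveLanguageModel(config).eval()
    prompt = torch.tensor([[464, 2068, 7586]])  # arbitrary placeholder token ids
    with torch.no_grad():
        generated = model.generate(prompt, max_new_tokens=8, do_sample=False)
    print(generated.shape)  # (1, prompt length + up to 8 new tokens)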