| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import LlamaConfig, LlamaForCausalLM |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
class LlamaAlbertConfig(LlamaConfig):
    """Llama configuration with ALBERT-style factorized embeddings.

    Adds ``embedding_dim``: the width of the low-rank embedding factor that
    sits between the vocabulary and ``hidden_size`` (ALBERT's "E").  All
    other options are inherited unchanged from ``LlamaConfig``.

    Args:
        embedding_dim (int, defaults to 128): size of the factorized
            embedding bottleneck.
        **kwargs: forwarded to ``LlamaConfig``.
    """

    model_type = "llama_albert"
    architectures = ["LlamaAlbertForCausalLM"]

    def __init__(self, embedding_dim=128, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        # Point the Auto* factories at this repo's custom code so that
        # loading with `trust_remote_code=True` resolves the right classes.
        self.auto_map = {
            "AutoConfig": "modeling_llama_albert.LlamaAlbertConfig",
            "AutoModelForCausalLM": "modeling_llama_albert.LlamaAlbertForCausalLM",
        }
        # BUG FIX: `_auto_class` must be the *name* of the Auto class this
        # config is registered under (what `register_for_auto_class` sets),
        # not a dotted module path; the previous value
        # ("modeling_llama_albert.LlamaAlbertForCausalLM") would yield a
        # bogus auto_map key when the config is saved.
        self._auto_class = "AutoConfig"
|
|
|
|
class LlamaAlbertForCausalLM(LlamaForCausalLM):
    """``LlamaForCausalLM`` with ALBERT-style factorized embedding matrices.

    The full ``vocab_size x hidden_size`` input-embedding and LM-head
    matrices are each replaced by two low-rank factors that meet at
    ``config.embedding_dim``, reducing parameter count whenever
    ``embedding_dim < hidden_size``.  Everything else (attention stack,
    generation, loss) is inherited from ``LlamaForCausalLM``.
    """

    config_class = LlamaAlbertConfig

    def __init__(self, config):
        super().__init__(config)

        # Input side: token ids -> E-dim embedding -> projection to hidden.
        self.model.embed_tokens = nn.Sequential(
            nn.Embedding(config.vocab_size, config.embedding_dim),
            nn.Linear(config.embedding_dim, config.hidden_size, bias=False),
        )

        # Output side: hidden -> E-dim bottleneck -> vocabulary logits.
        self.lm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.embedding_dim, bias=False),
            nn.Linear(config.embedding_dim, config.vocab_size, bias=False),
        )

        # Re-run weight init (and weight tying, if enabled in the config)
        # over the freshly replaced modules.
        self.post_init()

    def get_input_embeddings(self):
        # Expose only the Embedding factor so HF utilities (vocab resizing,
        # weight tying) operate on the ``vocab_size x embedding_dim`` table.
        return self.model.embed_tokens[0]

    def set_input_embeddings(self, value):
        self.model.embed_tokens[0] = value

    def get_output_embeddings(self):
        # The final Linear's weight is ``vocab_size x embedding_dim`` — the
        # same shape as the input table — so standard weight tying applies.
        return self.lm_head[1]

    def set_output_embeddings(self, new_embeddings):
        self.lm_head[1] = new_embeddings

    # NOTE(review): the previous no-op `forward` override (which only
    # forwarded `input_ids` and `**kwargs` to `super().forward`) has been
    # removed; the inherited implementation is used directly, which also
    # preserves the parent's explicit signature for introspection.
|
|