| | import torch |
| | from transformers import PreTrainedModel |
| | from transformers.generation import GenerationMixin |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
| | from .configuration_tinyllm import TinyLLMConfig |
| | from models.core_models import GenericTransformer |
| | from models.embedding_models import Embedder |
| | from models.model_heads import AutoregressiveLMHead |
| | from models.model_shell import ModelShell |
| | from models.components.base_tokenizer import BaseTokenizer |
| |
|
| |
|
def _build_tinyllm(model_cfg):
    """Assemble a ModelShell (token embedder -> transformer core -> LM head).

    Args:
        model_cfg: dict-like model configuration; the optional key
            ``"embedding_weight_tying"`` enables sharing one weight matrix
            between the input embedding and the output projection.

    Returns:
        A ModelShell wiring the three sub-modules together.
    """
    tok = BaseTokenizer()
    embedder = Embedder(model_cfg=model_cfg, tokenizer=tok)
    transformer = GenericTransformer(model_cfg=model_cfg)
    lm_head = AutoregressiveLMHead(model_cfg=model_cfg)

    # Optional weight tying: the embedding table aliases the head's
    # projection matrix, so both train as a single parameter.
    if model_cfg.get("embedding_weight_tying", False):
        embedder.token_embedder.weight = lm_head.linear.weight

    return ModelShell(
        embedding_model=embedder,
        core_model=transformer,
        model_head=lm_head,
    )
| |
|
| |
|
class TinyLLMForCausalLM(PreTrainedModel, GenerationMixin):
    """HuggingFace-compatible causal-LM wrapper around the TinyLLM ModelShell.

    Exposes the project model through the standard `transformers` interface
    so that `generate()` and the Trainer machinery work unchanged.
    """

    config_class = TinyLLMConfig
    base_model_prefix = "model"

    def __init__(self, config: TinyLLMConfig):
        super().__init__(config)
        # Build the underlying shell from the nested model config dict.
        self.model = _build_tinyllm(config.model_cfg)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the model; compute a shifted next-token loss when labels are given.

        NOTE(review): any caller-supplied `attention_mask` is deliberately
        discarded — the underlying shell is always called with mask=None
        (presumably it applies its own causal masking; confirm upstream).
        """
        logits, _ = self.model(input_ids, attention_mask=None)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            pred = logits[:, :-1, :].contiguous()
            target = labels[:, 1:].contiguous()
            vocab = pred.size(-1)
            loss = torch.nn.functional.cross_entropy(
                pred.view(-1, vocab),
                target.view(-1),
                ignore_index=-100,  # HF convention: -100 marks padded/masked labels
            )

        # No KV cache is maintained, so `past_key_values` is left unset.
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Feed the full sequence every step (no KV cache to trim against)."""
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def get_input_embeddings(self):
        """Return the token-embedding module required by the HF resize API."""
        return self.model.embedding_model.token_embedder

    def set_input_embeddings(self, value):
        """Replace the token-embedding module (HF resize/tie hook)."""
        self.model.embedding_model.token_embedder = value
|