import torch
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_tinyllm import TinyLLMConfig
from models.core_models import GenericTransformer
from models.embedding_models import Embedder
from models.model_heads import AutoregressiveLMHead
from models.model_shell import ModelShell
from models.components.base_tokenizer import BaseTokenizer


def _build_tinyllm(model_cfg):
    """Assemble a TinyLLM ``ModelShell`` from its three components.

    Args:
        model_cfg: dict-like model configuration. Read here for
            ``embedding_weight_tying`` and forwarded to each component.

    Returns:
        ModelShell wrapping the embedder, transformer core, and LM head.
    """
    tokenizer = BaseTokenizer()
    embedding_model = Embedder(model_cfg=model_cfg, tokenizer=tokenizer)
    core_model = GenericTransformer(model_cfg=model_cfg)
    model_head = AutoregressiveLMHead(model_cfg=model_cfg)
    if model_cfg.get("embedding_weight_tying", False):
        # Share one parameter tensor between the input embedding matrix and
        # the output projection (classic weight tying).
        embedding_model.token_embedder.weight = model_head.linear.weight
    return ModelShell(
        embedding_model=embedding_model,
        core_model=core_model,
        model_head=model_head,
    )


class TinyLLMForCausalLM(PreTrainedModel, GenerationMixin):
    """HuggingFace-compatible causal-LM wrapper around a TinyLLM ``ModelShell``."""

    config_class = TinyLLMConfig
    base_model_prefix = "model"

    def __init__(self, config: TinyLLMConfig):
        super().__init__(config)
        self.model = _build_tinyllm(config.model_cfg)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the model and optionally compute the shifted next-token loss.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: accepted for HF compatibility but ignored —
                see note below.
            labels: optional ``(batch, seq)`` target ids; positions equal to
                ``-100`` are excluded from the loss.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and, when ``labels``
            is given, ``loss``.
        """
        # TinyLLM uses causal attention internally; ignore HF attention_mask to avoid shape mismatches.
        attention_mask = None
        logits, _ = self.model(input_ids, attention_mask=attention_mask)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is returned from forward(), so generation always feeds
        # the full sequence back through the model.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def get_input_embeddings(self):
        return self.model.embedding_model.token_embedder

    def set_input_embeddings(self, value):
        self.model.embedding_model.token_embedder = value

    def get_output_embeddings(self):
        # Expose the LM head projection so PreTrainedModel.tie_weights(),
        # resize_token_embeddings(), and save_pretrained's shared-tensor
        # handling can see the weight tied in _build_tinyllm().
        return self.model.model_head.linear

    def set_output_embeddings(self, new_embeddings):
        self.model.model_head.linear = new_embeddings