# TinyDoc-VLM: tinydoc_vlm/decoder.py
# Uploaded by GautamKishore using huggingface_hub (commit 65880fe, 1.92 kB).
import torch
import torch.nn as nn
from typing import Optional, List
from transformers import LlamaForCausalLM, LlamaConfig
class TinyDocDecoder(nn.Module):
    """
    Decoder wrapper around LlamaForCausalLM (used by SmolLM2).

    Builds the underlying causal LM from a LlamaConfig, handles
    vocabulary/embedding resizing for special tokens, and forwards calls to
    the wrapped model.

    Attributes:
        config: The LlamaConfig this wrapper was built from. Its ``vocab_size``
            is kept in sync with the wrapped model after resizing.
        lm: The wrapped ``LlamaForCausalLM`` instance.
        hidden_size: Copy of ``config.hidden_size`` taken at construction time.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.lm = LlamaForCausalLM(config)
        self.hidden_size = config.hidden_size

    def resize_token_embeddings(
        self,
        new_num_tokens: int,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        """
        Resize the input token embeddings and output LM head of the decoder.

        Args:
            new_num_tokens: Requested vocabulary size, e.g. after adding special
                tokens to the tokenizer.
            pad_to_multiple_of: Optional. When given, it is forwarded to
                ``LlamaForCausalLM.resize_token_embeddings`` so the embedding
                matrix is padded up to a multiple of this value. When ``None``
                (the default), no padding argument is passed.

        Returns:
            The (resized) input embedding module returned by the wrapped model.
        """
        extra_kwargs = {}
        if pad_to_multiple_of is not None:
            # Only forward when explicitly set, so transformers versions whose
            # resize_token_embeddings lacks this parameter keep working.
            extra_kwargs["pad_to_multiple_of"] = pad_to_multiple_of
        resized = self.lm.resize_token_embeddings(new_num_tokens, **extra_kwargs)
        # Mirror the size transformers recorded on the wrapped model's config
        # rather than the requested count: the final size can exceed
        # new_num_tokens when padding is applied, and a no-op resize must not
        # overwrite vocab_size. NOTE(review): self.config may or may not be the
        # same object as self.lm.config depending on the transformers version;
        # this assignment keeps both in agreement either way.
        self.config.vocab_size = self.lm.config.vocab_size
        return resized

    def get_input_embeddings(self) -> nn.Module:
        """Return the wrapped model's input token embedding module."""
        return self.lm.get_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """
        Run the wrapped LlamaForCausalLM.

        Every named argument is passed through unchanged. Any additional keyword
        arguments are forwarded too, so options accepted by newer
        ``LlamaForCausalLM.forward`` versions (for example ``cache_position``)
        can be used without changing this wrapper.

        Returns:
            Whatever ``LlamaForCausalLM.forward`` returns for these arguments.
        """
        return self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )