"""PyTorch Small Transformer model for English to Hindi/Bengali translation.""" import math import torch import torch.nn as nn from typing import Optional, Tuple from transformers import PreTrainedModel from transformers.modeling_outputs import Seq2SeqLMOutput from transformers.configuration_utils import PretrainedConfig class SmallTransformerConfig(PretrainedConfig): model_type = "small_transformer" def __init__( self, vocab_size=80000, d_model=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=512, dropout=0.1, max_position_embeddings=512, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs ): self.vocab_size = vocab_size self.d_model = d_model self.nhead = nhead self.num_encoder_layers = num_encoder_layers self.num_decoder_layers = num_decoder_layers self.dim_feedforward = dim_feedforward self.dropout = dropout self.max_position_embeddings = max_position_embeddings super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs ) class SmallTransformerPreTrainedModel(PreTrainedModel): config_class = SmallTransformerConfig base_model_prefix = "small_transformer" supports_gradient_checkpointing = False _no_split_modules = [] def _init_weights(self, module): if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=0.02) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class SmallTransformer(SmallTransformerPreTrainedModel): def __init__(self, config: SmallTransformerConfig): super().__init__(config) self.config = config self.embedding = nn.Embedding( config.vocab_size, config.d_model, padding_idx=config.pad_token_id ) self.pos_encoder = nn.Embedding(config.max_position_embeddings, config.d_model) self.pos_decoder = nn.Embedding(config.max_position_embeddings, config.d_model) self.embed_scale = math.sqrt(config.d_model) enc_layer = nn.TransformerEncoderLayer( d_model=config.d_model, nhead=config.nhead, dim_feedforward=config.dim_feedforward, dropout=config.dropout, batch_first=True ) dec_layer = nn.TransformerDecoderLayer( d_model=config.d_model, nhead=config.nhead, dim_feedforward=config.dim_feedforward, dropout=config.dropout, batch_first=True ) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=config.num_encoder_layers) self.decoder = nn.TransformerDecoder(dec_layer, num_layers=config.num_decoder_layers) self.output_layer = nn.Linear(config.d_model, config.vocab_size) # Initialize weights self.post_init() def get_encoder(self): return self.encoder def get_decoder(self): return self.decoder def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, **kwargs ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Use decoder_input_ids if provided, otherwise shift labels if decoder_input_ids is None and labels is not None: decoder_input_ids = labels.clone() src = input_ids tgt = decoder_input_ids assert src.dim() == 2 and tgt.dim() == 2 # Create masks src_mask = (src == self.config.pad_token_id) tgt_mask_pad = (tgt == self.config.pad_token_id) T = tgt.size(1) causal_mask = torch.triu(torch.ones((T, T), device=tgt.device), diagonal=1).bool() # Positional indices src_pos = 
            torch.arange(0, src.size(1), device=src.device)
            .unsqueeze(0)
            .expand(src.size(0), -1)
            .clamp(max=self.config.max_position_embeddings - 1)
        )
        tgt_pos = (
            torch.arange(0, tgt.size(1), device=tgt.device)
            .unsqueeze(0)
            .expand(tgt.size(0), -1)
            .clamp(max=self.config.max_position_embeddings - 1)
        )

        # Embeddings
        src_emb = self.embedding(src) * self.embed_scale + self.pos_encoder(src_pos)
        tgt_emb = self.embedding(tgt) * self.embed_scale + self.pos_decoder(tgt_pos)

        # Encode and decode
        memory = self.encoder(src_emb, src_key_padding_mask=src_mask)
        decoder_output = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_mask_pad,
            memory_key_padding_mask=src_mask,
        )

        logits = self.output_layer(decoder_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=None,
            decoder_hidden_states=None,
            decoder_attentions=None,
            cross_attentions=None,
            encoder_last_hidden_state=memory,
            encoder_hidden_states=None,
            encoder_attentions=None,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        lang_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        """Simple greedy generation for translation."""
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id

        # max_new_tokens counts tokens generated after the start token;
        # max_length includes the start token itself.
        if max_new_tokens is not None:
            max_length = max_new_tokens + 1
        elif max_length is None:
            max_length = 64

        batch_size = input_ids.size(0)
        device = input_ids.device

        # Start decoding from the target-language token
        if lang_token_id is None:
            raise ValueError("lang_token_id must be provided for generation")
        decoder_input_ids = torch.full(
            (batch_size, 1), lang_token_id, dtype=torch.long, device=device
        )

        for _ in range(max_length - 1):
            # No KV cache: the source is re-encoded and the full prefix is
            # re-decoded every step, which is acceptable for short outputs.
            outputs = self.forward(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                return_dict=True,
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1)

            # Stop once every sequence in the batch has just emitted EOS
            if (next_tokens == eos_token_id).all():
                break

        return decoder_input_ids
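

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the model definition. It shows
# a random-weight forward pass with labels and a short greedy decode. The
# token ids below, including the hypothetical lang_token_id=3, are arbitrary
# placeholders rather than ids from a real tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = SmallTransformerConfig(vocab_size=1000, max_position_embeddings=64)
    model = SmallTransformer(config)
    model.eval()

    # Dummy batch: 2 source sentences of length 10, labels of length 8
    # (ids >= 4 avoid the pad/bos/eos/lang special tokens assumed here)
    src = torch.randint(4, config.vocab_size, (2, 10))
    labels = torch.randint(4, config.vocab_size, (2, 8))

    out = model(input_ids=src, labels=labels, return_dict=True)
    print("loss:", out.loss.item(), "logits:", tuple(out.logits.shape))

    # Greedy decoding, starting from the placeholder language token
    generated = model.generate(src, max_new_tokens=16, lang_token_id=3)
    print("generated:", tuple(generated.shape))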