""" Translation Transformer Model for HuggingFace Hub """ import torch import torch.nn as nn from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import Seq2SeqLMOutput from typing import Optional, Tuple, Union import math class PositionalEncoding(nn.Module): """Positional encoding for transformer""" def __init__(self, d_model, max_length=5000): super().__init__() pe = torch.zeros(max_length, d_model) position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.register_buffer('pe', pe) def forward(self, x): return x + self.pe[:, :x.size(1)] class TranslationTransformerConfig(PretrainedConfig): """Configuration class for TranslationTransformer""" model_type = "translation_transformer" def __init__( self, vocab_size=32000, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, pad_token_id=0, bos_token_id=2, eos_token_id=3, max_length=512, **kwargs ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs ) self.vocab_size = vocab_size self.d_model = d_model self.nhead = nhead self.num_encoder_layers = num_encoder_layers self.num_decoder_layers = num_decoder_layers self.dim_feedforward = dim_feedforward self.dropout = dropout self.max_length = max_length # Required for HuggingFace compatibility self.is_encoder_decoder = True self.decoder_start_token_id = bos_token_id class TranslationTransformerModel(PreTrainedModel): """ Encoder-Decoder Transformer for Translation Compatible with HuggingFace Hub """ config_class = TranslationTransformerConfig base_model_prefix = "translation_transformer" supports_gradient_checkpointing = True def __init__(self, config): super().__init__(config) self.config = config # Embeddings self.embedding = nn.Embedding( config.vocab_size, config.d_model, padding_idx=config.pad_token_id ) self.pos_encoder = PositionalEncoding(config.d_model, config.max_length) self.pos_decoder = PositionalEncoding(config.d_model, config.max_length) # Transformer self.transformer = nn.Transformer( d_model=config.d_model, nhead=config.nhead, num_encoder_layers=config.num_encoder_layers, num_decoder_layers=config.num_decoder_layers, dim_feedforward=config.dim_feedforward, dropout=config.dropout, batch_first=True ) # Output layer self.fc_out = nn.Linear(config.d_model, config.vocab_size) # Initialize weights self.post_init() def _init_weights(self, module): """Initialize weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=0.02) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def get_encoder(self): """Return encoder for compatibility""" return self.transformer.encoder def get_decoder(self): """Return decoder for compatibility""" return self.transformer.decoder def generate_square_subsequent_mask(self, sz, device): """Generate causal mask for decoder""" mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1) mask = mask.masked_fill(mask == 1, float('-inf')) return mask def create_padding_mask(self, seq, pad_token_id): """Create padding mask""" return (seq == pad_token_id) def forward( self, input_ids: Optional[torch.LongTensor] = 
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        """Forward pass"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        device = input_ids.device

        # If labels provided but no decoder_input_ids, shift labels to create decoder_input_ids
        if labels is not None and decoder_input_ids is None:
            labels_shifted = labels.clone()
            labels_shifted[labels_shifted == -100] = self.config.pad_token_id
            decoder_input_ids = torch.cat([
                torch.full((labels.shape[0], 1), self.config.bos_token_id, dtype=torch.long, device=device),
                labels_shifted[:, :-1]
            ], dim=1)

        # Embeddings with scaling
        src_emb = self.embedding(input_ids) * math.sqrt(self.config.d_model)
        src_emb = self.pos_encoder(src_emb)
        tgt_emb = self.embedding(decoder_input_ids) * math.sqrt(self.config.d_model)
        tgt_emb = self.pos_decoder(tgt_emb)

        # Create masks
        tgt_seq_len = decoder_input_ids.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len, device)
        src_key_padding_mask = self.create_padding_mask(input_ids, self.config.pad_token_id)
        tgt_key_padding_mask = self.create_padding_mask(decoder_input_ids, self.config.pad_token_id)

        # Transformer forward pass
        output = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )

        # Output projection
        logits = self.fc_out(output)

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """Prepare inputs for generation"""
        return {
            "input_ids": kwargs.get("input_ids"),
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cache for beam search"""
        return past_key_values

    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        max_length: int = 128,
        num_beams: int = 1,
        temperature: float = 1.0,
        do_sample: bool = False,
        top_k: int = 50,
        top_p: float = 1.0,
        **kwargs
    ) -> torch.LongTensor:
        """Generate translations"""
        device = input_ids.device
        batch_size = input_ids.size(0)

        # Start with BOS token
        decoder_input_ids = torch.full(
            (batch_size, 1), self.config.bos_token_id, dtype=torch.long, device=device
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Generate tokens one by one
        for _ in range(max_length - 1):
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                return_dict=True
            )
            next_token_logits = outputs.logits[:, -1, :] / temperature

            if do_sample:
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Pad tokens for sequences that finished on an earlier step, then mark
            # sequences that just produced EOS as finished (so the EOS token itself is kept)
            next_token[finished] = self.config.pad_token_id
            finished = finished | (next_token.squeeze(-1) == self.config.eos_token_id)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)

            if finished.all():
                break

        return decoder_input_ids


# Register the model in the AutoModel registry
from transformers import AutoConfig, AutoModel, AutoModelForSeq2SeqLM

AutoConfig.register("translation_transformer", TranslationTransformerConfig)
AutoModel.register(TranslationTransformerConfig, TranslationTransformerModel)
AutoModelForSeq2SeqLM.register(TranslationTransformerConfig, TranslationTransformerModel)
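

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module's API): a minimal smoke test
# showing how the classes above are intended to be used. The tiny config
# values, random token ids, and printouts below are illustrative assumptions,
# not values the model requires. When the files live on the Hub, the
# registered classes would normally be loaded with
# AutoModelForSeq2SeqLM.from_pretrained(..., trust_remote_code=True) rather
# than imported directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Small configuration so the smoke test runs quickly on CPU
    config = TranslationTransformerConfig(
        vocab_size=1000,
        d_model=128,
        nhead=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=256,
    )
    model = TranslationTransformerModel(config)

    # Random token ids stand in for tokenized source sentences and references
    src = torch.randint(4, config.vocab_size, (2, 10))
    labels = torch.randint(4, config.vocab_size, (2, 12))

    # Training-style forward pass: labels are shifted internally to build
    # decoder_input_ids, and the cross-entropy loss is returned
    outputs = model(input_ids=src, labels=labels, return_dict=True)
    print("loss:", outputs.loss.item(), "logits shape:", tuple(outputs.logits.shape))

    # Greedy decoding starting from the BOS token
    model.eval()
    with torch.no_grad():
        generated = model.generate(src, max_length=16)
    print("generated ids shape:", tuple(generated.shape))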