"""
Translation Transformer Model for HuggingFace Hub
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput
from typing import Optional, Tuple, Union
import math
class PositionalEncoding(nn.Module):
"""Positional encoding for transformer"""
def __init__(self, d_model, max_length=5000):
super().__init__()
pe = torch.zeros(max_length, d_model)
position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
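        # Sinusoidal encoding from "Attention Is All You Need":
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))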
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1)]
class TranslationTransformerConfig(PretrainedConfig):
"""Configuration class for TranslationTransformer"""
model_type = "translation_transformer"
def __init__(
self,
vocab_size=32000,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
pad_token_id=0,
bos_token_id=2,
eos_token_id=3,
max_length=512,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs
)
self.vocab_size = vocab_size
self.d_model = d_model
self.nhead = nhead
self.num_encoder_layers = num_encoder_layers
self.num_decoder_layers = num_decoder_layers
self.dim_feedforward = dim_feedforward
self.dropout = dropout
self.max_length = max_length
# Required for HuggingFace compatibility
self.is_encoder_decoder = True
self.decoder_start_token_id = bos_token_id
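
# Example (sketch): the defaults above give a 6+6 layer, 512-dim model. To match a custom
# tokenizer, pass your own vocab_size and special token ids, e.g.:
#
#   config = TranslationTransformerConfig(vocab_size=16000, pad_token_id=0, bos_token_id=2, eos_token_id=3)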
class TranslationTransformerModel(PreTrainedModel):
"""
Encoder-Decoder Transformer for Translation
Compatible with HuggingFace Hub
"""
config_class = TranslationTransformerConfig
base_model_prefix = "translation_transformer"
supports_gradient_checkpointing = True
def __init__(self, config):
super().__init__(config)
self.config = config
# Embeddings
self.embedding = nn.Embedding(
config.vocab_size,
config.d_model,
padding_idx=config.pad_token_id
)
self.pos_encoder = PositionalEncoding(config.d_model, config.max_length)
self.pos_decoder = PositionalEncoding(config.d_model, config.max_length)
# Transformer
self.transformer = nn.Transformer(
d_model=config.d_model,
nhead=config.nhead,
num_encoder_layers=config.num_encoder_layers,
num_decoder_layers=config.num_decoder_layers,
dim_feedforward=config.dim_feedforward,
dropout=config.dropout,
batch_first=True
)
# Output layer
self.fc_out = nn.Linear(config.d_model, config.vocab_size)
# Initialize weights
self.post_init()
def _init_weights(self, module):
"""Initialize weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def get_encoder(self):
"""Return encoder for compatibility"""
return self.transformer.encoder
def get_decoder(self):
"""Return decoder for compatibility"""
return self.transformer.decoder
def generate_square_subsequent_mask(self, sz, device):
"""Generate causal mask for decoder"""
mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
def create_padding_mask(self, seq, pad_token_id):
"""Create padding mask"""
return (seq == pad_token_id)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, Seq2SeqLMOutput]:
"""Forward pass"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
device = input_ids.device
# If labels provided but no decoder_input_ids, shift labels to create decoder_input_ids
if labels is not None and decoder_input_ids is None:
labels_shifted = labels.clone()
labels_shifted[labels_shifted == -100] = self.config.pad_token_id
decoder_input_ids = torch.cat([
torch.full((labels.shape[0], 1), self.config.bos_token_id, dtype=torch.long, device=device),
labels_shifted[:, :-1]
], dim=1)
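            # decoder_input_ids is now the right-shifted target, [BOS, y_0, ..., y_{T-2}]
            # (standard teacher forcing), so the logits at position t predict labels[:, t].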
# Embeddings with scaling
src_emb = self.embedding(input_ids) * math.sqrt(self.config.d_model)
src_emb = self.pos_encoder(src_emb)
tgt_emb = self.embedding(decoder_input_ids) * math.sqrt(self.config.d_model)
tgt_emb = self.pos_decoder(tgt_emb)
# Create masks
tgt_seq_len = decoder_input_ids.size(1)
tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len, device)
src_key_padding_mask = self.create_padding_mask(input_ids, self.config.pad_token_id)
tgt_key_padding_mask = self.create_padding_mask(decoder_input_ids, self.config.pad_token_id)
# Transformer forward pass
output = self.transformer(
src_emb,
tgt_emb,
tgt_mask=tgt_mask,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask
)
# Output projection
logits = self.fc_out(output)
# Calculate loss if labels provided
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=logits,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
"""Prepare inputs for generation"""
return {
"input_ids": kwargs.get("input_ids"),
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
"""Reorder cache for beam search"""
return past_key_values
def generate(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.FloatTensor] = None,
max_length: int = 128,
num_beams: int = 1,
temperature: float = 1.0,
do_sample: bool = False,
top_k: int = 50,
top_p: float = 1.0,
**kwargs
) -> torch.LongTensor:
"""Generate translations"""
device = input_ids.device
batch_size = input_ids.size(0)
# Start with BOS token
decoder_input_ids = torch.full(
(batch_size, 1),
self.config.bos_token_id,
dtype=torch.long,
device=device
)
finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
# Generate tokens one by one
for _ in range(max_length - 1):
outputs = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
return_dict=True
)
next_token_logits = outputs.logits[:, -1, :] / temperature
if do_sample:
if top_k > 0:
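                    # Top-k filtering: mask out every logit below the k-th largest value
                    # so that only the top-k tokens can be sampled.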
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = float('-inf')
if top_p < 1.0:
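                    # Nucleus (top-p) filtering: keep the smallest set of tokens whose cumulative
                    # probability exceeds top_p; the shift below always keeps the most likely token.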
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            # Pad out sequences that have already finished, then mark newly finished ones;
            # doing it in this order keeps the EOS token in the output instead of overwriting it.
            next_token[finished] = self.config.pad_token_id
            finished = finished | (next_token.squeeze(-1) == self.config.eos_token_id)
decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
if finished.all():
break
return decoder_input_ids
# Register the model in the AutoModel registry
from transformers import AutoConfig, AutoModel, AutoModelForSeq2SeqLM
AutoConfig.register("translation_transformer", TranslationTransformerConfig)
AutoModel.register(TranslationTransformerConfig, TranslationTransformerModel)
AutoModelForSeq2SeqLM.register(TranslationTransformerConfig, TranslationTransformerModel)
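
# Example usage (sketch): loading this model from the Hub with custom code enabled.
# "<repo-id>" is a placeholder for the actual repository name; the tokenizer is assumed
# to share the vocabulary and special token ids declared in the config.
#
#   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#
#   tokenizer = AutoTokenizer.from_pretrained("<repo-id>")
#   model = AutoModelForSeq2SeqLM.from_pretrained("<repo-id>", trust_remote_code=True)
#
#   inputs = tokenizer("Hello, world!", return_tensors="pt")
#   output_ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=64)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))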