"""Translation Transformer model for the HuggingFace Hub."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for batch-first transformer inputs."""

    def __init__(self, d_model, max_length=5000):
        super().__init__()
        pe = torch.zeros(max_length, d_model)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_length, d_model), broadcast over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]


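# A quick shape sanity check for PositionalEncoding (illustrative only; the
# sizes below are arbitrary):
#
#     pe = PositionalEncoding(d_model=512)
#     x = torch.zeros(2, 10, 512)          # (batch, seq_len, d_model)
#     assert pe(x).shape == (2, 10, 512)   # encodings are added elementwise

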
class TranslationTransformerConfig(PretrainedConfig):
    """Configuration class for TranslationTransformer."""

    model_type = "translation_transformer"

    def __init__(
        self,
        vocab_size=32000,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=3,
        max_length=512,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.max_length = max_length

        # Flags consumed by the `transformers` generation utilities.
        self.is_encoder_decoder = True
        self.decoder_start_token_id = bos_token_id


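# Illustrative: a deliberately small configuration (the values here are
# arbitrary, not tuned):
#
#     config = TranslationTransformerConfig(
#         vocab_size=8000, d_model=256, nhead=4,
#         num_encoder_layers=3, num_decoder_layers=3,
#     )

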
class TranslationTransformerModel(PreTrainedModel):
    """
    Encoder-decoder Transformer for translation, compatible with the
    HuggingFace Hub (`save_pretrained` / `from_pretrained`).
    """

    config_class = TranslationTransformerConfig
    base_model_prefix = "translation_transformer"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Embedding table shared between encoder and decoder inputs.
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.d_model,
            padding_idx=config.pad_token_id,
        )
        self.pos_encoder = PositionalEncoding(config.d_model, config.max_length)
        self.pos_decoder = PositionalEncoding(config.d_model, config.max_length)

        self.transformer = nn.Transformer(
            d_model=config.d_model,
            nhead=config.nhead,
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=True,
        )

        self.fc_out = nn.Linear(config.d_model, config.vocab_size)

        self.post_init()

    def _init_weights(self, module):
        """Initialize weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_encoder(self):
        """Return the encoder (for generation-utility compatibility)."""
        return self.transformer.encoder

    def get_decoder(self):
        """Return the decoder (for generation-utility compatibility)."""
        return self.transformer.decoder

    def generate_square_subsequent_mask(self, sz, device):
        """Generate an additive causal mask for the decoder (0 = attend, -inf = blocked)."""
        mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def create_padding_mask(self, seq, pad_token_id):
        """Boolean mask that is True at padding positions."""
        return seq == pad_token_id

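    # For sz = 3, `generate_square_subsequent_mask` returns the additive mask
    #
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]
    #
    # so that position i can only attend to positions <= i.
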
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        """Forward pass."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        device = input_ids.device

        # Teacher forcing: if only labels are given, build decoder inputs by
        # shifting them right and prepending BOS; -100 (ignored by the loss)
        # is replaced with padding.
        if labels is not None and decoder_input_ids is None:
            labels_shifted = labels.clone()
            labels_shifted[labels_shifted == -100] = self.config.pad_token_id

            decoder_input_ids = torch.cat([
                torch.full((labels.shape[0], 1), self.config.bos_token_id, dtype=torch.long, device=device),
                labels_shifted[:, :-1],
            ], dim=1)

        # Embed, scale by sqrt(d_model) as in "Attention Is All You Need",
        # and add positional encodings.
        src_emb = self.embedding(input_ids) * math.sqrt(self.config.d_model)
        src_emb = self.pos_encoder(src_emb)

        tgt_emb = self.embedding(decoder_input_ids) * math.sqrt(self.config.d_model)
        tgt_emb = self.pos_decoder(tgt_emb)

        # Causal mask for the decoder, padding masks for both streams.
        tgt_seq_len = decoder_input_ids.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len, device)

        src_key_padding_mask = self.create_padding_mask(input_ids, self.config.pad_token_id)
        tgt_key_padding_mask = self.create_padding_mask(decoder_input_ids, self.config.pad_token_id)

        hidden_states = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        logits = self.fc_out(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
        )

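    # Training sketch (illustrative; the random tensors stand in for real
    # tokenized data):
    #
    #     model = TranslationTransformerModel(config)
    #     src = torch.randint(4, config.vocab_size, (8, 20))
    #     tgt = torch.randint(4, config.vocab_size, (8, 20))
    #     out = model(input_ids=src, labels=tgt)
    #     out.loss.backward()
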
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """Prepare inputs for generation."""
        return {
            "input_ids": kwargs.get("input_ids"),
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cache for beam search (a no-op: this model keeps no KV cache)."""
        return past_key_values

    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        max_length: int = 128,
        num_beams: int = 1,
        temperature: float = 1.0,
        do_sample: bool = False,
        top_k: int = 50,
        top_p: float = 1.0,
        **kwargs
    ) -> torch.LongTensor:
        """
        Generate translations.

        Decoding is greedy, or sampled with temperature / top-k / top-p when
        `do_sample=True`. `num_beams` is accepted for API compatibility, but
        beam search is not implemented.
        """
        device = input_ids.device
        batch_size = input_ids.size(0)

        # Every sequence starts with BOS.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.config.bos_token_id,
            dtype=torch.long,
            device=device,
        )

        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length - 1):
            # No KV cache: the full prefix is re-decoded each step, which is
            # quadratic in output length but keeps the loop simple.
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                return_dict=True,
            )

            next_token_logits = outputs.logits[:, -1, :] / temperature

            if do_sample:
                if top_k > 0:
                    # Top-k: drop every logit below the k-th largest.
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                if top_p < 1.0:
                    # Nucleus (top-p): keep the smallest set of tokens whose
                    # cumulative probability exceeds top_p.
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token past the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Sequences that finished earlier keep emitting padding; sequences
            # producing EOS this step are marked finished, with the EOS token
            # itself kept in the output.
            next_token = next_token.masked_fill(finished.unsqueeze(-1), self.config.pad_token_id)
            finished = finished | (next_token.squeeze(-1) == self.config.eos_token_id)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)

            if finished.all():
                break

        return decoder_input_ids


from transformers import AutoConfig, AutoModel, AutoModelForSeq2SeqLM

# Register the custom classes so the Auto* factories can resolve
# "translation_transformer" configs and checkpoints.
AutoConfig.register("translation_transformer", TranslationTransformerConfig)
AutoModel.register(TranslationTransformerConfig, TranslationTransformerModel)
AutoModelForSeq2SeqLM.register(TranslationTransformerConfig, TranslationTransformerModel)
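
# End-to-end usage sketch (illustrative; the tiny hyperparameters and random
# token IDs are arbitrary, and "my-translation-model" is a hypothetical path):
#
#     config = TranslationTransformerConfig(
#         vocab_size=1000, d_model=128, nhead=4,
#         num_encoder_layers=2, num_decoder_layers=2,
#     )
#     model = TranslationTransformerModel(config)
#     src = torch.randint(4, config.vocab_size, (2, 12))
#     greedy = model.generate(src, max_length=32)
#     sampled = model.generate(src, max_length=32, do_sample=True,
#                              temperature=0.8, top_k=50, top_p=0.9)
#
#     model.save_pretrained("my-translation-model")
#     reloaded = AutoModelForSeq2SeqLM.from_pretrained("my-translation-model")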