"""PyTorch Small Transformer model for English to Hindi/Bengali translation."""

import math
from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput


class SmallTransformerConfig(PretrainedConfig):
    """Configuration for the small encoder-decoder translation model."""

    model_type = "small_transformer"

    def __init__(
        self,
        vocab_size=80000,
        d_model=256,
        nhead=8,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dim_feedforward=512,
        dropout=0.1,
        max_position_embeddings=512,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.max_position_embeddings = max_position_embeddings

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )


class SmallTransformerPreTrainedModel(PreTrainedModel):
    """Base class wiring SmallTransformerConfig into the HF PreTrainedModel API."""

    config_class = SmallTransformerConfig
    base_model_prefix = "small_transformer"
    supports_gradient_checkpointing = False
    _no_split_modules = []

    def _init_weights(self, module):
        # BERT-style initialization: normal(0, 0.02) weights, zero biases,
        # and a zeroed embedding row for the padding index.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class SmallTransformer(SmallTransformerPreTrainedModel):
    def __init__(self, config: SmallTransformerConfig):
        super().__init__(config)
        self.config = config

        # One embedding table shared by encoder and decoder inputs, plus
        # separate learned position embeddings for each side.
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.d_model,
            padding_idx=config.pad_token_id,
        )
        self.pos_encoder = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.pos_decoder = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.embed_scale = math.sqrt(config.d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=True,
        )
        dec_layer = nn.TransformerDecoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=True,
        )

        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=config.num_encoder_layers)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=config.num_decoder_layers)
        self.output_layer = nn.Linear(config.d_model, config.vocab_size)

        self.post_init()
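
    # Back-of-the-envelope size note (an estimate, not measured): with the
    # default config the two vocabulary matrices dominate -- the input embedding
    # and the output projection are each roughly 80,000 x 256 ≈ 20.5M parameters,
    # versus about 4M for all encoder/decoder layers combined (~45M total).
    # Tying the output projection to the embedding would nearly halve the model,
    # e.g. (optional, untested here):
    #     self.output_layer.weight = self.embedding.weight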

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if decoder_input_ids is None:
            if labels is None:
                raise ValueError("Either decoder_input_ids or labels must be provided.")
            # Teacher forcing: shift labels one position to the right so position
            # t is predicted from tokens < t. Feeding labels in unshifted would
            # let the decoder simply copy its input. This assumes bos_token_id is
            # the decoder start token; adjust if training targets already carry a
            # language/start token in position 0.
            decoder_input_ids = labels.new_full(labels.shape, self.config.pad_token_id)
            decoder_input_ids[:, 1:] = labels[:, :-1]
            decoder_input_ids[:, 0] = self.config.bos_token_id

        src = input_ids
        tgt = decoder_input_ids
        assert src.dim() == 2 and tgt.dim() == 2, "expected (batch, seq_len) tensors"

        # Key-padding masks: True marks positions attention should ignore. Use
        # the caller-provided masks when given, otherwise infer from pad tokens.
        if attention_mask is not None:
            src_mask = attention_mask == 0
        else:
            src_mask = src == self.config.pad_token_id
        if decoder_attention_mask is not None:
            tgt_mask_pad = decoder_attention_mask == 0
        else:
            tgt_mask_pad = tgt == self.config.pad_token_id

        # Causal mask: True above the diagonal blocks attention to future tokens.
        T = tgt.size(1)
        causal_mask = torch.triu(torch.ones((T, T), device=tgt.device, dtype=torch.bool), diagonal=1)

        # Learned absolute position ids, clamped so sequences longer than
        # max_position_embeddings reuse the last position instead of erroring.
        src_pos = torch.arange(src.size(1), device=src.device).unsqueeze(0).expand(src.size(0), -1)
        src_pos = src_pos.clamp(max=self.config.max_position_embeddings - 1)
        tgt_pos = torch.arange(tgt.size(1), device=tgt.device).unsqueeze(0).expand(tgt.size(0), -1)
        tgt_pos = tgt_pos.clamp(max=self.config.max_position_embeddings - 1)

        # Token embeddings are scaled by sqrt(d_model), as in the original Transformer.
        src_emb = self.embedding(src) * self.embed_scale + self.pos_encoder(src_pos)
        tgt_emb = self.embedding(tgt) * self.embed_scale + self.pos_decoder(tgt_pos)

        # Encode once, then decode with causal and padding masks.
        memory = self.encoder(src_emb, src_key_padding_mask=src_mask)
        decoder_out = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_mask_pad,
            memory_key_padding_mask=src_mask,
        )
        logits = self.output_layer(decoder_out)

        loss = None
        if labels is not None:
            # Pad positions are excluded from the loss via ignore_index.
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            out = (logits,)
            return ((loss,) + out) if loss is not None else out

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            encoder_last_hidden_state=memory,
        )

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        lang_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        """Simple greedy generation for translation: the decoder is seeded with
        `lang_token_id` and extended one argmax token at a time until EOS or the
        length limit."""
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id

        # max_length counts the start token; max_new_tokens does not.
        if max_new_tokens is not None:
            max_length = max_new_tokens + 1
        elif max_length is None:
            max_length = 64

        batch_size = input_ids.size(0)
        device = input_ids.device

        if lang_token_id is None:
            raise ValueError("lang_token_id must be provided for generation")

        # Seed every sequence in the batch with the target-language token.
        decoder_input_ids = torch.full((batch_size, 1), lang_token_id, dtype=torch.long, device=device)

        # Track which sequences have already produced EOS. Note the encoder is
        # re-run on every step here; caching `memory` once would be cheaper.
        unfinished = torch.ones(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length - 1):
            outputs = self.forward(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                return_dict=True,
            )

            # Greedy step: argmax over the vocabulary at the last position.
            next_token_logits = outputs.logits[:, -1, :]
            next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Sequences that already finished keep emitting padding.
            next_tokens[~unfinished] = self.config.pad_token_id
            decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1)

            unfinished &= next_tokens.squeeze(-1) != eos_token_id
            if not unfinished.any():
                break

        return decoder_input_ids
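

# --- Minimal smoke test (illustrative sketch, not part of the model API) ---
# Exercises an untrained model end to end. The token ids used here (pad=0,
# bos=1, eos=2, and a made-up language token id of 3) are assumptions for this
# demo; real usage would take them from the trained tokenizer.
if __name__ == "__main__":
    config = SmallTransformerConfig(
        vocab_size=1000,
        d_model=64,
        nhead=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=128,
        max_position_embeddings=64,
    )
    model = SmallTransformer(config)
    model.eval()

    src = torch.randint(4, 1000, (2, 10))    # fake source batch
    labels = torch.randint(4, 1000, (2, 8))  # fake target batch

    # Training-style call: labels are shifted internally and a loss is returned.
    out = model(input_ids=src, labels=labels, return_dict=True)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    # Greedy decoding seeded with the hypothetical language token.
    with torch.no_grad():
        generated = model.generate(src, max_new_tokens=12, lang_token_id=3)
    print("generated shape:", tuple(generated.shape))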