"""PyTorch Small Transformer model for English to Hindi/Bengali translation."""
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.configuration_utils import PretrainedConfig

class SmallTransformerConfig(PretrainedConfig):
    model_type = "small_transformer"

    def __init__(
        self,
        vocab_size=80000,
        d_model=256,
        nhead=8,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dim_feedforward=512,
        dropout=0.1,
        max_position_embeddings=512,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.max_position_embeddings = max_position_embeddings
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )
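
# A minimal usage sketch for the config (the directory path below is an
# assumption for illustration, not something this repo defines):
#
#   config = SmallTransformerConfig(vocab_size=80000, d_model=256)
#   config.save_pretrained("./small_transformer")
#   config = SmallTransformerConfig.from_pretrained("./small_transformer")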

class SmallTransformerPreTrainedModel(PreTrainedModel):
    config_class = SmallTransformerConfig
    base_model_prefix = "small_transformer"
    supports_gradient_checkpointing = False
    _no_split_modules = []

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

class SmallTransformer(SmallTransformerPreTrainedModel):
    def __init__(self, config: SmallTransformerConfig):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.d_model,
            padding_idx=config.pad_token_id
        )
        self.pos_encoder = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.pos_decoder = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.embed_scale = math.sqrt(config.d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=True
        )
        dec_layer = nn.TransformerDecoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=config.num_encoder_layers)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=config.num_decoder_layers)
        self.output_layer = nn.Linear(config.d_model, config.vocab_size)
        # Initialize weights
        self.post_init()

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Use decoder_input_ids when provided; otherwise feed the labels
        # directly as decoder inputs (no token shifting is applied here)
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = labels.clone()
        src = input_ids
        tgt = decoder_input_ids
        assert src.dim() == 2 and tgt.dim() == 2
        # Create masks (True marks positions that attention should ignore)
        src_mask = (src == self.config.pad_token_id)
        tgt_mask_pad = (tgt == self.config.pad_token_id)
        T = tgt.size(1)
        causal_mask = torch.triu(torch.ones((T, T), device=tgt.device), diagonal=1).bool()
        # Positional indices
        src_pos = torch.arange(0, src.size(1), device=src.device).unsqueeze(0).expand(src.size(0), -1).clamp(
            max=self.config.max_position_embeddings - 1
        )
        tgt_pos = torch.arange(0, tgt.size(1), device=tgt.device).unsqueeze(0).expand(tgt.size(0), -1).clamp(
            max=self.config.max_position_embeddings - 1
        )
        # Embeddings
        src_emb = self.embedding(src) * self.embed_scale + self.pos_encoder(src_pos)
        tgt_emb = self.embedding(tgt) * self.embed_scale + self.pos_decoder(tgt_pos)
        # Encode and decode
        memory = self.encoder(src_emb, src_key_padding_mask=src_mask)
        output = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_mask_pad,
            memory_key_padding_mask=src_mask
        )
        logits = self.output_layer(output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=None,
            decoder_hidden_states=None,
            decoder_attentions=None,
            cross_attentions=None,
            encoder_last_hidden_state=memory,
            encoder_hidden_states=None,
            encoder_attentions=None,
        )
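
    # A rough training-step sketch for reference; "batch", "model" and
    # "optimizer" are assumed to come from the surrounding training script:
    #
    #   out = model(input_ids=batch["input_ids"], labels=batch["labels"], return_dict=True)
    #   out.loss.backward()
    #   optimizer.step()
    #   optimizer.zero_grad()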

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        lang_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        """Simple greedy generation for translation."""
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id
        # Handle max_new_tokens parameter
        if max_new_tokens is not None:
            max_length = max_new_tokens
        elif max_length is None:
            max_length = 64
        batch_size = input_ids.size(0)
        device = input_ids.device
        # Start with language token
        if lang_token_id is None:
            raise ValueError("lang_token_id must be provided for generation")
        decoder_input_ids = torch.full((batch_size, 1), lang_token_id, dtype=torch.long, device=device)
        for _ in range(max_length - 1):
            outputs = self.forward(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                return_dict=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1)
            # Stop if all sequences have generated EOS
            if (next_tokens == eos_token_id).all():
                break
        return decoder_input_ids
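
if __name__ == "__main__":
    # Minimal smoke-test / usage sketch. The token ids below (including
    # lang_token_id=3) are placeholders for illustration; real ids come from
    # the tokenizer that accompanies this model.
    config = SmallTransformerConfig()
    model = SmallTransformer(config)
    model.eval()

    # A fake batch of two already-tokenized, padded source sentences.
    input_ids = torch.tensor([[5, 17, 42, 2, 0, 0],
                              [9, 8, 23, 61, 77, 2]], dtype=torch.long)

    with torch.no_grad():
        # Teacher-forced call: labels double as decoder inputs in this model.
        out = model(input_ids=input_ids, labels=input_ids, return_dict=True)
        print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))

        # Greedy decoding, seeded with a placeholder target-language token id.
        generated = model.generate(input_ids, max_new_tokens=16, lang_token_id=3)
        print("generated ids:", generated.tolist())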