encoder-decoder-en-hi-bn / configuration_translation_transformer.py
tamoghna's picture
Update configuration_translation_transformer.py
5b15156 verified
from transformers import PretrainedConfig
class TranslationTransformerConfig(PretrainedConfig):
    """Hyperparameter container for ``TranslationTransformerModel``.

    The special-token ids and any extra keyword arguments are handed to
    ``PretrainedConfig.__init__``; every other constructor argument is stored
    unchanged on the instance under the same name.
    """

    # Identifier for this config type; keep it in sync with the
    # ``model_type`` entry of the checkpoint's config.json.
    model_type = "translation_transformer"

    def __init__(
        self,
        vocab_size: int = 32000,
        d_model: int = 256,
        nhead: int = 8,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        max_length: int = 128,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0,
        decoder_start_token_id: int = 2,
        **kwargs,
    ) -> None:
        """Build the configuration.

        Args:
            vocab_size: Stored as ``self.vocab_size``.
            d_model: Stored as ``self.d_model``.
            nhead: Stored as ``self.nhead``.
            num_encoder_layers: Stored as ``self.num_encoder_layers``.
            num_decoder_layers: Stored as ``self.num_decoder_layers``.
            dim_feedforward: Stored as ``self.dim_feedforward``.
            dropout: Stored as ``self.dropout``.
            max_length: Stored as ``self.max_length``.
            bos_token_id: Forwarded to ``PretrainedConfig.__init__``.
            eos_token_id: Forwarded to ``PretrainedConfig.__init__``.
            pad_token_id: Forwarded to ``PretrainedConfig.__init__``.
            decoder_start_token_id: Forwarded to
                ``PretrainedConfig.__init__``.
            **kwargs: Any further options, forwarded untouched to
                ``PretrainedConfig.__init__``.

        NOTE(review): how each hyperparameter is interpreted is decided by
        the model code that reads this config, which is not visible from
        this file.
        """
        special_token_ids = {
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
            "pad_token_id": pad_token_id,
            "decoder_start_token_id": decoder_start_token_id,
        }
        super().__init__(**special_token_ids, **kwargs)

        # Assigned only after the base initializer has run, so these values
        # are the ones that end up on the instance.
        # NOTE(review): the base class may also set ``max_length`` - confirm
        # before reordering.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.max_length = max_length