from transformers import PretrainedConfig


class TranslationTransformerConfig(PretrainedConfig):
    model_type = "translation_transformer"  # must match the "model_type" field in your config.json

    def __init__(
        self,
        vocab_size=32000,
        d_model=256,
        nhead=8,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dim_feedforward=512,
        dropout=0.1,
        max_length=128,
        bos_token_id=2,
        eos_token_id=3,
        pad_token_id=0,
        decoder_start_token_id=2,
        **kwargs,
    ):
        """
        Configuration class for TranslationTransformerModel.
        """
        # Special-token ids are stored and serialized by the
        # PretrainedConfig base class, so pass them through to it.
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )
        # Architecture hyperparameters; the names mirror torch.nn.Transformer.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.max_length = max_length
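A minimal usage sketch of how this config plugs into the standard transformers machinery, assuming the class above is importable; the `AutoConfig.register` call and the save/load round trip are standard transformers APIs, and the output directory name is illustrative:

from transformers import AutoConfig

# Register the custom class so AutoConfig can resolve the
# "translation_transformer" model_type from config.json.
AutoConfig.register("translation_transformer", TranslationTransformerConfig)

# Instantiate with a couple of overridden hyperparameters and round-trip
# the config through config.json (directory name is illustrative).
config = TranslationTransformerConfig(d_model=512, num_encoder_layers=6)
config.save_pretrained("translation-transformer-demo")

reloaded = AutoConfig.from_pretrained("translation-transformer-demo")
assert reloaded.d_model == 512
assert reloaded.model_type == "translation_transformer"

Registering the config is what lets downstream code load the model generically via the Auto classes instead of importing your class directly.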