tamoghna committed on
Commit
31831aa
·
verified ·
1 Parent(s): ee3183a

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +35 -1
modeling.py CHANGED
@@ -129,8 +129,42 @@ import torch.nn as nn
129
  from typing import Optional, Tuple
130
  from transformers import PreTrainedModel
131
  from transformers.modeling_outputs import Seq2SeqLMOutput
132
- from .configuration_small_transformer import SmallTransformerConfig
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  class SmallTransformerPreTrainedModel(PreTrainedModel):
136
  config_class = SmallTransformerConfig
 
129
  from typing import Optional, Tuple
130
  from transformers import PreTrainedModel
131
  from transformers.modeling_outputs import Seq2SeqLMOutput
 
132
 
133
+ class SmallTransformerConfig(PretrainedConfig):
134
+ model_type = "small_transformer"
135
+
136
+ def __init__(
137
+ self,
138
+ vocab_size=80000,
139
+ d_model=256,
140
+ nhead=8,
141
+ num_encoder_layers=3,
142
+ num_decoder_layers=3,
143
+ dim_feedforward=512,
144
+ dropout=0.1,
145
+ max_position_embeddings=512,
146
+ pad_token_id=0,
147
+ bos_token_id=1,
148
+ eos_token_id=2,
149
+ use_return_dict=True,
150
+ **kwargs
151
+ ):
152
+ super().__init__(
153
+ pad_token_id=pad_token_id,
154
+ bos_token_id=bos_token_id,
155
+ eos_token_id=eos_token_id,
156
+ **kwargs
157
+ )
158
+
159
+ self.vocab_size = vocab_size
160
+ self.d_model = d_model
161
+ self.nhead = nhead
162
+ self.num_encoder_layers = num_encoder_layers
163
+ self.num_decoder_layers = num_decoder_layers
164
+ self.dim_feedforward = dim_feedforward
165
+ self.dropout = dropout
166
+ self.max_position_embeddings = max_position_embeddings
167
+ self.use_return_dict = use_return_dict
168
 
169
  class SmallTransformerPreTrainedModel(PreTrainedModel):
170
  config_class = SmallTransformerConfig