Update modeling.py
Browse files- modeling.py +35 -1
modeling.py
CHANGED
|
@@ -129,8 +129,42 @@ import torch.nn as nn
|
|
| 129 |
from typing import Optional, Tuple

from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
class SmallTransformerPreTrainedModel(PreTrainedModel):
|
| 136 |
config_class = SmallTransformerConfig
|
|
|
|
| 129 |
from typing import Optional, Tuple
|
| 130 |
from transformers import PreTrainedModel
|
| 131 |
from transformers.modeling_outputs import Seq2SeqLMOutput
|
|
|
|
| 132 |
|
| 133 |
+
class SmallTransformerConfig(PretrainedConfig):
    """Configuration for a small encoder-decoder Transformer.

    Stores the hyperparameters consumed by the model classes in this
    file (embedding/vocab sizes, encoder/decoder depth, attention and
    feed-forward widths) on top of the generic ``PretrainedConfig``
    machinery (token ids, ``return_dict`` handling, serialization).

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Hidden size of the model (embedding dimension).
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Inner dimension of the feed-forward blocks.
        dropout: Dropout probability.
        max_position_embeddings: Maximum supported sequence length.
        pad_token_id: Id of the padding token.
        bos_token_id: Id of the beginning-of-sequence token.
        eos_token_id: Id of the end-of-sequence token.
        use_return_dict: Whether forward passes should return
            ``ModelOutput`` objects instead of tuples. Mapped onto the
            base class's ``return_dict`` attribute (see note below).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "small_transformer"

    def __init__(
        self,
        vocab_size=80000,
        d_model=256,
        nhead=8,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dim_feedforward=512,
        dropout=0.1,
        max_position_embeddings=512,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        use_return_dict=True,
        **kwargs,
    ):
        # BUG FIX: `use_return_dict` is a read-only @property on
        # PretrainedConfig (derived from the `return_dict` attribute), so
        # the original `self.use_return_dict = use_return_dict` raised
        # `AttributeError: can't set attribute` on every instantiation.
        # Route the value through the backing `return_dict` kwarg instead;
        # an explicit `return_dict=` passed by the caller still wins.
        kwargs.setdefault("return_dict", use_return_dict)
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.max_position_embeddings = max_position_embeddings
|
| 168 |
|
| 169 |
class SmallTransformerPreTrainedModel(PreTrainedModel):
|
| 170 |
config_class = SmallTransformerConfig
|