Upload modeling_baseline.py with huggingface_hub
modeling_baseline.py  CHANGED  (+12, -21)
@@ -2,16 +2,11 @@
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel
-
-try:
-    from .configuration_baseline import BaselineConfig
-except ImportError:
-    from configuration_baseline import BaselineConfig
-
+from .configuration_baseline import BaselineConfig
 try:
     from x_transformers import TransformerWrapper, Encoder
 except ImportError:
-    raise ImportError("
+    raise ImportError("pip install x-transformers")
 
 class BaselineModel(PreTrainedModel):
     config_class = BaselineConfig
@@ -19,13 +14,11 @@ class BaselineModel(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.config = config
-
-        # Instantiate the x_transformers model
         self.model = TransformerWrapper(
             num_tokens=config.vocab_size,
             max_seq_len=config.seq_len,
-            use_abs_pos_emb=False,
-            tie_embedding=True,
+            use_abs_pos_emb=False,
+            tie_embedding=True,
             attn_layers=Encoder(
                 dim=config.d_model,
                 depth=config.depth,
@@ -33,21 +26,19 @@ class BaselineModel(PreTrainedModel):
                 layer_dropout=config.dropout,
                 attn_dropout=config.dropout,
                 ff_dropout=config.dropout,
-                rotary_pos_emb=True,
-                attn_flash=True,
+                rotary_pos_emb=True,
+                attn_flash=True,
                 use_scalenorm=False
             )
         )
+        # TIE FIX
+        if hasattr(self.model.token_emb, 'emb'):
+            self.model.to_logits.weight = self.model.token_emb.emb.weight
+        else:
+            self.model.to_logits.weight = self.model.token_emb.weight
 
     def forward(self, input_ids, labels=None, mask=None):
-        # x_transformers takes 'mask' argument if provided
         logits = self.model(input_ids, mask=mask)
-
-        loss = None
         if labels is not None:
-
-            # Reshape for loss calculation
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-        return {"loss": loss, "logits": logits}
-
+            return {"loss": nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), labels.view(-1)), "logits": logits}
         return logits
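For context, a minimal usage sketch of the model as it stands after this change. This is not part of the commit: the repo id below is a placeholder, the calls assume the repository's config.json registers the custom classes via auto_map (so trust_remote_code can import modeling_baseline.py and configuration_baseline.py), and the example relies only on the config fields used above (vocab_size, seq_len, d_model, depth, dropout).

import torch
from transformers import AutoConfig, AutoModel

# Placeholder repo id; substitute the actual Hub repository.
repo_id = "your-username/baseline-model"

# trust_remote_code=True lets transformers load BaselineConfig / BaselineModel
# from the repo's configuration_baseline.py and modeling_baseline.py
# (assumes auto_map entries exist in config.json).
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# The "TIE FIX" in __init__ assigns the output projection weight to the token
# embedding weight; mirror the same hasattr check to confirm they are shared.
token_emb = model.model.token_emb
emb = token_emb.emb if hasattr(token_emb, "emb") else token_emb
print(model.model.to_logits.weight is emb.weight)  # expected: True if the tie survived loading

# With labels, forward returns {"loss", "logits"}; without labels it returns
# the raw logits of shape (batch, seq_len, vocab_size).
input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids, labels=input_ids.clone())
print(out["loss"].item(), out["logits"].shape)

logits = model(input_ids)
print(logits.shape)

Tying the output projection to the token embedding keeps a single shared matrix for both, which is what the tie_embedding=True argument appears to be aiming at; the explicit assignment in __init__ makes the sharing hold regardless of which token_emb layout the installed x-transformers version uses.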