Update modeling_footballbert.py
Browse files- modeling_footballbert.py +6 -6
modeling_footballbert.py
CHANGED
|
@@ -8,7 +8,7 @@ from dataclasses import dataclass
|
|
| 8 |
from transformers import PreTrainedModel
|
| 9 |
from transformers.modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput
|
| 10 |
|
| 11 |
-
|
| 12 |
|
| 13 |
@dataclass
|
| 14 |
class FootballBERTOutput(BaseModelOutputWithPooling):
|
|
@@ -72,7 +72,7 @@ class PlayerSelfAttention(nn.Module):
|
|
| 72 |
class PlayerTransformerBlock(nn.Module):
|
| 73 |
"""Standard Transformer encoder layer."""
|
| 74 |
|
| 75 |
-
def __init__(self, config
|
| 76 |
super().__init__()
|
| 77 |
self.attention = PlayerSelfAttention(config.hidden_size, config.num_attention_heads)
|
| 78 |
|
|
@@ -96,7 +96,7 @@ class PlayerTransformerBlock(nn.Module):
|
|
| 96 |
|
| 97 |
class PlayerEncoder(nn.Module):
|
| 98 |
|
| 99 |
-
def __init__(self, config
|
| 100 |
|
| 101 |
super(PlayerEncoder, self).__init__()
|
| 102 |
|
|
@@ -133,7 +133,7 @@ class FootballBERTPreTrainedModel(PreTrainedModel):
|
|
| 133 |
and loading pretrained models.
|
| 134 |
"""
|
| 135 |
|
| 136 |
-
|
| 137 |
base_model_prefix = "footballbert"
|
| 138 |
supports_gradient_checkpointing = False
|
| 139 |
|
|
@@ -186,7 +186,7 @@ class FootballBERTModel(FootballBERTPreTrainedModel):
|
|
| 186 |
```
|
| 187 |
"""
|
| 188 |
|
| 189 |
-
def __init__(self, config
|
| 190 |
super().__init__(config)
|
| 191 |
self.config = config
|
| 192 |
|
|
@@ -259,7 +259,7 @@ class FootballBERTForMaskedPlayerPrediction(FootballBERTPreTrainedModel):
|
|
| 259 |
```
|
| 260 |
"""
|
| 261 |
|
| 262 |
-
def __init__(self, config
|
| 263 |
super().__init__(config)
|
| 264 |
self.config = config
|
| 265 |
|
|
|
|
| 8 |
from transformers import PreTrainedModel
|
| 9 |
from transformers.modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput
|
| 10 |
|
| 11 |
+
from .configuration_footballbert import FootballBERTConfig
|
| 12 |
|
| 13 |
@dataclass
|
| 14 |
class FootballBERTOutput(BaseModelOutputWithPooling):
|
|
|
|
| 72 |
class PlayerTransformerBlock(nn.Module):
|
| 73 |
"""Standard Transformer encoder layer."""
|
| 74 |
|
| 75 |
+
def __init__(self, config: FootballBERTConfig):
|
| 76 |
super().__init__()
|
| 77 |
self.attention = PlayerSelfAttention(config.hidden_size, config.num_attention_heads)
|
| 78 |
|
|
|
|
| 96 |
|
| 97 |
class PlayerEncoder(nn.Module):
|
| 98 |
|
| 99 |
+
def __init__(self, config: FootballBERTConfig):
|
| 100 |
|
| 101 |
super(PlayerEncoder, self).__init__()
|
| 102 |
|
|
|
|
| 133 |
and loading pretrained models.
|
| 134 |
"""
|
| 135 |
|
| 136 |
+
config_class = FootballBERTConfig
|
| 137 |
base_model_prefix = "footballbert"
|
| 138 |
supports_gradient_checkpointing = False
|
| 139 |
|
|
|
|
| 186 |
```
|
| 187 |
"""
|
| 188 |
|
| 189 |
+
def __init__(self, config: FootballBERTConfig):
|
| 190 |
super().__init__(config)
|
| 191 |
self.config = config
|
| 192 |
|
|
|
|
| 259 |
```
|
| 260 |
"""
|
| 261 |
|
| 262 |
+
def __init__(self, config: FootballBERTConfig):
|
| 263 |
super().__init__(config)
|
| 264 |
self.config = config
|
| 265 |
|