Commit
·
fcf40f1
1
Parent(s):
f203678
Update augvit_model.py
Browse files- augvit_model.py +8 -3
augvit_model.py
CHANGED
|
@@ -156,7 +156,7 @@ class AUGViT(Model):
|
|
| 156 |
|
| 157 |
from transformers import TFPreTrainedModel
|
| 158 |
from .augvit_config import AugViTConfig
|
| 159 |
-
|
| 160 |
class AugViTForImageClassification(TFPreTrainedModel):
|
| 161 |
config_class = AugViTConfig
|
| 162 |
def __init__(self, config):
|
|
@@ -173,6 +173,11 @@ class AugViTForImageClassification(TFPreTrainedModel):
|
|
| 173 |
emb_dropout =config.emb_dropout
|
| 174 |
)
|
| 175 |
|
| 176 |
-
def call(self,
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
return logits
|
|
|
|
| 156 |
|
| 157 |
from transformers import TFPreTrainedModel
|
| 158 |
from .augvit_config import AugViTConfig
|
| 159 |
+
from typing import Dict, Optional, Tuple, Union
|
| 160 |
class AugViTForImageClassification(TFPreTrainedModel):
|
| 161 |
config_class = AugViTConfig
|
| 162 |
def __init__(self, config):
|
|
|
|
| 173 |
emb_dropout =config.emb_dropout
|
| 174 |
)
|
| 175 |
|
| 176 |
+
def call(self, pixel_values: tf.Tensor | None = None,
|
| 177 |
+
output_hidden_states: Optional[bool] = None,
|
| 178 |
+
labels: tf.Tensor | None = None,
|
| 179 |
+
return_dict: Optional[bool] = None,
|
| 180 |
+
training: Optional[bool] = False,
|
| 181 |
+
**kwargs):
|
| 182 |
+
logits = self.model(pixel_values)
|
| 183 |
return logits
|