Commit ·
65a341c
1
Parent(s): aab5de0
Update augvit_model.py
Browse files- augvit_model.py +4 -2
augvit_model.py
CHANGED
|
@@ -179,6 +179,8 @@ class AugViTForImageClassification(TFPreTrainedModel):
|
|
| 179 |
return_dict: Optional[bool] = None,
|
| 180 |
training: Optional[bool] = False,
|
| 181 |
**kwargs):
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
| 184 |
return logits
|
|
|
|
| 179 |
return_dict: Optional[bool] = None,
|
| 180 |
training: Optional[bool] = False,
|
| 181 |
**kwargs):
|
| 182 |
+
inp = pixel_values['pixel_values']
|
| 183 |
+
if inp.shape[-1]!=3:
|
| 184 |
+
inp = tf.transpose(inp,[0,2,3,1])
|
| 185 |
+
logits = self.model()
|
| 186 |
return logits
|