tensorgirl commited on
Commit
65a341c
·
1 Parent(s): aab5de0

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +4 -2
augvit_model.py CHANGED
@@ -179,6 +179,8 @@ class AugViTForImageClassification(TFPreTrainedModel):
179
  return_dict: Optional[bool] = None,
180
  training: Optional[bool] = False,
181
  **kwargs):
182
- print(pixel_values)
183
- logits = self.model(pixel_values['pixel_values'])
 
 
184
  return logits
 
179
  return_dict: Optional[bool] = None,
180
  training: Optional[bool] = False,
181
  **kwargs):
182
+ inp = pixel_values['pixel_values']
183
+ if inp.shape[-1]!=3:
184
+ inp = tf.transpose(inp,[0,2,3,1])
185
+ logits = self.model()
186
  return logits