i4ata commited on
Commit
d835e09
·
1 Parent(s): 563e5fb
Files changed (1) hide show
  1. utils.py +6 -16
utils.py CHANGED
@@ -4,22 +4,12 @@ from typing import Literal, Dict
4
 
5
  _weights = models.ViT_B_16_Weights.DEFAULT
6
 
7
- model_transforms: Dict[Literal['custom', 'pretrained'], Dict[Literal['train', 'val'], transforms.Compose]] = {
8
- 'custom': {
9
- 'train': transforms.Compose([
10
- transforms.Resize((224, 224)),
11
- transforms.TrivialAugmentWide(),
12
- transforms.ToTensor()
13
- ]),
14
- 'val': transforms.Compose([
15
- transforms.Resize((224, 224)),
16
- transforms.ToTensor()
17
- ])
18
- },
19
- 'pretrained': {
20
- 'train': _weights.transforms(),
21
- 'val': _weights.transforms()
22
- }
23
  }
24
 
25
  def get_pretrained_vit() -> models.VisionTransformer:
 
4
 
5
  _weights = models.ViT_B_16_Weights.DEFAULT
6
 
7
+ model_transforms: Dict[Literal['Custom', 'Pretrained'], transforms.Compose] = {
8
+ 'Custom': transforms.Compose([
9
+ transforms.Resize((224, 224)),
10
+ transforms.ToTensor()
11
+ ]),
12
+ 'Pretrained': _weights.transforms()
 
 
 
 
 
 
 
 
 
 
13
  }
14
 
15
  def get_pretrained_vit() -> models.VisionTransformer: