Chukwuka commited on
Commit
bf737d3
·
1 Parent(s): 0654871

Corrected the torch.transforms

Browse files
Files changed (1) hide show
  1. data_setup.py +6 -7
data_setup.py CHANGED
@@ -7,17 +7,17 @@ import numpy as np
7
  import json
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
- # import torchvision.transforms as tt
11
- import albumentations as A
12
- from albumentations.pytorch import ToTensorV2
13
 
14
 
15
  stats = (0.4862, 0.4561, 0.3941), (0.2202, 0.2142, 0.2160)
16
 
17
- model_tsfm = A.Compose([
18
- A.Resize(224, 224),
19
  # A.Normalize(*stats),
20
- ToTensorV2()
21
  ])
22
 
23
 
@@ -29,7 +29,6 @@ cat_to_name.index = cat_to_name.index.astype(np.int32)
29
  cat_to_name.sort_index(inplace=True)
30
  classes = cat_to_name.values
31
 
32
- # classes = ['Australian terrier', 'Border terrier', 'Samoyed', 'Beagle', 'Shih-Tzu', 'English foxhound', 'Rhodesian ridgeback', 'Dingo', 'Golden retriever', 'Old English sheepdog']
33
 
34
  if __name__ == "__main__":
35
  parser = argparse.ArgumentParser()
 
7
  import json
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+ # import albumentations as A
12
+ # from albumentations.pytorch import ToTensorV2
13
 
14
 
15
  stats = (0.4862, 0.4561, 0.3941), (0.2202, 0.2142, 0.2160)
16
 
17
+ model_tsfm = T.Compose([
18
+ T.Resize((224, 224)),
19
  # A.Normalize(*stats),
20
+ T.ToTensor()
21
  ])
22
 
23
 
 
29
  cat_to_name.sort_index(inplace=True)
30
  classes = cat_to_name.values
31
 
 
32
 
33
  if __name__ == "__main__":
34
  parser = argparse.ArgumentParser()