Someshfengde commited on
Commit
14b3fa8
·
verified ·
1 Parent(s): 7e3ab71

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. script.py +20 -39
  2. submission.csv +4 -0
script.py CHANGED
@@ -5,8 +5,7 @@ import os
5
  from tqdm import tqdm
6
  import timm
7
  import torchvision.transforms as T
8
- import timm,wandb,albumentations as A
9
- from albumentations.pytorch import ToTensorV2
10
 
11
  from PIL import Image
12
  import torch
@@ -25,40 +24,14 @@ def is_gpu_available():
25
  """Check if the python package `onnxruntime-gpu` is installed."""
26
  return torch.cuda.is_available()
27
 
28
- VALID_AUG = A.Compose([
29
- A.SmallestMaxSize(max_size=SZ + 16, p=1.0),
30
- A.CenterCrop(height=SZ, width=SZ, p=1.0),
31
- A.Normalize(),
32
- ToTensorV2(),
33
- ])
34
 
35
 
36
- class LoadImagesAndLabels(torch.utils.data.Dataset):
37
-
38
- def __init__(self, df, transforms, mode='train'):
39
- self.df = df
40
- self.transforms = transforms
41
- self.mode = mode
42
-
43
- def __len__(self): return len(self.df)
44
-
45
- def __getitem__(self,index):
46
- row = self.df.iloc[index]
47
- image_path = str(row.filename)
48
- images_root_path="/tmp/data/private_testset"
49
- image_path = os.path.join(images_root_path, str(row.filename))
50
- img = Image.open(image_path).convert("RGB")
51
- img = np.array(img)
52
-
53
- if self.transforms is not None:
54
- img = self.transforms(image=img)['image']
55
-
56
- if self.mode == 'test':
57
- return img
58
-
59
- label = torch.tensor(labels_class_map[row.binomial_name]).long()
60
- return img, label
61
-
62
  def get_corn_model(model_name, pretrained=True, **kwargs):
63
  model = timm.create_model(model_name, pretrained=pretrained, **kwargs)
64
  model = nn.Sequential(
@@ -82,7 +55,10 @@ class PytorchWorker:
82
  model_ckpt = torch.load("./NB_EXP_V2_008/vit_base_patch16_224_224_bs32_ep16_lr6e05_wd0.05_mixup_cutmix_CV_0.pth", map_location=self.device)
83
  model.load_state_dict(model_ckpt)
84
  return model.to(self.device)
85
-
 
 
 
86
  self.model = _load_model()
87
 
88
  def predict_image(self, image: np.ndarray) -> list():
@@ -90,8 +66,8 @@ class PytorchWorker:
90
  :param image: Input image as numpy array.
91
  :return: A list with logits and confidences.
92
  """
93
- image = image.to(self.device)
94
- outputs = self.model(image.unsqueeze(dim = 0))
95
  logits = outputs
96
  return logits.tolist()
97
 
@@ -100,10 +76,15 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
100
  """Make submission with given """
101
 
102
  model = PytorchWorker()
103
- data = LoadImagesAndLabels(test_metadata, VALID_AUG, mode='test')
104
  predictions = []
105
 
106
- for image in data:
 
 
 
 
 
107
  output = model.predict_image(image)
108
  string_label_dup = LABELS.get(str(np.argmax(output)), 'Acanthophis antarcticus')
109
  prediction_class = ORIGINAL_LABELS.get(string_label_dup, 1)
 
5
  from tqdm import tqdm
6
  import timm
7
  import torchvision.transforms as T
8
+ # from albumentations.pytorch import ToTensorV2
 
9
 
10
  from PIL import Image
11
  import torch
 
24
  """Check if the python package `onnxruntime-gpu` is installed."""
25
  return torch.cuda.is_available()
26
 
27
+ # VALID_AUG = A.Compose([
28
+ # A.SmallestMaxSize(max_size=SZ + 16, p=1.0),
29
+ # A.CenterCrop(height=SZ, width=SZ, p=1.0),
30
+ # A.Normalize(),
31
+ # ToTensorV2(),
32
+ # ])
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def get_corn_model(model_name, pretrained=True, **kwargs):
36
  model = timm.create_model(model_name, pretrained=pretrained, **kwargs)
37
  model = nn.Sequential(
 
55
  model_ckpt = torch.load("./NB_EXP_V2_008/vit_base_patch16_224_224_bs32_ep16_lr6e05_wd0.05_mixup_cutmix_CV_0.pth", map_location=self.device)
56
  model.load_state_dict(model_ckpt)
57
  return model.to(self.device)
58
+
59
+ self.transforms = T.Compose([T.Resize((SZ, SZ)),
60
+ T.ToTensor(),
61
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
62
  self.model = _load_model()
63
 
64
  def predict_image(self, image: np.ndarray) -> list():
 
66
  :param image: Input image as numpy array.
67
  :return: A list with logits and confidences.
68
  """
69
+ image_data = self.transforms(image).unsqueeze(0).to(self.device)
70
+ outputs = self.model(image_data)
71
  logits = outputs
72
  return logits.tolist()
73
 
 
76
  """Make submission with given """
77
 
78
  model = PytorchWorker()
79
+ data = LoadImagesAndLabels(test_metadata, None, mode='test')
80
  predictions = []
81
 
82
+ for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
83
+ image_path = os.path.join(images_root_path, row.filename)
84
+ # image_path = row.filename
85
+
86
+ image = Image.open(image_path).convert("RGB")
87
+
88
  output = model.predict_image(image)
89
  string_label_dup = LABELS.get(str(np.argmax(output)), 'Acanthophis antarcticus')
90
  prediction_class = ORIGINAL_LABELS.get(string_label_dup, 1)
submission.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ observation_id,class_id
2
+ 1,419
3
+ 2,419
4
+ 3,419