sieberm commited on
Commit
aa9560a
·
verified ·
1 Parent(s): aa37edc

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +23 -10
script.py CHANGED
@@ -1,31 +1,45 @@
1
  import pandas as pd
2
  import numpy as np
3
- import onnxruntime as ort
4
  import os
 
5
  from tqdm import tqdm
6
  import timm
7
  import torchvision.transforms as T
8
  from PIL import Image
9
  import torch
10
 
 
11
  def is_gpu_available():
12
  """Check if the python package `onnxruntime-gpu` is installed."""
13
  return torch.cuda.is_available()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  class PytorchWorker:
17
  """Run inference using ONNX runtime."""
18
 
19
  def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1784):
20
-
21
  def _load_model(model_name, model_path):
22
-
23
  print("Setting up Pytorch Model")
24
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
  print(f"Using devide: {self.device}")
26
 
27
- model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
28
-
29
  # if not torch.cuda.is_available():
30
  # model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
31
  # else:
@@ -42,10 +56,8 @@ class PytorchWorker:
42
  T.ToTensor(),
43
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
44
 
45
-
46
  def predict_image(self, image: np.ndarray) -> list():
47
  """Run inference using ONNX runtime.
48
-
49
  :param image: Input image as numpy array.
50
  :return: A list with logits and confidences.
51
  """
@@ -55,7 +67,8 @@ class PytorchWorker:
55
  return logits.tolist()
56
 
57
 
58
- def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
 
59
  """Make submission with given """
60
 
61
  model = PytorchWorker(model_path, model_name)
@@ -78,9 +91,8 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
78
 
79
 
80
  if __name__ == "__main__":
81
-
82
  import zipfile
83
-
84
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
85
  zip_ref.extractall("/tmp/data")
86
 
@@ -99,3 +111,4 @@ if __name__ == "__main__":
99
  model_name=MODEL_NAME,
100
  # images_root_path='/home/zeleznyt/mnt/data-ntis/projects/korpusy_cv/SnakeCLEF2024/val/SnakeCLEF2023-medium_size'
101
  )
 
 
1
  import pandas as pd
2
  import numpy as np
 
3
  import os
4
+ from torch import nn
5
  from tqdm import tqdm
6
  import timm
7
  import torchvision.transforms as T
8
  from PIL import Image
9
  import torch
10
 
11
+
12
  def is_gpu_available():
13
  """Check if the python package `onnxruntime-gpu` is installed."""
14
  return torch.cuda.is_available()
15
 
16
+ class CustomModel(nn.Module):
17
+ def __init__(self, base_model_name, num_classes1, num_classes2):
18
+ super(CustomModel, self).__init__()
19
+ self.base_model = timm.create_model(base_model_name, pretrained=False)
20
+ in_features = self.base_model.get_classifier().in_features
21
+ self.base_model.reset_classifier(0) # Remove the original classification layer
22
+
23
+ self.fc1 = nn.Linear(in_features, num_classes1) # Binary classification output
24
+ self.fc2 = nn.Linear(in_features, num_classes2) # Categorical classification output
25
+
26
+ def forward(self, x):
27
+ x = self.base_model(x)
28
+ out1 = torch.sigmoid(self.fc1(x)) # Binary output
29
+ out2 = self.fc2(x) # Categorical output
30
+ return out2
31
 
32
  class PytorchWorker:
33
  """Run inference using ONNX runtime."""
34
 
35
  def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1784):
 
36
  def _load_model(model_name, model_path):
 
37
  print("Setting up Pytorch Model")
38
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39
  print(f"Using devide: {self.device}")
40
 
41
+ # model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
42
+ model =CustomModel(model_name, 1, number_of_categories)
43
  # if not torch.cuda.is_available():
44
  # model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
45
  # else:
 
56
  T.ToTensor(),
57
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
58
 
 
59
  def predict_image(self, image: np.ndarray) -> list():
60
  """Run inference using ONNX runtime.
 
61
  :param image: Input image as numpy array.
62
  :return: A list with logits and confidences.
63
  """
 
67
  return logits.tolist()
68
 
69
 
70
+ def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv",
71
+ images_root_path="/tmp/data/private_testset"):
72
  """Make submission with given """
73
 
74
  model = PytorchWorker(model_path, model_name)
 
91
 
92
 
93
  if __name__ == "__main__":
 
94
  import zipfile
95
+
96
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
97
  zip_ref.extractall("/tmp/data")
98
 
 
111
  model_name=MODEL_NAME,
112
  # images_root_path='/home/zeleznyt/mnt/data-ntis/projects/korpusy_cv/SnakeCLEF2024/val/SnakeCLEF2023-medium_size'
113
  )
114
+