chirmy commited on
Commit
e3deef7
·
verified ·
1 Parent(s): a354281

Delete script.py

Browse files
Files changed (1) hide show
  1. script.py +0 -97
script.py DELETED
@@ -1,97 +0,0 @@
1
- import pandas as pd
2
- import numpy as np
3
- import onnxruntime as ort
4
- import os
5
- from tqdm import tqdm
6
- import timm
7
- import torchvision.transforms as T
8
- from PIL import Image
9
- import torch
10
-
11
- def is_gpu_available():
12
- """Check if the python package `onnxruntime-gpu` is installed."""
13
- return torch.cuda.is_available()
14
-
15
-
16
- class PytorchWorker:
17
- """Run inference using ONNX runtime."""
18
-
19
- def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1604):
20
-
21
- def _load_model(model_name, model_path):
22
-
23
- print("Setting up Pytorch Model")
24
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
- print(f"Using devide: {self.device}")
26
-
27
- model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
28
-
29
- # if not torch.cuda.is_available():
30
- # model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
31
- # else:
32
- # model_ckpt = torch.load(model_path)
33
-
34
- model_ckpt = torch.load(model_path, map_location=self.device)
35
- model.load_state_dict(model_ckpt)
36
-
37
- return model.to(self.device).eval()
38
-
39
- self.model = _load_model(model_name, model_path)
40
-
41
- self.transforms = T.Compose([T.Resize((299, 299)),
42
- T.ToTensor(),
43
- T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
44
-
45
-
46
- def predict_image(self, image: np.ndarray) -> list():
47
- """Run inference using ONNX runtime.
48
-
49
- :param image: Input image as numpy array.
50
- :return: A list with logits and confidences.
51
- """
52
-
53
- logits = self.model(self.transforms(image).unsqueeze(0).to(self.device))
54
-
55
- return logits.tolist()
56
-
57
-
58
- def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
59
- """Make submission with given """
60
-
61
- model = PytorchWorker(model_path, model_name)
62
-
63
- predictions = []
64
-
65
- for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
66
- image_path = os.path.join(images_root_path, row.image_path)
67
-
68
- test_image = Image.open(image_path).convert("RGB")
69
-
70
- logits = model.predict_image(test_image)
71
-
72
- predictions.append(np.argmax(logits))
73
-
74
- test_metadata["class_id"] = predictions
75
-
76
- user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
77
- user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
78
-
79
-
80
- if __name__ == "__main__":
81
-
82
- import zipfile
83
-
84
- with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
85
- zip_ref.extractall("/tmp/data")
86
-
87
- MODEL_PATH = "pytorch_model.bin"
88
- MODEL_NAME = "tf_efficientnet_b1.ap_in1k"
89
-
90
- metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
91
- test_metadata = pd.read_csv(metadata_file_path)
92
-
93
- make_submission(
94
- test_metadata=test_metadata,
95
- model_path=MODEL_PATH,
96
- model_name=MODEL_NAME
97
- )