| import pandas as pd | |
| import numpy as np | |
| from PIL import Image | |
| import onnxruntime as ort | |
| import os | |
| from tqdm import tqdm | |
def is_gpu_available():
    """Tell whether ONNX Runtime reports ``"GPU"`` as its device.

    :return: ``True`` when ``ort.get_device()`` equals ``"GPU"``, else ``False``.
        Presumably this is the case only when the ``onnxruntime-gpu`` build is
        installed -- confirm against the ONNX Runtime documentation.
    """
    device = ort.get_device()
    return device == "GPU"
class ONNXWorker:
    """Run inference using ONNX runtime.

    Holds one ``ort.InferenceSession`` (``self.ort_session``) and a flag
    (``self.use_gpu``) recording whether the GPU provider was requested.
    """

    def __init__(self, onnx_path: str):
        """Create the inference session for the given model file.

        :param onnx_path: Path of the ONNX model handed to ``ort.InferenceSession``.
        """
        print("Setting up ONNX runtime session.")
        self.use_gpu = is_gpu_available()
        # The CPU provider is listed in both cases; CUDA is only requested
        # (and listed first) when the runtime reports a GPU device.
        if self.use_gpu:
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        print(f"Using {providers}")
        self.ort_session = ort.InferenceSession(onnx_path, providers=providers)

    def _resize_image(self, image: np.ndarray, size: tuple = (300, 300)) -> np.ndarray:
        """Resize an image array with PIL and return the result as an array.

        The previous implementation read an undefined local and returned
        nothing; this version actually resizes the ``image`` argument.

        :param image: Input image as numpy array. It must be convertible by
            ``PIL.Image.fromarray`` (presumably uint8, HxW or HxWxC -- confirm
            with callers).
        :param size: Target ``(width, height)`` passed to ``Image.resize``.
            Defaults to the formerly hard-coded ``(300, 300)``.
        :return: The resized image as a numpy array.
        """
        resized = Image.fromarray(image).resize(size)
        return np.asarray(resized)

    def predict_image(self, image: np.ndarray) -> list:
        """Run inference using ONNX runtime.

        :param image: Input image as numpy array; it is cast to ``uint8`` and
            fed to the model under the input name ``"input"``.
        :return: The first model output (logits) converted to a nested list.
            The session is expected to return exactly two outputs; the second
            one is discarded.
        """
        logits, _ = self.ort_session.run(None, {"input": image.astype(dtype=np.uint8)})
        return logits.tolist()
def make_submission(test_metadata, model_path, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
    """Predict a class for every test image and write the submission CSV.

    :param test_metadata: DataFrame with at least ``filename`` and
        ``observation_id`` columns. A ``class_id`` column is written onto it
        in place.
    :param model_path: Path of the ONNX model given to ``ONNXWorker``.
    :param output_csv_path: Destination of the CSV holding ``observation_id``
        and ``class_id``.
    :param images_root_path: Directory that the ``filename`` values are
        joined onto.
    """
    worker = ONNXWorker(model_path)

    def classify(filename):
        # Load as RGB, scale to the fixed 256x256 input and take the arg-max logit.
        rgb = Image.open(os.path.join(images_root_path, filename)).convert("RGB")
        pixels = np.asarray(rgb.resize((256, 256)))
        return np.argmax(worker.predict_image(pixels))

    rows = tqdm(test_metadata.iterrows(), total=len(test_metadata))
    class_ids = [classify(record.filename) for _, record in rows]

    test_metadata["class_id"] = class_ids
    # Only the first row of each observation is kept in the submission.
    per_observation = test_metadata.drop_duplicates("observation_id", keep="first")
    submission = per_observation[["observation_id", "class_id"]]
    submission.to_csv(output_csv_path, index=None)
if __name__ == "__main__":
    import zipfile

    # Unpack the test-set archive into /tmp/data before any image is read.
    archive = zipfile.ZipFile("/tmp/data/private_testset.zip", 'r')
    with archive:
        archive.extractall("/tmp/data")

    metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
    ONNX_MODEL_PATH = "./swinv2_tiny_window16_256.onnx"

    # Output path and image root are left at make_submission's defaults.
    test_metadata = pd.read_csv(metadata_file_path)
    make_submission(model_path=ONNX_MODEL_PATH, test_metadata=test_metadata)