Mhara's picture
Upload folder using huggingface_hub
dae5c90 verified
Raw
History Blame Contribute Delete
3.07 kB
import numpy as np
import onnxruntime
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision import datasets
from PIL import Image
import cv2
import pandas as pd
import os
from tqdm import tqdm
import torch
import torch.nn.functional as F
import argparse
def compute_preds_sum_prob_w_prior_shift(outputs, num_classes, num_domains, prior_shift_weight=None):
    """Predict one class per sample from domain-split logits with prior-shift reweighting.

    The last axis of ``outputs`` holds ``num_domains * num_classes`` logits laid out
    domain-major: logits ``[i * num_classes, (i + 1) * num_classes)`` belong to domain ``i``.
    All logits of a sample are soft-maxed jointly, multiplied element-wise by a per-logit
    prior-shift weight, and the weighted scores of each class are summed over the domains
    before taking the argmax.

    Args:
        outputs: Float tensor of logits whose last dimension has size
            ``num_classes * num_domains``. Any leading (batch) dimensions are preserved.
        num_classes: Number of classes per domain.
        num_domains: Number of domains.
        prior_shift_weight: Optional sequence or tensor of ``num_classes * num_domains``
            per-logit weights. When omitted, the hard-coded training-set weights below are
            used; they only fit ``num_classes=2, num_domains=4``.

    Returns:
        Integer tensor of predicted class indices with shape ``outputs.shape[:-1]``.

    Raises:
        ValueError: If the last dimension of ``outputs`` or the number of weights differs
            from ``num_classes * num_domains``.
    """
    n_logits = num_classes * num_domains
    if prior_shift_weight is None:
        # Training distributions per domain. In each consecutive pair the numerator equals
        # the sum of the two denominators (e.g. 1072 + 16 = 1088), so every weight is
        # domain_total / class_count, i.e. the inverse class frequency within that domain.
        # The common 1/100 factor does not affect the argmax.
        prior_shift_weight = torch.tensor([
            1088/1072, 1088/16, 17746/17515, 17746/231, 6454/6273, 6454/181, 850/834, 850/16
        ], device=outputs.device) / 100
    else:
        prior_shift_weight = torch.as_tensor(prior_shift_weight, device=outputs.device).reshape(-1)

    if outputs.shape[-1] != n_logits:
        raise ValueError(
            f"expected last dimension of size {n_logits} (num_classes * num_domains), "
            f"got outputs of shape {tuple(outputs.shape)}"
        )
    if prior_shift_weight.numel() != n_logits:
        raise ValueError(
            f"expected {n_logits} prior-shift weights, got {prior_shift_weight.numel()}"
        )

    # Softmax must run over the logit axis. Normalising over dim 0 (the batch) would turn
    # every probability of a single-sample batch into 1 and make the prediction constant.
    probs = F.softmax(outputs, dim=-1) * prior_shift_weight
    # Split the logit axis into (domain, class) and add up each class across domains.
    per_domain = probs.reshape(*probs.shape[:-1], num_domains, num_classes)
    summed_probs = per_domain.sum(dim=-2)
    return torch.argmax(summed_probs, dim=-1)
def run_inference(onnx_model_path, image_folder, image_width, image_height):
    """Run the ONNX classifier on every readable image file in ``image_folder``.

    Each image is read with OpenCV, converted from BGR to the LAB colour space, resized,
    turned into a tensor, normalised, and fed to the model one at a time. The model output
    is decoded with ``compute_preds_sum_prob_w_prior_shift`` using 2 classes x 4 domains.

    Args:
        onnx_model_path: Path to the ``.onnx`` model file.
        image_folder: Directory whose entries are processed (non-recursively).
        image_width: Target width passed to the resize transform.
        image_height: Target height passed to the resize transform.

    Returns:
        DataFrame with columns ``'image_name'`` (file name within ``image_folder``) and
        ``'target'`` (predicted class index as int), one row per readable image, sorted by
        file name.
    """
    session = onnxruntime.InferenceSession(onnx_model_path)
    print(f"Model loaded from {onnx_model_path}")
    input_name = session.get_inputs()[0].name

    _transforms = transforms.Compose([
        # torchvision's Resize expects the size as (height, width).
        transforms.Resize((image_height, image_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.370, 0.133, 0.092], std=[0.327, 0.090, 0.105]),
    ])

    predictions = []
    # os.listdir order is arbitrary; sort so the output rows are reproducible.
    for image_path in tqdm(sorted(os.listdir(image_folder)), desc=f"Processing images in {image_folder}"):
        bgr = cv2.imread(os.path.join(image_folder, image_path))
        if bgr is None:
            # cv2.imread returns None for files it cannot decode (e.g. .DS_Store, subdirs).
            tqdm.write(f"Skipping unreadable file: {image_path}")
            continue
        image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
        image = _transforms(image)
        # session.run returns a list with one array per model output.
        # Assumes the first output holds the logits (TODO confirm for multi-output models).
        outputs = session.run(None, {input_name: image.unsqueeze(0).numpy()})
        preds = compute_preds_sum_prob_w_prior_shift(torch.as_tensor(outputs[0]), 2, 4)
        predictions.append({
            'image_name': image_path,
            'target': int(preds)
        })
    # Explicit columns keep the CSV header even when no image was readable.
    return pd.DataFrame(predictions, columns=['image_name', 'target'])
def main():
    """Parse command-line arguments, run inference and write the predictions CSV."""
    parser = argparse.ArgumentParser(description="Run inference on images using an ONNX model.")
    parser.add_argument("--onnx_model_path", type=str, required=False, default="./weights/best_model.onnx",
                        help="Path to the ONNX model.")
    # Optional: a default is provided, so the flag must not be marked required
    # (argparse ignores the default of a required argument).
    parser.add_argument("--image_folder", type=str, required=False,
                        default="./isic2020_challenge/valid/malignant",
                        help="Path to the folder containing images.")
    parser.add_argument("--image_width", type=int, default=224, help="Width of the input images.")
    parser.add_argument("--image_height", type=int, default=224, help="Height of the input images.")
    parser.add_argument("--output_csv", type=str, default='/predictions/pred.csv',
                        help="Path where to save predictions csv..")
    args = parser.parse_args()

    # DataFrame.to_csv does not create missing directories. Create the parent directory
    # up front so that a bad output path fails before the (slow) inference loop runs.
    output_dir = os.path.dirname(args.output_csv)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    predictions = run_inference(
        args.onnx_model_path,
        args.image_folder,
        args.image_width,
        args.image_height
    )
    predictions.to_csv(args.output_csv, index=False)