import json
import os

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

from model import resnet101


def _build_transform():
    """Return the preprocessing pipeline applied to every test image."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


def _load_image(img_path, transform):
    """Open *img_path*, convert it to RGB and return the transformed tensor."""
    # The context manager closes the underlying file handle as soon as the
    # pixels have been copied by convert(); the original left it to the GC.
    with Image.open(img_path) as img:
        return transform(img.convert('RGB'))


def _best_class_per_observation(observation_ids, class_ids, probs):
    """Map each observation id to the class of its most confident image.

    On a tie the earliest image wins, because only a strictly greater
    probability replaces the current best.
    """
    best_class = {}
    best_prob = {}
    for obs_id, class_id, prob in zip(observation_ids, class_ids, probs):
        if obs_id not in best_prob or prob > best_prob[obs_id]:
            best_prob[obs_id] = prob
            best_class[obs_id] = class_id
    return best_class


def predict(test_metadata, root_path='/tmp/data/private_testset',
            output_csv_path='./submission.csv', model=None, device=None,
            batch_size=64):
    """Classify every image in *test_metadata* and write a submission CSV.

    Each row of *test_metadata* names one image (``filename``) belonging to an
    observation (``observation_id``). Every image is classified on its own;
    an observation with several images gets the class of whichever image has
    the highest top-1 softmax probability.

    Args:
        test_metadata: DataFrame with ``observation_id`` and ``filename``
            columns. A ``class_id`` column (string class ids, one per row) is
            written onto it as a side effect.
        root_path: Directory that the ``filename`` values are joined onto.
        output_csv_path: Destination of the CSV, which has the columns
            ``observation_id`` and ``class_id`` and one row per observation,
            in order of first appearance.
        model: Network to run. When omitted, the module-level ``model`` set up
            by the ``__main__`` block is used. This function does not call
            ``model.eval()``; the caller is expected to have done so.
        device: Device the image batches are moved to. When omitted, the
            module-level ``device`` is used if it exists, otherwise the device
            of the model's first parameter.
        batch_size: Number of images per forward pass. Bounds peak memory,
            which previously grew with the size of the whole test set.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
        NameError: If no model is passed and no module-level ``model`` exists.
        AssertionError: If any referenced image file is missing.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))

    # Fall back to the module globals so existing callers that rely on the
    # script-level `model` / `device` keep working unchanged.
    if model is None:
        model = globals().get("model")
        if model is None:
            raise NameError("name 'model' is not defined")
    if device is None:
        device = globals().get("device")
        if device is None:
            device = next(model.parameters()).device

    data_transform = _build_transform()

    id_list = test_metadata['observation_id'].tolist()
    img_name_list = test_metadata['filename'].tolist()

    print(os.path.abspath(os.path.dirname(__file__)))

    # Validate every path up front so a missing file fails before any
    # inference work is done, as in the original single-pass loader.
    img_paths = [os.path.join(root_path, img_name) for img_name in img_name_list]
    for img_path in img_paths:
        assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)

    class_ids = []
    probs = []
    with torch.no_grad():
        for start in range(0, len(img_paths), batch_size):
            chunk = img_paths[start:start + batch_size]
            batch_img = torch.stack(
                [_load_image(img_path, data_transform) for img_path in chunk],
                dim=0,
            )
            output = model(batch_img.to(device)).cpu()
            # Top-1 probability and class index for each image in the chunk.
            chunk_probs, chunk_classes = torch.softmax(output, dim=1).max(dim=1)
            probs.extend(chunk_probs.tolist())
            class_ids.extend(chunk_classes.tolist())

    best_class = _best_class_per_observation(id_list, class_ids, probs)

    # Every row of an observation carries that observation's winning class,
    # so keeping the first row per observation yields one prediction each.
    test_metadata["class_id"] = [str(best_class[obs_id]) for obs_id in id_list]
    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)


if __name__ == '__main__':
    import zipfile

    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")
    root_path = '/tmp/data/private_testset'

    # The class-index file is only checked for presence; its contents are not
    # read because the submission uses raw class ids.
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # create model
    model = resnet101(num_classes=1784).to(device)

    # load model weights
    weights_path = "./resNet101.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    # NOTE(review): torch.load unpickles the checkpoint; only load weight
    # files from a trusted source.
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()
    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)
    predict(test_metadata, root_path, model=model, device=device)