xcssgzs commited on
Commit
c4c6fe1
·
verified ·
1 Parent(s): 5a00965

Delete script.py

Browse files
Files changed (1) hide show
  1. script.py +0 -94
script.py DELETED
@@ -1,94 +0,0 @@
1
- import os
2
- import json
3
-
4
- import pandas as pd
5
- import torch
6
- from PIL import Image
7
- from torchvision import transforms
8
- from model import resnet101
9
-
10
-
11
- def predict(test_metadata, root_path='/tmp/data/private_testset', output_csv_path='./submission.csv'):
12
-
13
- data_transform = transforms.Compose(
14
- [transforms.Resize(256),
15
- transforms.CenterCrop(224),
16
- transforms.ToTensor(),
17
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
18
-
19
- # load image
20
- # img_name_list = ["1163.jpg", "1164.jpg"]
21
- # id_list = [1, 2]
22
- id_list = test_metadata['observation_id'].tolist()
23
- img_name_list = test_metadata['filename'].tolist()
24
- img_list = []
25
- print(os.path.abspath(os.path.dirname(__file__)))
26
- for img_name in img_name_list:
27
- img_path = os.path.join(root_path, img_name)
28
- assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
29
- img = Image.open(img_path).convert('RGB')
30
- img = data_transform(img)
31
- img_list.append(img)
32
-
33
- # batch img
34
- batch_img = torch.stack(img_list, dim=0)
35
-
36
- with torch.no_grad():
37
- # predict class
38
- output = model(batch_img.to(device)).cpu()
39
- predict = torch.softmax(output, dim=1)
40
- probs, classesId = torch.max(predict, dim=1)
41
- probs = probs.data.numpy().tolist()
42
- classesId = classesId.data.numpy().tolist()
43
- id2classId = dict()
44
- id2prob = dict()
45
- for i, id in enumerate(id_list):
46
- if id not in id2classId.keys():
47
- id2classId[id] = classesId[i]
48
- id2prob[id] = probs[i]
49
- else:
50
- if probs[i] > id2prob[id]:
51
- id2classId[id] = classesId[i]
52
- id2prob[id] = probs[i]
53
- classes = list()
54
- for id in id_list:
55
- classes.append(str(id2classId[id]))
56
- test_metadata["class_id"] = classes
57
-
58
- user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
59
- user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
60
-
61
-
62
- if __name__ == '__main__':
63
-
64
- import zipfile
65
-
66
- with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
67
- zip_ref.extractall("/tmp/data")
68
-
69
- root_path = '/tmp/data/private_testset'
70
- # root_path = "../../data_set/flower_data/val/n1"
71
-
72
- # read class_indict
73
- json_path = './class_indices.json'
74
- assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
75
-
76
- # json_file = open(json_path, "r")
77
- # index2class = json.load(json_file)
78
-
79
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
80
- # create model
81
- model = resnet101(num_classes=1784).to(device)
82
-
83
- # load model weights
84
- weights_path = "./resNet101.pth"
85
- assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
86
- model.load_state_dict(torch.load(weights_path, map_location=device))
87
-
88
- # prediction
89
- model.eval()
90
-
91
- metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
92
- # metadata_file_path = "./test.csv"
93
- test_metadata = pd.read_csv(metadata_file_path)
94
- predict(test_metadata, root_path)