sharvari0b26 committed on
Commit
a207b60
·
verified ·
1 Parent(s): ed5fe90

Upload 2 files

Browse files
Files changed (2) hide show
  1. multiclass_model.pkl +3 -0
  2. script.py +60 -0
multiclass_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d799eecd128c540ab311a7cb77db6ae088d9b8159a2a6d7f04238ea7859e4d6
3
+ size 1178808
script.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import cv2
4
+ import pandas as pd
5
+ import numpy as np
6
+ from utils.utils import extract_features_from_image, perform_pca, train_svm_model
7
+
8
+
9
def run_inference(TEST_IMAGE_PATH, svm_model, k, SUBMISSION_CSV_SAVE_PATH):
    """Classify every image in a directory and write a submission CSV.

    Parameters
    ----------
    TEST_IMAGE_PATH : str
        Directory containing the test images.
    svm_model : fitted classifier
        Object exposing ``predict`` (loaded from pickle by the caller).
    k : int
        Number of components passed to ``perform_pca``.
    SUBMISSION_CSV_SAVE_PATH : str
        Destination path for the ``file_name,category_id`` CSV.

    Raises
    ------
    ValueError
        If a file in the directory cannot be decoded as an image.
    """
    # Sorted listing keeps row order deterministic across filesystems.
    test_images = sorted(os.listdir(TEST_IMAGE_PATH))

    image_feature_list = []
    for test_image in test_images:
        path_to_image = os.path.join(TEST_IMAGE_PATH, test_image)
        image = cv2.imread(path_to_image)
        # cv2.imread returns None for unreadable/non-image files; fail
        # early with a clear message instead of a cryptic error downstream.
        if image is None:
            raise ValueError(f"Could not read image: {path_to_image}")
        image_feature_list.append(extract_features_from_image(image))

    features_multiclass = np.array(image_feature_list)
    features_multiclass_reduced = perform_pca(features_multiclass, k)
    multiclass_predictions = svm_model.predict(features_multiclass_reduced)

    # Build the DataFrame in a single constructor call: the original
    # per-row pd.concat loop is O(n^2) and loses column dtypes.
    df_predictions = pd.DataFrame({
        "file_name": test_images,
        "category_id": multiclass_predictions,
    })
    df_predictions.to_csv(SUBMISSION_CSV_SAVE_PATH, index=False)
40
+
41
+
42
+
43
+
44
if __name__ == "__main__":
    # Resolve model and output paths relative to this script so the
    # program works regardless of the current working directory.
    current_directory = os.path.dirname(os.path.abspath(__file__))

    TEST_IMAGE_PATH = "/tmp/data/test_images"
    SUBMISSION_CSV_SAVE_PATH = os.path.join(current_directory, "submission.csv")

    MODEL_NAME = "multiclass_model.pkl"
    MODEL_PATH = os.path.join(current_directory, MODEL_NAME)

    # Number of PCA components; presumably must match what the model was
    # trained with — TODO confirm against the training script.
    k = 100

    # Deserialize the trained SVM. NOTE(review): pickle.load executes
    # arbitrary code — only load model files from a trusted source.
    with open(MODEL_PATH, 'rb') as file:
        svm_model = pickle.load(file)

    run_inference(TEST_IMAGE_PATH, svm_model, k, SUBMISSION_CSV_SAVE_PATH)