Bellou1337 committed on
Commit
711e816
·
verified ·
1 Parent(s): ddc2ebc

feat: svm model

Browse files
evaluate.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Compare a predictions CSV against a ground-truth CSV and report accuracy
and a confusion matrix.

Both CSVs are expected to have a header row followed by `path,label` rows.
"""
import argparse
import csv
import sys


def load_label_map(csv_path):
    """Read a `path,label` CSV (with header) into a {path: label} dict."""
    labels = {}
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for row in reader:
            if not row:
                continue  # tolerate stray blank lines
            labels[row[0]] = row[1]
    return labels


def paired_labels(gt, preds):
    """Return (y_true, y_pred) lists for paths present in both mappings.

    Iterates the ground truth, so evaluation covers exactly those
    ground-truth paths that also received a prediction.
    """
    y_true = []
    y_pred = []
    for path, true_label in gt.items():
        if path in preds:
            y_true.append(true_label)
            y_pred.append(preds[path])
    return y_true, y_pred


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--ground-truth", required=True)
    parser.add_argument("--predictions", required=True)
    args = parser.parse_args()

    gt = load_label_map(args.ground_truth)
    preds = load_label_map(args.predictions)
    y_true, y_pred = paired_labels(gt, preds)

    # Previously, ground-truth entries without a prediction were silently
    # dropped, so the metrics could hide incomplete coverage. Surface it.
    missing = len(gt) - len(y_true)
    if missing:
        print(f"Warning: {missing} ground-truth entries have no prediction",
              file=sys.stderr)
    if not y_true:
        sys.exit("No overlapping paths between ground truth and predictions")

    # sklearn is only needed for the final metrics; import lazily so the
    # CSV helpers above are importable without it.
    from sklearn.metrics import accuracy_score, confusion_matrix

    acc = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred)

    print("Accuracy:", acc)
    print("Confusion matrix:")
    for row in cm:
        print(" ".join(str(x) for x in row))


if __name__ == "__main__":
    main()
fashion_predictions.csv ADDED
The diff for this file is too large to render. See raw diff
 
fashion_test.csv ADDED
The diff for this file is too large to render. See raw diff
 
fashion_train.csv ADDED
The diff for this file is too large to render. See raw diff
 
mnist.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Train or run inference with an RBF-kernel SVM over image files.

Usage:
    python mnist.py --mode train --dataset train.csv --model model.joblib
    python mnist.py --mode inference --model model.joblib \
        --input test.csv --output preds.csv

Dataset CSVs have a `path,label` header row; inference input only needs
the `path` column.
"""
import argparse
import csv

import numpy as np


def read_csv_rows(csv_path):
    """Return the data rows of *csv_path* (header skipped) as string lists."""
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        return [row for row in reader if row]


def load_image_as_vector(path):
    """Load an image as a flat 1-D numpy array of grayscale pixel values."""
    # Lazy import keeps the CSV helpers importable without PIL installed.
    from PIL import Image
    img = Image.open(path).convert('L')
    return np.array(img).flatten()


def build_feature_matrix(paths):
    """Stack per-image vectors into a 2-D array of shape (n_images, n_pixels)."""
    return np.array([load_image_as_vector(p) for p in paths])


def train(dataset_csv, model_path):
    """Fit an SVC on the labelled images listed in *dataset_csv*; save it."""
    from sklearn.svm import SVC
    from joblib import dump

    rows = read_csv_rows(dataset_csv)
    paths = [row[0] for row in rows]
    labels = [row[1] for row in rows]

    X = build_feature_matrix(paths)
    Y = np.array(labels)

    model = SVC(kernel="rbf", gamma="scale")
    model.fit(X, Y)
    dump(model, model_path)


def inference(model_path, input_csv, output_csv):
    """Predict labels for the images in *input_csv*; write a path,label CSV."""
    from joblib import load

    model = load(model_path)
    paths = [row[0] for row in read_csv_rows(input_csv)]
    preds = model.predict(build_feature_matrix(paths))

    with open(output_csv, 'w', newline="") as f:
        writer = csv.writer(f)
        writer.writerow(['path', 'label'])
        writer.writerows(zip(paths, preds))


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    # choices= makes an unknown mode a usage error instead of a silent no-op.
    parser.add_argument("--mode", required=True, choices=["train", "inference"])
    parser.add_argument("--dataset")
    parser.add_argument("--model", required=True)
    parser.add_argument("--input")
    parser.add_argument("--output")
    args = parser.parse_args()

    # Per-mode required arguments: fail with a clear usage message rather
    # than a TypeError from open(None) deep inside the pipeline.
    if args.mode == "train":
        if not args.dataset:
            parser.error("--dataset is required in train mode")
        train(args.dataset, args.model)
    else:
        if not (args.input and args.output):
            parser.error("--input and --output are required in inference mode")
        inference(args.model, args.input, args.output)


if __name__ == "__main__":
    main()
mnist_predictions.csv ADDED
The diff for this file is too large to render. See raw diff
 
mnist_test.csv ADDED
The diff for this file is too large to render. See raw diff
 
mnist_train.csv ADDED
The diff for this file is too large to render. See raw diff
 
prepare_from_hf.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Download MNIST and Fashion-MNIST from the Hugging Face Hub and export
them as PNG files plus `path,label` index CSVs — one image directory and
one CSV per split.
"""
import csv
import os


def export_split(split, out_dir, csv_path, prefix):
    """Save every image of *split* under *out_dir* and index it in *csv_path*.

    *split* is any iterable of {"image": <PIL image>, "label": <int>}
    records (e.g. a Hugging Face dataset split). Images are saved as
    `<prefix>_<index>_<label>.png` and each CSV data row is `path,label`.
    """
    os.makedirs(out_dir, exist_ok=True)
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["path", "label"])
        for idx, item in enumerate(split):
            label = item["label"]
            img_path = os.path.join(out_dir, f"{prefix}_{idx:05d}_{label}.png")
            item["image"].save(img_path)
            writer.writerow([img_path, label])


def main():
    # Lazy import: merely importing this module must not trigger a
    # multi-gigabyte dataset download.
    from datasets import load_dataset

    mnist = load_dataset("ylecun/mnist")
    export_split(mnist["train"], "mnist_images_train",
                 "mnist_train.csv", "mnist_train")
    export_split(mnist["test"], "mnist_images_test",
                 "mnist_test.csv", "mnist_test")

    fashion = load_dataset("fashion_mnist")
    export_split(fashion["train"], "fashion_images_train",
                 "fashion_train.csv", "fashion_train")
    export_split(fashion["test"], "fashion_images_test",
                 "fashion_test.csv", "fashion_test")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy
2
+ Pillow
3
+ scikit-learn
4
+ joblib
5
+ datasets
svm_fashion.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51f932d8909a56d06dcf4dc6c8dc9454d2e8506949291352ca17e491d0db2b86
3
+ size 133147243
svm_mnist.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:daee5169488a0e5eb3b481edfc1eaf86df62632f25e8335ae4e0ca268e826e93
3
+ size 79199979