# laba2/mnist.py
# Author: Bellou1337
# Commit 711e816 (verified): "feat: svm model"
import csv
import argparse
import numpy as np
from PIL import Image
from sklearn.svm import SVC
from joblib import dump, load
def load_image_as_vector(path):
    """Load the image at *path* as 8-bit grayscale and return its pixels as a flat 1-D array.

    The pixels are flattened row by row, so images of equal size always yield
    vectors of equal length with the same pixel ordering.
    """
    # Use a context manager so the underlying file handle is closed promptly;
    # convert() returns an independent image, so it stays usable after the with block.
    with Image.open(path) as img:
        gray = img.convert('L')
    return np.array(gray).flatten()
def _build_parser():
    """Return the argument parser for the ``train`` / ``inference`` command line."""
    parser = argparse.ArgumentParser(
        description="Train an RBF-kernel SVM on image files, or predict labels with a saved model."
    )
    parser.add_argument("--mode", required=True, choices=["train", "inference"])
    parser.add_argument(
        "--dataset",
        help="train: CSV with a header row, then an image path and a label per row",
    )
    parser.add_argument(
        "--model",
        required=True,
        help="file the model is written to (train) or read from (inference)",
    )
    parser.add_argument(
        "--input",
        help="inference: CSV with a header row, then an image path per row",
    )
    parser.add_argument(
        "--output",
        help="inference: CSV to write 'path,label' predictions to",
    )
    return parser


def _read_data_rows(csv_path):
    """Return the non-blank rows of *csv_path*, excluding its first (header) row."""
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header; an empty file simply yields no rows
        return [row for row in reader if row]  # csv.reader yields [] for blank lines


def _load_features(paths):
    """Stack the flattened pixel vectors of the images in *paths* into an (n_images, n_pixels) array."""
    # np.stack (unlike np.array) raises if the images differ in size instead of
    # silently building a ragged object array.
    return np.stack([load_image_as_vector(path) for path in paths])


def _train(dataset_csv, model_path):
    """Fit an RBF SVC on the labelled images in *dataset_csv* and dump it to *model_path*."""
    rows = _read_data_rows(dataset_csv)
    if not rows:
        raise ValueError(f"{dataset_csv}: no samples after the header row")
    short = [row for row in rows if len(row) < 2]
    if short:
        raise ValueError(f"{dataset_csv}: row {short[0]!r} has no label column")
    X = _load_features([row[0] for row in rows])
    # Labels are kept as the raw CSV strings, so predictions are written back verbatim.
    Y = np.array([row[1] for row in rows])
    model = SVC(kernel="rbf", gamma="scale")
    model.fit(X, Y)
    dump(model, model_path)


def _inference(model_path, input_csv, output_csv):
    """Predict a label for each image listed in *input_csv* and write ``path,label`` rows to *output_csv*."""
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only load model files from a trusted source.
    model = load(model_path)
    paths = [row[0] for row in _read_data_rows(input_csv)]
    # np.stack cannot stack zero arrays, so skip prediction for an empty input;
    # the output then contains only the header row.
    preds = model.predict(_load_features(paths)) if paths else []
    with open(output_csv, 'w', newline="") as f:
        writer = csv.writer(f)
        writer.writerow(['path', 'label'])
        writer.writerows(zip(paths, preds))


def main(argv=None):
    """Parse *argv* (``None`` means ``sys.argv[1:]``) and run the selected mode.

    Exits with status 2 (via ``parser.error``) when an option required by the
    chosen mode is missing or ``--mode`` is not ``train``/``inference``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.mode == "train":
        if args.dataset is None:
            parser.error("--dataset is required for --mode train")
        _train(args.dataset, args.model)
    else:  # "inference": choices= already rejected any other value
        missing = [
            opt
            for opt, value in (("--input", args.input), ("--output", args.output))
            if value is None
        ]
        if missing:
            parser.error(f"{' and '.join(missing)} required for --mode inference")
        _inference(args.model, args.input, args.output)


if __name__ == "__main__":
    main()