# laba2/evaluate.py
# Author: Bellou1337
# Commit: 711e816 — feat: svm model
import argparse
import csv
from sklearn.metrics import accuracy_score, confusion_matrix
def load_labels(csv_path):
    """Read a two-column CSV (path, label) and return a dict path -> label.

    The first row is treated as a header and skipped. Rows with fewer than
    two columns (e.g. trailing blank lines) are ignored instead of raising
    IndexError. If a path appears more than once, the last row wins.
    """
    labels = {}
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header; None default tolerates an empty file
        for row in reader:
            if len(row) >= 2:
                labels[row[0]] = row[1]
    return labels


def main():
    """Compare a predictions CSV against a ground-truth CSV and print metrics."""
    parser = argparse.ArgumentParser(
        description="Evaluate predictions against ground truth labels."
    )
    parser.add_argument(
        "--ground-truth",
        required=True,
        help="CSV with header and rows of (path, true_label)",
    )
    parser.add_argument(
        "--predictions",
        required=True,
        help="CSV with header and rows of (path, predicted_label)",
    )
    args = parser.parse_args()

    gt = load_labels(args.ground_truth)
    preds = load_labels(args.predictions)

    # Align by path: score only entries present in both files.
    y_true = []
    y_pred = []
    for path, true_label in gt.items():
        if path in preds:
            y_true.append(true_label)
            y_pred.append(preds[path])

    # Surface silently-dropped entries instead of hiding them in the score.
    missing = len(gt) - len(y_true)
    if missing:
        print(f"Warning: {missing} ground-truth entries have no prediction")

    if not y_true:
        # accuracy_score raises an opaque error on empty input; fail clearly.
        raise SystemExit("No overlapping paths between ground truth and predictions")

    acc = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred)
    print("Accuracy:", acc)
    print("Confusion matrix:")
    for row in cm:
        print(" ".join(str(x) for x in row))


if __name__ == "__main__":
    main()