|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
import matplotlib.pyplot as plt |
|
|
import xml.etree.ElementTree as ET |
|
|
import cv2 |
|
|
import glob |
|
|
import os |
|
|
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score |
|
|
|
|
|
|
|
|
# Evaluation inputs: a directory of test images, a directory of matching
# Pascal VOC XML annotation files, and the TFLite model file to score.
test_images_dir = "test_images/"
test_annotations_dir = "test_annotations/"
tflite_model_path = "efficientdet_lite0.tflite"

# Build the interpreter once for the whole run and allocate its tensors up
# front so that inputs can be set and the model invoked directly afterwards.
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
|
|
|
|
|
|
|
|
def run_inference(interpreter, image, output_index=0):
    """Run one forward pass of a TFLite model on a single image.

    The image is resized to the model's input resolution, converted from
    OpenCV's BGR channel order to RGB, converted to the model's input dtype,
    and given a leading batch dimension before the interpreter is invoked.

    Args:
        interpreter: A ``tf.lite.Interpreter`` whose tensors have already
            been allocated.
        image: An image array as returned by ``cv2.imread`` (BGR order for
            3-channel images).
        output_index: Index into ``interpreter.get_output_details()`` of the
            tensor to return. Defaults to 0, matching the previous behaviour.

    Returns:
        The numpy array held by the selected output tensor after inference.

    Raises:
        ValueError: If the model's input dtype is neither floating point nor
            uint8, since no pixel conversion is defined for it here.
    """
    input_detail = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()

    # Assumes an NHWC input tensor: [batch, height, width, channels].
    _, height, width = input_detail["shape"][:3]
    # cv2.resize takes dsize as (width, height); passing (height, width)
    # silently transposes the target size for non-square models.
    resized = cv2.resize(image, (int(width), int(height)))

    # cv2.imread yields BGR pixels; the model is assumed to expect RGB, as
    # TF/TFLite image models typically do. Confirm this for custom models.
    if resized.ndim == 3 and resized.shape[2] == 3:
        resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

    input_dtype = input_detail["dtype"]
    if np.issubdtype(input_dtype, np.floating):
        # Float models: scale pixel values into [0, 1].
        pixels = (resized.astype(np.float32) / 255.0).astype(input_dtype)
    elif input_dtype == np.uint8:
        # uint8 models (e.g. the stock EfficientDet-Lite exports) take raw
        # 0-255 pixels; feeding float32 here makes set_tensor reject the input.
        pixels = resized.astype(np.uint8)
    else:
        raise ValueError(f"Unsupported model input dtype: {input_dtype}")

    batch = np.expand_dims(pixels, axis=0)

    interpreter.set_tensor(input_detail["index"], batch)
    interpreter.invoke()

    # NOTE(review): detection exports emit several outputs (boxes, classes,
    # scores, count) whose order varies between exports; confirm that
    # output_index points at the score tensor for this model.
    return interpreter.get_tensor(output_details[output_index]["index"])
|
|
|
|
|
|
|
|
def parse_voc_annotation(xml_file):
    """Return an image-level binary label from a Pascal VOC annotation.

    Args:
        xml_file: Path (or file object) of the VOC XML annotation.

    Returns:
        1 if the root element has at least one direct ``<object>`` child,
        otherwise 0.
    """
    annotation_root = ET.parse(xml_file).getroot()
    # find() returns the first direct <object> child, or None when absent.
    has_object = annotation_root.find("object") is not None
    return int(has_object)
|
|
|
|
|
|
|
|
# Collect one (ground-truth label, confidence score) pair per annotated test
# image. Sorting makes the evaluation order reproducible across runs.
image_files = sorted(glob.glob(os.path.join(test_images_dir, "*.jpg")))

y_scores = []
y_true = []

for image_file in image_files:
    stem = os.path.splitext(os.path.basename(image_file))[0]
    xml_file = os.path.join(test_annotations_dir, stem + ".xml")
    # Check for the annotation first so unannotated images are never decoded.
    if not os.path.exists(xml_file):
        continue

    image = cv2.imread(image_file)
    # cv2.imread reports unreadable or corrupt files by returning None rather
    # than raising; skip them instead of crashing inside run_inference.
    if image is None:
        print(f"Warning: could not read image {image_file}; skipping")
        continue

    true_label = parse_voc_annotation(xml_file)

    scores = run_inference(interpreter, image)
    # Image-level confidence: the largest value anywhere in the returned
    # output tensor. NOTE(review): only meaningful if that tensor holds
    # detection scores; confirm the model's output ordering.
    max_score = np.max(scores)

    y_scores.append(max_score)
    y_true.append(true_label)

y_scores = np.array(y_scores)
y_true = np.array(y_true)
|
|
|
|
|
|
|
|
# ROC AUC and average precision need both positive and negative images in
# the ground truth. Fail with a clear message instead of erroring or
# producing NaN (with only a warning) inside scikit-learn.
if y_true.size == 0:
    raise SystemExit(
        "No annotated test images were evaluated; "
        "check test_images_dir and test_annotations_dir."
    )
if np.unique(y_true).size < 2:
    raise SystemExit(
        "Ground truth contains only one class (all images positive or all "
        "negative); ROC AUC and average precision are undefined."
    )

# ROC curve and the area under it for the image-level "contains an object"
# decision, ranked by the per-image confidence score.
fpr, tpr, _ = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)

# Precision-recall curve and average precision for the same decision.
precision, recall, _ = precision_recall_curve(y_true, y_scores)
average_precision = average_precision_score(y_true, y_scores)
|
|
|
|
|
|
|
|
# Plot the ROC curve with the chance-level diagonal as a reference line.
_, roc_ax = plt.subplots(figsize=(8, 6))
roc_ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC Curve (AUC = {roc_auc:.2f})')
roc_ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("ROC Curve")
roc_ax.legend()
plt.show()
|
|
|
|
|
|
|
|
# Plot the precision-recall curve; the legend reports average precision.
_, pr_ax = plt.subplots(figsize=(8, 6))
pr_ax.plot(recall, precision, color='green', lw=2, label=f'PR Curve (AP = {average_precision:.2f})')
pr_ax.set_xlabel("Recall")
pr_ax.set_ylabel("Precision")
pr_ax.set_title("Precision-Recall Curve")
pr_ax.legend()
plt.show()
|
|
|