# Car_VS_Rest / ROC_curve_TFlite_Model.py
# Nekshay's picture
# Create ROC_curve_TFlite_Model.py
# 993159a verified
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
import cv2
import glob
import os
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
# --- Evaluation inputs -------------------------------------------------------
tflite_model_path = "efficientdet_lite0.tflite"
test_annotations_dir = "test_annotations/"  # Pascal VOC XML files, one per image
test_images_dir = "test_images/"  # Images to evaluate

# --- Model -------------------------------------------------------------------
# Build the interpreter from the model file and allocate its tensors so that
# it is ready for set_tensor()/invoke() calls.
interpreter = tf.lite.Interpreter(tflite_model_path)
interpreter.allocate_tensors()
# Function to run inference
def run_inference(interpreter, image):
    """Run one image through a TFLite interpreter and return its first output.

    The image is resized to the model's input size, converted to the dtype
    the input tensor declares, given a leading batch dimension, and fed to
    the interpreter.

    Args:
        interpreter: A TFLite interpreter whose tensors are already
            allocated.
        image: Image array accepted by ``cv2.resize``. The channel order is
            passed through unchanged.

    Returns:
        The contents of the tensor described by ``get_output_details()[0]``.
        NOTE(review): callers treat this as confidence scores, but which
        output comes first depends on how the model was exported -- confirm
        that index 0 really is the score tensor for this model.
    """
    input_info = interpreter.get_input_details()[0]
    output_info = interpreter.get_output_details()[0]

    # Assumes an NHWC input tensor: shape == [batch, height, width, channels].
    _, in_height, in_width = input_info['shape'][:3]
    # cv2.resize expects dsize as (width, height), not (height, width).
    resized = cv2.resize(image, (int(in_width), int(in_height)))

    input_dtype = input_info['dtype']
    if np.issubdtype(input_dtype, np.floating):
        # Float-input model: scale pixel values to [0, 1].
        tensor = resized.astype(input_dtype) / 255.0
    else:
        # Integer-input models reject float data in set_tensor(), so cast to
        # the declared dtype instead.
        # NOTE(review): presumably such models expect raw 0-255 pixel values;
        # verify against the model's quantization parameters.
        tensor = resized.astype(input_dtype)

    batch = np.expand_dims(tensor, axis=0)  # Add batch dimension

    interpreter.set_tensor(input_info['index'], batch)
    interpreter.invoke()

    return interpreter.get_tensor(output_info['index'])
# Function to parse Pascal VOC XML annotation
def parse_voc_annotation(xml_file):
    """Return 1 if the annotation's root has an ``<object>`` child, else 0.

    Args:
        xml_file: Path or file object of a Pascal VOC XML annotation, as
            accepted by ``ET.parse``.

    Returns:
        ``1`` when at least one ``<object>`` element sits directly under the
        root element, ``0`` otherwise.
    """
    first_object = ET.parse(xml_file).getroot().find("object")
    # Compare against None explicitly: an Element's truthiness reflects its
    # children, not its existence.
    return int(first_object is not None)
# Load test images and annotations.
# Sorted so that samples are processed in a reproducible order.
image_files = sorted(glob.glob(os.path.join(test_images_dir, "*.jpg")))  # Adjust if using .png
y_scores = []
y_true = []
for image_file in image_files:
    # The annotation shares the image's base name, with an .xml extension.
    stem = os.path.splitext(os.path.basename(image_file))[0]
    xml_file = os.path.join(test_annotations_dir, stem + ".xml")
    if not os.path.exists(xml_file):
        continue  # Skip if annotation is missing (before paying for a decode)

    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread signals an unreadable or corrupt file by returning None
        # instead of raising; skip it rather than crash during inference.
        print(f"Skipping unreadable image: {image_file}")
        continue
    # cv2.imread yields BGR channel order.
    # NOTE(review): converted to RGB on the assumption that the model was
    # trained on RGB input -- confirm for this model.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Ground truth label (1 = object present, 0 = no object)
    true_label = parse_voc_annotation(xml_file)

    # Image-level score: the maximum value in the tensor run_inference returns.
    scores = run_inference(interpreter, image)
    max_score = np.max(scores)

    y_scores.append(max_score)
    y_true.append(true_label)

# Convert to numpy arrays for the metric functions
y_scores = np.array(y_scores)
y_true = np.array(y_true)
# Fail early with an actionable message when no samples were collected;
# otherwise the metric functions below raise a far less obvious error.
if len(y_true) == 0:
    raise ValueError(
        "No evaluation samples: no '*.jpg' image in "
        f"{test_images_dir!r} has a matching annotation in {test_annotations_dir!r}"
    )

# Compute ROC curve and AUC (the thresholds are not needed for plotting)
fpr, tpr, _ = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)

# Compute Precision-Recall curve and AP score
precision, recall, _ = precision_recall_curve(y_true, y_scores)
average_precision = average_precision_score(y_true, y_scores)
# ROC curve figure: model curve plus the chance-level diagonal.
roc_fig, roc_ax = plt.subplots(figsize=(8, 6))
roc_ax.plot(
    fpr,
    tpr,
    color='blue',
    lw=2,
    label=f'ROC Curve (AUC = {roc_auc:.2f})',
)
roc_ax.plot([0, 1], [0, 1], color='gray', linestyle='--')  # Diagonal line
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("ROC Curve")
roc_ax.legend()
plt.show()
# Precision-Recall figure, labelled with the average precision score.
pr_fig, pr_ax = plt.subplots(figsize=(8, 6))
pr_ax.plot(
    recall,
    precision,
    color='green',
    lw=2,
    label=f'PR Curve (AP = {average_precision:.2f})',
)
pr_ax.set_xlabel("Recall")
pr_ax.set_ylabel("Precision")
pr_ax.set_title("Precision-Recall Curve")
pr_ax.legend()
plt.show()