# FaceRecognisation / face_recognition.py
# Committed by ani0226: "Create face_recognition.py" (commit 4b4f0b9, verified)
import numpy as np
from PIL import Image
import tensorflow as tf # Or import your PyTorch modules
# --- Load the trained HTCNN model once, at import time ---
# Loading is best effort: if anything goes wrong the error is printed and the
# module still imports, leaving htcnn_model as None. extract_embedding()
# checks for None before using it.
htcnn_model = None
try:
    htcnn_model = tf.keras.models.load_model('htcnn_model.h5')  # Replace with your model path
except Exception as e:
    print(f"Error loading HTCNN model: {e}")
def extract_embedding(face_image, target_size=(160, 160)):
    """
    Extracts the feature embedding from a face image using the HTCNN model.

    The image is resized to ``target_size`` with bilinear interpolation,
    scaled by 1/255, given a leading batch dimension, and passed through
    ``htcnn_model``. The first (only) row of the prediction is returned.

    Args:
        face_image (numpy.ndarray): The cropped face image, shaped
            (height, width) or (height, width, channels). Pixel values are
            presumably in [0, 255] -- TODO confirm against the model's
            training preprocessing.
        target_size (tuple[int, int]): (height, width) expected by the model.
            Defaults to (160, 160); adjust to match your model's input.

    Returns:
        numpy.ndarray: The feature embedding, or None if the model is not
        loaded or the image is None/empty.
    """
    if htcnn_model is None:
        print("HTCNN model not loaded. Cannot extract embedding.")
        return None
    if face_image is None:
        return None

    face = np.asarray(face_image)
    if face.size == 0:
        # Nothing to resize; report failure the same way as a missing model.
        return None

    # Resize with TensorFlow (already a dependency of this module) rather than
    # OpenCV, which this module does not import. tf.image.resize defaults to
    # bilinear interpolation and needs a channel axis, so a 2-D grayscale
    # image gets a temporary one that is dropped again afterwards. This keeps
    # the output 2-D, as cv2.resize would.
    height, width = target_size
    if face.ndim == 2:
        resized = tf.image.resize(face[..., np.newaxis], (height, width))[..., 0]
    else:
        resized = tf.image.resize(face, (height, width))

    normalized_face = resized.numpy() / 255.0  # Example normalization

    # The model expects a batch, so wrap the single image in a batch of one.
    # verbose=0 suppresses Keras' per-call progress bar.
    batch = np.expand_dims(normalized_face, axis=0)
    embedding = htcnn_model.predict(batch, verbose=0)[0]
    return embedding
def recognize_face(face_image, face_embeddings_db, threshold=0.6):
    """
    Identify a face by nearest-neighbour search over known embeddings.

    The query embedding comes from ``extract_embedding``. It is compared to
    every stored embedding by Euclidean distance. The closest entry wins, and
    the earliest entry wins ties. The match is accepted only if its distance
    is strictly below ``threshold``.

    Args:
        face_image (numpy.ndarray): The cropped face image.
        face_embeddings_db (dict): Mapping of person name -> stored embedding.
        threshold (float): Maximum (exclusive) distance accepted as a match.

    Returns:
        str or None: Name of the closest person, or None when no embedding
        could be extracted, the database is empty, or the best distance is
        not below ``threshold``.
    """
    query = extract_embedding(face_image)
    if query is None:
        return None

    best_name, best_dist = None, float('inf')
    for candidate, reference in face_embeddings_db.items():
        dist = np.linalg.norm(query - reference)
        # A strict '<' keeps the first of equally close candidates and
        # never picks a NaN distance.
        if dist < best_dist:
            best_name, best_dist = candidate, dist

    return best_name if best_dist < threshold else None
if __name__ == '__main__':
    # Running the module directly does nothing yet. A real demo needs a
    # trained HTCNN model file and a name -> embedding database, and this
    # module provides neither.
    ...