stashface / models /image_processor.py
cc1234's picture
feat: implement batch processing for face embeddings in EnsembleFaceRecognition
9b06224
raw
history blame
3.84 kB
import io
import base64
import numpy as np
from uuid import uuid4
from PIL import Image as PILImage
from models.face_recognition import EnsembleFaceRecognition, extract_faces, extract_faces_mediapipe
from utils.vtt_parser import parse_vtt_offsets
def get_face_predictions(face, ensemble, data_manager, results):
    """
    Get predictions for a single face.

    Parameters:
        face: Face image array (H, W, C)
        ensemble: EnsembleFaceRecognition instance
        data_manager: DataManager instance
        results: Number of results to return

    Returns:
        List of (name, confidence) tuples
    """
    # Embed the face together with its horizontal mirror in a single batch
    # call, then average across the two orientations for a steadier embedding.
    mirrored = face[:, ::-1, :]
    batch = np.stack((face, mirrored), axis=0)
    embeddings = ensemble.get_face_embeddings_batch(batch)
    mean_facenet = np.mean(embeddings['facenet'], axis=0)
    mean_arc = np.mean(embeddings['arc'], axis=0)

    # Query each index with at least 50 candidates so the ensemble has enough
    # material to merge, regardless of how few results the caller asked for.
    top_k = max(results, 50)
    model_predictions = {
        'facenet': data_manager.query_facenet_index(mean_facenet, top_k),
        'arc': data_manager.query_arc_index(mean_arc, top_k),
    }
    return ensemble.ensemble_prediction(model_predictions)
def image_search_performers(image, data_manager, threshold=0.5, results=3):
    """
    Search for multiple performers in an image.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Confidence threshold (not consulted in this function;
            kept for caller compatibility)
        results: Number of results to return per face

    Returns:
        List of dictionaries, one per detected face, each with the
        base64-encoded JPEG crop, facial area, detection confidence,
        and matched performer information.
    """
    pixels = np.array(image)
    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})

    try:
        detections = extract_faces(pixels)
    except ValueError:
        raise ValueError("No faces found")

    response = []
    for detection in detections:
        predictions = get_face_predictions(detection['face'], ensemble, data_manager, results)

        # Encode a JPEG crop of the detected face region as base64 ASCII.
        region = detection['facial_area']
        box = (region['x'], region['y'],
               region['x'] + region['w'], region['y'] + region['h'])
        buffer = io.BytesIO()
        image.crop(box).save(buffer, format='JPEG')
        encoded = base64.b64encode(buffer.getvalue()).decode('ascii')

        # Resolve each predicted name to full performer info, skipping
        # entries the data manager declines to return.
        performers = []
        for name, confidence in predictions:
            info = data_manager.get_performer_info(data_manager.faces[name], confidence)
            if info:
                performers.append(info)

        response.append({
            'image': encoded,
            'area': region,
            'confidence': detection['confidence'],
            'performers': performers,
        })
    return response
def find_faces_in_sprite(image, vtt_file):
    """
    Find faces in a sprite image using VTT data.

    Parameters:
        image: sprite image as an array (it is passed to PILImage.fromarray,
            so a PIL Image object would not work here)
        vtt_file: file-like object exposing a .name path to the VTT file on
            disk; the path is re-opened, the object itself is not read

    Returns:
        List of dictionaries with face information
        (id, offset, frame, time, size)
    """
    # Re-read the VTT from disk and hand the parser raw UTF-8 bytes.
    with open(vtt_file.name, 'r', encoding='utf-8') as f:
        vtt = f.read().encode('utf-8')
    sprite = PILImage.fromarray(image)
    results = []
    for i, (left, top, right, bottom, time_seconds) in enumerate(parse_vtt_offsets(vtt)):
        # NOTE(review): the crop box is (left, top, left+right, top+bottom),
        # so "right"/"bottom" behave as width/height here — confirm against
        # parse_vtt_offsets before renaming anything.
        cut_frame = sprite.crop((left, top, left + right, top + bottom))
        faces = extract_faces_mediapipe(np.asarray(cut_frame), enforce_detection=False, align=False)
        # Keep only confident detections.
        faces = [face for face in faces if face['confidence'] > 0.6]
        if faces:
            # Use the first detection's bounding-box area as the face size.
            size = faces[0]['facial_area']['w'] * faces[0]['facial_area']['h']
            data = {'id': str(uuid4()), "offset": (left, top, right, bottom), "frame": i, "time": time_seconds, 'size': size}
            results.append(data)
    # NOTE(review): no `return results` is visible at the end of this chunk;
    # verify the full file returns `results`, otherwise callers receive None.