# megaface/models/data_manager.py
# cc1234's picture
# Initial commit: MegaFace facial recognition system
# 39ae7cb
import json
import os
import sqlite3
from typing import Dict, Any, Optional
from voyager import Index, Space, StorageDataType
class DataManager:
    """Access point for the faces JSON data, the arc vector index, and performer lookups.

    Constructing an instance immediately loads the faces file and the voyager
    index from disk. A failed load is reported with ``print`` and leaves the
    corresponding attribute at its empty default instead of raising.
    """

    # Trailing extensions removed from a filename to build a fallback display name.
    _IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png')

    # Prefix that is joined with a performer image id to form the image URL.
    _IMAGE_BASE_URL = 'https://meta4allphotos.s3.us-west-1.amazonaws.com/'

    def __init__(self, faces_path: str = "data/faces.json",
                 arc_index_path: str = "data/face_arc.voy",
                 performers_db_path: str = "data/peeps.db"):
        """
        Initialize the data manager and load its data from disk.

        Parameters:
            faces_path: Path to the faces.json file
            arc_index_path: Path to the arc index file
            performers_db_path: Path to the SQLite database holding the
                ``performers`` table
        """
        self.faces_path = faces_path
        self.arc_index_path = arc_index_path
        self.performers_db_path = performers_db_path

        # Empty 512-dimensional cosine index; kept as the fallback when the
        # on-disk index cannot be loaded.
        self.index_arc = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)

        # Load data
        self.faces = {}
        self.load_data()

    def load_data(self):
        """Load all data from files (faces first, then the index)."""
        self._load_faces()
        self._load_indices()

    def _load_faces(self):
        """Load faces from the JSON file into ``self.faces``.

        Any failure (missing file, invalid JSON, ...) is printed and
        ``self.faces`` is reset to an empty dict.
        """
        try:
            # JSON is UTF-8 by specification; do not rely on the platform default.
            with open(self.faces_path, 'r', encoding='utf-8') as f:
                self.faces = json.load(f)
        except Exception as e:
            # Deliberately best-effort: the manager stays usable without faces.
            print(f"Error loading faces: {e}")
            self.faces = {}

    def _load_indices(self):
        """Load the face recognition index from ``self.arc_index_path``.

        On failure the error is printed and the previously assigned index is
        left in place.
        """
        try:
            with open(self.arc_index_path, 'rb') as f:
                # load() builds and returns a new Index; it does not fill the
                # existing instance, hence the reassignment.
                self.index_arc = Index.load(f)
        except Exception as e:
            print(f"Error loading indices: {e}")

    def get_performer_data(self, image_filename: str) -> Optional[Dict[str, str]]:
        """
        Look up performer data by image filename

        Parameters:
            image_filename: The image filename to look up

        Returns:
            Dict with 'name' (the ``slug`` column) and 'url' of the first
            matching row, or None if no row matches or the query fails.
            Either value is None when the column is NULL in the database.
        """
        try:
            # A new connection is opened per call so no connection object is
            # shared between threads.
            conn = sqlite3.connect(self.performers_db_path)
            try:
                row = conn.execute(
                    'SELECT slug, url FROM performers WHERE image_filename = ?',
                    (image_filename,),
                ).fetchone()
            finally:
                # sqlite3's own context manager only ends the transaction; the
                # connection has to be closed explicitly or it leaks.
                conn.close()
        except Exception as e:
            print(f"Error querying performer database: {e}")
            return None

        if row:
            return {
                'name': row[0],
                'url': row[1]
            }
        return None

    @staticmethod
    def _strip_image_extension(filename: str) -> str:
        """Return ``filename`` without one trailing .jpg/.jpeg/.png (case-insensitive)."""
        lowered = filename.lower()
        for ext in DataManager._IMAGE_EXTENSIONS:
            if lowered.endswith(ext):
                return filename[:-len(ext)]
        return filename

    def get_performer_info(self, id: str, confidence: float) -> Optional[Dict[str, Any]]:
        """
        Build the result record for a matched performer image.

        Parameters:
            id: Image id/path of the match; its basename is the key used for
                the database lookup and the full value is appended to the
                image base URL
            confidence: Confidence score (0-1)

        Returns:
            Dictionary with 'id', 'name', 'confidence', 'image', 'distance'
            and 'url'. A dictionary is always returned; when the database has
            no usable entry the name falls back to the filename without its
            image extension and 'url' is None.
        """
        # Percentage, truncated toward zero by int().
        confidence_int = int(confidence * 100)
        filename = os.path.basename(id)

        # Defaults used when the database has nothing for this image.
        name = self._strip_image_extension(filename)
        url = None

        # Try to get performer data from database
        performer_data = self.get_performer_data(filename)
        if performer_data:
            # The literal string "NULL" is treated like a missing name.
            if performer_data['name'] != "NULL":
                name = performer_data['name'] or name
            url = performer_data['url']

        image_url = self._IMAGE_BASE_URL + id
        return {
            'id': id,
            "name": name,
            "confidence": confidence_int,
            'image': image_url,
            # NOTE(review): 'distance' carries the same percentage as
            # 'confidence' - confirm that consumers expect this.
            'distance': confidence_int,
            'url': url
        }

    def query_arc_index(self, embedding, limit):
        """Query the arc index with an embedding and return the index's query result.

        Parameters:
            embedding: Query vector passed straight to the index
            limit: Passed to the index as the number of results to request
        """
        return self.index_arc.query(embedding, limit)