File size: 4,010 Bytes
39ae7cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import os
import sqlite3
from typing import Dict, Any, Optional
from voyager import Index, Space, StorageDataType

class DataManager:
    """Load face-recognition assets and resolve performer metadata.

    Holds the parsed ``faces.json`` content, a Voyager arc index (the
    fallback index is 512-dimensional, cosine space, E4M3 storage) and the
    path to a SQLite performers database that is queried on demand.
    """

    # Image extensions stripped (case-insensitively, trailing only) from a
    # filename to build a fallback display name.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    # Prefix joined with an image id to build its public image URL.
    _IMAGE_BASE_URL = 'https://meta4allphotos.s3.us-west-1.amazonaws.com/'

    def __init__(self, faces_path: str = "data/faces.json",
                 arc_index_path: str = "data/face_arc.voy",
                 performers_db_path: str = "data/peeps.db"):
        """
        Initialize the data manager and load faces and the arc index.

        Parameters:
        faces_path: Path to the faces.json file
        arc_index_path: Path to the serialized Voyager arc index file
        performers_db_path: Path to the SQLite performers database
            (opened per query, not at construction time)
        """
        self.faces_path = faces_path
        self.arc_index_path = arc_index_path
        self.performers_db_path = performers_db_path

        # Empty fallback index; replaced by the on-disk index if it loads.
        self.index_arc = Index(Space.Cosine, num_dimensions=512,
                               storage_data_type=StorageDataType.E4M3)

        # Load data
        self.faces = {}
        self.load_data()

    def load_data(self):
        """Load all data from files (faces JSON, then the arc index)."""
        self._load_faces()
        self._load_indices()

    def _load_faces(self):
        """Load ``self.faces`` from the JSON file; fall back to ``{}`` on any error."""
        try:
            # JSON is UTF-8 by spec; don't depend on the platform default encoding.
            with open(self.faces_path, 'r', encoding='utf-8') as f:
                self.faces = json.load(f)
        except Exception as e:
            # Best-effort: a missing or malformed file leaves an empty mapping.
            print(f"Error loading faces: {e}")
            self.faces = {}

    def _load_indices(self):
        """Replace ``self.index_arc`` with the index stored at ``arc_index_path``.

        On failure the existing (empty) index is kept and the error is printed.
        """
        try:
            with open(self.arc_index_path, 'rb') as f:
                # `load` returns a new Index rather than filling this one in place.
                self.index_arc = self.index_arc.load(f)
        except Exception as e:
            print(f"Error loading indices: {e}")

    def get_performer_data(self, image_filename: str) -> Optional[Dict[str, Optional[str]]]:
        """
        Look up performer data by image filename.

        Parameters:
        image_filename: Value matched against ``performers.image_filename``

        Returns:
        Dict with 'name' (the row's slug) and 'url', or None if no row
        matches or the query fails.
        """
        try:
            # A fresh connection per query avoids sharing one connection across
            # threads (sqlite3 connections default to check_same_thread=True).
            conn = sqlite3.connect(self.performers_db_path)
            try:
                row = conn.execute(
                    'SELECT slug, url FROM performers WHERE image_filename = ?',
                    (image_filename,),
                ).fetchone()
            finally:
                # `with sqlite3.connect(...)` only manages the transaction and
                # does not close the connection, so close it explicitly.
                conn.close()
        except Exception as e:
            print(f"Error querying performer database: {e}")
            return None

        if row is None:
            return None
        slug, url = row
        return {
            'name': slug,
            'url': url,
        }

    def get_performer_info(self, id: str, confidence: float) -> Optional[Dict[str, Any]]:
        """
        Build the performer record for a matched face image.

        Parameters:
        id: Image id / relative path of the matched face. Its basename is the
            database lookup key; the full id is appended to the image base URL.
        confidence: Confidence score (0-1)

        Returns:
        Dict with 'id', 'name', 'confidence' (percent as a rounded int),
        'image', 'distance' (same value as 'confidence') and 'url'
        (None when the performer is not found in the database).
        """
        # round() rather than plain int(): int(0.29 * 100) truncates
        # 28.999999999999996 down to 28.
        confidence_int = int(round(confidence * 100))
        filename = os.path.basename(id)

        # Fallback display name: the filename minus a trailing image extension.
        stem, ext = os.path.splitext(filename)
        name = stem if ext.lower() in self._IMAGE_EXTENSIONS else filename

        url = None
        performer_data = self.get_performer_data(filename)
        if performer_data:
            # A slug equal to the string "NULL" is treated as missing.
            if performer_data['name'] != "NULL":
                name = performer_data['name'] or name
            url = performer_data['url']

        return {
            'id': id,
            "name": name,
            "confidence": confidence_int,
            'image': self._IMAGE_BASE_URL + id,
            # NOTE(review): 'distance' mirrors 'confidence'; kept for callers that read it.
            'distance': confidence_int,
            'url': url,
        }

    def query_arc_index(self, embedding, limit):
        """Query the arc index; returns whatever ``index_arc.query(embedding, limit)`` returns."""
        return self.index_arc.query(embedding, limit)