cc1234 commited on
Commit
39ae7cb
·
0 Parent(s):

Initial commit: MegaFace facial recognition system

Browse files

Add complete facial recognition server with:
- DeepFace-based face detection and embedding generation
- Voyager vector similarity search indices
- SQLite performer database integration
- Gradio web interface
- ArcFace model weights and configurations
- Project documentation and dependencies

.deepface/weights/arcface_weights.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6336979c0c602cae08d1122a66f4dfb862d059bbcd8ef80306aef2b2249b0c93
3
+ size 137026640
.deepface/weights/yolov8n-face.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d545bf1add5aa736a4febac4f4f9245a6d596cd0fe70d5d57989fe0cb9e626ca
3
+ size 6389512
.gitattributes ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.db filter=lfs diff=lfs merge=lfs -text
37
+ face.json filter=lfs diff=lfs merge=lfs -text
38
+ .deepface/weights/yolov8n-face.pt filter=lfs diff=lfs merge=lfs -text
39
+ .deepface/weights/face_recognition_sface_2021dec.onnx filter=lfs diff=lfs merge=lfs -text
40
+ .deepface/weights/res10_300x300_ssd_iter_140000.caffemodel filter=lfs diff=lfs merge=lfs -text
41
+ .deepface/weights/centerface.onnx filter=lfs diff=lfs merge=lfs -text
42
+ .deepface/weights/deploy.prototxt filter=lfs diff=lfs merge=lfs -text
43
+ .deepface/weights/facenet512_weights.h5 filter=lfs diff=lfs merge=lfs -text
44
+ .deepface/weights/retinaface.h5 filter=lfs diff=lfs merge=lfs -text
45
+ .deepface/weights/face_detection_yunet_2023mar.onnx filter=lfs diff=lfs merge=lfs -text
46
+ .deepface/weights/arcface_weights.h5 filter=lfs diff=lfs merge=lfs -text
47
+ face_arc.voy filter=lfs diff=lfs merge=lfs -text
48
+ face_facenet.voy filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .venv
2
+ flagged
3
+ temp.jpg
4
+ __pycache__
5
+ data/performers.json
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Stashface
3
+ emoji: 👀
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.25.2
8
+ app_file: app.py
9
+ python_version: 3.11
10
+ pinned: false
11
+ license: mit
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Set DeepFace home directory
4
+ os.environ["DEEPFACE_HOME"] = "."
5
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
6
+
7
+ from models.data_manager import DataManager
8
+ from web.interface import WebInterface
9
+
10
+ def main():
11
+ """Main entry point for the application"""
12
+ # Initialize data manager
13
+ data_manager = DataManager(
14
+ faces_path="data/faces.json",
15
+ arc_index_path="data/face_arc.voy"
16
+ )
17
+
18
+ # Initialize and launch web interface
19
+ web_interface = WebInterface(data_manager, default_threshold=0.5)
20
+ web_interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
21
+
22
+ if __name__ == "__main__":
23
+ main()
data/face_arc.voy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ae15849f24e304e8a86dd125d0ff72169b5af4febafc9e24ed48dbbb0cfe68f
3
+ size 322825475
data/peeps.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa96f486d191c8adb48bc4735d81cb08bba1e7b7ad1c32320ccc80d46a6646c2
3
+ size 146644992
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # models package
models/data_manager.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sqlite3
4
+ from typing import Dict, Any, Optional
5
+ from voyager import Index, Space, StorageDataType
6
+
7
+ class DataManager:
8
+ def __init__(self, faces_path: str = "data/faces.json",
9
+ arc_index_path: str = "data/face_arc.voy",
10
+ performers_db_path: str = "data/peeps.db"):
11
+ """
12
+ Initialize the data manager.
13
+
14
+ Parameters:
15
+ faces_path: Path to the faces.json file
16
+ performers_zip: Path to the performers zip file
17
+ facenet_index_path: Path to the facenet index file
18
+ arc_index_path: Path to the arc index file
19
+ """
20
+ self.faces_path = faces_path
21
+ self.arc_index_path = arc_index_path
22
+ self.performers_db_path = performers_db_path
23
+
24
+ # Initialize indices
25
+ self.index_arc = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
26
+
27
+ # Load data
28
+ self.faces = {}
29
+ self.load_data()
30
+
31
+ def load_data(self):
32
+ """Load all data from files"""
33
+ self._load_faces()
34
+ self._load_indices()
35
+
36
+ def _load_faces(self):
37
+ """Load faces from JSON file"""
38
+ try:
39
+ with open(self.faces_path, 'r') as f:
40
+ self.faces = json.load(f)
41
+ except Exception as e:
42
+ print(f"Error loading faces: {e}")
43
+ self.faces = {}
44
+
45
+ def _load_indices(self):
46
+ """Load face recognition indices"""
47
+ try:
48
+ with open(self.arc_index_path, 'rb') as f:
49
+ self.index_arc = self.index_arc.load(f)
50
+ except Exception as e:
51
+ print(f"Error loading indices: {e}")
52
+
53
+ def get_performer_data(self, image_filename: str) -> Optional[Dict[str, str]]:
54
+ """
55
+ Look up performer data by image filename
56
+
57
+ Parameters:
58
+ image_filename: The image filename to look up
59
+
60
+ Returns:
61
+ Dict with name, url, and image_url or None if not found
62
+ """
63
+ try:
64
+ # Create a new connection for each query to avoid threading issues
65
+ with sqlite3.connect(self.performers_db_path) as conn:
66
+ cursor = conn.cursor()
67
+ cursor.execute('SELECT slug, url FROM performers WHERE image_filename = ?', (image_filename,))
68
+ result = cursor.fetchone()
69
+ if result:
70
+ return {
71
+ 'name': result[0],
72
+ 'url': result[1]
73
+ }
74
+ return None
75
+ except Exception as e:
76
+ print(f"Error querying performer database: {e}")
77
+ return None
78
+
79
+ def get_performer_info(self, id: str, confidence: float) -> Optional[Dict[str, Any]]:
80
+ """
81
+ Get performer information from the database
82
+
83
+ Parameters:
84
+ stash_id: Stash ID of the performer
85
+ confidence: Confidence score (0-1)
86
+
87
+ Returns:
88
+ Dictionary with performer information or None if not found
89
+ """
90
+
91
+ confidence_int = int(confidence * 100)
92
+ filename = os.path.basename(id)
93
+
94
+ # Try to get performer data from database
95
+ performer_data = self.get_performer_data(filename)
96
+ name = filename.replace('.jpg', '').replace('.png', '').replace('.jpeg', '')
97
+
98
+ if performer_data:
99
+ if performer_data['name'] != "NULL":
100
+ name = performer_data['name'] or name
101
+ url = performer_data['url']
102
+ else:
103
+ url = None
104
+
105
+ image_url = 'https://meta4allphotos.s3.us-west-1.amazonaws.com/' + id
106
+
107
+ return {
108
+ 'id': id,
109
+ "name": name,
110
+ "confidence": confidence_int,
111
+ 'image': image_url,
112
+ 'distance': confidence_int,
113
+ 'url': url
114
+ }
115
+
116
+ def query_arc_index(self, embedding, limit):
117
+ """Query the arc index with an embedding"""
118
+ return self.index_arc.query(embedding, limit)
models/face_recognition.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import Dict, List, Tuple
3
+
4
+ from deepface import DeepFace
5
+ from deepface.modules import modeling, preprocessing
6
+
7
+ class EnsembleFaceRecognition:
8
+ def __init__(self, model_weights: Dict[str, float] = None):
9
+ """
10
+ Initialize ensemble face recognition system.
11
+
12
+ Parameters:
13
+ model_weights: Dictionary mapping model names to their weights
14
+ If None, all models are weighted equally
15
+ """
16
+ self.model_weights = model_weights or {}
17
+ self.boost_factor = 1.8
18
+
19
+ def normalize_distances(self, distances: np.ndarray) -> np.ndarray:
20
+ """Normalize distances to [0,1] range within each model's predictions"""
21
+ min_dist = np.min(distances)
22
+ max_dist = np.max(distances)
23
+ if max_dist == min_dist:
24
+ return np.zeros_like(distances)
25
+ return (distances - min_dist) / (max_dist - min_dist)
26
+
27
+ def compute_model_confidence(self,
28
+ distances: np.ndarray,
29
+ temperature: float = 0.1) -> np.ndarray:
30
+ """Convert distances to confidence scores for a single model"""
31
+ normalized_distances = self.normalize_distances(distances)
32
+ exp_distances = np.exp(-normalized_distances / temperature)
33
+ return exp_distances / np.sum(exp_distances)
34
+
35
+ def _preprocess_face_batch(self, faces: np.ndarray, target_size: Tuple[int, int], normalization: str) -> np.ndarray:
36
+ """Preprocess a batch of face images for model inference"""
37
+ batch_size = faces.shape[0]
38
+ processed_faces = []
39
+
40
+ for i in range(batch_size):
41
+ face = faces[i]
42
+ # Convert RGB to BGR (DeepFace expects BGR)
43
+ face = face[:, :, ::-1]
44
+
45
+ # Resize to model input size
46
+ resized = preprocessing.resize_image(face, target_size)
47
+
48
+ # Normalize
49
+ normalized = preprocessing.normalize_input(resized, normalization)
50
+
51
+ processed_faces.append(normalized)
52
+
53
+ # Stack into batch and remove the extra dimension added by resize_image
54
+ batch = np.vstack(processed_faces)
55
+ return batch
56
+
57
+ def get_face_embeddings_batch(self, faces: np.ndarray) -> Dict[str, np.ndarray]:
58
+ """Get face embeddings for a batch of images efficiently
59
+
60
+ Args:
61
+ faces: np.ndarray of shape (batch_size, height, width, channels)
62
+
63
+ Returns:
64
+ Dict with 'facenet' and 'arc' keys containing batched embeddings
65
+ """
66
+ # Load models (cached by DeepFace)
67
+ arcface_model = modeling.build_model(task="facial_recognition", model_name="ArcFace")
68
+
69
+ # Preprocess faces for each model
70
+ arcface_batch = self._preprocess_face_batch(faces, arcface_model.input_shape, "ArcFace")
71
+
72
+ # Get embeddings using direct model inference (bypassing DeepFace.represent)
73
+ arcface_embeddings = arcface_model.model(arcface_batch, training=False).numpy()
74
+
75
+ return {
76
+ 'arc': arcface_embeddings
77
+ }
78
+
79
+ def ensemble_prediction(self,
80
+ model_predictions: Dict[str, Tuple[List[str], List[float]]],
81
+ temperature: float = 0.1,
82
+ min_agreement: float = 0.5) -> List[Tuple[str, float]]:
83
+ """
84
+ Combine predictions from multiple models.
85
+
86
+ Parameters:
87
+ model_predictions: Dictionary mapping model names to their (distances, names) predictions
88
+ temperature: Temperature parameter for softmax scaling
89
+ min_agreement: Minimum agreement threshold between models
90
+
91
+ Returns:
92
+ final_predictions: List of (name, confidence) tuples
93
+ """
94
+ # Initialize vote counting
95
+ vote_dict = {}
96
+ confidence_dict = {}
97
+
98
+ # Process each model's predictions
99
+ for model_name, (names, distances) in model_predictions.items():
100
+ # Get model weight (default to 1.0 if not specified)
101
+ model_weight = self.model_weights.get(model_name, 1.0)
102
+
103
+ # Compute confidence scores for this model
104
+ confidences = self.compute_model_confidence(np.array(distances), temperature)
105
+
106
+ # Add weighted votes for top prediction
107
+ top_name = names[0]
108
+ top_confidence = confidences[0]
109
+
110
+ vote_dict[top_name] = vote_dict.get(top_name, 0) + model_weight
111
+ confidence_dict[top_name] = confidence_dict.get(top_name, [])
112
+ confidence_dict[top_name].append(top_confidence)
113
+
114
+ # Normalize votes
115
+ total_weight = sum(self.model_weights.values()) if self.model_weights else len(model_predictions)
116
+
117
+ # Compute final results with minimum agreement check
118
+ final_results = []
119
+ for name, votes in vote_dict.items():
120
+ normalized_votes = votes / total_weight
121
+ # Only include results that meet minimum agreement threshold
122
+ if normalized_votes >= min_agreement:
123
+ avg_confidence = np.mean(confidence_dict[name])
124
+ final_score = normalized_votes * avg_confidence * self.boost_factor
125
+ final_score = min(final_score, 1.0) # Cap at 1.0
126
+ final_results.append((name, final_score))
127
+
128
+ # Sort by final score
129
+ final_results.sort(key=lambda x: x[1], reverse=True)
130
+ return final_results
131
+
132
+ def extract_faces(image):
133
+ """Extract faces from an image using DeepFace"""
134
+ return DeepFace.extract_faces(image, detector_backend="yolov8")
models/image_processor.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import base64
3
+ import numpy as np
4
+
5
+ from models.face_recognition import EnsembleFaceRecognition, extract_faces
6
+
7
+
8
+
9
+ def get_face_predictions(face, ensemble, data_manager, results):
10
+ """
11
+ Get predictions for a single face
12
+
13
+ Parameters:
14
+ face: Face image array
15
+ ensemble: EnsembleFaceRecognition instance
16
+ data_manager: DataManager instance
17
+ results: Number of results to return
18
+
19
+ Returns:
20
+ List of (name, confidence) tuples
21
+ """
22
+ # Create batch with original and flipped images
23
+ face_batch = np.stack([face, face[:, ::-1, :]], axis=0)
24
+
25
+ # Get embeddings for both orientations in one batch call
26
+ embeddings_batch = ensemble.get_face_embeddings_batch(face_batch)
27
+ arc = np.mean(embeddings_batch['arc'], axis=0)
28
+
29
+ # Get predictions from both models
30
+ query_limit = max(results, 50)
31
+ arc_raw = data_manager.query_arc_index(arc, query_limit)
32
+
33
+ return ensemble.ensemble_prediction({'arc': arc_raw})
34
+
35
+
36
+ def image_search_performers(image, data_manager, threshold=0.5, results=3):
37
+ """
38
+ Search for multiple performers in an image
39
+
40
+ Parameters:
41
+ image: PIL Image object
42
+ data_manager: DataManager instance
43
+ threshold: Confidence threshold
44
+ results: Number of results to return
45
+
46
+ Returns:
47
+ List of dictionaries with face image and performer information
48
+ """
49
+ image_array = np.array(image)
50
+ ensemble = EnsembleFaceRecognition({"arc": 1.0})
51
+
52
+ try:
53
+ faces = extract_faces(image_array)
54
+ except ValueError:
55
+ raise ValueError("No faces found")
56
+
57
+ response = []
58
+ for face in faces:
59
+ predictions = get_face_predictions(face['face'], ensemble, data_manager, results)
60
+
61
+ # Crop and encode face image
62
+ area = face['facial_area']
63
+ cimage = image.crop((area['x'], area['y'], area['x'] + area['w'], area['y'] + area['h']))
64
+ buf = io.BytesIO()
65
+ cimage.save(buf, format='JPEG')
66
+ im_b64 = base64.b64encode(buf.getvalue()).decode('ascii')
67
+
68
+ # Get performer information
69
+ performers = []
70
+ for name, confidence in predictions:
71
+ performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
72
+ if performer_info:
73
+ performers.append(performer_info)
74
+
75
+ response.append({
76
+ 'image': im_b64,
77
+ 'area': area,
78
+ 'confidence': face['confidence'],
79
+ 'performers': performers
80
+ })
81
+ return response
pyproject.toml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "stashface"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "deepface",
9
+ "gradio==5.25.2",
10
+ "mediapipe>=0.10.21",
11
+ "pyzipper==0.3.6",
12
+ "retina-face==0.0.17",
13
+ "tensorflow==2.14.1",
14
+ "tf-keras==2.14.1",
15
+ "ultralytics==8.3.69",
16
+ "voyager==2.1.0",
17
+ ]
18
+
19
+ [tool.uv.sources]
20
+ deepface = { git = "https://github.com/serengil/deepface.git", rev = "cc484b54be5188eb47faf132995af16a871d70b9" }
requirements.txt ADDED
File without changes
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
web/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # web package
web/interface.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import base64
3
+ import io
4
+ from PIL import Image as PILImage
5
+
6
+ from models.data_manager import DataManager
7
+ from models.image_processor import (
8
+ image_search_performers,
9
+ )
10
+
11
+ class WebInterface:
12
+ def __init__(self, data_manager: DataManager, default_threshold: float = 0.5):
13
+ """
14
+ Initialize the web interface.
15
+
16
+ Parameters:
17
+ data_manager: DataManager instance
18
+ default_threshold: Default confidence threshold
19
+ """
20
+ self.data_manager = data_manager
21
+ self.default_threshold = default_threshold
22
+
23
+ def multiple_image_search(self, img):
24
+ """Wrapper for the multiple image search function"""
25
+ try:
26
+ # Use default values: threshold=0.5, results=3
27
+ return image_search_performers(img, self.data_manager, 0.5, 3)
28
+ except ValueError as e:
29
+ if "No faces found" in str(e):
30
+ return {"error": "No faces detected in the uploaded image. Please try uploading an image with visible faces."}
31
+ else:
32
+ raise e
33
+
34
+ def format_results_for_visual_display(self, json_results):
35
+ """
36
+ Convert JSON results to visual components for better UX
37
+
38
+ Parameters:
39
+ json_results: List of face detection results from image_search_performers
40
+
41
+ Returns:
42
+ tuple: (gallery_images, html_content)
43
+ """
44
+ if not json_results:
45
+ return [], "<p>No faces detected or no matches found.</p>"
46
+
47
+ # Handle error case
48
+ if isinstance(json_results, dict) and "error" in json_results:
49
+ error_html = f"""
50
+ <div class="performer-card">
51
+ <div class="face-info">
52
+ <h3 style="color: #ff6b6b;">Error</h3>
53
+ <p>{json_results['error']}</p>
54
+ </div>
55
+ </div>
56
+ """
57
+ return [], error_html
58
+
59
+ gallery_images = []
60
+ html_parts = []
61
+
62
+ html_parts.append("""
63
+ <style>
64
+ body, .gradio-container {
65
+ background-color: #1e1e1e !important;
66
+ color: #d4d4d4 !important;
67
+ }
68
+ .performer-card {
69
+ border: 1px solid #404040;
70
+ border-radius: 12px;
71
+ padding: 24px;
72
+ margin: 16px 0;
73
+ background: #2d2d2d;
74
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3);
75
+ color: #d4d4d4;
76
+ }
77
+ .face-info {
78
+ background: #3c3c3c;
79
+ padding: 20px;
80
+ border-radius: 8px;
81
+ margin-bottom: 24px;
82
+ border: 1px solid #4a4a4a;
83
+ display: flex;
84
+ align-items: flex-start;
85
+ gap: 20px;
86
+ }
87
+ .face-info-content {
88
+ flex: 1;
89
+ }
90
+ .face-info h3 {
91
+ color: #ffffff;
92
+ margin-top: 0;
93
+ font-size: 1.4em;
94
+ }
95
+ .performer-grid {
96
+ display: grid;
97
+ grid-template-columns: repeat(auto-fit, minmax(350px, 1fr));
98
+ gap: 24px;
99
+ margin-top: 16px;
100
+ }
101
+ .performer-item {
102
+ border: 1px solid #4a4a4a;
103
+ border-radius: 12px;
104
+ padding: 24px;
105
+ background: #333333;
106
+ text-align: center;
107
+ transition: all 0.3s ease;
108
+ box-shadow: 0 2px 8px rgba(0,0,0,0.2);
109
+ display: flex;
110
+ flex-direction: column;
111
+ align-items: center;
112
+ }
113
+ .performer-item:hover {
114
+ border-color: #569cd6;
115
+ box-shadow: 0 4px 16px rgba(0,0,0,0.4);
116
+ transform: translateY(-2px);
117
+ }
118
+ .performer-image {
119
+ width: 120px;
120
+ height: 120px;
121
+ border-radius: 12px;
122
+ object-fit: cover;
123
+ margin: 0 auto 16px auto;
124
+ display: block;
125
+ border: 2px solid #4a4a4a;
126
+ transition: all 0.3s ease;
127
+ text-align: center;
128
+ }
129
+ .performer-image:hover {
130
+ border-color: #569cd6;
131
+ transform: scale(1.05);
132
+ }
133
+ .performer-item h4 {
134
+ color: #ffffff;
135
+ margin: 16px 0 8px 0;
136
+ font-size: 1.2em;
137
+ }
138
+ .performer-item h4 a {
139
+ color: #569cd6;
140
+ text-decoration: none;
141
+ transition: color 0.3s ease;
142
+ }
143
+ .performer-item h4 a:hover {
144
+ color: #9cdcfe;
145
+ text-decoration: underline;
146
+ }
147
+ .performer-item p {
148
+ color: #cccccc;
149
+ margin: 8px 0;
150
+ }
151
+ .performer-item small {
152
+ color: #999999;
153
+ }
154
+ .confidence-bar {
155
+ background: #404040;
156
+ border-radius: 12px;
157
+ overflow: hidden;
158
+ height: 28px;
159
+ margin: 12px 0;
160
+ border: 1px solid #4a4a4a;
161
+ width: 100%;
162
+ max-width: 200px;
163
+ }
164
+ .confidence-fill {
165
+ height: 100%;
166
+ transition: width 0.5s ease;
167
+ text-align: center;
168
+ line-height: 28px;
169
+ color: white;
170
+ font-size: 13px;
171
+ font-weight: bold;
172
+ text-shadow: 0 1px 2px rgba(0,0,0,0.5);
173
+ }
174
+ .high-confidence {
175
+ background: linear-gradient(135deg, #4caf50, #66bb6a);
176
+ }
177
+ .medium-confidence {
178
+ background: linear-gradient(135deg, #ff9800, #ffb74d);
179
+ }
180
+ .low-confidence {
181
+ background: linear-gradient(135deg, #f44336, #ef5350);
182
+ }
183
+ .face-info p strong {
184
+ color: #9cdcfe;
185
+ }
186
+ .country-flag {
187
+ font-size: 1.2em;
188
+ margin-right: 6px;
189
+ vertical-align: middle;
190
+ }
191
+ </style>
192
+ """)
193
+
194
+ for i, face_result in enumerate(json_results):
195
+ # Convert base64 face image to PIL for gallery
196
+ try:
197
+ face_image_data = base64.b64decode(face_result['image'])
198
+ face_pil = PILImage.open(io.BytesIO(face_image_data))
199
+ gallery_images.append(face_pil)
200
+ except Exception as e:
201
+ print(f"Error decoding face image: {e}")
202
+ continue
203
+
204
+ # Create HTML for this face
205
+ face_confidence = face_result['confidence']
206
+ performers = face_result['performers']
207
+
208
+ # Create base64 data URL for the detected face image
209
+ face_image_b64 = f"data:image/jpeg;base64,{face_result['image']}"
210
+
211
+ html_parts.append(f"""
212
+ <div class="performer-card">
213
+ <div class="face-info">
214
+ <div class="detected-face">
215
+ <img src="{face_image_b64}" alt="Detected Face {i+1}" style="width: 120px; height: 120px; border-radius: 12px; object-fit: cover; border: 2px solid #569cd6; box-shadow: 0 4px 12px rgba(0,0,0,0.3);">
216
+ </div>
217
+ <div class="face-info-content">
218
+ <h3>Face {i+1}</h3>
219
+ <p><strong>Detection Confidence:</strong> {face_confidence:.1%}</p>
220
+ <p><strong>Matches Found:</strong> {len(performers)}</p>
221
+ </div>
222
+ </div>
223
+ """)
224
+
225
+ if performers:
226
+ html_parts.append('<div class="performer-grid">')
227
+ for performer in performers:
228
+ confidence_class = "high-confidence" if performer['confidence'] >= 80 else "medium-confidence" if performer['confidence'] >= 60 else "low-confidence"
229
+
230
+ # Create performer name with link if URL exists
231
+ performer_name = performer['name']
232
+ if performer.get('url'):
233
+ performer_name = f'<a href="{performer["url"]}" target="_blank">{performer["name"]}</a>'
234
+
235
+ html_parts.append(f"""
236
+ <div class="performer-item">
237
+ <img src="{performer['image']}" alt="{performer['name']}" class="performer-image" onerror="this.style.display='none'">
238
+ <h4>{performer_name}</h4>
239
+ <div class="confidence-bar">
240
+ <div class="confidence-fill {confidence_class}" style="width: {performer['confidence']}%">
241
+ {performer['confidence']}%
242
+ </div>
243
+ </div>
244
+ <p><small>Distance: {performer.get('distance', 'N/A')}</small></p>
245
+ </div>
246
+ """)
247
+ html_parts.append('</div>')
248
+ else:
249
+ html_parts.append('<p><em>No performer matches found for this face.</em></p>')
250
+
251
+ html_parts.append('</div>')
252
+
253
+ return gallery_images, ''.join(html_parts)
254
+
255
+ def multiple_image_search_with_visual(self, img):
256
+ """
257
+ Enhanced search function that returns both JSON and visual components
258
+
259
+ Returns:
260
+ tuple: (json_results, gallery_images, html_content)
261
+ """
262
+ try:
263
+ json_results = self.multiple_image_search(img)
264
+ gallery_images, html_content = self.format_results_for_visual_display(json_results)
265
+ return json_results, gallery_images, html_content
266
+ except Exception as e:
267
+ error_msg = f"<div class='performer-card'><h3>Error</h3><p>{str(e)}</p></div>"
268
+ return [], [], error_msg
269
+
270
+ def _create_visual_search_interface(self):
271
+ """Create the visual search interface"""
272
+ with gr.Blocks() as interface:
273
+ gr.Markdown("# Who is in the photo?")
274
+ gr.Markdown("Upload an image of a person(s) and we'll show you who it is with photos and details.")
275
+
276
+ with gr.Row():
277
+ with gr.Column():
278
+ img_input = gr.Image(type="pil")
279
+ search_btn = gr.Button("Search")
280
+
281
+ with gr.Column():
282
+ performer_info = gr.HTML(
283
+ label="Performer Information",
284
+ value="<p>Upload an image and click search to see results.</p>"
285
+ )
286
+
287
+ def visual_search_wrapper(img):
288
+ """Wrapper that returns only visual components"""
289
+ json_results, gallery_images, html_content = self.multiple_image_search_with_visual(img)
290
+ return html_content
291
+
292
+ search_btn.click(
293
+ fn=visual_search_wrapper,
294
+ inputs=[img_input],
295
+ outputs=[performer_info],
296
+ api_name="multiple_image_search_with_visual"
297
+ )
298
+
299
+ return interface
300
+
301
+
302
+ def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
303
+ """Launch the web interface"""
304
+ with gr.Blocks(
305
+ css="""
306
+ .gradio-container {
307
+ background-color: #1e1e1e !important;
308
+ color: #d4d4d4 !important;
309
+ }
310
+ .dark {
311
+ --background-fill-primary: #2d2d2d;
312
+ --background-fill-secondary: #3c3c3c;
313
+ --border-color-primary: #404040;
314
+ --block-title-text-color: #ffffff;
315
+ --body-text-color: #d4d4d4;
316
+ }
317
+ """
318
+ ) as demo:
319
+ with gr.Tabs():
320
+ with gr.TabItem("Visual Search"):
321
+ self._create_visual_search_interface()
322
+
323
+ demo.queue().launch(server_name=server_name, server_port=server_port, share=share, ssr_mode=False)