File size: 8,585 Bytes
9e1a324
 
 
 
 
5d71831
9e1a324
5d71831
9e1a324
 
 
dcb7fcc
9e1a324
 
 
 
 
2277fc2
 
5d71831
9e1a324
 
 
 
dcb7fcc
 
 
 
 
 
 
9e1a324
2277fc2
9e1a324
5d71831
 
9e1a324
 
5d71831
9e1a324
5d71831
 
 
 
72e2859
 
 
2277fc2
72e2859
 
 
2277fc2
 
 
 
 
 
 
 
5d71831
72e2859
9e1a324
 
 
9f52a10
9e1a324
9f52a10
9e1a324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f52a10
 
 
9e1a324
9f52a10
9e1a324
 
 
 
9f52a10
 
9e1a324
5d71831
 
 
 
 
2277fc2
5d71831
 
 
 
 
 
 
 
 
9f52a10
 
9e1a324
 
9f52a10
9e1a324
ded317b
 
 
 
1458622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ded317b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1458622
ded317b
 
1458622
ded317b
 
 
 
 
 
 
1458622
 
 
 
 
 
ded317b
1458622
ded317b
1458622
ded317b
 
 
 
 
 
 
9e1a324
 
 
 
 
 
 
 
 
648852a
 
 
 
 
 
 
 
 
9e1a324
648852a
 
9e1a324
 
9f52a10
 
 
9e1a324
 
 
 
9f52a10
9e1a324
 
fecdfaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# Cell 6: Gradio App (Parquet-Based Search Engine)
import gradio as gr
import pandas as pd
import numpy as np
import os
import io
from deepface import DeepFace
from PIL import Image, ImageDraw, ImageFont
from sklearn.metrics.pairwise import cosine_similarity

# --- 1. Load the Knowledge Base (Specific Target) ---
# The parquet file must provide an 'embedding' column (one vector per row);
# the search code below also reads 'name' and, optionally, 'image_bytes'.
TARGET_DB = "famous_faces_ArcFace_standalone.parquet"


def model_name_from_path(path):
    """Return the model name encoded in a '<prefix>_<Model>[_standalone].parquet' path.

    >>> model_name_from_path("famous_faces_ArcFace_standalone.parquet")
    'ArcFace'
    """
    parts = os.path.splitext(os.path.basename(path))[0].split("_")
    # A trailing "standalone" tag is a variant marker, not the model name;
    # the length check avoids an IndexError for a bare "standalone.parquet".
    if parts[-1] == "standalone" and len(parts) > 1:
        return parts[-2]
    return parts[-1]


if os.path.exists(TARGET_DB):
    DB_PATH = TARGET_DB
    print(f"πŸ“‚ Loaded Knowledge Base: {DB_PATH}")
    df_db = pd.read_parquet(DB_PATH)
    print(f"πŸ“Š Database columns: {df_db.columns.tolist()}")
    print(f"πŸ“Š Database shape: {df_db.shape}")

    if df_db.empty:
        # np.stack raises ValueError on an empty sequence, which would crash the
        # whole app at start-up; leave DB_VECTORS unset so search reports
        # "System Error: No DB" instead.
        print("CRITICAL: Knowledge Base is empty - search disabled.")
        DB_VECTORS = None
    else:
        # Stack the per-row embeddings into one 2-D matrix for vectorised math
        DB_VECTORS = np.stack(df_db['embedding'].to_numpy())

    MODEL_NAME = model_name_from_path(DB_PATH)
    print(f"βš™οΈ Model configured: {MODEL_NAME}")

else:
    print("❌ CRITICAL: Parquet file not found!")
    DB_PATH = None
    DB_VECTORS = None
    MODEL_NAME = "Unknown"
    df_db = None

# --- 2. Define the Search Logic ---
def create_placeholder(name):
    """Create a grey 200x200 placeholder image labelled with *name*.

    Used when a knowledge-base row has no usable image.

    Args:
        name: Label to draw. It is converted with ``str()`` (so None or numeric
            names do not raise) and truncated to 15 characters.

    Returns:
        A new RGB ``PIL.Image.Image`` showing the name plus "Image / Not Found".
    """
    size = 200
    placeholder = Image.new('RGB', (size, size), color=(220, 220, 220))
    draw = ImageDraw.Draw(placeholder)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
    except (OSError, ImportError):
        # OSError: font file missing; ImportError: Pillow built without
        # FreeType. Either way use Pillow's built-in font rather than failing.
        font = ImageFont.load_default()

    y_offset = 50
    for line in (str(name)[:15], "Image", "Not Found"):
        left, _, right, _ = draw.textbbox((0, 0), line, font=font)
        # Centre each line horizontally
        draw.text(((size - (right - left)) // 2, y_offset), line, fill=(100, 100, 100), font=font)
        y_offset += 30
    return placeholder

def _embed_user_image(user_image):
    """Embed *user_image* with MODEL_NAME and return it as a (1, dim) row vector."""
    # Gradio's gr.Image(type="numpy") yields RGB(A) arrays, while DeepFace
    # documents ndarray input as OpenCV-style BGR. Reverse the first three
    # channels (dropping any alpha) and copy into a C-contiguous buffer,
    # since negative-stride views are rejected by some OpenCV calls.
    if getattr(user_image, "ndim", 0) == 3 and user_image.shape[-1] >= 3:
        user_image = np.ascontiguousarray(user_image[..., 2::-1])
    representations = DeepFace.represent(
        img_path=user_image,
        model_name=MODEL_NAME,
        enforce_detection=False,  # don't raise when no face is detected
    )
    # DeepFace returns a list of results (presumably one per detected face);
    # only the first is used.
    return np.asarray(representations[0]["embedding"]).reshape(1, -1)


def _load_row_image(row):
    """Return ``(image, status_markdown)`` for one knowledge-base row.

    Decodes ``row['image_bytes']`` when present. If that field is absent or
    None, or decoding fails, it falls back to create_placeholder().
    """
    image_bytes = row.get('image_bytes')  # None when the row has no such field
    if image_bytes is not None:
        try:
            img = Image.open(io.BytesIO(image_bytes))
            # Image.open only reads the header; decode pixels now so corrupt
            # data is caught here rather than later during Gradio's output.
            img.load()
            return img, "βœ“ Image loaded from parquet\n\n"
        except Exception as e:
            print(f"⚠️ Could not load image bytes for {row['name']}: {e}")
    return create_placeholder(row['name']), "⚠️ Image not found (Bytes missing)\n\n"


def find_best_matches(user_image):
    """Find the three knowledge-base faces most similar to *user_image*.

    The face is embedded with DeepFace using MODEL_NAME and compared against
    every row of DB_VECTORS by cosine similarity.

    Args:
        user_image: Numpy image from the Gradio input (RGB or RGBA), or None
            when nothing was supplied.

    Returns:
        ``(gallery, markdown)``. ``gallery`` is a list of ``(PIL image, caption)``
        pairs for ``gr.Gallery``; ``markdown`` is a ranked summary. If the input
        is missing, the DB is unavailable, or any step fails, the gallery is
        empty and the markdown carries the message.
    """
    if user_image is None:
        return [], "No Image"
    if DB_VECTORS is None:
        return [], "System Error: No DB"

    try:
        user_vector = _embed_user_image(user_image)
        similarities = cosine_similarity(user_vector, DB_VECTORS)[0]
        # Best three first (fewer when the DB has fewer rows)
        top_indices = np.argsort(similarities)[::-1][:3]

        gallery_images = []
        text_parts = ["## Top 3 Matches:\n\n"]
        for rank, idx in enumerate(top_indices, 1):
            row = df_db.iloc[idx]
            display_name = f"{row['name']} (Match: {int(similarities[idx] * 100)}%)"
            text_parts.append(f"### #{rank}: {display_name}\n\n")
            img, status = _load_row_image(row)
            text_parts.append(status)
            gallery_images.append((img, display_name))

        return gallery_images, "".join(text_parts)

    except Exception as e:
        # UI boundary: show the failure in the Markdown panel instead of
        # crashing the request.
        return [], f"Error: {str(e)}"

# --- 3. Monkey Patch to Fix Gradio API Bug ---
# gradio_client/utils.py evaluates `if "const" in schema:` (line 882 in the
# version this was written against), but some schemas arrive as booleans
# rather than dicts, which raises TypeError. Several entry points are patched.

# Patch 1: Fix the _json_schema_to_python_type function to handle boolean schemas
try:
    from gradio_client import utils as client_utils

    _original_json_schema_to_python_type = client_utils._json_schema_to_python_type

    def safe_json_schema_to_python_type(schema, defs=None):
        """Convert a JSON schema to a type string, tolerating malformed input.

        Any non-dict schema (JSON-Schema booleans included), and any error
        raised by the upstream converter, produces the string "Any".
        """
        # bool is not a dict subclass, so this one check covers both the
        # boolean-schema case and every other non-dict value.
        if not isinstance(schema, dict):
            return "Any"
        try:
            converted = _original_json_schema_to_python_type(schema, defs)
        except Exception as exc:
            print(f"Warning: Schema parsing failed: {exc}, returning 'Any'")
            converted = "Any"
        return converted

    client_utils._json_schema_to_python_type = safe_json_schema_to_python_type
    print("βœ“ Patched gradio_client.utils._json_schema_to_python_type")
except Exception as exc:
    print(f"Could not patch _json_schema_to_python_type: {exc}")

# Patch 2: Fix the get_type function in gradio_client.utils
try:
    from gradio_client import utils as client_utils

    _original_get_type = client_utils.get_type

    def safe_get_type(schema):
        """Delegate to the upstream get_type, answering "Any" for non-dict schemas.

        The upstream function applies the `in` operator to *schema*, which
        raises TypeError when the schema is a bool.
        """
        return _original_get_type(schema) if isinstance(schema, dict) else "Any"

    client_utils.get_type = safe_get_type
    print("βœ“ Patched gradio_client.utils.get_type")
except Exception as exc:
    print(f"Could not patch gradio_client.utils.get_type: {exc}")

# Patch 3: Fix the get_api_info method in gradio.blocks
try:
    import functools

    from gradio import blocks
    from gradio_client.utils import APIInfoParseError

    _original_get_api_info = blocks.Blocks.get_api_info

    # Message fragments that identify the known boolean-schema bug.
    _KNOWN_SCHEMA_BUG_MESSAGES = (
        "argument of type 'bool' is not iterable",
        "'bool' is not iterable",
        "Cannot parse schema True",
        "Cannot parse schema False",
    )

    @functools.wraps(_original_get_api_info)
    def safe_get_api_info(self, *args, **kwargs):
        """Call the original get_api_info, tolerating the boolean-schema bug.

        Extra positional/keyword arguments (e.g. an ``all_endpoints`` flag in
        some Gradio versions) are forwarded unchanged. Any other error is
        re-raised.
        """
        try:
            return _original_get_api_info(self, *args, **kwargs)
        except (TypeError, APIInfoParseError) as e:
            error_str = str(e)
            if not any(fragment in error_str for fragment in _KNOWN_SCHEMA_BUG_MESSAGES):
                raise
            # Known Gradio bug: return an empty API description that keeps the
            # usual top-level keys (and is truthy, so callers can cache it).
            print(f"Warning: Caught Gradio API schema bug ({type(e).__name__}), returning empty API info")
            return {"named_endpoints": {}, "unnamed_endpoints": {}}

    blocks.Blocks.get_api_info = safe_get_api_info
    print("βœ“ Patched gradio.blocks.Blocks.get_api_info")
except Exception as e:
    print(f"Could not patch gradio.blocks.get_api_info: {e}")

# --- 4. Build Interface ---
# Optional sample photo offered through the demo button and the Examples row.
DEMO_IMAGE_PATH = "Sahar_Millis.png"


def load_demo_image():
    """Return the demo photo as an RGB numpy array, or None if the file is absent."""
    if not os.path.exists(DEMO_IMAGE_PATH):
        return None
    # `with` closes the file handle. convert("RGB") normalises the mode: a
    # palette ("P") PNG would otherwise become a 2-D array of palette indices
    # instead of colours.
    with Image.open(DEMO_IMAGE_PATH) as img:
        return np.array(img.convert("RGB"))


with gr.Blocks(title="Famous Face Matcher") as demo:
    gr.Markdown("# 🎭 Who is your Celebrity Twin?")
    gr.Markdown(f"Searching **{len(df_db) if df_db is not None else 0} faces** using **{MODEL_NAME}**.")

    with gr.Row():
        with gr.Column():
            user_input = gr.Image(sources=["upload", "webcam"], type="numpy", label="Your Photo")
            btn = gr.Button("Find Match", variant="primary")

            # Demo button to load test image
            demo_btn = gr.Button("Load Demo Image", variant="secondary")
            demo_btn.click(fn=load_demo_image, outputs=user_input)

            # Only offer the example when the file actually exists
            if os.path.exists(DEMO_IMAGE_PATH):
                gr.Examples(examples=[[DEMO_IMAGE_PATH]], inputs=user_input)

        with gr.Column():
            gallery = gr.Gallery(label="Top 3 Matches", show_label=True, elem_id="gallery",
                                 columns=3, rows=1, height="auto")
            results_md = gr.Markdown(label="Match Details")

    btn.click(
        fn=find_best_matches,
        inputs=user_input,
        outputs=[gallery, results_md]
    )

demo.launch(show_api=False)