megaface / web /interface.py
cc1234's picture
Update confidence scoring thresholds and remove distance display in performer results
4785fbb
import base64
import html
import io

import gradio as gr
from PIL import Image as PILImage

from models.data_manager import DataManager
from models.image_processor import (
    image_search_performers,
)
class WebInterface:
    """Gradio front end that runs a face search and renders the matches as HTML cards."""

    # Stylesheet prepended to the results HTML so the rendered cards are
    # self-contained (dark theme, confidence-bar colours, performer grid).
    _RESULTS_CSS = """
<style>
    body, .gradio-container {
        background-color: #1e1e1e !important;
        color: #d4d4d4 !important;
    }
    .performer-card {
        border: 1px solid #404040;
        border-radius: 12px;
        padding: 24px;
        margin: 16px 0;
        background: #2d2d2d;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
        color: #d4d4d4;
    }
    .face-info {
        background: #3c3c3c;
        padding: 20px;
        border-radius: 8px;
        margin-bottom: 24px;
        border: 1px solid #4a4a4a;
        display: flex;
        align-items: flex-start;
        gap: 20px;
    }
    .face-info-content {
        flex: 1;
    }
    .face-info h3 {
        color: #ffffff;
        margin-top: 0;
        font-size: 1.4em;
    }
    .performer-grid {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(350px, 1fr));
        gap: 24px;
        margin-top: 16px;
    }
    .performer-item {
        border: 1px solid #4a4a4a;
        border-radius: 12px;
        padding: 24px;
        background: #333333;
        text-align: center;
        transition: all 0.3s ease;
        box-shadow: 0 2px 8px rgba(0,0,0,0.2);
        display: flex;
        flex-direction: column;
        align-items: center;
    }
    .performer-item:hover {
        border-color: #569cd6;
        box-shadow: 0 4px 16px rgba(0,0,0,0.4);
        transform: translateY(-2px);
    }
    .performer-image {
        width: 120px;
        height: 120px;
        border-radius: 12px;
        object-fit: cover;
        margin: 0 auto 16px auto;
        display: block;
        border: 2px solid #4a4a4a;
        transition: all 0.3s ease;
        text-align: center;
    }
    .performer-image:hover {
        border-color: #569cd6;
        transform: scale(1.05);
    }
    .performer-item h4 {
        color: #ffffff;
        margin: 16px 0 8px 0;
        font-size: 1.2em;
    }
    .performer-item h4 a {
        color: #569cd6;
        text-decoration: none;
        transition: color 0.3s ease;
    }
    .performer-item h4 a:hover {
        color: #9cdcfe;
        text-decoration: underline;
    }
    .performer-item p {
        color: #cccccc;
        margin: 8px 0;
    }
    .performer-item small {
        color: #999999;
    }
    .confidence-bar {
        background: #404040;
        border-radius: 12px;
        overflow: hidden;
        height: 28px;
        margin: 12px 0;
        border: 1px solid #4a4a4a;
        width: 100%;
        max-width: 200px;
    }
    .confidence-fill {
        height: 100%;
        transition: width 0.5s ease;
        text-align: center;
        line-height: 28px;
        color: white;
        font-size: 13px;
        font-weight: bold;
        text-shadow: 0 1px 2px rgba(0,0,0,0.5);
    }
    .high-confidence {
        background: linear-gradient(135deg, #4caf50, #66bb6a);
    }
    .medium-confidence {
        background: linear-gradient(135deg, #ff9800, #ffb74d);
    }
    .low-confidence {
        background: linear-gradient(135deg, #f44336, #ef5350);
    }
    .face-info p strong {
        color: #9cdcfe;
    }
    .country-flag {
        font-size: 1.2em;
        margin-right: 6px;
        vertical-align: middle;
    }
</style>
"""

    def __init__(self, data_manager: "DataManager", default_threshold: float = 0.5, max_results: int = 4):
        """
        Initialize the web interface.

        Parameters:
            data_manager: DataManager instance, passed through to image_search_performers
            default_threshold: Confidence threshold passed to every search
            max_results: Value passed as the result-count argument of
                image_search_performers (presumably matches per face; confirm there)
        """
        self.data_manager = data_manager
        self.default_threshold = default_threshold
        self.max_results = max_results

    def multiple_image_search(self, img):
        """
        Search for performers matching the faces found in ``img``.

        Parameters:
            img: Image from the Gradio input (PIL image), or None if nothing was uploaded
        Returns:
            Whatever image_search_performers returns, or a dict with an ``"error"``
            key when no image was given or no faces were detected.
        Raises:
            ValueError: re-raised from image_search_performers for any failure
                other than the "No faces found" case.
        """
        if img is None:
            # Gradio's Image component yields None when no image has been uploaded.
            return {"error": "No image provided. Please upload an image to search."}
        try:
            return image_search_performers(
                img, self.data_manager, self.default_threshold, self.max_results
            )
        except ValueError as e:
            if "No faces found" in str(e):
                return {"error": "No faces detected in the uploaded image. Please try uploading an image with visible faces."}
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def _error_card(message):
        """Return an HTML card displaying ``message`` (HTML-escaped) as an error."""
        return f"""
<div class="performer-card">
<div class="face-info">
<h3 style="color: #ff6b6b;">Error</h3>
<p>{html.escape(str(message))}</p>
</div>
</div>
"""

    @staticmethod
    def _confidence_class(confidence):
        """Map a percentage match confidence to the CSS class colouring its bar."""
        if confidence >= 70:
            return "high-confidence"
        if confidence >= 50:
            return "medium-confidence"
        return "low-confidence"

    @staticmethod
    def _render_performer(performer):
        """Return the HTML grid item (photo, linked name, confidence bar) for one match."""
        confidence = performer['confidence']
        confidence_class = WebInterface._confidence_class(confidence)
        # Performer fields come from the search data, so escape them before
        # embedding them in markup or attribute values.
        safe_name = html.escape(str(performer['name']))
        name_html = safe_name
        if performer.get('url'):
            safe_url = html.escape(str(performer['url']))
            name_html = f'<a href="{safe_url}" target="_blank" rel="noopener noreferrer">{safe_name}</a>'
        safe_image = html.escape(str(performer['image']))
        return f"""
<div class="performer-item">
<img src="{safe_image}" alt="{safe_name}" class="performer-image" onerror="this.style.display='none'">
<h4>{name_html}</h4>
<div class="confidence-bar">
<div class="confidence-fill {confidence_class}" style="width: {confidence}%">
{confidence}%
</div>
</div>
</div>
"""

    def _render_face_card(self, index, face_result):
        """Return the HTML card for detected face ``index`` (0-based) and its performer matches."""
        face_confidence = face_result['confidence']
        performers = face_result['performers']
        # The face crop is already base64 encoded (it is b64-decoded for the
        # gallery), so it can be embedded directly as a data URL.
        face_src = html.escape(f"data:image/jpeg;base64,{face_result['image']}")
        parts = [f"""
<div class="performer-card">
<div class="face-info">
<div class="detected-face">
<img src="{face_src}" alt="Detected Face {index + 1}" style="width: 120px; height: 120px; border-radius: 12px; object-fit: cover; border: 2px solid #569cd6; box-shadow: 0 4px 12px rgba(0,0,0,0.3);">
</div>
<div class="face-info-content">
<h3>Face {index + 1}</h3>
<p><strong>Detection Confidence:</strong> {face_confidence:.1%}</p>
<p><strong>Matches Found:</strong> {len(performers)}</p>
</div>
</div>
"""]
        if performers:
            parts.append('<div class="performer-grid">')
            parts.extend(self._render_performer(performer) for performer in performers)
            parts.append('</div>')
        else:
            parts.append('<p><em>No performer matches found for this face.</em></p>')
        parts.append('</div>')
        return ''.join(parts)

    def format_results_for_visual_display(self, json_results):
        """
        Convert search results to visual components for better UX.

        Parameters:
            json_results: List of face results from image_search_performers, or a
                dict with an ``"error"`` key as returned by multiple_image_search
        Returns:
            tuple: (gallery_images, html_content) where gallery_images holds one
            PIL image per successfully decoded face.
        """
        if not json_results:
            return [], "<p>No faces detected or no matches found.</p>"
        if isinstance(json_results, dict) and "error" in json_results:
            return [], self._error_card(json_results['error'])

        gallery_images = []
        html_parts = [self._RESULTS_CSS]
        for i, face_result in enumerate(json_results):
            try:
                face_image_data = base64.b64decode(face_result['image'])
                gallery_images.append(PILImage.open(io.BytesIO(face_image_data)))
            except Exception as e:
                # Best effort: a face whose image cannot be decoded is logged and
                # omitted from both gallery and HTML; the other faces still render.
                print(f"Error decoding face image: {e}")
                continue
            html_parts.append(self._render_face_card(i, face_result))
        return gallery_images, ''.join(html_parts)

    def multiple_image_search_with_visual(self, img):
        """
        Enhanced search function that returns both JSON and visual components.

        Returns:
            tuple: (json_results, gallery_images, html_content); on any exception
            returns ([], [], error card HTML) instead of propagating.
        """
        try:
            json_results = self.multiple_image_search(img)
            gallery_images, html_content = self.format_results_for_visual_display(json_results)
            return json_results, gallery_images, html_content
        except Exception as e:
            # UI boundary: show the failure to the user rather than crashing the handler.
            return [], [], self._error_card(e)

    def _create_visual_search_interface(self):
        """Create the visual search interface (image input, Search button, HTML results)."""
        with gr.Blocks() as interface:
            gr.Markdown("# Who is in the photo?")
            gr.Markdown("Upload an image of a person(s) and we'll show you who it is with photos and details.")
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil")
                    search_btn = gr.Button("Search")
                with gr.Column():
                    performer_info = gr.HTML(
                        label="Performer Information",
                        value="<p>Upload an image and click search to see results.</p>"
                    )

            def visual_search_wrapper(img):
                """Wrapper that returns only the HTML component of the search."""
                _, _, html_content = self.multiple_image_search_with_visual(img)
                return html_content

            search_btn.click(
                fn=visual_search_wrapper,
                inputs=[img_input],
                outputs=[performer_info],
                api_name="multiple_image_search_with_visual"
            )
        return interface

    def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
        """Build the tabbed app with the dark theme and launch it with queuing enabled."""
        with gr.Blocks(
            css="""
            .gradio-container {
                background-color: #1e1e1e !important;
                color: #d4d4d4 !important;
            }
            .dark {
                --background-fill-primary: #2d2d2d;
                --background-fill-secondary: #3c3c3c;
                --border-color-primary: #404040;
                --block-title-text-color: #ffffff;
                --body-text-color: #d4d4d4;
            }
            """
        ) as demo:
            with gr.Tabs():
                with gr.TabItem("Visual Search"):
                    self._create_visual_search_interface()
        demo.queue().launch(server_name=server_name, server_port=server_port, share=share, ssr_mode=False)