# stashface_onnx/web/interface.py
# cc1234
# Stashface: face recognition with ensemble models
# 9282c17
import html
import logging

import gradio as gr

from models.data_manager import DataManager
from models.image_processor import image_search_performers
# Maps an upper-case two-letter country code to its flag emoji.
# A flag emoji is the pair of Unicode "regional indicator" symbols matching
# the two letters of the code (U+1F1E6 is the indicator for "A"), so each
# value is derived from its key instead of being spelled out by hand.
_COUNTRY_FLAGS = {
    code: "".join(chr(0x1F1E6 + ord(letter) - ord("A")) for letter in code)
    for code in (
        "AD AE AF AG AI AL AM AO "
        "AQ AR AS AT AU AW AX AZ "
        "BA BB BD BE BF BG BH BI "
        "BJ BL BM BN BO BQ BR BS "
        "BT BV BW BY BZ CA CC CD "
        "CF CG CH CI CK CL CM CN "
        "CO CR CU CV CW CX CY CZ "
        "DE DJ DK DM DO DZ EC EE "
        "EG EH ER ES ET FI FJ FK "
        "FM FO FR GA GB GD GE GF "
        "GG GH GI GL GM GN GP GQ "
        "GR GS GT GU GW GY HK HM "
        "HN HR HT HU ID IE IL IM "
        "IN IO IQ IR IS IT JE JM "
        "JO JP KE KG KH KI KM KN "
        "KP KR KW KY KZ LA LB LC "
        "LI LK LR LS LT LU LV LY "
        "MA MC MD ME MF MG MH MK "
        "ML MM MN MO MP MQ MR MS "
        "MT MU MV MW MX MY MZ NA "
        "NC NE NF NG NI NL NO NP "
        "NR NU NZ OM PA PE PF PG "
        "PH PK PL PM PN PR PS PT "
        "PW PY QA RE RO RS RU RW "
        "SA SB SC SD SE SG SH SI "
        "SJ SK SL SM SN SO SR SS "
        "ST SV SX SY SZ TC TD TF "
        "TG TH TJ TK TL TM TN TO "
        "TR TT TV TW TZ UA UG UM "
        "US UY UZ VA VC VE VG VI "
        "VN VU WF WS YE YT ZA ZM "
        "ZW"
    ).split()
}
class WebInterface:
    """Gradio web front end for the face search.

    Two tabs are exposed: a visual search that renders matches as HTML cards,
    and a JSON search whose raw result is intended for API clients.
    """

    # Stylesheet emitted ahead of the result cards. Kept as a plain (non-f)
    # string so the CSS braces need no escaping.
    _RESULTS_STYLE = """
<style>
    body, .gradio-container {
        background-color: #1e1e1e !important;
        color: #d4d4d4 !important;
    }
    .performer-card {
        border: 1px solid #404040;
        border-radius: 12px;
        padding: 24px;
        margin: 16px 0;
        background: #2d2d2d;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
        color: #d4d4d4;
    }
    .face-info {
        background: #3c3c3c;
        padding: 20px;
        border-radius: 8px;
        margin-bottom: 24px;
        border: 1px solid #4a4a4a;
        display: flex;
        align-items: flex-start;
        gap: 20px;
    }
    .face-info-content {
        flex: 1;
    }
    .face-info h3 {
        color: #ffffff;
        margin-top: 0;
        font-size: 1.4em;
    }
    .performer-grid {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(350px, 1fr));
        gap: 24px;
        margin-top: 16px;
    }
    .performer-item {
        border: 1px solid #4a4a4a;
        border-radius: 12px;
        padding: 24px;
        background: #333333;
        text-align: center;
        transition: all 0.3s ease;
        box-shadow: 0 2px 8px rgba(0,0,0,0.2);
        display: flex;
        flex-direction: column;
        align-items: center;
    }
    .performer-item:hover {
        border-color: #569cd6;
        box-shadow: 0 4px 16px rgba(0,0,0,0.4);
        transform: translateY(-2px);
    }
    .performer-image {
        width: 120px;
        height: 120px;
        border-radius: 12px;
        object-fit: cover;
        margin: 0 auto 16px auto;
        display: block;
        border: 2px solid #4a4a4a;
        transition: all 0.3s ease;
        text-align: center;
    }
    .performer-image:hover {
        border-color: #569cd6;
        transform: scale(1.05);
    }
    .performer-item h4 {
        color: #ffffff;
        margin: 16px 0 8px 0;
        font-size: 1.2em;
    }
    .performer-item h4 a {
        color: #569cd6;
        text-decoration: none;
        transition: color 0.3s ease;
    }
    .performer-item h4 a:hover {
        color: #9cdcfe;
        text-decoration: underline;
    }
    .performer-item p {
        color: #cccccc;
        margin: 8px 0;
    }
    .performer-item small {
        color: #999999;
    }
    .confidence-bar {
        background: #404040;
        border-radius: 12px;
        overflow: hidden;
        height: 28px;
        margin: 12px 0;
        border: 1px solid #4a4a4a;
        width: 100%;
        max-width: 200px;
    }
    .confidence-fill {
        height: 100%;
        transition: width 0.5s ease;
        text-align: center;
        line-height: 28px;
        color: white;
        font-size: 13px;
        font-weight: bold;
        text-shadow: 0 1px 2px rgba(0,0,0,0.5);
    }
    .high-confidence {
        background: linear-gradient(135deg, #4caf50, #66bb6a);
    }
    .medium-confidence {
        background: linear-gradient(135deg, #ff9800, #ffb74d);
    }
    .low-confidence {
        background: linear-gradient(135deg, #f44336, #ef5350);
    }
    .face-info p strong {
        color: #9cdcfe;
    }
    .country-flag {
        font-size: 1.2em;
        margin-right: 6px;
        vertical-align: middle;
    }
    .detected-face img {
        width: 120px;
        height: 120px;
        border-radius: 12px;
        object-fit: cover;
        border: 2px solid #569cd6;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
    }
</style>
"""

    def __init__(self, data_manager: "DataManager", default_threshold: float = 50):
        """
        Initialize the web interface.

        Parameters:
            data_manager: DataManager instance handed to every search call
            default_threshold: Initial value of the threshold sliders, as a
                percentage (the sliders range from 1 to 100)
        """
        self.data_manager = data_manager
        self.default_threshold = default_threshold

    def get_country_flag(self, country_code):
        """
        Return the flag emoji for a two-letter country code.

        The lookup is case-insensitive. An empty string is returned for
        empty, non-string, wrong-length or unknown codes.
        """
        if not isinstance(country_code, str) or len(country_code) != 2:
            return ""
        return _COUNTRY_FLAGS.get(country_code.upper(), "")

    def multiple_image_search(self, img, threshold, results):
        """
        Wrapper for the multiple image search function.

        Parameters:
            img: Image from the Gradio image input (None when nothing was
                uploaded)
            threshold: Confidence threshold as a percentage; it is divided
                by 100 before being passed on
            results: Passed through to image_search_performers unchanged
                (presumably the maximum number of matches per face - confirm
                against that function)
        Returns:
            The result of image_search_performers, or a dict with a single
            "error" key when no image was supplied or no face was found.
        Raises:
            ValueError: Any ValueError from the search other than the
                "No faces found" case.
        """
        # Gradio passes None when Search is clicked without an upload.
        if img is None:
            return {"error": "No image provided. Please upload an image before searching."}
        try:
            return image_search_performers(img, self.data_manager, threshold / 100.0, results)
        except ValueError as e:
            if "No faces found" in str(e):
                return {"error": "No faces detected in the uploaded image. Please try uploading an image with visible faces."}
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def _error_card(message):
        """Return an error card; the message is HTML-escaped before embedding."""
        safe_message = html.escape(str(message))
        return f"""
        <div class="performer-card">
            <div class="face-info">
                <h3 style="color: #ff6b6b;">Error</h3>
                <p>{safe_message}</p>
            </div>
        </div>
        """

    def _render_performer(self, performer):
        """Return the HTML card for one performer match, with all values escaped."""
        confidence = performer['confidence']
        if confidence >= 80:
            confidence_class = "high-confidence"
        elif confidence >= 60:
            confidence_class = "medium-confidence"
        else:
            confidence_class = "low-confidence"

        country_code = performer.get('country', '')
        country_flag = self.get_country_flag(country_code)
        if country_flag:
            country_display = f"{country_flag} {country_code}"
        else:
            country_display = country_code or 'Unknown'

        # Every value below comes from search results / database content and
        # is escaped so it cannot break out of its attribute or inject markup.
        name = html.escape(str(performer['name']), quote=True)
        image_src = html.escape(str(performer['image']), quote=True)
        profile_url = html.escape(str(performer['performer_url']), quote=True)
        country_text = html.escape(str(country_display))
        confidence_text = html.escape(str(confidence), quote=True)
        distance_text = html.escape(str(performer.get('distance', 'N/A')))

        return f"""
            <div class="performer-item">
                <img src="{image_src}" alt="{name}" class="performer-image" onerror="this.style.display='none'">
                <h4><a href="{profile_url}" target="_blank" rel="noopener noreferrer">{name}</a></h4>
                <p><strong>Country:</strong> {country_text}</p>
                <div class="confidence-bar">
                    <div class="confidence-fill {confidence_class}" style="width: {confidence_text}%">
                        {confidence_text}%
                    </div>
                </div>
                <p><small>Distance: {distance_text}</small></p>
            </div>
            """

    def format_results_for_visual_display(self, json_results):
        """
        Convert JSON results to HTML for visual display.

        Parameters:
            json_results: Either a dict with an "error" key, or a list of
                per-face dicts. Each face dict is read for 'confidence'
                (formatted as a percentage), 'image' (embedded as a base64
                JPEG data URI) and 'performers' (a list of dicts with
                'name', 'image', 'performer_url', 'confidence' and optional
                'country' and 'distance').
        Returns:
            str: HTML content
        """
        if not json_results:
            return "<p>No faces detected or no matches found.</p>"

        # Handle error case
        if isinstance(json_results, dict) and "error" in json_results:
            return self._error_card(json_results["error"])

        html_parts = [self._RESULTS_STYLE]
        for face_number, face_result in enumerate(json_results, start=1):
            face_confidence = face_result['confidence']
            performers = face_result['performers']
            face_src = html.escape(
                f"data:image/jpeg;base64,{face_result['image']}", quote=True
            )
            html_parts.append(f"""
        <div class="performer-card">
            <div class="face-info">
                <div class="detected-face">
                    <img src="{face_src}" alt="Detected Face {face_number}">
                </div>
                <div class="face-info-content">
                    <h3>Face {face_number}</h3>
                    <p><strong>Detection Confidence:</strong> {face_confidence:.1%}</p>
                    <p><strong>Matches Found:</strong> {len(performers)}</p>
                </div>
            </div>
        """)
            if performers:
                html_parts.append('<div class="performer-grid">')
                html_parts.extend(
                    self._render_performer(performer) for performer in performers
                )
                html_parts.append('</div>')
            else:
                html_parts.append('<p><em>No performer matches found for this face.</em></p>')
            # Closes the performer-card opened above.
            html_parts.append('</div>')
        return ''.join(html_parts)

    def multiple_image_search_with_visual(self, img, threshold, results):
        """
        Run face search and return HTML for visual display.

        Any exception is logged and turned into an error card so the UI
        always receives HTML rather than an exception.
        """
        try:
            json_results = self.multiple_image_search(img, threshold, results)
            return self.format_results_for_visual_display(json_results)
        except Exception as e:
            # UI boundary: keep the traceback in the log, show an escaped
            # message to the user.
            logging.getLogger(__name__).exception("Face search failed")
            return self._error_card(e)

    def _build_search_inputs(self):
        """
        Create the input column shared by both search tabs.

        Must be called inside an open Gradio layout context.

        Returns:
            tuple: (image input, threshold slider, results slider, search button)
        """
        with gr.Column():
            img_input = gr.Image(type="pil")
            threshold = gr.Slider(
                label="threshold",
                minimum=1,
                maximum=100,
                value=self.default_threshold,
                step=1
            )
            results_count = gr.Slider(
                label="results",
                minimum=0,
                maximum=50,
                value=3,
                step=1
            )
            search_btn = gr.Button("Search")
        return img_input, threshold, results_count, search_btn

    def _create_json_search_interface(self):
        """Create the JSON API search interface"""
        with gr.Blocks() as interface:
            gr.Markdown("# Face Recognition API")
            gr.Markdown("Upload an image and get JSON results - perfect for API integration.")
            with gr.Row():
                img_input, threshold, results_count, search_btn = self._build_search_inputs()
                with gr.Column():
                    json_output = gr.JSON(label="JSON Results")
            search_btn.click(
                fn=self.multiple_image_search,
                inputs=[img_input, threshold, results_count],
                outputs=json_output,
                api_name="multiple_image_search"
            )
        return interface

    def _create_visual_search_interface(self):
        """Create the visual search interface"""
        with gr.Blocks() as interface:
            gr.Markdown("# Who is in the photo?")
            gr.Markdown("Upload an image of a person(s) and we'll show you who it is with photos and details.")
            with gr.Row():
                img_input, threshold, results_count, search_btn = self._build_search_inputs()
                with gr.Column():
                    performer_info = gr.HTML(
                        label="Performer Information",
                        value="<p>Upload an image and click search to see results.</p>"
                    )
            search_btn.click(
                fn=self.multiple_image_search_with_visual,
                inputs=[img_input, threshold, results_count],
                outputs=[performer_info],
                api_name="multiple_image_search_with_visual"
            )
        return interface

    def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
        """
        Launch the web interface.

        Parameters:
            server_name: Host/interface passed to Gradio's launch()
            server_port: Port passed to Gradio's launch()
            share: Passed to Gradio's launch()
                NOTE(review): the defaults bind to all interfaces and request
                a Gradio share link - confirm this exposure is intended for
                every deployment.
        """
        with gr.Blocks(
            css="""
            .gradio-container {
                background-color: #1e1e1e !important;
                color: #d4d4d4 !important;
            }
            .dark {
                --background-fill-primary: #2d2d2d;
                --background-fill-secondary: #3c3c3c;
                --border-color-primary: #404040;
                --block-title-text-color: #ffffff;
                --body-text-color: #d4d4d4;
            }
            """
        ) as demo:
            with gr.Tabs():
                with gr.TabItem("Visual Search"):
                    self._create_visual_search_interface()
                with gr.TabItem("JSON API"):
                    self._create_json_search_interface()
        demo.queue().launch(server_name=server_name, server_port=server_port, share=share, ssr_mode=False)