Spaces:
Running
Running
import html

import gradio as gr

from models.data_manager import DataManager
from models.image_processor import image_search_performers
# Flag emoji keyed by upper-case two-letter country code.
# A flag emoji is simply the two letters of the code written with Unicode
# regional indicator symbols (U+1F1E6 for 'A' through U+1F1FF for 'Z'), so the
# glyphs are derived from the code list rather than typed out one by one.
# Only the codes listed here get a flag; anything else is absent from the map.
_COUNTRY_FLAGS = {
    code: "".join(chr(0x1F1E6 + ord(letter) - ord("A")) for letter in code)
    for code in (
        "AD AE AF AG AI AL AM AO "
        "AQ AR AS AT AU AW AX AZ "
        "BA BB BD BE BF BG BH BI "
        "BJ BL BM BN BO BQ BR BS "
        "BT BV BW BY BZ CA CC CD "
        "CF CG CH CI CK CL CM CN "
        "CO CR CU CV CW CX CY CZ "
        "DE DJ DK DM DO DZ EC EE "
        "EG EH ER ES ET FI FJ FK "
        "FM FO FR GA GB GD GE GF "
        "GG GH GI GL GM GN GP GQ "
        "GR GS GT GU GW GY HK HM "
        "HN HR HT HU ID IE IL IM "
        "IN IO IQ IR IS IT JE JM "
        "JO JP KE KG KH KI KM KN "
        "KP KR KW KY KZ LA LB LC "
        "LI LK LR LS LT LU LV LY "
        "MA MC MD ME MF MG MH MK "
        "ML MM MN MO MP MQ MR MS "
        "MT MU MV MW MX MY MZ NA "
        "NC NE NF NG NI NL NO NP "
        "NR NU NZ OM PA PE PF PG "
        "PH PK PL PM PN PR PS PT "
        "PW PY QA RE RO RS RU RW "
        "SA SB SC SD SE SG SH SI "
        "SJ SK SL SM SN SO SR SS "
        "ST SV SX SY SZ TC TD TF "
        "TG TH TJ TK TL TM TN TO "
        "TR TT TV TW TZ UA UG UM "
        "US UY UZ VA VC VE VG VI "
        "VN VU WF WS YE YT ZA ZM "
        "ZW"
    ).split()
}
| class WebInterface: | |
def __init__(self, data_manager: DataManager, default_threshold: float = 50):
    """
    Set up the web interface.

    Parameters:
        data_manager: DataManager instance; it is handed to every face
            search this interface runs
        default_threshold: Starting value of the "threshold" sliders in
            the UI (those sliders range from 1 to 100)
    """
    self.default_threshold = default_threshold
    self.data_manager = data_manager
def get_country_flag(self, country_code):
    """
    Look up the flag emoji for a two-letter country code.

    The lookup is case-insensitive. An empty/None value, a value that is
    not exactly two characters long, or a code missing from the flag table
    all yield an empty string.
    """
    is_two_letter_code = bool(country_code) and len(country_code) == 2
    return _COUNTRY_FLAGS.get(country_code.upper(), "") if is_two_letter_code else ""
def multiple_image_search(self, img, threshold, results):
    """
    Wrapper for the multiple image search function.

    Parameters:
        img: Image to search (the UI passes the uploaded image)
        threshold: Confidence threshold on the 1-100 slider scale; it is
            divided by 100 before being passed to the search
        results: Forwarded unchanged to image_search_performers
            (presumably the number of matches wanted per face - confirm
            against that function)
    Returns:
        Whatever image_search_performers returns, or a dict with a single
        "error" key when the search reports that no faces were found.
    Raises:
        ValueError: Any ValueError from the search other than the
            "No faces found" case is propagated unchanged.
    """
    try:
        return image_search_performers(img, self.data_manager, threshold / 100.0, results)
    except ValueError as e:
        if "No faces found" not in str(e):
            # Bare raise re-raises the active exception with its original
            # traceback instead of adding this frame as a new raise site.
            raise
        return {"error": "No faces detected in the uploaded image. Please try uploading an image with visible faces."}
def format_results_for_visual_display(self, json_results):
    """
    Convert JSON results to HTML for visual display.

    Every piece of text taken from json_results (error message, performer
    name, image URL, profile URL, country, distance) is HTML-escaped before
    it is placed in the markup, so values containing <, >, & or quotes are
    displayed literally instead of breaking or injecting into the page.

    Parameters:
        json_results: List of face detection results from
            image_search_performers, or a dict with an "error" key.
            Each face result is read for 'confidence', 'performers' and
            'image'; each performer for 'confidence', 'image', 'name',
            'performer_url' and, optionally, 'country' and 'distance'.
    Returns:
        str: HTML content
    """
    if not json_results:
        return "<p>No faces detected or no matches found.</p>"

    # Handle error case
    if isinstance(json_results, dict) and "error" in json_results:
        error_text = html.escape(str(json_results['error']))
        return f"""
        <div class="performer-card">
            <div class="face-info">
                <h3 style="color: #ff6b6b;">Error</h3>
                <p>{error_text}</p>
            </div>
        </div>
        """

    # The stylesheet is a plain (non-f) string: its braces are CSS, not
    # format placeholders.
    html_parts = ["""
    <style>
    body, .gradio-container {
        background-color: #1e1e1e !important;
        color: #d4d4d4 !important;
    }
    .performer-card {
        border: 1px solid #404040;
        border-radius: 12px;
        padding: 24px;
        margin: 16px 0;
        background: #2d2d2d;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
        color: #d4d4d4;
    }
    .face-info {
        background: #3c3c3c;
        padding: 20px;
        border-radius: 8px;
        margin-bottom: 24px;
        border: 1px solid #4a4a4a;
        display: flex;
        align-items: flex-start;
        gap: 20px;
    }
    .face-info-content {
        flex: 1;
    }
    .face-info h3 {
        color: #ffffff;
        margin-top: 0;
        font-size: 1.4em;
    }
    .performer-grid {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(350px, 1fr));
        gap: 24px;
        margin-top: 16px;
    }
    .performer-item {
        border: 1px solid #4a4a4a;
        border-radius: 12px;
        padding: 24px;
        background: #333333;
        text-align: center;
        transition: all 0.3s ease;
        box-shadow: 0 2px 8px rgba(0,0,0,0.2);
        display: flex;
        flex-direction: column;
        align-items: center;
    }
    .performer-item:hover {
        border-color: #569cd6;
        box-shadow: 0 4px 16px rgba(0,0,0,0.4);
        transform: translateY(-2px);
    }
    .performer-image {
        width: 120px;
        height: 120px;
        border-radius: 12px;
        object-fit: cover;
        margin: 0 auto 16px auto;
        display: block;
        border: 2px solid #4a4a4a;
        transition: all 0.3s ease;
        text-align: center;
    }
    .performer-image:hover {
        border-color: #569cd6;
        transform: scale(1.05);
    }
    .performer-item h4 {
        color: #ffffff;
        margin: 16px 0 8px 0;
        font-size: 1.2em;
    }
    .performer-item h4 a {
        color: #569cd6;
        text-decoration: none;
        transition: color 0.3s ease;
    }
    .performer-item h4 a:hover {
        color: #9cdcfe;
        text-decoration: underline;
    }
    .performer-item p {
        color: #cccccc;
        margin: 8px 0;
    }
    .performer-item small {
        color: #999999;
    }
    .confidence-bar {
        background: #404040;
        border-radius: 12px;
        overflow: hidden;
        height: 28px;
        margin: 12px 0;
        border: 1px solid #4a4a4a;
        width: 100%;
        max-width: 200px;
    }
    .confidence-fill {
        height: 100%;
        transition: width 0.5s ease;
        text-align: center;
        line-height: 28px;
        color: white;
        font-size: 13px;
        font-weight: bold;
        text-shadow: 0 1px 2px rgba(0,0,0,0.5);
    }
    .high-confidence {
        background: linear-gradient(135deg, #4caf50, #66bb6a);
    }
    .medium-confidence {
        background: linear-gradient(135deg, #ff9800, #ffb74d);
    }
    .low-confidence {
        background: linear-gradient(135deg, #f44336, #ef5350);
    }
    .face-info p strong {
        color: #9cdcfe;
    }
    .country-flag {
        font-size: 1.2em;
        margin-right: 6px;
        vertical-align: middle;
    }
    .detected-face img {
        width: 120px;
        height: 120px;
        border-radius: 12px;
        object-fit: cover;
        border: 2px solid #569cd6;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
    }
    </style>
    """]

    for face_number, face_result in enumerate(json_results, start=1):
        face_confidence = face_result['confidence']
        performers = face_result['performers']
        # The face crop is embedded as a data URI built from the result's
        # 'image' value; escaping keeps the src attribute well-formed.
        face_image_src = html.escape(f"data:image/jpeg;base64,{face_result['image']}")

        html_parts.append(f"""
        <div class="performer-card">
            <div class="face-info">
                <div class="detected-face">
                    <img src="{face_image_src}" alt="Detected Face {face_number}">
                </div>
                <div class="face-info-content">
                    <h3>Face {face_number}</h3>
                    <p><strong>Detection Confidence:</strong> {face_confidence:.1%}</p>
                    <p><strong>Matches Found:</strong> {len(performers)}</p>
                </div>
            </div>
        """)

        if performers:
            html_parts.append('<div class="performer-grid">')
            for performer in performers:
                # Colour bands for the confidence bar: >= 80 high,
                # >= 60 medium, otherwise low.
                confidence = performer['confidence']
                if confidence >= 80:
                    confidence_class = "high-confidence"
                elif confidence >= 60:
                    confidence_class = "medium-confidence"
                else:
                    confidence_class = "low-confidence"

                country_code = performer.get('country', '')
                country_flag = self.get_country_flag(country_code)
                if country_flag:
                    country_text = f"{country_flag} {country_code}"
                else:
                    country_text = country_code if country_code else 'Unknown'

                # Escape everything that originates from the result data
                # before it goes into attributes or element text.
                name = html.escape(str(performer['name']))
                image_url = html.escape(str(performer['image']))
                profile_url = html.escape(str(performer['performer_url']))
                country_display = html.escape(str(country_text))
                distance = html.escape(str(performer.get('distance', 'N/A')))

                html_parts.append(f"""
                <div class="performer-item">
                    <img src="{image_url}" alt="{name}" class="performer-image" onerror="this.style.display='none'">
                    <h4><a href="{profile_url}" target="_blank">{name}</a></h4>
                    <p><strong>Country:</strong> {country_display}</p>
                    <div class="confidence-bar">
                        <div class="confidence-fill {confidence_class}" style="width: {confidence}%">
                            {confidence}%
                        </div>
                    </div>
                    <p><small>Distance: {distance}</small></p>
                </div>
                """)
            html_parts.append('</div>')
        else:
            html_parts.append('<p><em>No performer matches found for this face.</em></p>')

        # Closes the performer-card opened for this face.
        html_parts.append('</div>')

    return ''.join(html_parts)
def multiple_image_search_with_visual(self, img, threshold, results):
    """
    Run face search and return HTML for visual display.

    Parameters:
        img, threshold, results: Passed straight to multiple_image_search
    Returns:
        str: The formatted result HTML, or an error card containing the
        (HTML-escaped) exception message if the search or the formatting
        raised.
    """
    try:
        json_results = self.multiple_image_search(img, threshold, results)
        return self.format_results_for_visual_display(json_results)
    except Exception as e:
        # This is the UI boundary: any failure is shown to the user as a
        # card rather than propagated. The message is escaped because
        # exception text can contain <, > or & and must not be parsed as
        # markup.
        return f"<div class='performer-card'><h3>Error</h3><p>{html.escape(str(e))}</p></div>"
def _create_json_search_interface(self):
    """
    Build the "JSON API" layout: image + sliders on the left, raw JSON
    output on the right.

    The Search button is wired to multiple_image_search and exposed under
    the API name "multiple_image_search".

    Returns:
        The gr.Blocks object containing the layout.
    """
    with gr.Blocks() as api_blocks:
        gr.Markdown("# Face Recognition API")
        gr.Markdown("Upload an image and get JSON results - perfect for API integration.")

        with gr.Row():
            # Input column
            with gr.Column():
                image_in = gr.Image(type="pil")
                threshold_slider = gr.Slider(
                    label="threshold", minimum=1, maximum=100, value=self.default_threshold, step=1
                )
                count_slider = gr.Slider(label="results", minimum=0, maximum=50, value=3, step=1)
                run_button = gr.Button("Search")
            # Output column
            with gr.Column():
                json_view = gr.JSON(label="JSON Results")

        run_button.click(
            fn=self.multiple_image_search,
            inputs=[image_in, threshold_slider, count_slider],
            outputs=json_view,
            api_name="multiple_image_search",
        )

    return api_blocks
def _create_visual_search_interface(self):
    """
    Build the "Visual Search" layout: image + sliders on the left, an HTML
    results panel on the right.

    The Search button is wired to multiple_image_search_with_visual and
    exposed under the API name "multiple_image_search_with_visual".

    Returns:
        The gr.Blocks object containing the layout.
    """
    with gr.Blocks() as visual_blocks:
        gr.Markdown("# Who is in the photo?")
        gr.Markdown("Upload an image of a person(s) and we'll show you who it is with photos and details.")

        with gr.Row():
            # Input column
            with gr.Column():
                image_in = gr.Image(type="pil")
                threshold_slider = gr.Slider(
                    label="threshold", minimum=1, maximum=100, value=self.default_threshold, step=1
                )
                count_slider = gr.Slider(label="results", minimum=0, maximum=50, value=3, step=1)
                run_button = gr.Button("Search")
            # Output column, showing a placeholder until the first search
            with gr.Column():
                results_panel = gr.HTML(
                    label="Performer Information",
                    value="<p>Upload an image and click search to see results.</p>",
                )

        run_button.click(
            fn=self.multiple_image_search_with_visual,
            inputs=[image_in, threshold_slider, count_slider],
            outputs=[results_panel],
            api_name="multiple_image_search_with_visual",
        )

    return visual_blocks
def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
    """
    Assemble the tabbed application and start it.

    The app has two tabs, "Visual Search" and "JSON API", built by the
    corresponding _create_*_interface methods, styled with a dark theme.

    Parameters:
        server_name: Forwarded to Gradio's launch()
        server_port: Forwarded to Gradio's launch()
        share: Forwarded to Gradio's launch()
    """
    dark_theme_css = """
    .gradio-container {
        background-color: #1e1e1e !important;
        color: #d4d4d4 !important;
    }
    .dark {
        --background-fill-primary: #2d2d2d;
        --background-fill-secondary: #3c3c3c;
        --border-color-primary: #404040;
        --block-title-text-color: #ffffff;
        --body-text-color: #d4d4d4;
    }
    """

    with gr.Blocks(css=dark_theme_css) as demo:
        with gr.Tabs():
            with gr.TabItem("Visual Search"):
                self._create_visual_search_interface()
            with gr.TabItem("JSON API"):
                self._create_json_search_interface()

    # Enable the request queue first, then start the server.
    queued_app = demo.queue()
    queued_app.launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        ssr_mode=False,
    )