Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| from PIL import Image | |
| from src.config import ALL_CATEGORIES_LABEL, ConfigurationError, METADATA_CSV | |
| from src.dataset import DatasetError, get_categories, load_metadata | |
| from src.favorites import add_result_to_favorites, prepare_favorites_export | |
| from src.runtime import AppRuntime | |
| from src.ui.constants import IMAGE_QUERY_EMPTY_PREVIEW, TEXT_QUERY_EMPTY_PREVIEW | |
| from src.ui.formatters import ( | |
| error_overview, | |
| favorite_gallery, | |
| image_query_preview, | |
| initial_overview, | |
| result_gallery, | |
| result_to_state, | |
| search_overview, | |
| text_query_preview, | |
| ) | |
| from src.ui.types import FavoriteState, GalleryItem, ResultState | |
logger = logging.getLogger(__name__)


def category_to_backend(category: str | None) -> str:
    """Translate a category selection into the value sent to the search backend.

    A truthy selection is passed through unchanged. An empty string or ``None``
    falls back to ``ALL_CATEGORIES_LABEL``.
    """
    if category:
        return category
    return ALL_CATEGORIES_LABEL
def current_categories() -> list[str]:
    """Return the categories read from ``METADATA_CSV``.

    If reading fails, or yields nothing, ``[ALL_CATEGORIES_LABEL]`` is
    returned instead. A failure is logged as a warning and never re-raised,
    so callers always get at least one entry.
    """
    fallback = [ALL_CATEGORIES_LABEL]
    try:
        loaded = get_categories(METADATA_CSV)
    except Exception as exc:  # deliberate best-effort: log and degrade
        logger.warning("Could not load categories: %s", exc)
        return fallback
    if not loaded:
        return fallback
    return loaded
def metadata_status() -> tuple[bool, str]:
    """Report whether ``METADATA_CSV`` is usable for searching.

    The file must exist, load via ``load_metadata``, and not be ``.empty``
    (``load_metadata`` presumably returns a pandas DataFrame; confirm).

    Returns:
        ``(True, "")`` when all checks pass. Otherwise ``(False, reason)``,
        where ``reason`` is a human-readable explanation.
    """
    if METADATA_CSV.exists():
        try:
            metadata = load_metadata(METADATA_CSV)
        except Exception as exc:
            return False, f"metadata.csv could not be loaded: {exc}"
        if not metadata.empty:
            return True, ""
        return (
            False,
            "metadata.csv does not contain any images. Add images and run `python scripts/build_metadata.py`.",
        )
    return (
        False,
        "metadata.csv was not found. Put images under data/images/<category>/ and run `python scripts/build_metadata.py`.",
    )
def configuration_message(exc: Exception) -> str:
    """Convert a configuration exception into a user-facing message.

    If the exception text mentions Upstash Vector (either ``UPSTASH_VECTOR``
    or ``Upstash Vector``), a fixed hint naming the environment variables is
    returned. Otherwise the exception text is shown after a generic prefix.
    """
    detail = str(exc)
    upstash_markers = ("UPSTASH_VECTOR", "Upstash Vector")
    if not any(marker in detail for marker in upstash_markers):
        return f"Configuration could not be loaded: {detail}"
    return "Upstash Vector configuration was not detected. Check UPSTASH_VECTOR_REST_URL and UPSTASH_VECTOR_REST_TOKEN in .env."
def dataset_message(exc: Exception) -> str:
    """Convert a dataset-related exception into a user-facing message.

    Exception text that mentions ``metadata.csv`` is labelled as a
    configuration error. Anything else is reported as a read failure.
    """
    detail = str(exc)
    return (
        f"Dataset configuration error: {detail}"
        if "metadata.csv" in detail
        else f"Dataset could not be read: {detail}"
    )
class AppCallbacks:
    """Callbacks for the image-search UI.

    Only ``run_search`` uses the injected runtime. Every other callback depends
    solely on its arguments and is declared ``@staticmethod``, so it works
    whether it is called on the class or through an instance.
    """

    def __init__(self, runtime: AppRuntime) -> None:
        """Keep the runtime that supplies (and can reset) the search service."""
        self.runtime = runtime

    @staticmethod
    def _error_response(
        title: str,
        message: str,
    ) -> tuple[str, list[GalleryItem], str | None, list[ResultState]]:
        """Build ``run_search`` outputs for a search that produced no results."""
        # Outputs: overview, empty gallery, no selected id, no result states.
        return error_overview(title, message), [], None, []

    @staticmethod
    def preview_text_query(text: str) -> str:
        """Return the preview for the current text query."""
        return text_query_preview(text)

    @staticmethod
    def preview_image_query(image: Image.Image | None) -> str:
        """Return the image-query preview. Only the image's presence is used."""
        return image_query_preview(image is not None)

    @staticmethod
    def select_text_mode() -> str:
        """Return the query-mode value for text search."""
        return "text"

    @staticmethod
    def select_image_mode() -> str:
        """Return the query-mode value for image search."""
        return "image"

    def run_search(
        self,
        query_mode: str,
        text_query: str,
        image_query: Image.Image | None,
        top_k: int,
        category_filter: str,
        min_similarity: float,
    ) -> tuple[str, list[GalleryItem], str | None, list[ResultState]]:
        """Run a text or image search and build the result outputs.

        Args:
            query_mode: ``"text"`` or ``"image"``. Any other value is treated
                as ``"text"``.
            text_query: Text query. ``None`` is treated as ``""`` and
                surrounding whitespace is stripped.
            image_query: Query image. It is required in image mode.
            top_k: Number of results requested from the service, passed
                through ``int()``.
            category_filter: Selected category. It is mapped to the backend
                value by ``category_to_backend``.
            min_similarity: Results scoring below this are dropped *after*
                retrieval, so fewer than ``top_k`` may be returned. A falsy
                value means 0.

        Returns:
            ``(overview, gallery items, None, filtered result states)``. The
            third slot is always ``None`` here. It presumably clears the
            selected-result id produced by ``select_result_from_gallery``;
            confirm against the UI wiring.

            Input, dataset and service failures are not raised. They are
            reported through an error overview with no results.
        """
        mode = query_mode if query_mode in {"text", "image"} else "text"
        text = (text_query or "").strip()
        category = category_to_backend(category_filter)
        threshold = float(min_similarity or 0)

        # Validate that the active mode actually has a query.
        if mode == "text":
            if not text:
                return self._error_response(
                    "Search cannot be completed",
                    "Enter an image description before running text search.",
                )
            current_query = text
        elif image_query is not None:
            current_query = "Uploaded image"
        else:
            return self._error_response(
                "Search cannot be completed",
                "Upload a query image before running image search.",
            )

        ready, message = metadata_status()
        if not ready:
            return self._error_response("Dataset error", message)

        try:
            service = self.runtime.get_search_service()
            if mode == "text":
                results = service.search_by_text(text, int(top_k), category)
            else:
                # Guaranteed by the mode checks above; kept for type narrowing.
                assert image_query is not None
                results = service.search_by_image(image_query, int(top_k), category)
        # Every failure path resets the runtime's search service. Presumably
        # this makes the next search rebuild it from scratch (verify in
        # AppRuntime).
        except ConfigurationError as exc:
            self.runtime.reset_search_service()
            return self._error_response("Configuration error", configuration_message(exc))
        except (DatasetError, FileNotFoundError, ValueError) as exc:
            self.runtime.reset_search_service()
            return self._error_response("Search cannot be completed", dataset_message(exc))
        except Exception:
            self.runtime.reset_search_service()
            logger.exception("Search failed")
            preload_note = (
                " The CLIP model also failed to preload earlier."
                if self.runtime.preload_error
                else ""
            )
            return self._error_response(
                "Search cannot be completed",
                f"An unexpected error occurred during search. Check the application logs for details.{preload_note}",
            )

        filtered_results = [
            result_to_state(result)
            for result in results
            if float(result.score) >= threshold
        ]
        empty_message = None
        if not filtered_results:
            empty_message = (
                "No images matched the current filters. Try changing the query, category, or minimum similarity."
            )
        return (
            search_overview(
                current_query=current_query,
                returned_results=len(filtered_results),
                message=empty_message,
            ),
            result_gallery(filtered_results),
            None,
            filtered_results,
        )

    @staticmethod
    def select_result_from_gallery(
        current_results: list[ResultState] | None,
        evt: gr.SelectData,
    ) -> str | None:
        """Return the id of the clicked gallery item, or ``None``.

        ``evt.index`` may be an int or a list/tuple; for a sequence, the first
        element is used. A non-integer or out-of-range index yields ``None``.
        A matching result without an ``"id"`` yields ``""``.
        """
        results = list(current_results or [])
        index: Any = evt.index
        if isinstance(index, (list, tuple)):
            index = index[0] if index else None
        try:
            selected_index = int(index)
        except (TypeError, ValueError):
            return None
        # Reject negatives explicitly: Python indexing would wrap them around.
        if 0 <= selected_index < len(results):
            return str(results[selected_index].get("id") or "")
        return None

    @staticmethod
    def add_selected_to_favorites(
        selected_result_id: str | None,
        current_results: list[ResultState] | None,
        favorites: list[FavoriteState] | None,
    ) -> tuple[list[FavoriteState], list[GalleryItem], str, Path | None]:
        """Add the selected result to favorites and refresh the dependent outputs.

        Returns:
            ``(updated favorites, favorites gallery, status message, export
            path or None)``.
        """
        updated_favorites, message = add_result_to_favorites(
            selected_result_id,
            current_results,
            favorites,
        )
        return (
            updated_favorites,
            favorite_gallery(updated_favorites),
            message,
            prepare_favorites_export(updated_favorites),
        )

    @staticmethod
    def clear_search(
        top_k: int,
        category_filter: str,
        min_similarity: float,
    ) -> tuple[
        str,
        None,
        str,
        str,
        str,
        list[GalleryItem],
        None,
        list[ResultState],
    ]:
        """Reset the search outputs to their initial state.

        The current filter values are kept only to render the initial
        overview. Presumably the outputs are: text query, image query, text
        preview, image preview, overview, gallery, selected id and result
        states (confirm against the UI wiring).
        """
        return (
            "",
            None,
            TEXT_QUERY_EMPTY_PREVIEW,
            IMAGE_QUERY_EMPTY_PREVIEW,
            initial_overview(top_k, category_filter, min_similarity),
            [],
            None,
            [],
        )