Image-Retrieval-System / src /ui /callbacks.py
s1ngledoge's picture
upd
1c8d1ba
Raw
History Blame Contribute Delete
7.95 kB
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
import gradio as gr
from PIL import Image
from src.config import ALL_CATEGORIES_LABEL, ConfigurationError, METADATA_CSV
from src.dataset import DatasetError, get_categories, load_metadata
from src.favorites import add_result_to_favorites, prepare_favorites_export
from src.runtime import AppRuntime
from src.ui.constants import IMAGE_QUERY_EMPTY_PREVIEW, TEXT_QUERY_EMPTY_PREVIEW
from src.ui.formatters import (
error_overview,
favorite_gallery,
image_query_preview,
initial_overview,
result_gallery,
result_to_state,
search_overview,
text_query_preview,
)
from src.ui.types import FavoriteState, GalleryItem, ResultState
# Module-level logger named after this module (``__name__``); the callbacks in
# this file use it for warnings and for exception traces.
logger: logging.Logger = logging.getLogger(__name__)
def category_to_backend(category: str | None) -> str:
    """Map the UI category selection to the value passed to the search backend.

    A missing or empty selection (``None`` or ``""``) becomes
    ``ALL_CATEGORIES_LABEL``; any other value is passed through unchanged.
    """
    if category:
        return category
    return ALL_CATEGORIES_LABEL
def current_categories() -> list[str]:
    """Return the category choices read from ``METADATA_CSV``.

    Categories come from ``get_categories(METADATA_CSV)``. If that call
    raises, the error is logged as a warning and ``[ALL_CATEGORIES_LABEL]``
    is returned. The same single-item fallback is used when the call yields
    an empty (falsy) result.
    """
    try:
        loaded = get_categories(METADATA_CSV)
    except Exception as error:  # best-effort: fall back instead of propagating
        logger.warning("Could not load categories: %s", error)
        return [ALL_CATEGORIES_LABEL]
    if loaded:
        return loaded
    return [ALL_CATEGORIES_LABEL]
def metadata_status() -> tuple[bool, str]:
    """Report whether ``METADATA_CSV`` exists, loads, and is non-empty.

    Returns:
        ``(True, "")`` when the metadata is usable. Otherwise
        ``(False, message)``, where *message* describes the problem and how
        to fix it. Load errors are caught and folded into the message rather
        than raised.
    """
    if not METADATA_CSV.exists():
        missing_hint = "metadata.csv was not found. Put images under data/images/<category>/ and run `python scripts/build_metadata.py`."
        return False, missing_hint

    try:
        frame = load_metadata(METADATA_CSV)
    except Exception as error:
        return False, f"metadata.csv could not be loaded: {error}"

    # ``.empty`` suggests a pandas DataFrame from load_metadata; verify against src.dataset.
    if not frame.empty:
        return True, ""
    empty_hint = "metadata.csv does not contain any images. Add images and run `python scripts/build_metadata.py`."
    return False, empty_hint
def configuration_message(exc: Exception) -> str:
    """Convert a configuration exception into a user-facing message.

    If the exception text mentions Upstash Vector (``"UPSTASH_VECTOR"`` or
    ``"Upstash Vector"``), a fixed hint about the required environment
    variables is returned. Otherwise the original text is reported with a
    generic prefix.
    """
    detail = str(exc)
    markers = ("UPSTASH_VECTOR", "Upstash Vector")
    if not any(marker in detail for marker in markers):
        return f"Configuration could not be loaded: {detail}"
    return "Upstash Vector configuration was not detected. Check UPSTASH_VECTOR_REST_URL and UPSTASH_VECTOR_REST_TOKEN in .env."
def dataset_message(exc: Exception) -> str:
    """Convert a dataset or search exception into a user-facing message.

    Text that mentions ``metadata.csv`` is labelled as a dataset
    configuration error; anything else is reported as a read failure.
    """
    detail = str(exc)
    if "metadata.csv" not in detail:
        return f"Dataset could not be read: {detail}"
    return f"Dataset configuration error: {detail}"
class AppCallbacks:
    """Gradio event handlers for the image-retrieval UI.

    Wraps an ``AppRuntime``, which supplies (and can reset) the search
    service. Each handler returns the tuple of component values that the UI
    wiring (defined elsewhere) maps onto its outputs.
    """

    def __init__(self, runtime: AppRuntime) -> None:
        self.runtime = runtime

    @staticmethod
    def preview_text_query(text: str) -> str:
        """Return the preview markup for the current text query."""
        return text_query_preview(text)

    @staticmethod
    def preview_image_query(image: Image.Image | None) -> str:
        """Return the preview markup; it depends only on whether an image is present."""
        return image_query_preview(image is not None)

    @staticmethod
    def select_text_mode() -> str:
        """Return the query-mode identifier for text search."""
        return "text"

    @staticmethod
    def select_image_mode() -> str:
        """Return the query-mode identifier for image search."""
        return "image"

    @staticmethod
    def _error_result(
        title: str,
        message: str,
    ) -> tuple[str, list[GalleryItem], str | None, list[ResultState]]:
        """Build the ``run_search`` return value for a search that could not run.

        The overview shows the error, the gallery and result state are
        cleared, and no result is selected.
        """
        return error_overview(title, message), [], None, []

    def run_search(
        self,
        query_mode: str,
        text_query: str,
        image_query: Image.Image | None,
        top_k: int,
        category_filter: str,
        min_similarity: float,
    ) -> tuple[str, list[GalleryItem], str | None, list[ResultState]]:
        """Run a text or image search and format the results for the UI.

        Args:
            query_mode: ``"text"`` or ``"image"``; any other value is
                treated as ``"text"``.
            text_query: Description used in text mode (surrounding whitespace
                is stripped).
            image_query: Query image used in image mode.
            top_k: Number of results requested from the search service.
                Must convert to an int of at least 1.
            category_filter: UI category; empty/None means all categories.
            min_similarity: Results whose score is below this value are
                dropped after retrieval (a falsy value means 0).

        Returns:
            ``(overview, gallery_items, selected_result_id, result_state)``.
            ``selected_result_id`` is always ``None``, which clears any
            previous selection. On any failure the gallery and state are
            empty and the overview explains the error.
        """
        mode = query_mode if query_mode in {"text", "image"} else "text"
        text = (text_query or "").strip()
        category = category_to_backend(category_filter)
        threshold = float(min_similarity or 0)

        if mode == "text":
            if not text:
                return self._error_result(
                    "Search cannot be completed",
                    "Enter an image description before running text search.",
                )
            current_query = text
        elif image_query is not None:
            current_query = "Uploaded image"
        else:
            return self._error_result(
                "Search cannot be completed",
                "Upload a query image before running image search.",
            )

        # Validate the result count before touching the service. A bad value
        # is an input problem: it must not be reported as a dataset or
        # unexpected error, and it must not reset the cached search service
        # (as it would inside the try block below).
        try:
            limit = int(top_k)
        except (TypeError, ValueError, OverflowError):
            limit = 0
        if limit < 1:
            return self._error_result(
                "Search cannot be completed",
                "Choose at least one result to return before running search.",
            )

        ready, message = metadata_status()
        if not ready:
            return self._error_result("Dataset error", message)

        try:
            service = self.runtime.get_search_service()
            if mode == "text":
                results = service.search_by_text(text, limit, category)
            else:
                # The mode checks above guarantee an image here; this only
                # narrows the type for static checkers.
                assert image_query is not None
                results = service.search_by_image(image_query, limit, category)
        except ConfigurationError as exc:
            self.runtime.reset_search_service()
            return self._error_result("Configuration error", configuration_message(exc))
        except (DatasetError, FileNotFoundError, ValueError) as exc:
            self.runtime.reset_search_service()
            return self._error_result("Search cannot be completed", dataset_message(exc))
        except Exception:
            # Log first so the original traceback is recorded even if the
            # reset below raises.
            logger.exception("Search failed")
            self.runtime.reset_search_service()
            preload_note = (
                " The CLIP model also failed to preload earlier."
                if self.runtime.preload_error
                else ""
            )
            return self._error_result(
                "Search cannot be completed",
                f"An unexpected error occurred during search. Check the application logs for details.{preload_note}",
            )

        # The similarity threshold is applied after retrieval, so fewer than
        # ``limit`` results may remain.
        filtered_results = [
            result_to_state(result)
            for result in results
            if float(result.score) >= threshold
        ]
        empty_message = None
        if not filtered_results:
            empty_message = (
                "No images matched the current filters. Try changing the query, category, or minimum similarity."
            )
        return (
            search_overview(
                current_query=current_query,
                returned_results=len(filtered_results),
                message=empty_message,
            ),
            result_gallery(filtered_results),
            None,
            filtered_results,
        )

    @staticmethod
    def select_result_from_gallery(
        current_results: list[ResultState] | None,
        evt: gr.SelectData,
    ) -> str | None:
        """Map a gallery click to the id of the matching result.

        ``evt.index`` may be a scalar or a sequence; in the latter case its
        first element is used.

        Returns:
            The result's ``"id"`` as a string (``""`` when the entry has no
            truthy id), or ``None`` when the index is missing, not
            int-convertible, or out of range.
        """
        results = list(current_results or [])
        index: Any = evt.index
        if isinstance(index, (list, tuple)):
            index = index[0] if index else None
        try:
            selected_index = int(index)
        except (TypeError, ValueError):
            return None
        if 0 <= selected_index < len(results):
            return str(results[selected_index].get("id") or "")
        return None

    @staticmethod
    def add_selected_to_favorites(
        selected_result_id: str | None,
        current_results: list[ResultState] | None,
        favorites: list[FavoriteState] | None,
    ) -> tuple[list[FavoriteState], list[GalleryItem], str, Path | None]:
        """Add the selected result to favorites and refresh the dependent outputs.

        Returns:
            ``(updated_favorites, favorites_gallery, status_message,
            export_path)``, where the last item comes from
            ``prepare_favorites_export`` and may be ``None``.
        """
        updated_favorites, message = add_result_to_favorites(
            selected_result_id,
            current_results,
            favorites,
        )
        return (
            updated_favorites,
            favorite_gallery(updated_favorites),
            message,
            prepare_favorites_export(updated_favorites),
        )

    @staticmethod
    def clear_search(
        top_k: int,
        category_filter: str,
        min_similarity: float,
    ) -> tuple[
        str,
        None,
        str,
        str,
        str,
        list[GalleryItem],
        None,
        list[ResultState],
    ]:
        """Reset the search inputs and results to their initial state.

        Returns, in order: an empty text query, no image, the empty text and
        image previews, an initial overview built from the current settings,
        an empty gallery, no selected result, and empty result state.
        """
        return (
            "",
            None,
            TEXT_QUERY_EMPTY_PREVIEW,
            IMAGE_QUERY_EMPTY_PREVIEW,
            initial_overview(top_k, category_filter, min_similarity),
            [],
            None,
            [],
        )