Spaces:
Sleeping
Sleeping
| """Interactive Gradio demo for multimodal art retrieval.""" | |
| from __future__ import annotations | |
| from functools import lru_cache | |
| import gradio as gr | |
| from src.retrieval import RetrievalService | |
| import pandas as pd | |
| from pathlib import Path | |
@lru_cache(maxsize=1)
def get_service() -> RetrievalService:
    """Return the shared ``RetrievalService``, constructing it on first use.

    The search handlers and ``build_demo`` all obtain the service through this
    function. Without memoisation each call would construct a brand-new
    ``RetrievalService``; caching the single result means the service is built
    once and reused for the lifetime of the process.

    Returns:
        The cached ``RetrievalService`` instance.
    """
    return RetrievalService()
| def _load_mapping(mapping_dir: Path) -> dict[str, dict[str, str]]: | |
| mapping = {} | |
| for name in ["artist", "genre", "style"]: | |
| path = mapping_dir / f"{name}.csv" | |
| try: | |
| df = pd.read_csv(path)["name"].to_list() | |
| mapping[name] = {i: v for i, v in enumerate(df)} | |
| except Exception as e: | |
| print(f"Error loading {path}: {e}") | |
| mapping[name] = {} | |
| return mapping | |
# Label lookup tables, loaded once at import time from the default directory
# (a relative path, so it resolves against the current working directory).
# set_new_mapping_location() rebinds this name to tables from another directory.
MAPPING = _load_mapping(Path("artifacts/type_mappings"))
def set_new_mapping_location(new_path: Path) -> None:
    """Reload the label tables from *new_path* and rebind the module-level ``MAPPING``.

    Args:
        new_path: Directory to read the mapping CSV files from.
    """
    global MAPPING
    reloaded = _load_mapping(new_path)
    MAPPING = reloaded
| def _format_results(results: list[dict]) -> pd.DataFrame: | |
| gallery_items: list[tuple[str, str, str, str, str, str]] = [] | |
| for item in results: | |
| artist_disp = MAPPING["artist"].get(item.get("artist", 0), "unknown") | |
| style_disp = MAPPING["style"].get(item.get("style", 0), "unknown") | |
| genre_disp = MAPPING["genre"].get(item.get("genre", 0), "unknown") | |
| image_path = item.get("image_path", "") | |
| caption = item.get("caption", "") | |
| tags_display = ", ".join(item.get("tags", []) or []) | |
| gallery_items.append((image_path, artist_disp, style_disp, genre_disp, caption, tags_display)) | |
| return pd.DataFrame( | |
| gallery_items, | |
| columns=["Image Path", "Artist", "Style", "Genre", "Caption", "Tags"], | |
| ) | |
def image_to_image(query_image: str | None, top_k: int) -> pd.DataFrame:
    """Search for images similar to *query_image* and tabulate the hits.

    Args:
        query_image: Value of the image input (configured in the UI with
            ``type="filepath"``); ``None`` or empty when nothing was provided.
        top_k: Number of results requested from the service.

    Returns:
        The formatted results table, or an empty DataFrame when no query
        image was supplied.
    """
    if query_image:
        hits = get_service().search_similar_images(query_image, top_k=top_k)
        return _format_results(hits)
    return pd.DataFrame()
def caption_to_image(query_text: str, top_k: int) -> pd.DataFrame:
    """Search for images matching a text description and tabulate the hits.

    Args:
        query_text: Free-text query. ``None``, an empty string or
            whitespace-only text is treated as "no query".
        top_k: Number of results requested from the service.

    Returns:
        The formatted results table, or an empty DataFrame when there is no
        usable query text.
    """
    # Check for None before calling .strip(): a missing value would otherwise
    # raise AttributeError instead of producing an empty result.
    if query_text is None or not query_text.strip():
        return pd.DataFrame()
    service = get_service()
    results = service.search_by_caption(query_text, top_k=top_k)
    return _format_results(results)
def omni_to_image(
    query_text: str,
    styles: list[str],
    genres: list[str],
    tags: list[str],
    top_k: int,
) -> pd.DataFrame:
    """Run a combined text / style / genre / tag search and tabulate the hits.

    Args:
        query_text: Optional free-text query; ``None`` or blank text means
            "no text query".
        styles: Selected style values.
        genres: Selected genre values.
        tags: Selected extra tags.
        top_k: Number of results requested from the service.

    Returns:
        The formatted results table, or an empty DataFrame when no text and
        no filter was supplied.
    """
    # Normalise once: tolerates None and avoids stripping the text twice.
    text = (query_text or "").strip()
    if not (text or styles or genres or tags):
        return pd.DataFrame()
    service = get_service()
    results = service.search_omni(
        # The service receives None rather than "" when there is no text.
        text_query=text or None,
        styles=styles,
        genres=genres,
        extra_tags=tags,
        top_k=top_k,
    )
    return _format_results(results)
def build_demo() -> gr.Blocks:
    """Assemble the Gradio UI: three tabs that each fill a results table.

    The tabs are "Image -> Image", "Caption -> Image" and "Omni Search"; their
    buttons are wired to ``image_to_image``, ``caption_to_image`` and
    ``omni_to_image`` respectively.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    # NOTE(review): the listing this was taken from carried no indentation, so
    # which components sit inside each ``gr.Row`` below is reconstructed —
    # confirm against the intended layout.
    service = get_service()
    metadata = service.metadata.reset_index()
    # Distinct, non-null style/genre values found in the metadata, sorted.
    # NOTE(review): the ``if style`` / ``if genre`` filters also drop falsy
    # values; if these are integer ids, id 0 would be excluded — confirm intent.
    styles = sorted({style for style in metadata["style"].dropna().unique() if style})
    genres = sorted({genre for genre in metadata["genre"].dropna().unique() if genre})
    tags = service.omni_tags
    # (label, value) pairs: the display name comes from MAPPING, the raw
    # metadata value is what the handler receives.
    style_choices = [(MAPPING["style"].get(style, "unknown"), style) for style in styles]
    genre_choices = [(MAPPING["genre"].get(genre, "unknown"), genre) for genre in genres]
    with gr.Blocks(title="WikiArt Multimodal Retrieval") as demo:
        gr.Markdown(
            """
# WikiArt Multimodal Retrieval
"""
        )
        # Tab 1: query by example image.
        with gr.Tab("Image -> Image"):
            with gr.Row():
                image_input = gr.Image(type="filepath", label="Запросное изображение")
                top_k_slider = gr.Slider(3, 20, value=10, step=1, label="Top-K")
            image_df = gr.DataFrame(headers=["Image Path", "Artist", "Style", "Genre", "Caption", "Tags"], datatype=["str", "str", "str", "str", "str", "str"], label="Результаты (номер сэмпла соответствует номеру в датасете)")
            run_btn = gr.Button("Найти похожие")
            run_btn.click(
                image_to_image,
                inputs=[image_input, top_k_slider],
                outputs=image_df,
            )
        # Tab 2: query by text description.
        with gr.Tab("Caption -> Image"):
            with gr.Row():
                caption_input = gr.Textbox(label="Текстовый запрос", lines=3)
                caption_topk = gr.Slider(3, 20, value=10, step=1, label="Top-K")
            caption_df = gr.DataFrame(headers=["Image Path", "Artist", "Style", "Genre", "Caption", "Tags"], datatype=["str", "str", "str", "str", "str", "str"], label="Результаты (номер сэмпла соответствует номеру в датасете)")
            caption_run = gr.Button("Найти по описанию")
            caption_run.click(
                caption_to_image,
                inputs=[caption_input, caption_topk],
                outputs=caption_df,
            )
        # Tab 3: free text combined with style / genre / tag filters.
        with gr.Tab("Omni Search"):
            query_box = gr.Textbox(label="Свободный запрос", lines=2)
            with gr.Row():
                style_select = gr.CheckboxGroup(choices=style_choices, label="Стиль")
                genre_select = gr.CheckboxGroup(choices=genre_choices, label="Жанр")
            tag_select = gr.CheckboxGroup(choices=tags, label="Zero-Shot теги")
            omni_topk = gr.Slider(3, 20, value=10, step=1, label="Top-K")
            omni_df = gr.DataFrame(headers=["Image Path", "Artist", "Style", "Genre", "Caption", "Tags"], datatype=["str", "str", "str", "str", "str", "str"], label="Результаты (номер сэмпла соответствует номеру в датасете)")
            omni_run = gr.Button("Объединить запрос")
            omni_run.click(
                omni_to_image,
                inputs=[query_box, style_select, genre_select, tag_select, omni_topk],
                outputs=omni_df,
            )
    return demo
def main() -> None:
    """Build the demo and launch it with ``ssr_mode=False``."""
    build_demo().launch(ssr_mode=False)
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()