multimodal-hw / app.py
AlekMan's picture
Update app.py
e16d30f verified
"""Interactive Gradio demo for multimodal art retrieval."""
from __future__ import annotations
from functools import lru_cache
import gradio as gr
from src.retrieval import RetrievalService
import pandas as pd
from pathlib import Path
@lru_cache(maxsize=1)
def get_service() -> RetrievalService:
return RetrievalService()
def _load_mapping(mapping_dir: Path) -> dict[str, dict[str, str]]:
mapping = {}
for name in ["artist", "genre", "style"]:
path = mapping_dir / f"{name}.csv"
try:
df = pd.read_csv(path)["name"].to_list()
mapping[name] = {i: v for i, v in enumerate(df)}
except Exception as e:
print(f"Error loading {path}: {e}")
mapping[name] = {}
return mapping
MAPPING = _load_mapping(Path("artifacts/type_mappings"))
def set_new_mapping_location(new_path: Path) -> None:
global MAPPING
MAPPING = _load_mapping(new_path)
def _format_results(results: list[dict]) -> pd.DataFrame:
gallery_items: list[tuple[str, str, str, str, str, str]] = []
for item in results:
artist_disp = MAPPING["artist"].get(item.get("artist", 0), "unknown")
style_disp = MAPPING["style"].get(item.get("style", 0), "unknown")
genre_disp = MAPPING["genre"].get(item.get("genre", 0), "unknown")
image_path = item.get("image_path", "")
caption = item.get("caption", "")
tags_display = ", ".join(item.get("tags", []) or [])
gallery_items.append((image_path, artist_disp, style_disp, genre_disp, caption, tags_display))
return pd.DataFrame(
gallery_items,
columns=["Image Path", "Artist", "Style", "Genre", "Caption", "Tags"],
)
def image_to_image(query_image: str | None, top_k: int) -> pd.DataFrame:
if not query_image:
return pd.DataFrame()
service = get_service()
results = service.search_similar_images(query_image, top_k=top_k)
return _format_results(results)
def caption_to_image(query_text: str, top_k: int) -> pd.DataFrame:
if not query_text.strip():
return pd.DataFrame()
service = get_service()
results = service.search_by_caption(query_text, top_k=top_k)
return _format_results(results)
def omni_to_image(
query_text: str,
styles: list[str],
genres: list[str],
tags: list[str],
top_k: int,
) -> pd.DataFrame:
if not any([query_text.strip(), styles, genres, tags]):
return pd.DataFrame()
service = get_service()
results = service.search_omni(
text_query=query_text.strip() or None,
styles=styles,
genres=genres,
extra_tags=tags,
top_k=top_k,
)
return _format_results(results)
def build_demo() -> gr.Blocks:
service = get_service()
metadata = service.metadata.reset_index()
styles = sorted({style for style in metadata["style"].dropna().unique() if style})
genres = sorted({genre for genre in metadata["genre"].dropna().unique() if genre})
tags = service.omni_tags
style_choices = [(MAPPING["style"].get(style, "unknown"), style) for style in styles]
genre_choices = [(MAPPING["genre"].get(genre, "unknown"), genre) for genre in genres]
with gr.Blocks(title="WikiArt Multimodal Retrieval") as demo:
gr.Markdown(
"""
# WikiArt Multimodal Retrieval
"""
)
with gr.Tab("Image -> Image"):
with gr.Row():
image_input = gr.Image(type="filepath", label="Запросное изображение")
top_k_slider = gr.Slider(3, 20, value=10, step=1, label="Top-K")
image_df = gr.DataFrame(headers=["Image Path", "Artist", "Style", "Genre", "Caption", "Tags"], datatype=["str", "str", "str", "str", "str", "str"], label="Результаты (номер сэмпла соответствует номеру в датасете)")
run_btn = gr.Button("Найти похожие")
run_btn.click(
image_to_image,
inputs=[image_input, top_k_slider],
outputs=image_df,
)
with gr.Tab("Caption -> Image"):
with gr.Row():
caption_input = gr.Textbox(label="Текстовый запрос", lines=3)
caption_topk = gr.Slider(3, 20, value=10, step=1, label="Top-K")
caption_df = gr.DataFrame(headers=["Image Path", "Artist", "Style", "Genre", "Caption", "Tags"], datatype=["str", "str", "str", "str", "str", "str"], label="Результаты (номер сэмпла соответствует номеру в датасете)")
caption_run = gr.Button("Найти по описанию")
caption_run.click(
caption_to_image,
inputs=[caption_input, caption_topk],
outputs=caption_df,
)
with gr.Tab("Omni Search"):
query_box = gr.Textbox(label="Свободный запрос", lines=2)
with gr.Row():
style_select = gr.CheckboxGroup(choices=style_choices, label="Стиль")
genre_select = gr.CheckboxGroup(choices=genre_choices, label="Жанр")
tag_select = gr.CheckboxGroup(choices=tags, label="Zero-Shot теги")
omni_topk = gr.Slider(3, 20, value=10, step=1, label="Top-K")
omni_df = gr.DataFrame(headers=["Image Path", "Artist", "Style", "Genre", "Caption", "Tags"], datatype=["str", "str", "str", "str", "str", "str"], label="Результаты (номер сэмпла соответствует номеру в датасете)")
omni_run = gr.Button("Объединить запрос")
omni_run.click(
omni_to_image,
inputs=[query_box, style_select, genre_select, tag_select, omni_topk],
outputs=omni_df,
)
return demo
def main() -> None:
demo = build_demo()
demo.launch(ssr_mode=False)
if __name__ == "__main__":
main()