Spaces:
Sleeping
Sleeping
| """Runtime retrieval service utilities.""" | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from .config import CONFIG | |
| from .embeddings import ImageEncoder, TextEncoder | |
| from .indexers import AnnoyVectorIndex | |
| class RetrievalService: | |
| def __init__(self) -> None: | |
| cfg = CONFIG | |
| cfg.prepare() | |
| paths = cfg.paths | |
| def _parse_tags(value: object) -> list[str]: | |
| if isinstance(value, str): | |
| try: | |
| parsed = json.loads(value) | |
| if isinstance(parsed, list): | |
| return [str(item) for item in parsed] | |
| except json.JSONDecodeError: | |
| return [part.strip() for part in value.split(",") if part.strip()] | |
| elif isinstance(value, (list, tuple, set)): | |
| return [str(item) for item in value] | |
| return [] | |
| metadata_df = pd.read_parquet(paths.omni_metadata_path).set_index("id") | |
| if "tags" in metadata_df.columns: | |
| metadata_df["tags"] = metadata_df["tags"].apply(_parse_tags) | |
| else: | |
| metadata_df["tags"] = [[] for _ in range(len(metadata_df))] | |
| metadata_df["style"] = metadata_df["style"].fillna("unknown") | |
| metadata_df["genre"] = metadata_df["genre"].fillna("unknown") | |
| self.metadata = metadata_df | |
| self.id_to_idx = {sample_id: idx for idx, sample_id in enumerate(self.metadata.index)} | |
| self.omni_tag_matrix = np.load(paths.indexes_dir / "labels" / "omni_tag_embeddings.npy") | |
| with open(paths.indexes_dir / "labels" / "omni_tags.json", "r", encoding="utf-8") as fin: | |
| self.omni_tags: list[str] = json.load(fin) | |
| self.caption_embeddings = np.load(paths.embeddings_dir / "caption_embeddings.npy") | |
| self.image_embeddings = np.load(paths.embeddings_dir / "image_embeddings.npy") | |
| self.image_index = AnnoyVectorIndex( | |
| dimension=self.image_embeddings.shape[1], | |
| ) | |
| self.image_index.load(paths.indexes_dir / "image.ann") | |
| self.caption_index = AnnoyVectorIndex( | |
| dimension=self.caption_embeddings.shape[1], | |
| ) | |
| self.caption_index.load(paths.indexes_dir / "caption.ann") | |
| self.paths = paths | |
| self.image_encoder = ImageEncoder() | |
| self.text_encoder = TextEncoder() | |
| def _resolve_image(self, reference: str | Path | Image.Image) -> Image.Image: | |
| if isinstance(reference, Image.Image): | |
| return reference | |
| path = Path(reference) | |
| if not path.is_absolute(): | |
| path = self.paths.root / path | |
| image = Image.open(path) | |
| return image | |
| def _metadata_for(self, sample_id: str) -> dict: | |
| row = self.metadata.loc[sample_id].to_dict() | |
| row["id"] = sample_id | |
| image_path = row.get("image_path") | |
| if image_path: | |
| row["image_path"] = str((self.paths.root / image_path).resolve()) | |
| return row | |
| def search_similar_images(self, image: str | Path | Image.Image, top_k: int | None = None) -> list[dict]: | |
| candidate_image = self._resolve_image(image) | |
| embedding = self.image_encoder.encode([candidate_image]).numpy()[0] | |
| results = self.image_index.query(embedding, top_k=top_k) | |
| return [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results] | |
| def search_by_caption(self, caption: str, top_k: int | None = None) -> list[dict]: | |
| embedding = self.text_encoder.encode([caption]).numpy()[0] | |
| results = self.caption_index.query(embedding, top_k=top_k) | |
| return [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results] | |
| def search_omni( | |
| self, | |
| text_query: str | None = None, | |
| styles: Optional[Iterable[int]] = None, | |
| genres: Optional[Iterable[int]] = None, | |
| extra_tags: Optional[Iterable[str]] = None, | |
| top_k: int | None = None, | |
| ) -> list[dict]: | |
| candidate_ids = self._filter_candidates( | |
| styles=styles or [], | |
| genres=genres or [], | |
| extra_tags=extra_tags or [], | |
| ) | |
| if not candidate_ids: | |
| return [] | |
| candidate_vectors = self.caption_embeddings[[self.id_to_idx[sample_id] for sample_id in candidate_ids]] | |
| if text_query: | |
| temp_index = AnnoyVectorIndex(dimension=candidate_vectors.shape[1]) # Building runs quickly, so we can afford it in inference | |
| temp_index.build(candidate_vectors, candidate_ids) | |
| requested_top_k = top_k or CONFIG.index.top_k | |
| requested_top_k = min(requested_top_k, len(candidate_ids)) | |
| text_embedding = self.text_encoder.encode([text_query]).numpy()[0] | |
| results = temp_index.query(text_embedding, top_k=requested_top_k) | |
| else: | |
| results = [(sample_id, 0.0) for sample_id in candidate_ids[: top_k or len(candidate_ids)]] | |
| formatted = [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results] | |
| return formatted | |
| def _filter_candidates( | |
| self, | |
| styles: Iterable[int], | |
| genres: Iterable[int], | |
| extra_tags: Iterable[str], | |
| ) -> list[str]: | |
| df = self.metadata | |
| mask = pd.Series(True, index=df.index) | |
| if styles: | |
| mask &= df["style"].isin(list(styles)) | |
| if genres: | |
| mask &= df["genre"].isin(list(genres)) | |
| if extra_tags: | |
| required_tags = set(extra_tags) | |
| mask &= df["tags"].apply(lambda tags: required_tags.issubset(set(tags))) | |
| return df.index[mask].tolist() | |