# multimodal-hw/src/retrieval.py
"""Runtime retrieval service utilities."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional
import numpy as np
import pandas as pd
from PIL import Image
from .config import CONFIG
from .embeddings import ImageEncoder, TextEncoder
from .indexers import AnnoyVectorIndex


class RetrievalService:
    """Serves image and caption similarity search over prebuilt Annoy indexes and metadata."""

    def __init__(self) -> None:
        cfg = CONFIG
        cfg.prepare()
        paths = cfg.paths
        def _parse_tags(value: object) -> list[str]:
            """Normalise a raw ``tags`` cell into a list of strings."""
            if isinstance(value, str):
                try:
                    parsed = json.loads(value)
                    if isinstance(parsed, list):
                        return [str(item) for item in parsed]
                except json.JSONDecodeError:
                    # Not JSON: fall back to comma-separated parsing.
                    return [part.strip() for part in value.split(",") if part.strip()]
            elif isinstance(value, (list, tuple, set, np.ndarray)):
                # Parquet list columns come back from pandas as numpy arrays.
                return [str(item) for item in value]
            return []
metadata_df = pd.read_parquet(paths.omni_metadata_path).set_index("id")
if "tags" in metadata_df.columns:
metadata_df["tags"] = metadata_df["tags"].apply(_parse_tags)
else:
metadata_df["tags"] = [[] for _ in range(len(metadata_df))]
metadata_df["style"] = metadata_df["style"].fillna("unknown")
metadata_df["genre"] = metadata_df["genre"].fillna("unknown")
self.metadata = metadata_df
        # Positional lookups below assume the embedding matrices are
        # row-aligned with the metadata index order.
        self.id_to_idx = {sample_id: idx for idx, sample_id in enumerate(self.metadata.index)}
        # Precomputed tag-label embeddings and the matching tag names.
        self.omni_tag_matrix = np.load(paths.indexes_dir / "labels" / "omni_tag_embeddings.npy")
        with open(paths.indexes_dir / "labels" / "omni_tags.json", "r", encoding="utf-8") as fin:
            self.omni_tags: list[str] = json.load(fin)
        self.caption_embeddings = np.load(paths.embeddings_dir / "caption_embeddings.npy")
        self.image_embeddings = np.load(paths.embeddings_dir / "image_embeddings.npy")
        self.image_index = AnnoyVectorIndex(dimension=self.image_embeddings.shape[1])
        self.image_index.load(paths.indexes_dir / "image.ann")
        self.caption_index = AnnoyVectorIndex(dimension=self.caption_embeddings.shape[1])
        self.caption_index.load(paths.indexes_dir / "caption.ann")
self.paths = paths
self.image_encoder = ImageEncoder()
self.text_encoder = TextEncoder()

    def _resolve_image(self, reference: str | Path | Image.Image) -> Image.Image:
        """Return a PIL image, resolving relative paths against the project root."""
        if isinstance(reference, Image.Image):
            return reference
        path = Path(reference)
        if not path.is_absolute():
            path = self.paths.root / path
        return Image.open(path)

    def _metadata_for(self, sample_id: str) -> dict:
        """Build the result payload for one sample, with an absolute image path."""
        row = self.metadata.loc[sample_id].to_dict()
        row["id"] = sample_id
        image_path = row.get("image_path")
        if image_path:
            row["image_path"] = str((self.paths.root / image_path).resolve())
        return row

    def search_similar_images(self, image: str | Path | Image.Image, top_k: int | None = None) -> list[dict]:
        """Return metadata for the nearest neighbours of ``image`` in the image index."""
        candidate_image = self._resolve_image(image)
        embedding = self.image_encoder.encode([candidate_image]).numpy()[0]
        results = self.image_index.query(embedding, top_k=top_k)
        return [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results]

    def search_by_caption(self, caption: str, top_k: int | None = None) -> list[dict]:
        """Return metadata for samples whose captions are most similar to ``caption``."""
        embedding = self.text_encoder.encode([caption]).numpy()[0]
        results = self.caption_index.query(embedding, top_k=top_k)
        return [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results]
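
    # The search methods above return raw distances from the Annoy-backed
    # index. The helper below is a minimal sketch of mapping such a distance
    # back to a cosine similarity, assuming the index was built with Annoy's
    # "angular" metric, where d = sqrt(2 * (1 - cos)). ``AnnoyVectorIndex``
    # is project code, so confirm its metric before relying on this.
    @staticmethod
    def distance_to_cosine_similarity(distance: float) -> float:
        """Convert an Annoy angular distance into a cosine similarity."""
        return 1.0 - (distance ** 2) / 2.0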

    def search_omni(
        self,
        text_query: str | None = None,
        styles: Iterable[str] | None = None,
        genres: Iterable[str] | None = None,
        extra_tags: Iterable[str] | None = None,
        top_k: int | None = None,
    ) -> list[dict]:
        """Filter candidates by style/genre/tags, then optionally rank them by caption similarity."""
        candidate_ids = self._filter_candidates(
            styles=styles or [],
            genres=genres or [],
            extra_tags=extra_tags or [],
        )
        if not candidate_ids:
            return []
        candidate_vectors = self.caption_embeddings[[self.id_to_idx[sample_id] for sample_id in candidate_ids]]
        if text_query:
            # Building an index over only the filtered candidates is quick,
            # so we can afford to do it per query at inference time.
            temp_index = AnnoyVectorIndex(dimension=candidate_vectors.shape[1])
            temp_index.build(candidate_vectors, candidate_ids)
            requested_top_k = min(top_k or CONFIG.index.top_k, len(candidate_ids))
            text_embedding = self.text_encoder.encode([text_query]).numpy()[0]
            results = temp_index.query(text_embedding, top_k=requested_top_k)
        else:
            # No text query: return candidates in metadata order with zero distance.
            results = [(sample_id, 0.0) for sample_id in candidate_ids[: top_k or len(candidate_ids)]]
        return [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results]

    def _filter_candidates(
        self,
        styles: Iterable[str],
        genres: Iterable[str],
        extra_tags: Iterable[str],
    ) -> list[str]:
        """Return ids of samples matching every provided style, genre, and tag constraint."""
        # Materialise the iterables once so emptiness checks also work for generators.
        styles = list(styles)
        genres = list(genres)
        extra_tags = list(extra_tags)
        df = self.metadata
        mask = pd.Series(True, index=df.index)
        if styles:
            mask &= df["style"].isin(styles)
        if genres:
            mask &= df["genre"].isin(genres)
        if extra_tags:
            # A sample qualifies only if it carries every requested tag.
            required_tags = set(extra_tags)
            mask &= df["tags"].apply(lambda tags: required_tags.issubset(set(tags)))
        return df.index[mask].tolist()
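

# Minimal usage sketch. The query strings, style label, and top_k values below
# are illustrative placeholders, not data shipped with the project. Run as a
# module (e.g. ``python -m src.retrieval``) so the relative imports resolve.
if __name__ == "__main__":
    service = RetrievalService()
    for hit in service.search_by_caption("a portrait of a woman in a red dress", top_k=5):
        print(hit["id"], round(hit["distance"], 4))
    # Filtered search: style values must match entries in the metadata's "style" column.
    for hit in service.search_omni(text_query="stormy seascape", styles=["impressionism"], top_k=3):
        print(hit["id"], hit.get("style"), round(hit["distance"], 4))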