from pathlib import Path
from sqlite3 import Row
from typing import Any, Literal, Sequence
from urllib.parse import quote

import numpy as np
from PIL import Image
# NOTE(review): `json` is not used in this part of the file, and this binds
# `pydantic.json`, not the stdlib `json` module. It was probably meant to be
# `import json`. It is kept so that any other code relying on the name still
# resolves; confirm and clean up.
from pydantic import json
from sentence_transformers import SentenceTransformer

# Parent of the directory that contains this module (presumably the `src`
# root, judging by the name).
SRC_PATH = Path(__file__).resolve().parents[1]


def load_image_paths(data_dir: Path) -> list[Path]:
    """Return the ``*.jpg`` files directly inside ``data_dir``, sorted.

    Matching uses ``Path.glob("*.jpg")``. It is non-recursive, and on POSIX
    systems it is case-sensitive, so ``.JPG`` files are skipped there.
    ``.jpeg`` files are never matched.

    Raises:
        FileNotFoundError: If no matching file is found.
    """
    paths = sorted(data_dir.glob("*.jpg"))
    if not paths:
        raise FileNotFoundError(f"No JPG images found in {data_dir}")
    return paths


def encode_images(
    model: SentenceTransformer,
    images: list[Image.Image],
    batch_size: int = 16,
) -> np.ndarray:
    """Embed ``images`` with ``model`` and return L2-normalized vectors.

    Args:
        model: Encoder that accepts PIL images. This is assumed to be an
            image-capable SentenceTransformer; the images are forwarded to
            ``model.encode`` unchanged.
        images: Images to embed.
        batch_size: Number of images encoded per forward pass.

    Returns:
        NumPy array with one unit-norm row per image. A progress bar is shown
        while encoding.
    """
    return model.encode(
        images,
        batch_size=batch_size,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=True,
    )


def encode_texts(
    model: SentenceTransformer,
    texts: list[str],
) -> np.ndarray:
    """Embed ``texts`` with ``model`` and return L2-normalized vectors.

    Batch size and progress-bar behaviour are left to ``model.encode``'s
    defaults.

    Returns:
        NumPy array with one unit-norm row per text.
    """
    return model.encode(
        texts,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )


def _public_file_url(path: Path) -> str:
    """Build the ``/gradio_api/file=`` URL under which ``path`` is requested.

    The path is resolved to an absolute path and percent-encoded. ``/`` and
    ``:`` are left unescaped, so directory separators and drive letters stay
    readable.
    """
    absolute_path = str(path.resolve())
    return "/gradio_api/file=" + quote(absolute_path, safe="/:")


def _row_to_item(row: Row) -> dict[str, Any]:
    """Convert a database row into a plain item dict.

    The ``id``, ``name`` and ``category`` columns are copied as-is.
    ``available`` is coerced to ``bool``. The row's ``path`` column is
    replaced by a public ``url``.
    """
    path = Path(row["path"])
    return {
        "id": row["id"],
        "name": row["name"],
        "category": row["category"],
        "available": bool(row["available"]),
        "url": _public_file_url(path),
    }


def normalize_rows(values: Any) -> np.ndarray:
    """Return ``values`` as float32 rows scaled to unit L2 norm.

    A 1-D input is treated as a single row, so the result is always 2-D.

    Raises:
        ValueError: If the input is not 1-D or 2-D, or if any row has zero
            norm (which cannot be normalized).
    """
    values = np.asarray(values, dtype=np.float32)
    if values.ndim == 1:
        values = values[None, :]
    # Norms are taken along axis 1. For any other rank that is not a per-row
    # norm, so reject it explicitly instead of returning a meaningless result.
    if values.ndim != 2:
        raise ValueError(
            f"Expected a 1-D or 2-D array of embeddings, got {values.ndim}-D"
        )
    norms = np.linalg.norm(values, axis=1, keepdims=True)
    if np.any(norms == 0):
        raise ValueError("Received a zero embedding")
    return values / norms


def deserialize_embedding(
    blob: bytes | bytearray | memoryview,
    embedding_dim: int,
    embedding_dtype: str,
) -> np.ndarray:
    """Decode a raw embedding blob into a 1-D float32 vector.

    Args:
        blob: Raw bytes of the stored vector.
        embedding_dim: Expected number of elements.
        embedding_dtype: NumPy dtype string the bytes were written with
            (e.g. ``"float32"``).

    Returns:
        A new, writable float32 array of length ``embedding_dim`` that does
        not share memory with ``blob``.

    Raises:
        ValueError: If the byte length is not a multiple of the dtype's item
            size (raised by ``np.frombuffer``), or if the element count
            differs from ``embedding_dim``.
    """
    if isinstance(blob, memoryview):
        blob = blob.tobytes()
    dtype = np.dtype(embedding_dtype)
    embedding = np.frombuffer(blob, dtype=dtype)
    if embedding.size != embedding_dim:
        raise ValueError(
            "Invalid stored embedding: "
            f"expected dimension {embedding_dim}, "
            f"got {embedding.size}"
        )
    # np.frombuffer returns a view onto `blob`: read-only for `bytes`, and
    # aliasing the caller's memory for `bytearray`. With copy=False,
    # float32 data would be returned as that view while other dtypes would
    # get a fresh copy. Always copy so every result is writable and
    # independent of the source.
    return embedding.astype(np.float32)
def load_outfit_embeddings(
    conn,
    outfit: Sequence[dict[str, Any]],
    embedding_type: Literal["clip", "outfit"],
) -> np.ndarray:
    """Fetch the stored image embeddings for every item of an outfit.

    Rows are looked up in the ``metadata`` table by item id. The result
    follows the order of ``outfit``, duplicates included, and every row is
    re-normalized to unit length.

    Args:
        conn: Database connection whose rows can be indexed by column name.
            This is assumed to be sqlite3 with a ``sqlite3.Row`` row factory;
            confirm with the caller.
        outfit: Items carrying an ``"id"`` key. Ids are compared as strings.
        embedding_type: Column prefix choosing which embedding set to read.

    Returns:
        2-D float32 array with one unit-norm row per outfit item.

    Raises:
        ValueError: If ``outfit`` is empty, ``embedding_type`` is
            unsupported, or any item has no matching row.
    """
    wanted_ids = [str(entry["id"]) for entry in outfit]
    if not wanted_ids:
        raise ValueError("Cannot load embeddings for an empty outfit")
    # The prefix is interpolated into the SQL text below, so it must come
    # from a fixed whitelist rather than being trusted as-is.
    if embedding_type not in {"clip", "outfit"}:
        raise ValueError(
            f"Unsupported embedding type: {embedding_type}"
        )

    blob_column = f"{embedding_type}_image_embedding"
    dim_column = f"{embedding_type}_embedding_dim"
    dtype_column = f"{embedding_type}_embedding_dtype"
    marks = ", ".join(["?"] * len(wanted_ids))
    query = f"""
        SELECT
            id,
            {blob_column},
            {dim_column},
            {dtype_column}
        FROM metadata
        WHERE id IN ({marks})
    """
    fetched = conn.execute(query, wanted_ids).fetchall()

    # Key by the string form of the id so lookups match `wanted_ids`.
    found: dict[str, np.ndarray] = {}
    for record in fetched:
        found[str(record["id"])] = deserialize_embedding(
            blob=record[blob_column],
            embedding_dim=record[dim_column],
            embedding_dtype=record[dtype_column],
        )

    absent = [item_id for item_id in wanted_ids if item_id not in found]
    if absent:
        raise ValueError(
            f"Missing {embedding_type} embeddings for items: {absent}"
        )

    stacked = np.stack([found[item_id] for item_id in wanted_ids])
    return normalize_rows(stacked)