| from pathlib import Path |
| from pathlib import Path |
| from sqlite3 import Row |
| from typing import Any, Literal, Sequence |
| from urllib.parse import quote |
|
|
| import numpy as np |
| from PIL import Image |
| from pydantic import json |
| from sentence_transformers import SentenceTransformer |
|
|
# Absolute path of the directory two levels above this file, i.e. the parent
# of the directory that contains it (``parents[0]`` is the containing dir).
# NOTE(review): presumably the project's ``src`` root — confirm against the
# repository layout.
SRC_PATH = Path(__file__).resolve().parents[1]
|
|
def load_image_paths(data_dir: Path) -> list[Path]:
    """Return the ``*.jpg`` entries directly inside ``data_dir``, sorted.

    Only the top level of ``data_dir`` is searched (non-recursive glob) and the
    pattern is matched as written, so e.g. ``.jpeg`` files are not included.

    Raises:
        FileNotFoundError: if no entry matches ``*.jpg``.
    """
    found = list(data_dir.glob("*.jpg"))
    found.sort()

    if len(found) == 0:
        raise FileNotFoundError(f"No JPG images found in {data_dir}")

    return found
|
|
def encode_images(
    model: SentenceTransformer,
    images: list[Image.Image],
    batch_size: int = 16,
) -> np.ndarray:
    """Embed PIL images with ``model`` and return the result as NumPy.

    Args:
        model: Loaded sentence-transformers model used for encoding.
        images: PIL images to embed.
        batch_size: Number of images passed to the model per batch.

    Returns:
        Whatever ``model.encode`` returns when asked for NumPy output with
        ``normalize_embeddings=True``; a progress bar is shown while encoding.
    """
    encode_kwargs = {
        "batch_size": batch_size,
        "normalize_embeddings": True,
        "convert_to_numpy": True,
        "show_progress_bar": True,
    }
    embeddings = model.encode(images, **encode_kwargs)
    return embeddings
|
|
def encode_texts(
    model: SentenceTransformer,
    texts: list[str],
    batch_size: int = 32,
) -> np.ndarray:
    """Embed text strings with ``model`` and return the result as NumPy.

    Args:
        model: Loaded sentence-transformers model used for encoding.
        texts: Strings to embed.
        batch_size: Number of texts passed to the model per batch. Defaults to
            32, the same default ``SentenceTransformer.encode`` uses, so
            callers that omit it get the previous behavior. Exposed for parity
            with ``encode_images``.

    Returns:
        Whatever ``model.encode`` returns when asked for NumPy output with
        ``normalize_embeddings=True``.
    """
    return model.encode(
        texts,
        batch_size=batch_size,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
|
|
def _public_file_url(path: Path) -> str:
    """Build a ``/gradio_api/file=`` URL for ``path``.

    The path is resolved to an absolute path and percent-encoded. ``/`` and
    ``:`` are left as-is; every other reserved character (spaces, ``#``, ...)
    is escaped.
    """
    encoded_path = quote(str(path.resolve()), safe="/:")
    return f"/gradio_api/file={encoded_path}"
|
|
def _row_to_item(row: Row) -> dict[str, Any]:
    """Convert a ``metadata`` row into the item dict used by the app.

    ``id``, ``name`` and ``category`` are copied verbatim, ``available`` is
    coerced to ``bool``, and the stored ``path`` is turned into a public file
    URL under the ``url`` key.
    """
    image_path = Path(row["path"])
    item: dict[str, Any] = {
        field: row[field] for field in ("id", "name", "category")
    }
    item["available"] = bool(row["available"])
    item["url"] = _public_file_url(image_path)
    return item
|
|
def normalize_rows(values: Any) -> np.ndarray:
    """Scale each row of ``values`` to unit L2 norm.

    ``values`` is first converted to a float32 array. A 1-D input is treated
    as a single row and comes back with shape ``(1, n)``. Norms are taken
    along axis 1 and the result is a new array.

    Raises:
        ValueError: if any row has a norm of exactly zero.
    """
    matrix = np.asarray(values, dtype=np.float32)
    if matrix.ndim == 1:
        # Promote a lone vector to a one-row matrix so ``axis=1`` is valid.
        matrix = np.expand_dims(matrix, axis=0)

    row_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    if (row_norms == 0).any():
        raise ValueError("Received a zero embedding")

    return matrix / row_norms
|
|
|
|
| def deserialize_embedding( |
| blob: bytes | bytearray | memoryview, |
| embedding_dim: int, |
| embedding_dtype: str, |
| ) -> np.ndarray: |
| if isinstance(blob, memoryview): |
| blob = blob.tobytes() |
|
|
| dtype = np.dtype(embedding_dtype) |
|
|
| embedding = np.frombuffer( |
| blob, |
| dtype=dtype, |
| ) |
|
|
| if embedding.size != embedding_dim: |
| raise ValueError( |
| "Invalid stored embedding: " |
| f"expected dimension {embedding_dim}, " |
| f"got {embedding.size}" |
| ) |
|
|
| return embedding.astype(np.float32, copy=False) |
|
|
|
|
def load_outfit_embeddings(
    conn,
    outfit: Sequence[dict[str, Any]],
    embedding_type: Literal["clip", "outfit"],
) -> np.ndarray:
    """Load the stored image embeddings for every item of an outfit.

    Args:
        conn: Database connection with an ``execute`` method whose rows can be
            indexed by column name. NOTE(review): presumably
            ``conn.row_factory = sqlite3.Row``; confirm with callers.
        outfit: Items whose ``"id"`` values identify rows in ``metadata``.
            Ids are compared as strings.
        embedding_type: Column family to read, ``"clip"`` or ``"outfit"``.

    Returns:
        float32 array with one L2-normalized row per outfit item, in outfit
        order (duplicate items yield duplicate rows).

    Raises:
        ValueError: if the outfit is empty, ``embedding_type`` is unsupported,
            any item has no row or a NULL embedding, or a stored embedding is
            malformed.
    """
    outfit_ids = [str(item["id"]) for item in outfit]
    if not outfit_ids:
        raise ValueError("Cannot load embeddings for an empty outfit")

    # The column prefix is interpolated into the SQL text (ids are bound
    # parameters), so it must be whitelisted at runtime: the Literal hint is
    # not enforced by Python.
    if embedding_type not in {"clip", "outfit"}:
        raise ValueError(
            f"Unsupported embedding type: {embedding_type}"
        )

    blob_col = f"{embedding_type}_image_embedding"
    dim_col = f"{embedding_type}_embedding_dim"
    dtype_col = f"{embedding_type}_embedding_dtype"

    # Query each distinct id once. Duplicates are restored when stacking.
    unique_ids = list(dict.fromkeys(outfit_ids))
    placeholders = ", ".join("?" for _ in unique_ids)

    rows = conn.execute(
        f"""
        SELECT
            id,
            {blob_col},
            {dim_col},
            {dtype_col}
        FROM metadata
        WHERE id IN ({placeholders})
        """,
        unique_ids,
    ).fetchall()

    embeddings_by_id = {
        str(row["id"]): deserialize_embedding(
            blob=row[blob_col],
            embedding_dim=row[dim_col],
            embedding_dtype=row[dtype_col],
        )
        for row in rows
        # A NULL embedding column means the item has not been embedded. Treat
        # it as missing so it is reported below, instead of crashing with a
        # TypeError inside np.frombuffer(None).
        if row[blob_col] is not None
    }

    missing_ids = [
        item_id
        for item_id in outfit_ids
        if item_id not in embeddings_by_id
    ]
    if missing_ids:
        raise ValueError(
            f"Missing {embedding_type} embeddings for items: "
            f"{missing_ids}"
        )

    return normalize_rows(
        np.stack([embeddings_by_id[item_id] for item_id in outfit_ids])
    )