wardrobe-ai / src /utils /utils.py
elalber2000's picture
first commit
59830d4 verified
from pathlib import Path
from pathlib import Path
from sqlite3 import Row
from typing import Any, Literal, Sequence
from urllib.parse import quote
import numpy as np
from PIL import Image
from pydantic import json
from sentence_transformers import SentenceTransformer
# Absolute path of the `src/` directory: this module lives at
# src/utils/utils.py, so `src/` is the parent of this file's parent directory.
SRC_PATH = Path(__file__).resolve().parent.parent
def load_image_paths(data_dir: Path) -> list[Path]:
    """Return every ``*.jpg`` file directly inside ``data_dir``, sorted by path.

    The glob is non-recursive. On POSIX the pattern match is case-sensitive,
    so ``.JPG`` files are skipped there; ``.jpeg`` files are never matched.

    Raises:
        FileNotFoundError: if no matching file is found (this also covers a
            ``data_dir`` that does not exist, since the glob is then empty).
    """
    jpg_files = list(data_dir.glob("*.jpg"))
    if len(jpg_files) == 0:
        raise FileNotFoundError(f"No JPG images found in {data_dir}")
    jpg_files.sort()
    return jpg_files
def encode_images(
    model: SentenceTransformer,
    images: list[Image.Image],
    batch_size: int = 16,
) -> np.ndarray:
    """Embed PIL images with ``model.encode`` and return a NumPy array.

    Images are encoded ``batch_size`` at a time with a progress bar shown.
    ``normalize_embeddings=True`` asks the model to L2-normalize each
    embedding. Presumably the result has one row per input image; confirm
    against the sentence-transformers version in use.
    """
    encode_options = {
        "batch_size": batch_size,
        "normalize_embeddings": True,
        "convert_to_numpy": True,
        "show_progress_bar": True,
    }
    return model.encode(images, **encode_options)
def encode_texts(
    model: SentenceTransformer,
    texts: list[str],
    batch_size: int = 32,
) -> np.ndarray:
    """Embed text strings with ``model.encode`` and return a NumPy array.

    Mirrors ``encode_images``: embeddings are requested L2-normalized
    (``normalize_embeddings=True``) and converted to NumPy.

    Args:
        model: The sentence-transformers model to encode with.
        texts: Strings to embed.
        batch_size: Number of texts per forward pass. The default of 32
            matches ``SentenceTransformer.encode``'s own default, so existing
            callers see unchanged behavior.

    Returns:
        The NumPy array produced by ``model.encode``. Presumably it has one
        row per input text; confirm against the library version in use.
    """
    return model.encode(
        texts,
        batch_size=batch_size,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
def _public_file_url(path: Path) -> str:
absolute_path = str(path.resolve())
return (
"/gradio_api/file="
+ quote(absolute_path, safe="/:")
)
def _row_to_item(row: Row) -> dict[str, Any]:
    """Convert a metadata row into a plain dict.

    The row must expose ``id``, ``name``, ``category``, ``available`` and
    ``path`` columns. ``available`` is coerced to ``bool``, and ``path`` is
    replaced by a served URL built with ``_public_file_url``.
    """
    file_path = Path(row["path"])
    item = {column: row[column] for column in ("id", "name", "category")}
    item["available"] = bool(row["available"])
    item["url"] = _public_file_url(file_path)
    return item
def normalize_rows(values: Any) -> np.ndarray:
    """L2-normalize each row of ``values`` to unit length.

    Args:
        values: Array-like convertible to float32. A 1-D input is treated as
            a single row. The norm is taken along axis 1, so the function is
            intended for 1-D or 2-D input.

    Returns:
        A float32 array of unit-norm rows. A 1-D input of length ``d``
        becomes shape ``(1, d)``.

    Raises:
        ValueError: if any value is NaN or infinite, or any row has zero norm.
    """
    matrix = np.asarray(values, dtype=np.float32)
    if matrix.ndim == 1:
        matrix = matrix[None, :]
    # NaN/inf would pass the zero-norm check below and silently yield NaN
    # "unit" vectors, poisoning every similarity computed from them.
    if not np.isfinite(matrix).all():
        raise ValueError("Received a non-finite embedding")
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    if np.any(norms == 0):
        raise ValueError("Received a zero embedding")
    return matrix / norms
def deserialize_embedding(
blob: bytes | bytearray | memoryview,
embedding_dim: int,
embedding_dtype: str,
) -> np.ndarray:
if isinstance(blob, memoryview):
blob = blob.tobytes()
dtype = np.dtype(embedding_dtype)
embedding = np.frombuffer(
blob,
dtype=dtype,
)
if embedding.size != embedding_dim:
raise ValueError(
"Invalid stored embedding: "
f"expected dimension {embedding_dim}, "
f"got {embedding.size}"
)
return embedding.astype(np.float32, copy=False)
def load_outfit_embeddings(
    conn,
    outfit: Sequence[dict[str, Any]],
    embedding_type: Literal["clip", "outfit"],
) -> np.ndarray:
    """Load the stored image embeddings for an outfit and L2-normalize them.

    Args:
        conn: Database connection queried via ``conn.execute``. Result rows
            are indexed by column name, so presumably ``conn.row_factory`` is
            ``sqlite3.Row``; confirm at the call site.
        outfit: Items to look up. Each must have an ``"id"`` key, which is
            compared as a string.
        embedding_type: Selects the ``{embedding_type}_image_embedding``,
            ``{embedding_type}_embedding_dim`` and
            ``{embedding_type}_embedding_dtype`` columns of ``metadata``.

    Returns:
        A float32 array with one unit-norm row per outfit item, in outfit
        order. Duplicate ids produce duplicate rows.

    Raises:
        ValueError: if the outfit is empty, ``embedding_type`` is unsupported,
            any item has no ``metadata`` row or a NULL embedding, or a stored
            embedding is malformed or has zero norm.
    """
    outfit_ids = [str(item["id"]) for item in outfit]
    if not outfit_ids:
        raise ValueError("Cannot load embeddings for an empty outfit")
    # embedding_type is interpolated into the SQL below (column names cannot
    # be bound parameters), so it must be checked against an allow-list first.
    if embedding_type not in {"clip", "outfit"}:
        raise ValueError(
            f"Unsupported embedding type: {embedding_type}"
        )

    blob_column = f"{embedding_type}_image_embedding"
    dim_column = f"{embedding_type}_embedding_dim"
    dtype_column = f"{embedding_type}_embedding_dtype"
    placeholders = ", ".join("?" for _ in outfit_ids)
    rows = conn.execute(
        f"""
        SELECT
            id,
            {blob_column},
            {dim_column},
            {dtype_column}
        FROM metadata
        WHERE id IN ({placeholders})
        """,
        outfit_ids,
    ).fetchall()

    # Rows whose embedding blob is NULL are skipped here so they are reported
    # as missing below, instead of crashing inside np.frombuffer(None).
    embeddings_by_id = {
        str(row["id"]): deserialize_embedding(
            blob=row[blob_column],
            embedding_dim=row[dim_column],
            embedding_dtype=row[dtype_column],
        )
        for row in rows
        if row[blob_column] is not None
    }

    missing_ids = [
        item_id
        for item_id in outfit_ids
        if item_id not in embeddings_by_id
    ]
    if missing_ids:
        raise ValueError(
            f"Missing {embedding_type} embeddings for items: "
            f"{missing_ids}"
        )

    # Stack in outfit order (not query order) so rows line up with `outfit`.
    return normalize_rows(
        np.stack([embeddings_by_id[item_id] for item_id in outfit_ids])
    )