Spaces:
Sleeping
Sleeping
| """ | |
| Image embedding: local CLIP (with Mac MPS / CUDA) or Azure AI Vision multimodal. | |
| Same model must be used at index time and query time for retrieval. | |
| """ | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Union | |
| import numpy as np | |
| from photo_editor.config import get_settings | |
| logger = logging.getLogger(__name__) | |
class AzureVisionEmbedder:
    """Encode images via Azure AI Vision retrieval:vectorizeImage API.

    Every image is serialized to JPEG and sent to the service in its own
    POST request; the embedding is read from the ``"vector"`` field of the
    JSON response.

    Args:
        endpoint: Base URL of the Azure resource. A trailing ``/`` is removed.
        key: Subscription key, sent as ``Ocp-Apim-Subscription-Key``.
        model_version: Value of the ``model-version`` query parameter.
        timeout: Socket timeout in seconds handed to ``urlopen`` for each
            request. ``None`` disables the timeout (block indefinitely).
    """

    def __init__(
        self,
        endpoint: str,
        key: str,
        model_version: str = "2023-04-15",
        timeout: Union[float, None] = 30.0,
    ):
        self.endpoint = endpoint.rstrip("/")
        self.key = key
        self.model_version = model_version
        # Without a timeout a stalled connection would hang the caller forever.
        self.timeout = timeout
        # Embedding size, filled in lazily by the first dimension() call.
        self._dim: Union[int, None] = None

    @staticmethod
    def _jpeg_bytes(pixels: np.ndarray) -> bytes:
        """Serialize a uint8 image array to in-memory JPEG bytes."""
        import io

        from PIL import Image

        buf = io.BytesIO()
        Image.fromarray(pixels).save(buf, format="JPEG")
        return buf.getvalue()

    def dimension(self) -> int:
        """Return the embedding length, probing the service once and caching it.

        Raises:
            RuntimeError: If the probe request gets an HTTP error response.
        """
        if self._dim is not None:
            return self._dim
        # Get dimension from one dummy call (or set from known model: 1024 for 2023-04-15)
        dummy = np.zeros((224, 224, 3), dtype=np.uint8)
        self._dim = len(self._vectorize_image_bytes(self._jpeg_bytes(dummy)))
        return self._dim

    def _vectorize_image_bytes(self, image_bytes: bytes) -> List[float]:
        """POST encoded image bytes to vectorizeImage and return the vector.

        Raises:
            RuntimeError: On an HTTP error response; the message carries the
                status code, reason and response body. Transport-level errors
                (``urllib.error.URLError``, timeouts) are not wrapped.
        """
        import json
        import urllib.error
        import urllib.request

        # Production API only. 2023-02-01-preview returns 410 Gone (deprecated).
        # Docs: https://learn.microsoft.com/en-us/rest/api/computervision/vectorize/image-stream
        # Path: POST <endpoint>/computervision/retrieval:vectorizeImage?overload=stream&model-version=...&api-version=2024-02-01
        url = (
            f"{self.endpoint}/computervision/retrieval:vectorizeImage"
            f"?overload=stream&model-version={self.model_version}&api-version=2024-02-01"
        )
        req = urllib.request.Request(url, data=image_bytes, method="POST")
        req.add_header("Ocp-Apim-Subscription-Key", self.key)
        req.add_header("Content-Type", "image/jpeg")
        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                data = json.loads(resp.read().decode())
            return data["vector"]
        except urllib.error.HTTPError as e:
            # Reading the error body is best-effort; it must not mask the HTTP error.
            try:
                body = e.fp.read().decode() if e.fp else "(no body)"
            except Exception:
                body = "(could not read body)"
            logger.error(
                "Azure Vision vectorizeImage failed: HTTP %s %s. %s",
                e.code,
                e.reason,
                body,
                exc_info=False,
            )
            raise RuntimeError(
                f"Azure Vision vectorizeImage failed: HTTP {e.code} {e.reason}. {body}"
            ) from e

    def encode_images(self, images: List[np.ndarray]) -> np.ndarray:
        """Embed each image with one request per image.

        Args:
            images: Arrays whose values are clipped to [0, 1] and scaled to
                uint8 before JPEG encoding.

        Returns:
            float32 array with one row per input image.
        """
        out = []
        for im in images:
            pixels = (np.clip(im, 0, 1) * 255).astype(np.uint8)
            out.append(self._vectorize_image_bytes(self._jpeg_bytes(pixels)))
        return np.array(out, dtype=np.float32)

    def encode_image(self, image: np.ndarray) -> np.ndarray:
        """Embed a single image and return its vector."""
        return self.encode_images([image])[0]
class ImageEmbedder:
    """Local CLIP image encoder producing fixed-size vectors for vector search.

    The model and processor are loaded on first use and then moved to
    ``device``.
    """

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: str = "cpu",
    ):
        self.model_name = model_name
        self.device = device
        # Both stay None until _load() runs.
        self._model = None
        self._processor = None

    def _load(self) -> None:
        """Load processor and model once; later calls return immediately.

        Raises:
            ImportError: If ``transformers`` cannot be imported.
        """
        already_loaded = self._model is not None
        if already_loaded:
            return
        try:
            from transformers import CLIPModel, CLIPProcessor
        except ImportError as e:
            raise ImportError(
                "transformers and torch required for CLIP. "
                "Install with: pip install transformers torch"
            ) from e
        self._processor = CLIPProcessor.from_pretrained(self.model_name)
        clip = CLIPModel.from_pretrained(self.model_name)
        self._model = clip
        clip.to(self.device)
        clip.eval()

    def dimension(self) -> int:
        """Return the model's ``projection_dim``, loading the model if needed."""
        self._load()
        return self._model.config.projection_dim

    def encode_images(self, images: List[np.ndarray]) -> np.ndarray:
        """
        Embed a batch of images in one forward pass.

        images: list of HWC float32 [0,1] RGB arrays (e.g. from dng_to_rgb).
        Returns (N, dim) float32 numpy.
        """
        import torch
        from PIL import Image

        self._load()
        # The processor is fed PIL images, so convert each array to uint8 first.
        pil_list = []
        for im in images:
            as_uint8 = (np.clip(im, 0, 1) * 255).astype(np.uint8)
            pil_list.append(Image.fromarray(as_uint8))
        batch = self._processor(images=pil_list, return_tensors="pt", padding=True)
        inputs = {name: tensor.to(self.device) for name, tensor in batch.items()}
        with torch.no_grad():
            out = self._model.get_image_features(**inputs)
        # Newer transformers may return an output object instead of a tensor:
        # take pooler_output first, then the first token of last_hidden_state,
        # and otherwise treat the return value itself as the feature tensor.
        features = getattr(out, "pooler_output", None)
        if features is None:
            if hasattr(out, "last_hidden_state"):
                features = out.last_hidden_state[:, 0]
            else:
                features = out
        return features.detach().cpu().float().numpy()

    def encode_image(self, image: np.ndarray) -> np.ndarray:
        """Single image -> (dim,) vector."""
        return self.encode_images([image])[0]
| def _default_device() -> str: | |
| import torch | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() and torch.backends.mps.is_built(): | |
| return "mps" # Mac GPU (Apple Silicon) | |
| return "cpu" | |
def get_embedder():
    """Return Azure Vision embedder if configured and available in region; else local CLIP.

    When Azure Vision is configured, one probe request is made through
    ``dimension()`` to confirm the endpoint serves the vectorize API. If that
    probe fails with a ``RuntimeError`` whose message contains
    "not enabled in this region" or "InvalidRequest", a local CLIP
    ``ImageEmbedder`` is returned instead.

    Raises:
        RuntimeError: If the Azure probe fails for any other HTTP reason.
    """
    s = get_settings()

    def _local_clip():
        # Single construction point for the local fallback embedder.
        return ImageEmbedder(
            model_name=s.embedding_model,
            device=_default_device(),
        )

    if not s.azure_vision_configured():
        logger.info(
            "Azure Vision not configured; using local CLIP (model=%s)",
            s.embedding_model,
        )
        return _local_clip()

    model_version = s.azure_vision_model_version or "2023-04-15"
    try:
        emb = AzureVisionEmbedder(
            endpoint=s.azure_vision_endpoint,
            key=s.azure_vision_key,
            model_version=model_version,
        )
        # dimension is a method: it must be *called* for the probe request to
        # happen. Merely reading the attribute never contacts the service.
        emb.dimension()  # one call to verify region supports the API
    except RuntimeError as e:
        err = str(e)
        if "not enabled in this region" in err or "InvalidRequest" in err:
            logger.warning(
                "Azure Vision retrieval/vectorize not available (region/InvalidRequest). "
                "Falling back to local CLIP (Mac MPS/CUDA/CPU). Error: %s",
                err,
            )
            return _local_clip()
        logger.error("Azure Vision embedder failed: %s", err)
        raise

    logger.info(
        "Using Azure Vision embedder (endpoint=%s, model_version=%s)",
        s.azure_vision_endpoint,
        model_version,
    )
    return emb