File size: 7,487 Bytes
5d36f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
Image embedding: local CLIP (with Mac MPS / CUDA) or Azure AI Vision multimodal.
Same model must be used at index time and query time for retrieval.
"""
import logging
from pathlib import Path
from typing import List, Union

import numpy as np

from photo_editor.config import get_settings

logger = logging.getLogger(__name__)


class AzureVisionEmbedder:
    """Encode images via the Azure AI Vision ``retrieval:vectorizeImage`` API.

    Each image is JPEG-encoded and POSTed in its own request; the ``vector``
    field of each JSON response becomes one row of the returned matrix.

    Args:
        endpoint: Azure AI Vision resource endpoint; a trailing ``/`` is stripped.
        key: Subscription key, sent as the ``Ocp-Apim-Subscription-Key`` header.
        model_version: Value of the ``model-version`` query parameter.
        timeout: Seconds each HTTP request may take before ``urlopen`` gives up.
            Without a timeout, a stalled connection would block the caller forever.
    """

    def __init__(
        self,
        endpoint: str,
        key: str,
        model_version: str = "2023-04-15",
        timeout: float = 30.0,
    ):
        self.endpoint = endpoint.rstrip("/")
        self.key = key
        self.model_version = model_version
        self.timeout = timeout
        # Embedding length; learned from the first successful API call.
        self._dim: Union[int, None] = None

    @property
    def dimension(self) -> int:
        """Length of the vectors returned by the service.

        Cached once any image has been vectorized. If nothing has been encoded
        yet, one request is made with a black 224x224 JPEG to discover it.
        """
        if self._dim is None:
            dummy = np.zeros((224, 224, 3), dtype=np.float32)
            self._dim = len(self._vectorize_image_bytes(self._to_jpeg_bytes(dummy)))
        return self._dim

    @staticmethod
    def _to_jpeg_bytes(image: np.ndarray) -> bytes:
        """Clip a [0, 1] image, scale it to uint8 and return it JPEG-encoded."""
        import io
        from PIL import Image

        pil = Image.fromarray((np.clip(image, 0, 1) * 255).astype(np.uint8))
        buf = io.BytesIO()
        pil.save(buf, format="JPEG")
        return buf.getvalue()

    def _vectorize_image_bytes(self, image_bytes: bytes) -> List[float]:
        """POST one JPEG to vectorizeImage and return its embedding.

        Raises:
            RuntimeError: the service answered with an HTTP error status; the
                message includes the status code, reason and response body.
        """
        import json
        import urllib.error
        import urllib.request

        # Production API only. 2023-02-01-preview returns 410 Gone (deprecated).
        # Docs: https://learn.microsoft.com/en-us/rest/api/computervision/vectorize/image-stream
        # Path: POST <endpoint>/computervision/retrieval:vectorizeImage?overload=stream&model-version=...&api-version=2024-02-01
        url = (
            f"{self.endpoint}/computervision/retrieval:vectorizeImage"
            f"?overload=stream&model-version={self.model_version}&api-version=2024-02-01"
        )
        req = urllib.request.Request(url, data=image_bytes, method="POST")
        req.add_header("Ocp-Apim-Subscription-Key", self.key)
        req.add_header("Content-Type", "image/jpeg")
        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                data = json.loads(resp.read().decode())
            return data["vector"]
        except urllib.error.HTTPError as e:
            try:
                body = e.fp.read().decode() if e.fp else "(no body)"
            except Exception:
                body = "(could not read body)"
            logger.error(
                "Azure Vision vectorizeImage failed: HTTP %s %s. %s",
                e.code,
                e.reason,
                body,
                exc_info=False,
            )
            raise RuntimeError(
                f"Azure Vision vectorizeImage failed: HTTP {e.code} {e.reason}. {body}"
            ) from e

    def encode_images(self, images: List[np.ndarray]) -> np.ndarray:
        """Vectorize each image with one API request per image.

        images: sequence of HWC float [0,1] RGB arrays (values are clipped).
        Returns (N, dim) float32 numpy; an empty input yields a (0, dim) array.
        """
        # len() rather than truthiness so a stacked ndarray batch also works.
        if len(images) == 0:
            return np.empty((0, self.dimension), dtype=np.float32)
        vectors = [
            self._vectorize_image_bytes(self._to_jpeg_bytes(im)) for im in images
        ]
        if self._dim is None:
            # Remember the size so `dimension` needs no extra dummy request.
            self._dim = len(vectors[0])
        return np.array(vectors, dtype=np.float32)

    def encode_image(self, image: np.ndarray) -> np.ndarray:
        """Single image -> (dim,) vector."""
        vecs = self.encode_images([image])
        return vecs[0]


class ImageEmbedder:
    """Encode images to fixed-size vectors for vector search with a local CLIP model.

    The transformers processor and model are loaded lazily on first use.

    Args:
        model_name: Hugging Face model id passed to ``from_pretrained``.
        device: torch device string the model is moved to ("cpu", "cuda", "mps", ...).
    """

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: str = "cpu",
    ):
        self.model_name = model_name
        self.device = device
        self._model = None
        self._processor = None

    def _load(self) -> None:
        """Load the processor and model once; later calls are no-ops.

        Raises:
            ImportError: transformers (and torch) are not installed.
        """
        if self._model is not None:
            return
        try:
            from transformers import CLIPModel, CLIPProcessor
        except ImportError as e:
            raise ImportError(
                "transformers and torch required for CLIP. "
                "Install with: pip install transformers torch"
            ) from e
        processor = CLIPProcessor.from_pretrained(self.model_name)
        model = CLIPModel.from_pretrained(self.model_name)
        model.to(self.device)
        model.eval()
        # Publish only after every step succeeded, so a failed load (e.g. an
        # unusable device) is retried rather than leaving a half-initialised model.
        self._processor = processor
        self._model = model

    @property
    def dimension(self) -> int:
        """Embedding length: the model config's ``projection_dim``."""
        self._load()
        return self._model.config.projection_dim

    @staticmethod
    def _pooled_features(out):
        """Extract the per-image embedding tensor from ``get_image_features`` output.

        Older transformers return the embedding tensor directly; newer ones
        return an output object whose ``pooler_output`` holds it.
        NOTE(review): presumably that pooled tensor is already projected to
        projection_dim; verify against the installed transformers version.

        Raises:
            RuntimeError: the output only has hidden states. Slicing the CLS token
                out of them would yield un-projected hidden_size vectors that do
                not match ``dimension``.
        """
        pooled = getattr(out, "pooler_output", None)
        if pooled is not None:
            return pooled
        if hasattr(out, "last_hidden_state"):
            raise RuntimeError(
                "CLIP get_image_features returned hidden states without a pooled "
                "output; cannot produce projection_dim-sized image embeddings."
            )
        return out

    def encode_images(self, images: List[np.ndarray]) -> np.ndarray:
        """
        images: list of HWC float32 [0,1] RGB arrays (e.g. from dng_to_rgb).
        Returns (N, dim) float32 numpy; an empty input yields a (0, dim) array.
        """
        # The CLIP processor cannot handle an empty batch.
        if len(images) == 0:
            return np.empty((0, self.dimension), dtype=np.float32)
        import torch
        from PIL import Image
        self._load()
        # CLIPProcessor expects PIL Images
        pil_list = [
            Image.fromarray((np.clip(im, 0, 1) * 255).astype(np.uint8))
            for im in images
        ]
        inputs = self._processor(images=pil_list, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = self._model.get_image_features(**inputs)
        features = self._pooled_features(out)
        return features.detach().cpu().float().numpy()

    def encode_image(self, image: np.ndarray) -> np.ndarray:
        """Single image -> (dim,) vector."""
        vecs = self.encode_images([image])
        return vecs[0]


def _default_device() -> str:
    import torch
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return "mps"  # Mac GPU (Apple Silicon)
    return "cpu"


def get_embedder():
    """Return Azure Vision embedder if configured and available in region; else local CLIP.

    When Azure is configured, one vectorize request (via ``dimension``) checks
    that the service works. A RuntimeError whose message mentions
    "not enabled in this region" or "InvalidRequest" triggers a logged fallback
    to local CLIP. Any other RuntimeError is logged and re-raised.
    """
    settings = get_settings()

    def _local_clip():
        # Local CLIP on the best available device (CUDA / MPS / CPU).
        return ImageEmbedder(
            model_name=settings.embedding_model,
            device=_default_device(),
        )

    if not settings.azure_vision_configured():
        logger.info(
            "Azure Vision not configured; using local CLIP (model=%s)",
            settings.embedding_model,
        )
        return _local_clip()

    model_version = settings.azure_vision_model_version or "2023-04-15"
    try:
        azure = AzureVisionEmbedder(
            endpoint=settings.azure_vision_endpoint,
            key=settings.azure_vision_key,
            model_version=model_version,
        )
        _ = azure.dimension  # one call to verify region supports the API
        logger.info(
            "Using Azure Vision embedder (endpoint=%s, model_version=%s)",
            settings.azure_vision_endpoint,
            model_version,
        )
        return azure
    except RuntimeError as exc:
        message = str(exc)
        unsupported = "not enabled in this region" in message or "InvalidRequest" in message
        if not unsupported:
            logger.error("Azure Vision embedder failed: %s", message)
            raise
        logger.warning(
            "Azure Vision retrieval/vectorize not available (region/InvalidRequest). "
            "Falling back to local CLIP (Mac MPS/CUDA/CPU). Error: %s",
            message,
        )
        return _local_clip()