File size: 2,918 Bytes
23680f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Image embedding computation via EmbedAnything."""

import os
import tempfile
from pathlib import Path

import numpy as np
from embed_anything import EmbeddingModel
from PIL import Image

from hyperview.core.sample import Sample


class EmbeddingComputer:
    """Compute embeddings for image samples using EmbedAnything.

    The underlying ``EmbeddingModel`` is created lazily on first use and then
    cached on the instance, so every later call reuses the same model object.
    """

    def __init__(self, model: str):
        """Initialize the embedding computer.

        The model itself is not loaded here; it is loaded on first use.

        Args:
            model: HuggingFace model ID to load via EmbedAnything. Leading and
                trailing whitespace is stripped before the ID is stored.

        Raises:
            ValueError: If ``model`` is not a string, or is empty or
                whitespace-only.
        """
        # Check the type first so any non-string input (e.g. None or an int)
        # fails with the same ValueError instead of an AttributeError from
        # ``.strip()``.
        if not isinstance(model, str) or not model.strip():
            raise ValueError("model must be a non-empty HuggingFace model_id")

        # Validation ignores surrounding whitespace, so the stored ID must
        # too; otherwise the padded string would be passed to the loader.
        self.model_id = model.strip()
        # Cached model instance; None until _get_model() is first called.
        self._model: EmbeddingModel | None = None

    def _get_model(self) -> EmbeddingModel:
        """Return the cached EmbedAnything model, loading it on first call."""
        if self._model is None:
            self._model = EmbeddingModel.from_pretrained_hf(model_id=self.model_id)
        return self._model

    def _load_rgb_image(self, sample: Sample) -> Image.Image:
        """Load an image and normalize it to RGB.

        For file-backed samples, returns an in-memory copy and closes the file
        handle immediately to avoid leaking descriptors during batch processing.

        Note that ``Image.convert("RGB")`` drops any alpha channel rather than
        compositing onto a background colour.
        """
        with sample.load_image() as img:
            # Force pixel data to be read while the source is still open.
            img.load()
            if img.mode != "RGB":
                # convert() already returns a new, independent image.
                return img.convert("RGB")
            # Detach from the source that the `with` block is about to close.
            return img.copy()

    def _embed_file(self, file_path: str) -> np.ndarray:
        """Embed a single image file and return its vector as float32.

        Raises:
            RuntimeError: If EmbedAnything returns no embeddings, or more than
                one, for the file.
        """
        model = self._get_model()
        result = model.embed_file(file_path)

        if not result:
            raise RuntimeError(f"EmbedAnything returned no embeddings for: {file_path}")
        if len(result) != 1:
            raise RuntimeError(
                f"Expected 1 embedding for an image file, got {len(result)}: {file_path}"
            )

        return np.asarray(result[0].embedding, dtype=np.float32)

    def _embed_pil_image(self, image: Image.Image) -> np.ndarray:
        """Embed an in-memory PIL image by round-tripping through a temp PNG.

        The embedding call used here takes a file path, so the image is written
        to a temporary file. That file is always removed afterwards, even if
        saving or embedding fails.
        """
        temp_fd, temp_path = tempfile.mkstemp(suffix=".png")
        # Close the raw descriptor right away; PIL reopens the file by path
        # (and some platforms refuse to reopen a file that is still held open).
        os.close(temp_fd)
        try:
            image.save(temp_path, format="PNG")
            return self._embed_file(temp_path)
        finally:
            Path(temp_path).unlink(missing_ok=True)

    def compute_single(self, sample: Sample) -> np.ndarray:
        """Compute the float32 embedding for a single sample."""
        image = self._load_rgb_image(sample)
        return self._embed_pil_image(image)

    def compute_batch(
        self,
        samples: list[Sample],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> list[np.ndarray]:
        """Compute embeddings for a list of samples.

        Samples are embedded one at a time, in input order.

        Args:
            samples: Samples to embed.
            batch_size: Must be positive. It is validated but does not
                currently change how samples are processed.
            show_progress: If True, print a one-line status message before
                starting.

        Returns:
            One float32 embedding per sample, in the same order as
            ``samples``. An empty input returns an empty list without loading
            the model.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be > 0")
        if not samples:
            # Nothing to embed: skip the (potentially expensive) model load.
            return []

        # Load the model up front so model-loading errors surface before any
        # image is read.
        self._get_model()

        if show_progress:
            print(f"Computing embeddings for {len(samples)} samples...")

        return [self.compute_single(sample) for sample in samples]