File size: 6,019 Bytes
3e0528a
 
3b2121f
 
 
3e0528a
3b2121f
3e0528a
5eb4524
 
 
 
3b2121f
 
 
 
 
 
 
 
 
3e0528a
 
 
3b2121f
 
 
 
 
 
5eb4524
3b2121f
 
 
 
 
3e0528a
5eb4524
3e0528a
3b2121f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5eb4524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b2121f
 
 
 
 
 
5eb4524
 
 
 
 
 
 
3b2121f
5eb4524
 
 
 
3b2121f
 
 
 
 
3e0528a
 
 
 
3b2121f
 
 
3e0528a
 
 
5eb4524
3b2121f
 
bcfaacf
 
3b2121f
 
 
 
 
 
 
3e0528a
5eb4524
bcfaacf
3e0528a
 
3b2121f
bcfaacf
3b2121f
 
3e0528a
 
3b2121f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import base64
import io
from pathlib import Path
from typing import Any, Dict, List
from urllib.request import urlopen

import open_clip
import torch
from PIL import Image, ImageOps
from torchvision.transforms import Compose, Normalize, ToTensor

INPUT_SIZE = 224


def _is_git_lfs_pointer(path: Path) -> bool:
    if not path.is_file() or path.stat().st_size > 1024:
        return False

    with path.open("rb") as handle:
        return handle.read(64).startswith(b"version https://git-lfs.github.com/spec/v1")


class EndpointHandler:
    """Inference Endpoints handler serving OpenCLIP image/text embeddings.

    Loads an OpenCLIP model from a local repository directory and exposes a
    single ``__call__`` entry point. The payload may contain ``text`` (a str
    or list of str), ``image`` (an http(s) URL or a base64 string, optionally
    a ``data:`` URI), or both; the response holds L2-normalized embeddings as
    plain Python lists.
    """

    def __init__(self, model_dir: str = "", **kwargs: Any):
        """Load the OpenCLIP model, tokenizer, and preprocessing pipeline.

        Args:
            model_dir: Directory containing the OpenCLIP config and checkpoint.
                Defaults to ``/repository`` (the Inference Endpoints mount).
            **kwargs: Ignored; accepted for forward compatibility with the
                endpoint runtime.

        Raises:
            FileNotFoundError: If the config or checkpoint file is missing.
            RuntimeError: If the checkpoint is a Git LFS pointer file.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_dir = Path(model_dir or "/repository")

        # Fail fast with actionable errors before open_clip raises opaque ones.
        self._validate_model_files()

        model_id = f"local-dir:{self.model_dir}"
        self.model, preprocess = open_clip.create_model_from_pretrained(
            model_id,
            device=self.device,
            return_transform=True,
        )
        self.tokenizer = open_clip.get_tokenizer(model_id)
        self.model.eval()
        self.tensor_preprocess = self._build_tensor_preprocess(preprocess)

    def _validate_model_files(self) -> None:
        """Verify the repo holds a config and a real (non-LFS-pointer) checkpoint.

        Raises:
            FileNotFoundError: If ``open_clip_config.json`` or every known
                checkpoint filename is absent from ``self.model_dir``.
            RuntimeError: If a checkpoint file is only a Git LFS pointer.
        """
        config_path = self.model_dir / "open_clip_config.json"
        checkpoint_paths = [
            self.model_dir / "open_clip_model.safetensors",
            self.model_dir / "open_clip_pytorch_model.bin",
        ]

        if not config_path.is_file():
            raise FileNotFoundError(
                f"Missing {config_path.name} in {self.model_dir}. "
                "This repository must contain the OpenCLIP config file."
            )

        existing_checkpoints = [path for path in checkpoint_paths if path.is_file()]
        if not existing_checkpoints:
            raise FileNotFoundError(
                f"No OpenCLIP checkpoint found in {self.model_dir}. "
                "Expected open_clip_model.safetensors or open_clip_pytorch_model.bin."
            )

        # Detect the common deploy mistake of pushing pointer files without blobs.
        pointer_paths = [path.name for path in existing_checkpoints if _is_git_lfs_pointer(path)]
        if pointer_paths:
            raise RuntimeError(
                "The repository contains Git LFS pointer files instead of real model weights: "
                f"{', '.join(pointer_paths)}. "
                "Upload the actual LFS blobs to the Hugging Face model repo before starting the endpoint."
            )

    @staticmethod
    def _build_tensor_preprocess(original_preprocess) -> Compose:
        """Extract Normalize from the model's preprocess and build ToTensor + Normalize only.

        The default model preprocess includes Resize + CenterCrop + ToTensor + Normalize.
        Since we manually squash images to INPUT_SIZE x INPUT_SIZE, we only need
        ToTensor + Normalize to match the existing embedding pipeline.
        """
        # Reuse the model's own normalization constants when present; fall back
        # to the common CLIP-style (0.5, 0.5, 0.5) otherwise.
        normalize = next(
            (t for t in original_preprocess.transforms if isinstance(t, Normalize)),
            None,
        )
        if normalize is None:
            normalize = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        return Compose([ToTensor(), normalize])

    @staticmethod
    def _prepare_image(img: Image.Image) -> Image.Image:
        """Squash image to INPUT_SIZE x INPUT_SIZE (aspect ratio is NOT kept)."""
        return img.resize((INPUT_SIZE, INPUT_SIZE), Image.BICUBIC)

    def _load_image(self, image_input: Any) -> Image.Image | None:
        """Decode an image from a URL or base64 string; return None for non-strings.

        Accepts ``http://``/``https://`` URLs (fetched with a 10 s timeout) or a
        base64 payload, optionally prefixed as a ``data:`` URI (the text before
        the last comma is discarded). EXIF orientation is applied and the image
        is converted to RGB.
        """
        if not isinstance(image_input, str):
            return None

        if image_input.startswith(("http://", "https://")):
            # NOTE(review): fetching caller-supplied URLs is SSRF-prone; confirm
            # the endpoint is not exposed to untrusted callers.
            with urlopen(image_input, timeout=10) as response:
                img = Image.open(io.BytesIO(response.read()))
        else:
            # split(",")[-1] strips a "data:image/...;base64," prefix if present.
            image_bytes = base64.b64decode(image_input.split(",")[-1])
            img = Image.open(io.BytesIO(image_bytes))

        img = ImageOps.exif_transpose(img)
        return img.convert("RGB")

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Squash to INPUT_SIZE and apply tensor normalization; returns a 1-image batch."""
        image = self._prepare_image(image)
        return self.tensor_preprocess(image).unsqueeze(0).to(self.device)

    def _tokenize_text(self, text: str | List[str]) -> torch.Tensor:
        """Tokenize one string or a list of strings into a batch on the model device."""
        texts = text if isinstance(text, list) else [text]
        return self.tokenizer(texts).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the supplied text and/or image.

        Args:
            data: Request payload; inputs may be nested under ``"inputs"`` or
                given at the top level. Recognized keys: ``"text"`` (str or
                list[str]) and ``"image"`` (URL or base64 string).

        Returns:
            A dict with ``"image_embedding"`` (list[float]) and/or
            ``"text_embedding"`` (list[float], for a single string) or
            ``"text_embeddings"`` (list[list[float]], for a list of strings),
            or ``{"error": ...}`` when neither input is provided.
        """
        payload = data.get("inputs", data)

        text = payload.get("text")
        image = self._load_image(payload.get("image"))

        if image is None and text is None:
            return {"error": "Provide 'text' or 'image' (base64 or URL)."}

        # Image and text embeddings are independent; encode whichever inputs
        # are present and merge into one response (replaces three near-identical
        # branches that duplicated the encode logic).
        response: Dict[str, Any] = {}
        with torch.no_grad():
            if image is not None:
                image_tensor = self._preprocess_image(image)
                image_features = self.model.encode_image(image_tensor, normalize=True)
                response["image_embedding"] = image_features[0].cpu().tolist()
            if text is not None:
                text_tensor = self._tokenize_text(text)
                text_features = self.model.encode_text(text_tensor, normalize=True)
                if isinstance(text, list):
                    response["text_embeddings"] = text_features.cpu().tolist()
                else:
                    response["text_embedding"] = text_features[0].cpu().tolist()
        return response