| from time import perf_counter |
| from typing import Any, Dict, List, Tuple, Union |
|
|
| import clip |
| import numpy as np |
| import onnxruntime |
| from PIL import Image |
|
|
| from inference.core.entities.requests.clip import ( |
| ClipCompareRequest, |
| ClipImageEmbeddingRequest, |
| ClipInferenceRequest, |
| ClipTextEmbeddingRequest, |
| ) |
| from inference.core.entities.requests.inference import InferenceRequestImage |
| from inference.core.entities.responses.clip import ( |
| ClipCompareResponse, |
| ClipEmbeddingResponse, |
| ) |
| from inference.core.entities.responses.inference import InferenceResponse |
| from inference.core.env import ( |
| CLIP_MAX_BATCH_SIZE, |
| CLIP_MODEL_ID, |
| ONNXRUNTIME_EXECUTION_PROVIDERS, |
| REQUIRED_ONNX_PROVIDERS, |
| TENSORRT_CACHE_PATH, |
| ) |
| from inference.core.exceptions import OnnxProviderNotAvailable |
| from inference.core.models.roboflow import OnnxRoboflowCoreModel |
| from inference.core.models.types import PreprocessReturnMetadata |
| from inference.core.utils.image_utils import load_image_rgb |
| from inference.core.utils.onnx import get_onnxruntime_execution_providers |
| from inference.core.utils.postprocess import cosine_similarity |
|
|
|
|
class Clip(OnnxRoboflowCoreModel):
    """Roboflow ONNX CLIP model.

    This class is responsible for handling the ONNX CLIP model, including
    loading the model, preprocessing the input, and performing inference.

    Attributes:
        visual_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for visual inference.
        textual_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for textual inference.
        resolution (int): The resolution of the input image.
        clip_preprocess (function): Function to preprocess the image.
    """

    def __init__(
        self,
        *args,
        model_id: str = CLIP_MODEL_ID,
        onnxruntime_execution_providers: List[
            str
        ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS),
        **kwargs,
    ):
        """Initializes the Clip with the given arguments and keyword arguments.

        Args:
            model_id (str, optional): Identifier of the CLIP model weights to load.
                Defaults to ``CLIP_MODEL_ID``.
            onnxruntime_execution_providers (List[str], optional): Execution providers
                passed to ONNX Runtime. NOTE: this default is evaluated once at class
                definition time, which is intentional here (it reads static env config).
        """
        self.onnxruntime_execution_providers = onnxruntime_execution_providers
        t1 = perf_counter()
        super().__init__(*args, model_id=model_id, **kwargs)

        self.log("Creating inference sessions")
        self.visual_onnx_session = onnxruntime.InferenceSession(
            self.cache_file("visual.onnx"),
            providers=self.onnxruntime_execution_providers,
        )
        self.textual_onnx_session = onnxruntime.InferenceSession(
            self.cache_file("textual.onnx"),
            providers=self.onnxruntime_execution_providers,
        )

        # Fail fast if a provider the deployment requires (e.g. TensorRT/CUDA)
        # is not present in this onnxruntime build.
        if REQUIRED_ONNX_PROVIDERS:
            available_providers = onnxruntime.get_available_providers()
            for provider in REQUIRED_ONNX_PROVIDERS:
                if provider not in available_providers:
                    raise OnnxProviderNotAvailable(
                        f"Required ONNX Execution Provider {provider} is not available. "
                        "Check that you are using the correct docker image on a supported device."
                    )

        # Visual input is NCHW; dim 2 holds the (square) expected image resolution.
        self.resolution = self.visual_onnx_session.get_inputs()[0].shape[2]
        self.clip_preprocess = clip.clip._transform(self.resolution)
        self.log(f"CLIP model loaded in {perf_counter() - t1:.2f} seconds")

    def compare(
        self,
        subject: Any,
        prompt: Any,
        subject_type: str = "image",
        prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
        **kwargs,
    ) -> Union[List[float], Dict[str, float]]:
        """
        Compares the subject with the prompt to calculate similarity scores.

        Args:
            subject (Any): The subject data to be compared. Can be either an image or text.
            prompt (Any): The prompt data to be compared against the subject. Can be a single
                value (image/text), list of values, or dictionary of values.
            subject_type (str, optional): Specifies the type of the subject data. Must be
                either "image" or "text". Defaults to "image".
            prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type
                of the prompt data. Can be "image", "text", list of these types, or a
                dictionary containing these types. Defaults to "text".
            **kwargs: Additional keyword arguments.

        Returns:
            Union[List[float], Dict[str, float]]: A list or dictionary containing cosine
                similarity scores between the subject and prompt(s). If prompt is a
                dictionary, returns a dictionary keyed like the original prompt dictionary.

        Raises:
            ValueError: If subject_type or prompt_type is neither "image" nor "text".
            ValueError: If the number of prompts exceeds the maximum batch size.
        """
        if subject_type == "image":
            subject_embeddings = self.embed_image(subject)
        elif subject_type == "text":
            subject_embeddings = self.embed_text(subject)
        else:
            raise ValueError(
                f"subject_type must be either 'image' or 'text', but got {subject_type}"
            )

        # A dict prompt is a named batch of prompts — unless it is a single
        # InferenceRequestImage-style payload ({"type": ..., "value": ...}).
        if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
            prompt_keys = prompt.keys()
            prompt = [prompt[k] for k in prompt_keys]
            prompt_obj = "dict"
        else:
            if not isinstance(prompt, list):
                prompt = [prompt]
            prompt_obj = "list"

        if len(prompt) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
            )

        if prompt_type == "image":
            prompt_embeddings = self.embed_image(prompt)
        elif prompt_type == "text":
            prompt_embeddings = self.embed_text(prompt)
        else:
            raise ValueError(
                f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
            )

        similarities = [
            cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
        ]

        # Restore the caller's dict keys so results map back to their prompts.
        if prompt_obj == "dict":
            similarities = dict(zip(prompt_keys, similarities))

        return similarities

    def make_compare_response(
        self, similarities: Union[List[float], Dict[str, float]]
    ) -> ClipCompareResponse:
        """
        Creates a ClipCompareResponse object from the provided similarity data.

        Args:
            similarities (Union[List[float], Dict[str, float]]): A list or dictionary
                containing similarity scores.

        Returns:
            ClipCompareResponse: An instance of ClipCompareResponse with the given
                similarity scores.

        Example:
            >>> make_compare_response({"image1": 0.98, "image2": 0.76})
            ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76})
        """
        return ClipCompareResponse(similarity=similarities)

    def embed_image(
        self,
        image: Any,
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds an image or a list of images using the CLIP visual encoder.

        Args:
            image (Any): The image or list of images to be embedded. Image can be in any
                format acceptable by the preproc_image method.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the image(s) as a numpy array.

        Raises:
            ValueError: If the number of images in the list exceeds the maximum batch size.
        """
        if isinstance(image, list):
            if len(image) > CLIP_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
                )
            imgs = [self.preproc_image(i) for i in image]
            img_in = np.concatenate(imgs, axis=0)
        else:
            img_in = self.preproc_image(image)

        onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in}
        embeddings = self.visual_onnx_session.run(None, onnx_input_image)[0]
        return embeddings

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Runs the visual encoder on an already-preprocessed image batch.

        Args:
            img_in (np.ndarray): Preprocessed image tensor (batched NCHW float32).

        Returns:
            Tuple[np.ndarray]: One-element tuple containing the image embeddings.
        """
        onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in}
        embeddings = self.visual_onnx_session.run(None, onnx_input_image)[0]
        return (embeddings,)

    def make_embed_image_response(
        self, embeddings: np.ndarray
    ) -> ClipEmbeddingResponse:
        """
        Converts the given embeddings into a ClipEmbeddingResponse object.

        Args:
            embeddings (np.ndarray): A numpy array containing the embeddings for an
                image or images.

        Returns:
            ClipEmbeddingResponse: An instance of ClipEmbeddingResponse with the
                provided embeddings converted to a list.

        Example:
            >>> embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
            >>> make_embed_image_response(embeddings_array)
            ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
        """
        return ClipEmbeddingResponse(embeddings=embeddings.tolist())

    def embed_text(
        self,
        text: Union[str, List[str]],
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds a text or a list of texts using the CLIP textual encoder.

        Args:
            text (Union[str, List[str]]): The text string or list of text strings to
                be embedded.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the text or texts as a numpy array.

        Raises:
            ValueError: If the number of text strings in the list exceeds the maximum
                batch size.
        """
        if isinstance(text, list):
            if len(text) > CLIP_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of text strings that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
                )
            texts = text
        else:
            texts = [text]

        # clip.tokenize returns int64 token ids; the exported ONNX textual
        # encoder expects int32 input.
        texts = clip.tokenize(texts).numpy().astype(np.int32)

        onnx_input_text = {self.textual_onnx_session.get_inputs()[0].name: texts}
        embeddings = self.textual_onnx_session.run(None, onnx_input_text)[0]
        return embeddings

    def make_embed_text_response(self, embeddings: np.ndarray) -> ClipEmbeddingResponse:
        """
        Converts the given text embeddings into a ClipEmbeddingResponse object.

        Args:
            embeddings (np.ndarray): A numpy array containing the embeddings for a
                text or texts.

        Returns:
            ClipEmbeddingResponse: An instance of ClipEmbeddingResponse with the
                provided embeddings converted to a list.

        Example:
            >>> embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
            >>> make_embed_text_response(embeddings_array)
            ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
        """
        return ClipEmbeddingResponse(embeddings=embeddings.tolist())

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: The list of file names.
        """
        return ["textual.onnx", "visual.onnx"]

    def infer_from_request(
        self, request: ClipInferenceRequest
    ) -> ClipEmbeddingResponse:
        """Routes the request to the appropriate inference function.

        Args:
            request (ClipInferenceRequest): The request object containing the
                inference details.

        Returns:
            ClipEmbeddingResponse: The response object containing the embeddings,
                with ``time`` set to the wall-clock inference duration in seconds.

        Raises:
            ValueError: If the request is not a recognized CLIP request type.
        """
        t1 = perf_counter()
        if isinstance(request, ClipImageEmbeddingRequest):
            infer_func = self.embed_image
            make_response_func = self.make_embed_image_response
        elif isinstance(request, ClipTextEmbeddingRequest):
            infer_func = self.embed_text
            make_response_func = self.make_embed_text_response
        elif isinstance(request, ClipCompareRequest):
            infer_func = self.compare
            make_response_func = self.make_compare_response
        else:
            raise ValueError(
                f"Request type {type(request)} is not a valid ClipInferenceRequest"
            )
        data = infer_func(**request.dict())
        response = make_response_func(data)
        response.time = perf_counter() - t1
        return response

    def make_response(self, embeddings, *args, **kwargs) -> InferenceResponse:
        """Wraps raw embeddings in a single-element response list."""
        return [self.make_embed_image_response(embeddings)]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Unwraps the single-element prediction tuple produced by ``predict``."""
        return predictions[0]

    def infer(self, image: Any, **kwargs) -> Any:
        """Embeds an image."""
        return super().infer(image, **kwargs)

    def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
        """Preprocesses an inference request image.

        Args:
            image (InferenceRequestImage): The object containing information
                necessary to load the image for inference.

        Returns:
            np.ndarray: A numpy array of the preprocessed image pixel data,
                batched (leading axis of size 1) and cast to float32.
        """
        pil_image = Image.fromarray(load_image_rgb(image))
        preprocessed_image = self.clip_preprocess(pil_image)
        img_in = np.expand_dims(preprocessed_image, axis=0)
        return img_in.astype(np.float32)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Preprocesses an image and returns it with empty metadata."""
        return self.preproc_image(image), PreprocessReturnMetadata({})
|
|