"""CLIP model for zero-shot classification; running on CPU machine"""
from typing import Dict, List

import open_clip
import torch
from open_clip import tokenizer
from PIL import Image

from src.core.logger import logger

# modules
from src.core.singleton import SingletonMeta


class ClipModel(metaclass=SingletonMeta):
    """Zero-shot image classifier backed by an OpenCLIP model on CPU.

    The model and its image transform are loaded once in ``__init__`` via
    ``open_clip.create_model_from_pretrained``. Calling the instance with an
    image and a list of candidate labels returns a softmax probability per
    label.

    The class uses ``SingletonMeta`` as its metaclass. How repeated
    construction is handled (for example, whether later constructor
    arguments are ignored) is decided by that metaclass, which is defined
    elsewhere. TODO confirm.

    Args:
        model_name: OpenCLIP architecture name forwarded to
            ``create_model_from_pretrained``.
        pretrained: Pretrained weight tag for ``model_name``.
        jit: Forwarded unchanged to ``create_model_from_pretrained``.
    """

    def __init__(
        self,
        model_name: str = "ViT-B/32",
        pretrained: str = "laion2b_s34b_b79k",
        jit: bool = False,
    ):
        logger.debug("creating CLIP Model Object")
        # Keyword arguments for open_clip.create_model_from_pretrained; kept
        # on the instance so callers can inspect how the model was built.
        self.config = {
            "model_name": model_name,
            "pretrained": pretrained,
            "precision": "bf16",  # request bfloat16 weights from open_clip
            "device": "cpu",
            "jit": jit,
            "cache_dir": "model_dir/",  # local directory for downloaded weights
        }
        self.model, self.preprocess = open_clip.create_model_from_pretrained(
            **self.config
        )
        self.model.eval()  # inference only
        # Use lazy % formatting in logging functions
        logger.info(
            "%s %s initialized",
            self.config.get("model_name"),
            self.config.get("pretrained"),
        )

    def __call__(self, image: Image.Image, text: List[str]) -> Dict[str, float]:
        """Run the zero-shot inference pipeline.

        Args:
            image: Input PIL image.
            text: Candidate labels / prompts to score the image against.

        Returns:
            Mapping from each entry of ``text`` to its softmax probability.
            Duplicate entries collapse to a single key. If ``text`` is empty,
            an empty dict is returned without running either encoder.
        """
        if not text:
            # No candidate labels: there is nothing to classify against, so
            # skip the (expensive) image and text encoders entirely.
            return {}

        # torch.autocast("cpu") replaces the deprecated
        # torch.cpu.amp.autocast(); both default to bfloat16 on CPU.
        with torch.inference_mode(), torch.autocast(device_type="cpu"):
            # compute image features
            image_input = self.preprocess_image(image)
            image_features = self.get_image_features(image_input)
            logger.info("image features computed")

            # compute text features
            text_input = self.preprocess_text(text)
            text_features = self.get_text_features(text_input)
            logger.info("text features computed")

            # zero-shot classification
            text_probs = self.matmul_and_softmax(image_features, text_features)
            logger.debug("text_probs: %s", text_probs)

        return dict(zip(text, text_probs))

    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Apply the model's image transform and add a leading batch dim of 1."""
        return self.preprocess(image).unsqueeze(0)

    @staticmethod
    def preprocess_text(text: List[str]) -> torch.Tensor:
        """Tokenize ``text`` with open_clip's module-level default tokenizer.

        NOTE(review): this tokenizer does not depend on ``model_name``.
        Architectures with their own tokenizer may need
        ``open_clip.get_tokenizer(model_name)`` instead. Confirm before
        changing the default model.
        """
        return tokenizer.tokenize(text)

    def get_image_features(self, image_input: torch.Tensor) -> torch.Tensor:
        """Encode images and L2-normalize each embedding along the last dim."""
        image_features = self.model.encode_image(image_input)
        # Unit-normalize so the later matmul yields cosine similarities.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features

    def get_text_features(self, text_input: torch.Tensor) -> torch.Tensor:
        """Encode token batches and L2-normalize each embedding along the last dim."""
        text_features = self.model.encode_text(text_input)
        # Unit-normalize so the later matmul yields cosine similarities.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    @staticmethod
    def matmul_and_softmax(
        image_features: torch.Tensor, text_features: torch.Tensor
    ) -> List[float]:
        """Turn image/text similarities into per-label probabilities.

        Scales the image-text similarity matrix by 100, applies softmax over
        the label axis, and drops the leading dim of size 1 (one image).

        NOTE(review): 100.0 is a fixed logit scale, presumably standing in
        for the model's learned ``logit_scale.exp()``. Confirm.
        """
        similarity = 100.0 * image_features @ text_features.T
        return similarity.softmax(dim=-1).squeeze(0).tolist()