Spaces:
Runtime error
Runtime error
| """CLIP model for zero-shot classification; running on CPU machine""" | |
| from typing import Dict, List | |
| import open_clip | |
| import torch | |
| from open_clip import tokenizer | |
| from PIL import Image | |
| from src.core.logger import logger | |
| # modules | |
| from src.core.singleton import SingletonMeta | |
class ClipModel(metaclass=SingletonMeta):
    """Zero-shot image classifier built on an open_clip CLIP model (CPU only).

    The singleton metaclass ensures the (expensive) model load happens once
    per process. Calling the instance with an image and a list of candidate
    labels returns a mapping of label -> softmax probability.
    """

    def __init__(
        self,
        model_name: str = "ViT-B/32",
        pretrained: str = "laion2b_s34b_b79k",
        jit: bool = False,
    ):
        """Load the CLIP model and its image-preprocessing transform.

        Args:
            model_name: open_clip architecture name (open_clip's factory
                normalizes the OpenAI-style "/" to "-").
            pretrained: pretrained-weights tag to download/load.
            jit: whether to load a TorchScript (JIT) version of the model.
        """
        logger.debug("creating CLIP Model Object")
        self.config = {
            "model_name": model_name,
            "pretrained": pretrained,
            "precision": "bf16",  # bfloat16 keeps CPU inference cheap
            "device": "cpu",
            "jit": jit,
            "cache_dir": "model_dir/",  # local cache for downloaded weights
        }
        self.model, self.preprocess = open_clip.create_model_from_pretrained(
            **self.config
        )
        self.model.eval()  # inference only: freeze dropout/batch-norm behavior
        # Use lazy % formatting in logging functions
        logger.info(
            "%s %s initialized",
            self.config.get("model_name"),
            self.config.get("pretrained"),
        )

    def __call__(self, image: Image.Image, text: List[str]) -> Dict[str, float]:
        """Inference pipeline: classify *image* against the labels in *text*.

        Args:
            image: PIL image to classify.
            text: candidate label strings.

        Returns:
            Dict mapping each label to its softmax probability.
        """
        with torch.inference_mode(), torch.cpu.amp.autocast():
            # compute image features
            image_input = self.preprocess_image(image)
            image_features = self.get_image_features(image_input)
            logger.info("image features computed")
            # compute text features
            text_input = self.preprocess_text(text)
            text_features = self.get_text_features(text_input)
            logger.info("text features computed")
            # zero-shot classification
            text_probs = self.matmul_and_softmax(image_features, text_features)
            logger.debug("text_probs: %s", text_probs)
            return dict(zip(text, text_probs))

    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Apply the model's transform and add a leading batch dimension."""
        return self.preprocess(image).unsqueeze(0)

    @staticmethod
    def preprocess_text(text: List[str]) -> torch.Tensor:
        """Tokenize the candidate label strings.

        BUG FIX: this was defined without ``self``, so the
        ``self.preprocess_text(text)`` call in ``__call__`` raised a
        TypeError (extra positional argument). ``@staticmethod`` makes it
        callable via both the instance and the class.
        """
        return tokenizer.tokenize(text)

    def get_image_features(self, image_input: torch.Tensor) -> torch.Tensor:
        """Encode the image batch and L2-normalize the embedding."""
        image_features = self.model.encode_image(image_input)
        image_features /= image_features.norm(
            dim=-1, keepdim=True
        )  # normalize vector prior
        return image_features

    def get_text_features(self, text_input: torch.Tensor) -> torch.Tensor:
        """Encode the tokenized labels and L2-normalize the embeddings."""
        text_features = self.model.encode_text(text_input)
        text_features /= text_features.norm(
            dim=-1, keepdim=True
        )  # normalize vector prior
        return text_features

    @staticmethod
    def matmul_and_softmax(
        image_features: torch.Tensor, text_features: torch.Tensor
    ) -> List[float]:
        """Scale cosine similarities by 100 and softmax into probabilities.

        BUG FIX: this was defined without ``self``, so the
        ``self.matmul_and_softmax(...)`` call in ``__call__`` raised a
        TypeError (extra positional argument). ``@staticmethod`` fixes the
        call while keeping class-level access working.
        """
        return (
            (100.0 * image_features @ text_features.T)
            .softmax(dim=-1)
            .squeeze(0)
            .tolist()
        )