# TYH71 | qol: linting | 2449b1f  (stray VCS metadata kept as a comment — the bare lines broke parsing)
"""CLIP model for zero-shot classification; running on CPU machine"""
from typing import Dict, List
import open_clip
import torch
from open_clip import tokenizer
from PIL import Image
from src.core.logger import logger
# modules
from src.core.singleton import SingletonMeta
class ClipModel(metaclass=SingletonMeta):
    """Singleton wrapper around an open_clip CLIP model for CPU zero-shot classification.

    The model is loaded once (SingletonMeta guarantees a single instance) and the
    instance is callable: it scores one image against a list of candidate text
    labels and returns a label -> probability mapping.
    """

    def __init__(
        self,
        model_name: str = "ViT-B/32",
        pretrained: str = "laion2b_s34b_b79k",
        jit: bool = False,
    ):
        """Load the pretrained CLIP model onto CPU.

        Args:
            model_name: open_clip architecture name. "/" is accepted because
                open_clip normalizes it to "-" internally.
            pretrained: pretrained-weights tag from the open_clip registry.
            jit: whether to load a TorchScript (JIT) version of the model.
        """
        logger.debug("creating CLIP Model Object")
        # Pinned to CPU with bf16 precision; downloaded weights are cached locally.
        self.config = {
            "model_name": model_name,
            "pretrained": pretrained,
            "precision": "bf16",
            "device": "cpu",
            "jit": jit,
            "cache_dir": "model_dir/",
        }
        self.model, self.preprocess = open_clip.create_model_from_pretrained(
            **self.config
        )
        self.model.eval()  # inference mode: freeze dropout / norm statistics
        # Lazy % formatting in logging calls (args only evaluated if emitted)
        logger.info(
            "%s %s initialized",
            self.config.get("model_name"),
            self.config.get("pretrained"),
        )

    def __call__(self, image: Image.Image, text: List[str]) -> Dict[str, float]:
        """Run zero-shot classification of `image` against the `text` labels.

        Args:
            image: input PIL image.
            text: candidate labels to score the image against.

        Returns:
            Mapping from each label to its softmax probability (sums to ~1.0).
        """
        # torch.amp.autocast("cpu") replaces the deprecated torch.cpu.amp.autocast()
        # (FutureWarning since torch 2.4); both default to bfloat16 on CPU.
        with torch.inference_mode(), torch.amp.autocast("cpu"):
            # compute image features
            image_input = self.preprocess_image(image)
            image_features = self.get_image_features(image_input)
            logger.info("image features computed")
            # compute text features
            text_input = self.preprocess_text(text)
            text_features = self.get_text_features(text_input)
            logger.info("text features computed")
            # zero-shot classification
            text_probs = self.matmul_and_softmax(image_features, text_features)
        logger.debug("text_probs: %s", text_probs)
        return dict(zip(text, text_probs))

    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Apply the model's preprocessing transform and add a batch dimension."""
        return self.preprocess(image).unsqueeze(0)

    @staticmethod
    def preprocess_text(text: List[str]) -> torch.Tensor:
        """Tokenize the candidate labels into a (len(text), context_len) tensor."""
        return tokenizer.tokenize(text)

    def get_image_features(self, image_input: torch.Tensor) -> torch.Tensor:
        """Encode the image and L2-normalize the embedding for cosine similarity."""
        image_features = self.model.encode_image(image_input)
        image_features /= image_features.norm(
            dim=-1, keepdim=True
        )  # normalize vector prior
        return image_features

    def get_text_features(self, text_input: torch.Tensor) -> torch.Tensor:
        """Encode the tokenized text and L2-normalize the embeddings."""
        text_features = self.model.encode_text(text_input)
        text_features /= text_features.norm(
            dim=-1, keepdim=True
        )  # normalize vector prior
        return text_features

    @staticmethod
    def matmul_and_softmax(
        image_features: torch.Tensor, text_features: torch.Tensor
    ) -> List[float]:
        """Scale cosine similarities (CLIP's 100x logit scale) and softmax to probabilities."""
        return (
            (100.0 * image_features @ text_features.T)
            .softmax(dim=-1)
            .squeeze(0)
            .tolist()
        )