"""CLIP model for zero-shot classification; running on CPU machine"""
from typing import Dict, List
import open_clip
import torch
from open_clip import tokenizer
from PIL import Image
from src.core.logger import logger
# modules
from src.core.singleton import SingletonMeta
class ClipModel(metaclass=SingletonMeta):
    """Zero-shot image classifier backed by an open_clip CLIP model.

    Built as a singleton (via ``SingletonMeta``) so the model weights are
    loaded only once per process. Inference runs on CPU with bf16 autocast.

    Args:
        model_name (str): open_clip architecture name. OpenAI-style names
            with a slash (e.g. ``"ViT-B/32"``) are accepted and normalized
            to open_clip's hyphenated convention (``"ViT-B-32"``).
        pretrained (str): pretrained-weight tag understood by open_clip.
        jit (bool): whether to load a TorchScript (JIT) build of the model.
    """

    def __init__(
        self,
        model_name: str = "ViT-B/32",
        pretrained: str = "laion2b_s34b_b79k",
        jit: bool = False,
    ):
        logger.debug("creating CLIP Model Object")
        self.config = {
            # open_clip registers architectures with "-" separators
            # ("ViT-B-32"); the OpenAI-style "ViT-B/32" spelling would make
            # create_model_from_pretrained fail, so normalize it here.
            "model_name": model_name.replace("/", "-"),
            "pretrained": pretrained,
            "precision": "bf16",
            "device": "cpu",
            "jit": jit,
            "cache_dir": "model_dir/",
        }
        self.model, self.preprocess = open_clip.create_model_from_pretrained(
            **self.config
        )
        # Inference only: freeze dropout / norm statistics.
        self.model.eval()
        # Lazy %-formatting: args are rendered only if the level is enabled.
        logger.info(
            "%s %s initialized",
            self.config.get("model_name"),
            self.config.get("pretrained"),
        )

    def __call__(self, image: Image.Image, text: List[str]) -> Dict[str, float]:
        """Run zero-shot classification of ``image`` against ``text`` labels.

        Args:
            image: PIL image to classify.
            text: candidate label strings.

        Returns:
            Mapping from each label to its softmax probability (the values
            sum to 1 across the given labels).
        """
        # torch.cpu.amp.autocast() is deprecated; torch.autocast("cpu") is
        # the supported spelling and defaults to bfloat16 on CPU, matching
        # the "bf16" precision the model was created with.
        with torch.inference_mode(), torch.autocast("cpu"):
            # compute image features
            image_input = self.preprocess_image(image)
            image_features = self.get_image_features(image_input)
            logger.info("image features computed")
            # compute text features
            text_input = self.preprocess_text(text)
            text_features = self.get_text_features(text_input)
            logger.info("text features computed")
            # zero-shot classification
            text_probs = self.matmul_and_softmax(image_features, text_features)
            logger.debug("text_probs: %s", text_probs)
            return dict(zip(text, text_probs))

    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Apply the model's preprocessing and add a batch dimension."""
        return self.preprocess(image).unsqueeze(0)

    @staticmethod
    def preprocess_text(text: List[str]) -> torch.Tensor:
        """Tokenize the label strings with the open_clip default tokenizer."""
        return tokenizer.tokenize(text)

    def get_image_features(self, image_input: torch.Tensor) -> torch.Tensor:
        """Encode an image batch and L2-normalize the embeddings."""
        image_features = self.model.encode_image(image_input)
        image_features /= image_features.norm(
            dim=-1, keepdim=True
        )  # normalize vector prior
        return image_features

    def get_text_features(self, text_input: torch.Tensor) -> torch.Tensor:
        """Encode tokenized text and L2-normalize the embeddings."""
        text_features = self.model.encode_text(text_input)
        text_features /= text_features.norm(
            dim=-1, keepdim=True
        )  # normalize vector prior
        return text_features

    @staticmethod
    def matmul_and_softmax(
        image_features: torch.Tensor, text_features: torch.Tensor
    ) -> List[float]:
        """Cosine-similarity logits (scaled by 100) softmaxed over labels."""
        return (
            (100.0 * image_features @ text_features.T)
            .softmax(dim=-1)
            .squeeze(0)
            .tolist()
        )