File size: 3,387 Bytes
2bbb2d2
2449b1f
 
2bbb2d2
2449b1f
2bbb2d2
2449b1f
 
 
2bbb2d2
 
 
 
3153924
6292093
 
 
 
 
 
3153924
 
 
2bbb2d2
 
3153924
2bbb2d2
 
2449b1f
 
 
 
 
 
 
 
3153924
 
2bbb2d2
 
6292093
 
 
 
3153924
6292093
2bbb2d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3153924
 
 
2bbb2d2
 
 
 
 
3153924
 
 
2bbb2d2
 
 
6292093
3153924
6292093
2bbb2d2
3153924
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""CLIP model for zero-shot classification; running on CPU machine"""
from typing import Dict, List

import open_clip
import torch
from open_clip import tokenizer
from PIL import Image

from src.core.logger import logger

# modules
from src.core.singleton import SingletonMeta


class ClipModel(metaclass=SingletonMeta):
    """Singleton wrapper around an open_clip CLIP model for CPU-only
    zero-shot image classification.

    ``SingletonMeta`` guarantees a single instance per process, so the
    model weights are loaded (and cached in memory) exactly once.

    Args:
        model_name: open_clip architecture name passed to
            ``create_model_from_pretrained``.
        pretrained: pretrained-weights tag for that architecture.
        jit: whether to load a TorchScript (JIT) variant of the model.
    """

    def __init__(
        self,
        model_name: str = "ViT-B/32",
        pretrained: str = "laion2b_s34b_b79k",
        jit: bool = False,
    ):
        logger.debug("creating CLIP Model Object")
        # NOTE(review): open_clip architecture names conventionally use
        # hyphens ("ViT-B-32"); confirm "ViT-B/32" resolves with the
        # pinned open_clip version.
        self.config = {
            "model_name": model_name,
            "pretrained": pretrained,
            "precision": "bf16",  # bfloat16 weights to keep CPU memory low
            "device": "cpu",
            "jit": jit,
            "cache_dir": "model_dir/",  # local cache for downloaded weights
        }
        self.model, self.preprocess = open_clip.create_model_from_pretrained(
            **self.config
        )
        # Inference only: freeze dropout / batch-norm statistics.
        self.model.eval()
        # Lazy % formatting so the message is only built if INFO is enabled.
        logger.info(
            "%s %s initialized",
            self.config.get("model_name"),
            self.config.get("pretrained"),
        )

    def __call__(self, image: Image.Image, text: List[str]) -> Dict[str, float]:
        """Run the full zero-shot classification pipeline.

        Args:
            image: input PIL image.
            text: candidate class labels.

        Returns:
            Mapping from each label in ``text`` to its softmax probability.
            Duplicate labels collapse to a single key (last value wins).
        """
        # torch.cpu.amp.autocast() is deprecated; torch.amp.autocast with
        # an explicit device_type is the supported, equivalent spelling.
        with torch.inference_mode(), torch.amp.autocast(device_type="cpu"):
            # compute image features
            image_input = self.preprocess_image(image)
            image_features = self.get_image_features(image_input)
            logger.info("image features computed")

            # compute text features
            text_input = self.preprocess_text(text)
            text_features = self.get_text_features(text_input)
            logger.info("text features computed")

            # zero-shot classification: similarity -> probability per label
            text_probs = self.matmul_and_softmax(image_features, text_features)
            logger.debug("text_probs: %s", text_probs)
            return dict(zip(text, text_probs))

    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Apply the model's preprocessing transform and add a batch dim."""
        return self.preprocess(image).unsqueeze(0)

    @staticmethod
    def preprocess_text(text: List[str]) -> torch.Tensor:
        """Tokenize the candidate labels into a (len(text), 77) tensor."""
        # NOTE(review): this is open_clip's generic tokenizer; models with
        # custom tokenizers need open_clip.get_tokenizer(model_name).
        return tokenizer.tokenize(text)

    def get_image_features(self, image_input: torch.Tensor) -> torch.Tensor:
        """Encode the image and L2-normalize the embedding."""
        image_features = self.model.encode_image(image_input)
        # Unit-normalize so the dot product below is cosine similarity.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features

    def get_text_features(self, text_input: torch.Tensor) -> torch.Tensor:
        """Encode the tokenized labels and L2-normalize the embeddings."""
        text_features = self.model.encode_text(text_input)
        # Unit-normalize so the dot product below is cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    @staticmethod
    def matmul_and_softmax(
        image_features: torch.Tensor, text_features: torch.Tensor
    ) -> List[float]:
        """Return per-label probabilities from normalized embeddings.

        The 100.0 factor is CLIP's logit scale (temperature) applied
        before the softmax; squeeze(0) drops the batch dimension.
        """
        return (
            (100.0 * image_features @ text_features.T)
            .softmax(dim=-1)
            .squeeze(0)
            .tolist()
        )