# (removed: web-scrape artifacts — file-size banner, blob hash, and line-number gutter)
from transformers import CLIPModel, PreTrainedModel, CLIPProcessor, AutoConfig
import torch
import pickle
from torch.nn.functional import cosine_similarity

CLIP_MODEL = "openai/clip-vit-large-patch14"


class Q16Model(PreTrainedModel):
    """CLIP-backed scorer that compares images against learned soft prompts.

    Wraps a ``CLIPModel`` image encoder; ``forward`` returns the cosine
    similarity between each image embedding and each soft-prompt embedding,
    which callers treat as classification logits.
    """

    def __init__(self, config):
        super().__init__(config)
        # Image encoder; weights are fetched from the HF hub at construction.
        self.clip_model = CLIPModel.from_pretrained(CLIP_MODEL)
        # Soft-prompt embeddings — presumably (num_prompts, embed_dim); must be
        # populated via load_soft_prompts() before the first forward() call.
        self.soft_prompts = None

    def load_soft_prompts(self, path):
        """Load pickled soft-prompt embeddings from ``path``.

        The pickle holds array-like prompt embeddings; they are round-tripped
        through float16 (matching how they were saved) and stored as a
        float32 CPU tensor.

        NOTE(security): ``pickle.load`` can execute arbitrary code — only
        load prompt files from trusted sources.
        """
        # 'with' guarantees the file handle is closed (the original code
        # leaked it by passing an anonymous open() to pickle.load).
        with open(path, 'rb') as f:
            raw = pickle.load(f)
        self.soft_prompts = torch.HalfTensor(raw).to('cpu').to(torch.float32)

    def forward(self, pixel_values):
        """Return cosine-similarity logits, one column per soft prompt.

        Args:
            pixel_values: preprocessed image batch accepted by
                ``CLIPModel.get_image_features``.

        Raises:
            RuntimeError: if ``load_soft_prompts`` has not been called.
        """
        # Fail fast with a clear message instead of a cryptic NoneType error.
        if self.soft_prompts is None:
            raise RuntimeError(
                "soft prompts not loaded; call load_soft_prompts() first")

        # Image embeddings from the CLIP vision tower.
        image_features = self.clip_model.get_image_features(
            pixel_values=pixel_values)

        # (batch, 1, dim) vs (1, num_prompts, dim) -> (batch, num_prompts).
        logits = cosine_similarity(
            image_features.unsqueeze(1), self.soft_prompts.unsqueeze(0), dim=-1)
        return logits

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model weights, then the companion ``prompts.p`` pickle."""
        config = kwargs.pop("config", None)
        model = super(Q16Model, cls).from_pretrained(
            pretrained_model_name_or_path, config=config, *model_args, **kwargs)
        # The soft prompts live next to the weights (see save_pretrained).
        model.load_soft_prompts(f"{pretrained_model_name_or_path}/prompts.p")
        return model

    def save_pretrained(self, save_directory):
        """Save model weights plus the soft prompts as ``prompts.p``."""
        super().save_pretrained(save_directory)
        # Soft prompts are not registered parameters, so persist them
        # separately alongside the checkpoint.
        with open(f"{save_directory}/prompts.p", 'wb') as f:
            pickle.dump(self.soft_prompts.cpu().numpy(), f)


if __name__ == "__main__":
    # Start from the stock CLIP config and record the soft-prompt width on it.
    clip_config = AutoConfig.from_pretrained(CLIP_MODEL)
    clip_config.soft_prompt_dim = 768

    # Build the wrapper model and attach the locally stored prompt pickle.
    q16 = Q16Model(clip_config)
    q16.load_soft_prompts("./prompts.p")

    # Persist the model (weights + prompts) and the matching processor
    # side by side in the current directory.
    out_dir = "."
    q16.save_pretrained(out_dir)
    CLIPProcessor.from_pretrained(CLIP_MODEL).save_pretrained(out_dir)