File size: 2,128 Bytes
263e7a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from transformers import CLIPModel, PreTrainedModel, CLIPProcessor, AutoConfig
import torch
import pickle
from torch.nn.functional import cosine_similarity
CLIP_MODEL = "openai/clip-vit-large-patch14"
class Q16Model(PreTrainedModel):
    """Q16 safety classifier: scores images against learned soft prompts.

    Wraps a CLIP vision backbone; ``forward`` returns the cosine similarity
    between each image's CLIP embedding and each soft-prompt embedding.
    """

    def __init__(self, config):
        super().__init__(config)
        # CLIP backbone used only for image feature extraction here.
        self.clip_model = CLIPModel.from_pretrained(CLIP_MODEL)
        # Learned prompt embeddings; must be populated via load_soft_prompts()
        # before forward()/save_pretrained() can be used.
        self.soft_prompts = None

    def load_soft_prompts(self, path):
        """Load pickled soft-prompt embeddings from *path* as float32 on CPU.

        SECURITY: ``pickle.load`` can execute arbitrary code — only load
        prompt files from trusted sources.
        """
        # Context manager ensures the file handle is closed (the previous
        # version leaked it). The half-precision round-trip is kept on
        # purpose to match how the prompts were originally serialized.
        with open(path, 'rb') as f:
            raw = pickle.load(f)
        self.soft_prompts = torch.HalfTensor(raw).to('cpu').to(torch.float32)

    def forward(self, pixel_values):
        """Return cosine-similarity logits of shape (batch, num_prompts).

        Raises:
            RuntimeError: if load_soft_prompts() has not been called yet
                (previously this surfaced as an opaque AttributeError).
        """
        if self.soft_prompts is None:
            raise RuntimeError(
                "Soft prompts not loaded; call load_soft_prompts() first.")
        # Image encodings from the CLIP backbone: (batch, dim)
        image_features = self.clip_model.get_image_features(
            pixel_values=pixel_values)
        # (batch, 1, dim) vs (1, num_prompts, dim) -> (batch, num_prompts)
        logits = cosine_similarity(
            image_features.unsqueeze(1), self.soft_prompts.unsqueeze(0), dim=-1)
        return logits

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model, then attach the soft prompts stored alongside it."""
        config = kwargs.pop("config", None)
        model = super(Q16Model, cls).from_pretrained(
            pretrained_model_name_or_path, config=config, *model_args, **kwargs)
        # Soft prompts are persisted next to the weights by save_pretrained().
        model.load_soft_prompts(f"{pretrained_model_name_or_path}/prompts.p")
        return model

    def save_pretrained(self, save_directory):
        """Save the model weights plus the soft prompts as a sidecar pickle.

        Raises:
            RuntimeError: if no soft prompts are loaded (previously this
                surfaced as an opaque AttributeError).
        """
        if self.soft_prompts is None:
            raise RuntimeError(
                "Soft prompts not loaded; call load_soft_prompts() first.")
        super().save_pretrained(save_directory)
        # Stored separately because the prompts are not registered parameters.
        with open(f"{save_directory}/prompts.p", 'wb') as f:
            pickle.dump(self.soft_prompts.cpu().numpy(), f)
if __name__ == "__main__":
    # Derive a config from the CLIP backbone and record the prompt width.
    config = AutoConfig.from_pretrained(CLIP_MODEL)
    config.soft_prompt_dim = 768

    # Build the wrapper model and attach the learned soft prompts.
    q16 = Q16Model(config)
    q16.load_soft_prompts("./prompts.p")

    # Persist the weights (with prompt sidecar) and the matching processor.
    target_dir = "."
    q16.save_pretrained(target_dir)
    CLIPProcessor.from_pretrained(CLIP_MODEL).save_pretrained(target_dir)
|