Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
import torch
import numpy as np
from PIL import Image
from dataclasses import dataclass
from torchvision.transforms import Normalize
from transformers import CLIPModel, CLIPTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List

ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
@dataclass
class CLIPEmbedOutput(ModelOutput):
    """Container for one CLIP encoder pass (vision or text).

    `ModelOutput` subclasses must be decorated with `@dataclass` so that
    transformers can collect the declared fields; the decorator was missing
    here even though `dataclass` is imported above.

    Attributes:
        last_hidden_state: per-token/per-patch hidden states from the encoder.
        pooler_output: pooled representation (the encoder's pooled output).
        embeds: pooled output projected into the joint CLIP embedding space.
    """
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    embeds: torch.FloatTensor = None
class CLIPEncoder(torch.nn.Module):
    """Frozen CLIP wrapper that produces image and text embeddings.

    Loads a pretrained `CLIPModel`, puts it in eval mode, and freezes all
    parameters. `encode_image` / `encode_text` each return a
    `CLIPEmbedOutput` with the encoder hidden states, pooled output, and
    the projected joint-space embedding.
    """

    def __init__(self, model_path="openai/clip-vit-base-patch32"):
        super().__init__()
        # Load the CLIP model and tokenizer.
        self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
        # NOTE(review): these are ImageNet normalization stats, not CLIP's
        # own (0.4815, 0.4578, 0.4082) / (0.2686, 0.2613, 0.2758) — confirm
        # this mismatch is intentional for the calling pipeline.
        self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
        # BUG FIX: `self.model.training = False` only set the flag on the
        # top-level module; .eval() propagates eval mode (dropout, etc.)
        # to every submodule.
        self.model.eval()
        # Freeze all weights — this wrapper is inference-only.
        for p in self.model.parameters():
            p.requires_grad = False

    def encode_image(self, images: Iterable[Optional[ImageType]]):
        """Encode a batch of images into CLIP embeddings.

        Args:
            images: image batch; `Normalize` requires a float tensor of
                shape (B, 3, H, W) — presumably callers pre-convert PIL /
                ndarray inputs; verify against the call sites.

        Returns:
            CLIPEmbedOutput with vision hidden states, pooled output, and
            the projected image embedding.
        """
        pixel_values = self.image_preprocess(images)
        vision_outputs = self.model.vision_model(pixel_values=pixel_values)
        pooler_output = vision_outputs[1]  # pooled_output
        image_features = self.model.visual_projection(pooler_output)
        return CLIPEmbedOutput(
            last_hidden_state=vision_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=image_features,
        )

    def encode_text(self, texts: List[str]):
        """Encode a batch of strings into CLIP embeddings.

        Args:
            texts: list of raw text strings; tokenized with padding.

        Returns:
            CLIPEmbedOutput with text hidden states, pooled output, and
            the projected text embedding.
        """
        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
        # BUG FIX: the original passed the whole BatchEncoding dict as
        # `input_ids`. Unpack the tensors explicitly, and forward the
        # attention mask so padding tokens are ignored.
        text_outputs = self.model.text_model(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
        )
        pooler_output = text_outputs[1]  # pooled_output
        text_features = self.model.text_projection(pooler_output)
        return CLIPEmbedOutput(
            last_hidden_state=text_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=text_features,
        )

    def forward(self,
                images: Iterable[Optional[ImageType]],
                texts: List[str]):
        """Encode images and texts; returns (visual_embeds, text_embeds)."""
        visual_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)
        return visual_embeds, text_embeds