import os
from typing import Callable, Optional

import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor
from transformers.modeling_utils import PreTrainedModel

from .configuration_cased import CaSEDConfig
from .retrieval_cased import RetrievalDatabase, download_retrieval_databases
from .transforms_cased import default_vocabulary_transforms


class CaSEDModel(PreTrainedModel):
    """Transformers module for Category Search from External Databases (CaSED).

    Reference:
        - Conti et al. Vocabulary-free Image Classification. NeurIPS 2023.

    Args:
        config (CaSEDConfig): Configuration class for CaSED.
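
    Example:
        A minimal usage sketch (illustrative only; it assumes the default ``CaSEDConfig``
        values, a local image file, and network access to download CLIP and the
        retrieval databases on first use):

        >>> from PIL import Image
        >>> model = CaSEDModel(CaSEDConfig())
        >>> image = Image.open("example.jpg")  # hypothetical image path
        >>> inputs = model.processor(images=image, return_tensors="pt")
        >>> outputs = model({"pixel_values": inputs["pixel_values"]})
        >>> best = outputs["scores"].argmax(dim=-1)
        >>> [outputs["words"][i] for i in best.tolist()]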
    """

    config_class = CaSEDConfig

    def __init__(self, config: CaSEDConfig):
        super().__init__(config)
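
        # load the pre-trained CLIP encoders, projections, and processor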
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
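
        # store the hyperparameters from the configuration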
        self.hparams = {}
        self.hparams["alpha"] = config.alpha
        self.hparams["index_name"] = config.index_name
        self.hparams["retrieval_num_results"] = config.retrieval_num_results
        self.hparams["cache_dir"] = config.cache_dir
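
        # create the cache directory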
        os.makedirs(self.hparams["cache_dir"], exist_ok=True)
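
        # download the retrieval databases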
        download_retrieval_databases(cache_dir=self.hparams["cache_dir"])
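
        # load the retrieval database and the vocabulary transforms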
        self.vocabulary = RetrievalDatabase("cc12m", self.hparams["cache_dir"])
        self._vocab_transform = default_vocabulary_transforms()

    @property
    def vocab_transform(self) -> Callable:
        """Get the vocabulary transform.

        The getter wraps the transform in a function that applies it to a list of
        vocabularies, one per image. If interested in the transform itself, use
        `self._vocab_transform`.
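
        Example (illustrative): ``self.vocab_transform([["a photo of a dog"], ["two cats"]])``
        applies ``self._vocab_transform`` to each caption list and returns one candidate
        vocabulary per image.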
        """
        vocab_transform = self._vocab_transform

        def vocabs_transforms(texts: list[list[str]]) -> list[list[str]]:
            return [vocab_transform(text) for text in texts]

        return vocabs_transforms

    def get_vocabulary(self, images_z: Optional[torch.Tensor] = None) -> list[list[str]]:
        """Get the vocabulary for a batch of images.

        Args:
            images_z (torch.Tensor): Batch of image embeddings.
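
        Returns:
            list[list[str]]: The captions retrieved from the database for each image,
            up to ``retrieval_num_results`` entries per image.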
        """
        num_samples = self.hparams["retrieval_num_results"]

        assert images_z is not None

        # normalize the image embeddings and convert them to a nested list
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        images_z = images_z.cpu().detach().numpy().tolist()

        # wrap a single embedding into a batch of one
        if isinstance(images_z[0], float):
            images_z = [images_z]

        # query the retrieval database with the image embeddings
        query = np.asarray(images_z, dtype="float32")
        results = self.vocabulary.query(query, modality="text", num_samples=num_samples)

        vocabularies = [[r["caption"] for r in result] for result in results]
        return vocabularies

    def forward(self, images: dict, alpha: Optional[float] = None) -> dict:
        """Forward pass.

        Args:
            images (dict): Dictionary with the images. The expected keys are:
                - pixel_values (torch.Tensor): Pixel values of the images.
            alpha (Optional[float]): Alpha value for the interpolation.
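
        Returns:
            dict: Dictionary with the following keys:
                - scores (torch.Tensor): Interpolated image-to-word and text-to-word
                  probabilities over the candidate words.
                - words (list[str]): Flat list of candidate words for the batch.
                - vocabularies (list[list[str]]): Filtered vocabulary of each image.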
        """
        alpha = alpha if alpha is not None else self.hparams["alpha"]
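
        # encode the images and retrieve a vocabulary for each of them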
        images["pixel_values"] = images["pixel_values"].to(self.device)
        images_z = self.vision_proj(self.vision_encoder(**images)[1])
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        vocabularies = self.get_vocabulary(images_z=images_z)
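
        # encode the unfiltered retrieved captions, truncated to CLIP's 77-token context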
        unfiltered_words = sum(vocabularies, [])
        texts_z = self.processor(unfiltered_words, return_tensors="pt", padding=True)
        texts_z["input_ids"] = texts_z["input_ids"][:, :77].to(self.device)
        texts_z["attention_mask"] = texts_z["attention_mask"][:, :77].to(self.device)
        texts_z = self.language_encoder(**texts_z)[1]
        texts_z = self.language_proj(texts_z)
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
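
        # average the caption embeddings into a single text centroid per image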
        unfiltered_words_per_image = [len(vocab) for vocab in vocabularies]
        texts_z = torch.split(texts_z, unfiltered_words_per_image)
        texts_z = torch.stack([text_z.mean(dim=0) for text_z in texts_z])
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
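
        # filter the captions into candidate words and encode them, falling back to "object"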
        vocabularies = self.vocab_transform(vocabularies)
        vocabularies = [vocab or ["object"] for vocab in vocabularies]
        words = sum(vocabularies, [])
        words_z = self.processor(words, return_tensors="pt", padding=True)
        words_z = {k: v.to(self.device) for k, v in words_z.items()}
        words_z = self.language_encoder(**words_z)[1]
        words_z = self.language_proj(words_z)
        words_z = words_z / words_z.norm(dim=-1, keepdim=True)
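
        # build a mask that pairs each image only with its own candidate words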
        words_per_image = [len(vocab) for vocab in vocabularies]
        col_indices = torch.arange(sum(words_per_image))
        row_indices = torch.arange(len(images_z)).repeat_interleave(torch.tensor(words_per_image))
        mask = torch.zeros(len(images_z), sum(words_per_image), device=self.device)
        mask[row_indices, col_indices] = 1
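
        # compute the image-to-word and text-centroid-to-word similarities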
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
        words_z = words_z / words_z.norm(dim=-1, keepdim=True)
        images_sim = self.logit_scale * images_z @ words_z.T
        texts_sim = self.logit_scale * texts_z @ words_z.T
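
        # mask out the words that belong to other images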
        images_sim = torch.masked_fill(images_sim, mask == 0, float("-inf"))
        texts_sim = torch.masked_fill(texts_sim, mask == 0, float("-inf"))
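
        # convert the similarities to probabilities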
        images_p = images_sim.softmax(dim=-1)
        texts_p = texts_sim.softmax(dim=-1)
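
        # interpolate the image and text predictions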
        samples_p = alpha * images_p + (1 - alpha) * texts_p

        return {"scores": samples_p, "words": words, "vocabularies": vocabularies}