import os
import tarfile
from pathlib import Path
from typing import Optional

import faiss
import numpy as np
import pyarrow as pa
import requests
import torch
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor
from transformers.modeling_utils import PreTrainedModel

from .configuration_cased import CaSEDConfig
from .transforms_cased import default_vocabulary_transforms

DATABASES = {
    "cc12m": {
        "url": "https://storage-cased.alessandroconti.me/cc12m.tar.gz",
        "cache_subdir": "./cc12m/vit-l-14/",
    },
}
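
# Further retrieval databases could be registered with the same layout and would be
# downloaded by `prepare_data()` below. The entry here is a hypothetical example:
# the name, URL, and cache subdirectory are placeholders, not a real archive.
#
#   DATABASES["my_database"] = {
#       "url": "https://example.com/my_database.tar.gz",
#       "cache_subdir": "./my_database/vit-l-14/",
#   }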


class MetadataProvider:
    """Metadata provider.

    It uses arrow files to store metadata and retrieve it efficiently.

    Code reference:
    - https://github.dev/rom1504/clip-retrieval
    """

    def __init__(self, arrow_folder: Path):
        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
        self.table = pa.concat_tables(
            [
                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
                for arrow_file in arrow_files
            ]
        )

    def get(self, ids: np.ndarray, cols: Optional[list] = None):
        """Get arrow metadata from ids.

        Args:
            ids (np.ndarray): Ids to retrieve.
            cols (Optional[list], optional): Columns to retrieve. Defaults to None.
        """
        if cols is None:
            cols = self.table.schema.names
        else:
            cols = list(set(self.table.schema.names) & set(cols))
        t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
        return t.select(cols).to_pandas().to_dict("records")
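
# Usage sketch (illustrative only; the metadata folder path is an assumption
# matching the cc12m cache layout used elsewhere in this module):
#
#   folder = Path("~/.cache/cased/databases/cc12m/vit-l-14/metadata").expanduser()
#   provider = MetadataProvider(folder)
#   records = provider.get(np.array([0, 1, 2]), cols=["caption"])
#   # records -> [{"caption": "..."}, {"caption": "..."}, {"caption": "..."}]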


class CaSEDModel(PreTrainedModel):
    """Transformers module for Category Search from External Databases (CaSED).

    Reference:
    - Conti et al. Vocabulary-free Image Classification. arXiv 2023.

    Args:
        config (CaSEDConfig): Configuration class for CaSED.
    """

    config_class = CaSEDConfig

    def __init__(self, config: CaSEDConfig):
        super().__init__(config)

        # load the CLIP encoders, projections, and processor
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # load the vocabulary transforms
        self.vocabulary_transforms = default_vocabulary_transforms()

        # set the hyperparameters
        self.hparams = {}
        self.hparams["alpha"] = config.alpha
        self.hparams["index_name"] = config.index_name
        self.hparams["retrieval_num_results"] = config.retrieval_num_results

        # create the cache directory
        self.hparams["cache_dir"] = Path(os.path.expanduser("~/.cache/cased"))
        os.makedirs(self.hparams["cache_dir"], exist_ok=True)

        # download the retrieval databases if needed
        self.prepare_data()

        # load the faiss text indices and the metadata providers
        self.resources = {}
        for name, items in DATABASES.items():
            database_path = self.hparams["cache_dir"] / "databases" / items["cache_subdir"]
            text_index_fp = database_path / "text.index"
            metadata_fp = database_path / "metadata/"

            text_index = faiss.read_index(
                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            metadata_provider = MetadataProvider(metadata_fp)

            self.resources[name] = {
                "device": self.device,
                "model": "ViT-L-14",
                "text_index": text_index,
                "metadata_provider": metadata_provider,
            }

    def prepare_data(self):
        """Download data if needed."""
        databases_path = Path(self.hparams["cache_dir"]) / "databases"

        for name, items in DATABASES.items():
            url = items["url"]
            database_path = Path(databases_path, name)
            if database_path.exists():
                continue

            # download the database archive with a progress bar
            target_path = Path(databases_path, name + ".tar.gz")
            os.makedirs(target_path.parent, exist_ok=True)
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                total_bytes_size = int(r.headers.get("content-length", 0))
                chunk_size = 8192
                p_bar = tqdm(
                    desc=f"Downloading {name} index",
                    total=total_bytes_size,
                    unit="iB",
                    unit_scale=True,
                )
                with open(target_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        p_bar.update(len(chunk))
                p_bar.close()

            # extract the archive, then remove it
            with tarfile.open(target_path, "r:gz") as tar:
                tar.extractall(target_path.parent)
            target_path.unlink()

    @torch.no_grad()
    def query_index(self, sample_z: torch.Tensor) -> list:
        """Query the external database index.

        Args:
            sample_z (torch.Tensor): Sample to query the index.

        Returns:
            list: Captions of the retrieved entries, ordered by similarity.
        """
        # get the resources of the selected index
        resources = self.resources[self.hparams["index_name"]]
        text_index = resources["text_index"]
        metadata_provider = resources["metadata_provider"]

        # prepare the query vector
        sample_z = sample_z.squeeze(0)
        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
        query_input = sample_z.cpu().detach().numpy().tolist()
        query = np.expand_dims(np.array(query_input).astype("float32"), 0)

        # search the index and drop the padding (-1) entries
        distances, idxs, _ = text_index.search_and_reconstruct(
            query, self.hparams["retrieval_num_results"]
        )
        results = idxs[0]
        nb_results = np.where(results == -1)[0]
        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
        indices = results[:nb_results]
        distances = distances[0][:nb_results]

        if len(distances) == 0:
            return []

        # attach the caption metadata to each retrieved entry
        results = []
        metadata = metadata_provider.get(indices[:20], ["caption"])
        for key, (d, i) in enumerate(zip(distances, indices)):
            output = {}
            meta = None if key + 1 > len(metadata) else metadata[key]
            if meta is not None:
                output.update(meta)
            output["id"] = i.item()
            output["similarity"] = d.item()
            results.append(output)

        # keep only the captions as the retrieved vocabulary
        vocabularies = [result["caption"] for result in results]

        return vocabularies

    @torch.no_grad()
    def forward(self, images: dict, alpha: Optional[float] = None) -> dict:
        """Forward pass.

        Args:
            images (dict): Dictionary with the images. The expected keys are:
                - pixel_values (torch.Tensor): Pixel values of the images.
            alpha (Optional[float]): Alpha value for the interpolation.

        Returns:
            dict: Dictionary with the per-image "vocabularies" and their "scores".
        """
        # encode the images
        images["pixel_values"] = images["pixel_values"].to(self.device)
        images_z = self.vision_proj(self.vision_encoder(**images)[1])

        vocabularies, samples_p = [], []
        for image_z in images_z:
            image_z = image_z.unsqueeze(0)

            # retrieve the captions and encode their normalized centroid
            vocabulary = self.query_index(image_z)
            text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
            text["input_ids"] = text["input_ids"][:, :77].to(self.device)
            text["attention_mask"] = text["attention_mask"][:, :77].to(self.device)
            text_z = self.language_encoder(**text)[1]
            text_z = self.language_proj(text_z)
            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
            text_z = text_z.mean(dim=0).unsqueeze(0)
            text_z = text_z / text_z.norm(dim=-1, keepdim=True)

            # build the candidate vocabulary from the captions and encode it
            vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
            text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
            text = {k: v.to(self.device) for k, v in text.items()}
            vocabulary_z = self.language_encoder(**text)[1]
            vocabulary_z = self.language_proj(vocabulary_z)
            vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)

            # score the vocabulary against the image and the caption centroid
            image_z = image_z / image_z.norm(dim=-1, keepdim=True)
            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
            image_p = (self.logit_scale * image_z @ vocabulary_z.T).softmax(dim=-1)
            text_p = (self.logit_scale * text_z @ vocabulary_z.T).softmax(dim=-1)

            # interpolate the image and text distributions
            alpha = alpha or self.hparams["alpha"]
            sample_p = alpha * image_p + (1 - alpha) * text_p

            # store the results
            samples_p.append(sample_p)
            vocabularies.append(vocabulary)

        # stack the scores of all the samples
        samples_p = torch.stack(samples_p, dim=0)
        scores = samples_p.cpu()

        # format the results
        results = {"vocabularies": vocabularies, "scores": scores}

        return results
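

# Example usage (illustrative sketch, not part of the module). It assumes this code
# is published on the Hugging Face Hub as custom model code loadable with
# `trust_remote_code=True`; the checkpoint id "altndrr/cased" and the image path
# are placeholders.
#
#   from PIL import Image
#   from transformers import AutoModel, CLIPProcessor
#
#   model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True)
#   processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
#
#   image = Image.open("example.jpg").convert("RGB")
#   images = processor(images=[image], return_tensors="pt")
#   outputs = model(images, alpha=0.5)
#   vocabulary, scores = outputs["vocabularies"][0], outputs["scores"][0]
#   print(vocabulary[scores.argmax().item()])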