| from .configuration_talk2dino import Talk2DINOConfig |
| from .dinotext import DINOText |
| from transformers import PreTrainedModel |
| import clip |
| import torch |
|
|
class Talk2DINO(DINOText, PreTrainedModel):
    """Hugging Face-compatible wrapper around the DINOText model.

    Combines ``DINOText`` (the actual vision/text model implementation) with
    ``PreTrainedModel`` so the model can be loaded/saved through the
    ``transformers`` API via ``Talk2DINOConfig``.
    """

    config_class = Talk2DINOConfig

    def __init__(self, config: Talk2DINOConfig):
        # Keep a reference to the config before the cooperative __init__
        # chain runs; DINOText is initialized from the flattened dict form.
        self.config = config
        cfg_dict = config.to_dict()
        super().__init__(**cfg_dict)

    def encode_text(self, texts):
        """Embed one or more text prompts.

        Args:
            texts: a string or list of strings.

        Returns:
            Text embeddings of shape (N, D), where N is the number of texts
            and D is the embedding dimension.
        """
        device = next(self.parameters()).device
        text_tokens = clip.tokenize(texts).to(device)
        txt_embed = self.clip_model.encode_text(text_tokens)
        # Project CLIP text features into the DINO embedding space.
        txt_embed = self.proj.project_clip_txt(txt_embed)
        return txt_embed

    def encode_image(self, images):
        """Embed one or more images as patch-token features.

        Args:
            images: a PIL image or list of PIL images.

        Returns:
            Image embeddings of shape (N, L, D), where N is the number of
            images, L is the number of patches and D is the embedding
            dimension.
        """
        if not isinstance(images, list):
            images = [images]
        device = next(self.parameters()).device
        img_preprocessed = torch.stack(
            [self.image_transforms(img).to(device) for img in images]
        )

        # Drop the first 5 tokens — presumably the CLS token plus 4 register
        # tokens of the DINOv2 backbone, keeping only patch tokens.
        # TODO(review): confirm against the backbone's token layout.
        img_embed = self.model.forward_features(img_preprocessed)[:, 5:]

        return img_embed