File size: 1,819 Bytes
4ee5289
 
c26362f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from .configuration_talk2dino import Talk2DINOConfig
from .dinotext import DINOText
from transformers import PreTrainedModel
import clip
import torch

class Talk2DINO(DINOText, PreTrainedModel):
    """DINOText model packaged with a ``Talk2DINOConfig`` configuration.

    The constructor unpacks the configuration into keyword arguments for the
    parent initializer. ``encode_text`` and ``encode_image`` are convenience
    helpers that move their inputs to the device of the model's first
    parameter before running the encoders.
    """

    config_class = Talk2DINOConfig

    def __init__(self, config: Talk2DINOConfig):
        """Build the model from ``config``.

        Args:
            config: configuration whose ``to_dict()`` entries are forwarded
                as keyword arguments to the parent initializer.
        """
        # Store the config
        self.config = config

        # Convert config to a dict (works for PretrainedConfig subclasses)
        cfg_dict = config.to_dict()

        # Initialize parent (DINOText) with unpacked kwargs
        super().__init__(**cfg_dict)

    def encode_text(self, texts):
        """Embed one or more texts.

        Args:
            texts: string or list of strings.

        Returns:
            Text embeddings (N, D) where N is the number of texts and D is
            the embedding dimension.
        """
        # Tokens must live on the same device as the model weights.
        device = next(self.parameters()).device
        text_tokens = clip.tokenize(texts).to(device)
        txt_embed = self.clip_model.encode_text(text_tokens)
        txt_embed = self.proj.project_clip_txt(txt_embed)
        return txt_embed

    def encode_image(self, images):
        """Embed one or more images into per-patch features.

        Args:
            images: PIL image, or a list/tuple of PIL images.

        Returns:
            Image embeddings (N, L, D) where N is the number of images, L is
            the number of patches and D is the embedding dimension.

        Raises:
            ValueError: if ``self.model_name`` does not name a supported
                backbone family.
        """
        # Accept a single image as well as any list/tuple of images.
        if not isinstance(images, (list, tuple)):
            images = [images]

        # The device is the same for every image, so look it up once.
        device = next(self.parameters()).device
        img_preprocessed = torch.stack(
            [self.image_transforms(img).to(device) for img in images]
        )

        model_name = self.model_name
        # Order matters: 'dinov2'/'dinov3' also contain 'dino', so the more
        # specific check has to come first.
        if 'dinov2' in model_name or 'dinov3' in model_name:
            features = self.model.forward_features(img_preprocessed)
            return features['x_norm_patchtokens']
        if 'mae' in model_name or 'clip' in model_name or 'dino' in model_name:
            # Drop the first token along dim 1 (presumably the CLS token --
            # confirm against the backbone) to keep only patch tokens.
            return self.model.forward_features(img_preprocessed)[:, 1:, :]

        # Fail loudly instead of hitting an unbound local at the return.
        raise ValueError(
            f"Unsupported model_name {model_name!r}: expected it to contain "
            "one of 'dinov2', 'dinov3', 'mae', 'clip' or 'dino'"
        )