"""Modified from https://github.com/khanrc/honeybee """ from dataclasses import dataclass from typing import Optional, Tuple import torch from transformers import CLIPVisionModel, Dinov2Model from transformers.modeling_outputs import BaseModelOutput from .utils import check_local_file, transformers_log_level class VisualEncoderMixin: """VisualEncoderMixin is an abstract class for visual encoders.""" def get_dtype(self): """dtype of visual encoder""" raise NotImplementedError() def get_num_tokens(self) -> int: """The number of ouptut tokens. Mostly, num_patches (without cls token)""" raise NotImplementedError() def has_cls_token(self) -> bool: """Whether the encoder has cls token or not. Default: True""" return True def freeze_blocks(self, n: int): """Freeze the first n blocks of the encoder""" raise NotImplementedError() def postprocess_for_projector(self, visual_features): """Perform any post-processing (e.g., cls feature removal) for visual features before submitting to projector. """ if self.has_cls_token(): # defaultly, we assume cls token is located at 0-index if exists. visual_features = ( visual_features[:, 1:] if visual_features.ndim == 3 else visual_features[:, :, 1:] ) return visual_features @dataclass class VisionModelOutput(BaseModelOutput): """ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. """ last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None class CustomCLIP(CLIPVisionModel, VisualEncoderMixin): def get_dtype(self): return self.vision_model.embeddings.class_embedding.data.dtype def get_num_tokens(self): return self.vision_model.embeddings.num_positions - 1 # -1: excluding cls token class CustomDinov2Model(Dinov2Model, VisualEncoderMixin): def get_dtype(self): return self.embeddings.patch_embeddings.projection.weight.dtype def get_num_tokens(self): return self.embeddings.patch_embeddings.num_patches def build_encoder(config): with transformers_log_level(40): # 40 == logging.ERROR if config.encoder_type == "openai.clip": vm_local_files_only, vm_file_name = check_local_file( config.pretrained_vision_name_or_path ) model = CustomCLIP.from_pretrained( vm_file_name, local_files_only=vm_local_files_only, ) elif config.encoder_type == "dinov2": vm_local_files_only, vm_file_name = check_local_file( config.pretrained_vision_name_or_path ) model = CustomDinov2Model.from_pretrained( vm_file_name, local_files_only=vm_local_files_only, ) else: raise NotImplementedError() return model