# M4CXR-TNNLS / visual_encoders.py
# jonggwon-park's picture
# debug import
# 795e71e
"""Modified from https://github.com/khanrc/honeybee
"""
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers import CLIPVisionModel, Dinov2Model
from transformers.modeling_outputs import BaseModelOutput
from .utils import check_local_file, transformers_log_level
class VisualEncoderMixin:
    """Interface shared by the visual encoders defined in this module.

    Subclasses are expected to override the accessors that raise
    ``NotImplementedError`` here; ``has_cls_token`` and
    ``postprocess_for_projector`` come with working defaults.
    """

    def get_dtype(self):
        """Return the dtype used by the visual encoder."""
        raise NotImplementedError()

    def get_num_tokens(self) -> int:
        """Return the number of output tokens.

        In most cases this is the number of patches, i.e. the cls token is
        not counted.
        """
        raise NotImplementedError()

    def has_cls_token(self) -> bool:
        """Tell whether the encoder emits a cls token. Defaults to ``True``."""
        return True

    def freeze_blocks(self, n: int):
        """Freeze the first ``n`` blocks of the encoder."""
        raise NotImplementedError()

    def postprocess_for_projector(self, visual_features):
        """Prepare encoder features for the projector by dropping the cls token.

        When ``has_cls_token()`` is false the input is returned untouched.
        Otherwise index 0 is removed along dim 1 for 3-D inputs and along
        dim 2 for every other rank.
        """
        if not self.has_cls_token():
            return visual_features

        # The cls token, when present, is assumed to sit at index 0 of the
        # token axis.
        if visual_features.ndim == 3:
            return visual_features[:, 1:]

        # NOTE(review): non-3-D inputs are presumably laid out with one extra
        # leading axis before the token axis -- confirm against callers.
        return visual_features[:, :, 1:]
@dataclass
class VisionModelOutput(BaseModelOutput):
    """
    Output container for vision encoders, declared on top of `BaseModelOutput`.

    This class redeclares `last_hidden_state` and `hidden_states`; no pooled
    image-embedding field is declared here.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    # Both fields default to None, hence the Optional annotations.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # One tensor per layer (plus the embedding output, if any).
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
class CustomCLIP(CLIPVisionModel, VisualEncoderMixin):
    """`CLIPVisionModel` extended with the `VisualEncoderMixin` accessors."""

    def get_dtype(self):
        """Return the dtype of the class-embedding tensor."""
        clip_embeddings = self.vision_model.embeddings
        return clip_embeddings.class_embedding.data.dtype

    def get_num_tokens(self):
        """Return the number of positions minus one (the cls token)."""
        total_positions = self.vision_model.embeddings.num_positions
        # Subtract one so the cls token is not counted.
        return total_positions - 1
class CustomDinov2Model(Dinov2Model, VisualEncoderMixin):
    """`Dinov2Model` extended with the `VisualEncoderMixin` accessors."""

    def get_dtype(self):
        """Return the dtype of the patch-embedding projection weight."""
        patch_embedder = self.embeddings.patch_embeddings
        return patch_embedder.projection.weight.dtype

    def get_num_tokens(self):
        """Return `num_patches` as reported by the patch-embedding module."""
        patch_embedder = self.embeddings.patch_embeddings
        return patch_embedder.num_patches
def build_encoder(config):
    """Instantiate the pretrained visual encoder selected by ``config``.

    Args:
        config: Object exposing ``encoder_type`` (``"openai.clip"`` or
            ``"dinov2"``) and ``pretrained_vision_name_or_path``.

    Returns:
        A ``CustomCLIP`` or ``CustomDinov2Model`` instance created through
        ``from_pretrained``.

    Raises:
        NotImplementedError: If ``config.encoder_type`` is not one of the
            supported values.
    """
    encoder_type = config.encoder_type

    # Resolve the encoder class up front so an unsupported type fails fast
    # with a message naming the offending value, before any loading starts.
    if encoder_type == "openai.clip":
        encoder_cls = CustomCLIP
    elif encoder_type == "dinov2":
        encoder_cls = CustomDinov2Model
    else:
        raise NotImplementedError(
            f"Unsupported encoder_type {encoder_type!r}; "
            "expected 'openai.clip' or 'dinov2'."
        )

    with transformers_log_level(40):  # 40 == logging.ERROR
        # check_local_file yields the local_files_only flag and the
        # name/path to hand to from_pretrained.
        vm_local_files_only, vm_file_name = check_local_file(
            config.pretrained_vision_name_or_path
        )
        model = encoder_cls.from_pretrained(
            vm_file_name,
            local_files_only=vm_local_files_only,
        )
    return model