| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch.nn as nn |
| |
|
| | from transformers.modeling_outputs import BaseModelOutputWithPooling |
| | from typing import Optional, Tuple, Union |
| |
|
| | from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower |
| | from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig |
| |
|
| |
|
class MultiBackboneChannelConcatenationVisionModel(nn.Module):
    """
    A vision model wrapper that concatenates channels from multiple backbones.

    Args:
        config (MultiBackboneChannelConcatenationVisionModelConfig): The configuration
            for the model. Its ``vision_tower``, ``grid_size``, ``convnext_img_size``
            and ``normalize_type`` fields are forwarded to the vision tower, and the
            whole object is also passed to the tower as ``args``.
        raw_config: Forwarded unchanged to the vision tower as ``raw_config``
            (its expected type is defined by the tower -- TODO confirm).

    Attributes:
        vision_model (MultiBackboneChannelConcatenationVisionTower): The vision tower
            that performs the channel concatenation.

    Notes:
        **This class does not inherit from ``PreTrainedModel`` in transformers.**
        It exposes a small part of that interface (``config_class``,
        ``main_input_name``, ``get_input_embeddings``, ``config``, ``dtype``,
        ``device``) by delegating to ``vision_model``.
    """

    config_class = MultiBackboneChannelConcatenationVisionModelConfig
    main_input_name = "pixel_values"

    def __init__(self, config: MultiBackboneChannelConcatenationVisionModelConfig, raw_config):
        super().__init__()

        # ``config`` itself is not stored on this module; the ``config`` property
        # below reads the configuration back from ``vision_model`` instead.
        self.vision_model = MultiBackboneChannelConcatenationVisionTower(
            vision_tower=config.vision_tower,
            args=config,
            grid_size=config.grid_size,
            convnext_img_size=config.convnext_img_size,
            normalize_type=config.normalize_type,
            raw_config=raw_config,
        )

    def get_input_embeddings(self):
        """Return the input embeddings of the first backbone in the vision tower."""
        return self.vision_model.vision_towers[0].get_input_embeddings()

    def forward(
        self,
        pixel_values,
        return_dict: Optional[bool] = True,
        output_hidden_states: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        Run the vision tower on ``pixel_values`` and wrap the result.

        Args:
            pixel_values: Input passed directly to ``self.vision_model``
                (presumably a (batch, channels, height, width) image tensor --
                TODO confirm against the vision tower).
            return_dict: Must resolve to a truthy value; ``None`` falls back to
                ``self.config.use_return_dict``. Tuple outputs are not supported.
            output_hidden_states: Must be falsy (``False`` or ``None``); returning
                intermediate hidden states is not supported.

        Returns:
            BaseModelOutputWithPooling: ``last_hidden_state`` holds the vision tower
            output; ``pooler_output``, ``hidden_states`` and ``attentions`` are
            ``None``.

        Raises:
            ValueError: If tuple output or hidden states are requested.
        """
        # ``None`` means "use the configured default".
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
        if not return_dict:
            raise ValueError("We only support return_dict")
        # ``None`` means "not requested"; only an explicit truthy request is rejected.
        if output_hidden_states:
            raise ValueError("We do not support output_hidden_states")

        features = self.vision_model(pixel_values)

        return BaseModelOutputWithPooling(
            last_hidden_state=features,
            pooler_output=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def dummy_feature(self):
        """Delegates to ``vision_model.dummy_feature``."""
        return self.vision_model.dummy_feature

    @property
    def dtype(self):
        """Delegates to ``vision_model.dtype``."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Delegates to ``vision_model.device``."""
        return self.vision_model.device

    @property
    def config(self):
        """Configuration held by ``vision_model`` (the ``__init__`` config is not stored here)."""
        return self.vision_model.config

    @property
    def hidden_size(self):
        """Delegates to ``vision_model.hidden_size``."""
        return self.vision_model.hidden_size

    @property
    def num_patches(self):
        """Delegates to ``vision_model.num_patches``."""
        return self.vision_model.num_patches
| |
|