| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ PyTorch CCT model.""" |
|
|
|
|
| from dataclasses import dataclass |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
| from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput |
| from transformers import PreTrainedModel |
| from .configuration_cct import CctConfig |
|
|
| |
# Documentation constants (judging by the `_FOR_DOC` suffix; their consumers are not visible in this part of the file).
_CONFIG_FOR_DOC = "CctConfig"

# Checkpoint name and expected output shape for the base model.
_CHECKPOINT_FOR_DOC = "rishabbala/cct_14_7x2_384"
_EXPECTED_OUTPUT_SHAPE = [1, 384]

# Checkpoint name and expected label for the image-classification head.
_IMAGE_CLASS_CHECKPOINT = "rishabbala/cct_14_7x2_384"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


# Checkpoint identifiers listed for this architecture.
CCT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "rishabbala/cct_14_7x2_384",
    "rishabbala/cct_14_7x2_224",
]
|
|
|
|
@dataclass
class BaseModelOutputWithSeqPool(ModelOutput):
    """
    Output container for the CCT encoder: token-level features, their sequence-pooled summary and, optionally, the
    intermediate hidden states. No attention weights are carried.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model prior to sequential pooling.
        hidden_state_post_pool (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Single vector per sample obtained by sequence-pooling `last_hidden_state`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor`. As filled in by `CctEncoder.forward`, it holds the embedding output and
            the output of each layer (the last one after the final layer norm), each of shape `(batch_size,
            sequence_length, hidden_size)`, followed by the pooled tensor of shape `(batch_size, hidden_size)`.
    """

    # Every field defaults to None, so the annotations are Optional.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_state_post_pool: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
| |
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample in the batch (first dimension) is either zeroed entirely or kept and rescaled by
    `1 / (1 - drop_prob)`, so the expected value of the output matches the input.

    Args:
        input (`torch.Tensor`): Tensor whose first dimension indexes the samples.
        drop_prob (`float`, *optional*, defaults to 0.0): Probability of dropping a sample. `None` is treated
            like `0.0`.
        training (`bool`, *optional*, defaults to `False`): Paths are only dropped when this is `True`.

    Returns:
        `torch.Tensor`: `input` itself when nothing can be dropped, otherwise a new tensor of the same shape.
    """
    # `None` is accepted because `CctDropPath` defaults its probability to None.
    if drop_prob is None or drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over all remaining dimensions.
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 keeps the sample, 0 drops it
    if keep_prob > 0.0:
        return input.div(keep_prob) * random_tensor
    # drop_prob == 1.0: every sample is dropped. Skip the rescale, which would compute inf * 0 = NaN.
    return input * random_tensor
|
|
|
|
| |
class CctDropPath(nn.Module):
    """Module wrapper that applies `drop_path` (stochastic depth per sample) using the module's training flag."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        # Stored as given; it is forwarded unchanged to `drop_path` on every call.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `hidden_states` after stochastic depth; paths are dropped only while `self.training` is set."""
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's printed representation."""
        return f"p={self.drop_prob}"
|
|
|
|
class CctConvEmbeddings(nn.Module):
    """
    Performs convolutional tokenization of the input image.

    The image goes through `config.num_conv_layers` groups of `Conv2d -> ReLU -> MaxPool2d`. The resulting feature
    map is flattened over its spatial dimensions so that every spatial location becomes one token.
    """

    def __init__(self, config: CctConfig):
        super().__init__()
        self.in_channels = config.in_channels
        self.img_size = config.img_size

        # Unpacking accepts any sequence for `out_channels` (list or tuple).
        channels_size = [config.in_channels, *config.out_channels]
        assert (
            len(channels_size) == config.num_conv_layers + 1
        ), "Ensure that the number output channels matches the number of conv layers"

        self.embedding_layers = nn.ModuleList([])
        for i in range(config.num_conv_layers):
            self.embedding_layers.extend(
                [
                    nn.Conv2d(
                        channels_size[i],
                        channels_size[i + 1],
                        kernel_size=config.conv_kernel_size,
                        stride=config.conv_stride,
                        padding=config.conv_padding,
                        bias=config.conv_bias,
                    ),
                    nn.ReLU(),
                    nn.MaxPool2d(config.pool_kernel_size, stride=config.pool_stride, padding=config.pool_padding),
                ]
            )

    def forward(self, pixel_values):
        """
        Tokenize a batch of images.

        Args:
            pixel_values (`torch.Tensor`): 4-D input that is fed to the convolutional stack.

        Returns:
            `torch.Tensor` of shape `(batch_size, height * width, num_channels)`, where the three sizes are those
            of the final feature map.
        """
        for layer in self.embedding_layers:
            pixel_values = layer(pixel_values)
        batch_size, num_channels, height, width = pixel_values.shape
        # Every spatial position of the feature map becomes one token.
        sequence_length = height * width
        return pixel_values.view(batch_size, num_channels, sequence_length).permute(0, 2, 1)

    def get_sequence_length(self) -> int:
        """
        Return the number of tokens produced for a square `img_size` x `img_size` input.

        The value is obtained by running a dummy forward pass. The dummy input follows the dtype and device of the
        module's parameters so the probe still works after the module has been cast or moved, and no autograd
        graph is recorded for it.
        """
        reference = next(self.parameters(), None)
        factory_kwargs = {} if reference is None else {"dtype": reference.dtype, "device": reference.device}
        dummy = torch.zeros((1, self.in_channels, self.img_size, self.img_size), **factory_kwargs)
        with torch.no_grad():
            return self.forward(dummy).shape[1]
|
|
|
|
class CctSelfAttention(nn.Module):
    """
    Multi-head self-attention over a sequence of hidden states. Queries, keys and values are all produced from the
    input by a single bias-free linear projection.
    """

    def __init__(self, embed_dim, num_heads=6, attention_drop_rate=0.1, drop_rate=0.0):
        super().__init__()
        self.num_heads = num_heads
        # Scale attention scores by 1 / sqrt(head_dim).
        self.scale = (embed_dim // num_heads) ** -0.5

        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.attn_drop = nn.Dropout(attention_drop_rate)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop_rate)

    def forward(self, hidden_state):
        """Apply self-attention to `hidden_state` of shape `(batch, tokens, channels)`; the shape is preserved."""
        batch_size, seq_len, channels = hidden_state.shape
        head_dim = channels // self.num_heads

        # (batch, tokens, 3 * channels) -> (3, batch, heads, tokens, head_dim)
        packed = self.qkv(hidden_state).reshape(batch_size, seq_len, 3, self.num_heads, head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        query, key, value = packed[0], packed[1], packed[2]

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # Merge the heads back into the channel dimension.
        context = torch.matmul(weights, value).transpose(1, 2).reshape(batch_size, seq_len, channels)
        return self.proj_drop(self.proj(context))
|
|
|
|
class CctStage(nn.Module):
    """
    A single CCT transformer layer: pre-normalised self-attention with a residual connection, followed by a layer
    norm and a two-layer GELU feed-forward network with a second residual connection.

    Args:
        embed_dim (`int`, defaults to 384): Width of the token representations.
        num_heads (`int`, defaults to 6): Number of attention heads.
        mlp_ratio (`int` or `float`, defaults to 3): Feed-forward width as a multiple of `embed_dim`. The product
            is truncated to an integer.
        drop_rate (`float`, defaults to 0.0): Dropout used in the feed-forward network and after the attention
            output projection.
        attention_drop_rate (`float`, defaults to 0.1): Dropout applied to the attention weights.
        drop_path_rate (`float`, defaults to 0.0): Stochastic-depth rate for both residual branches.
    """

    def __init__(
        self, embed_dim=384, num_heads=6, mlp_ratio=3, drop_rate=0.0, attention_drop_rate=0.1, drop_path_rate=0.0
    ):
        super().__init__()
        # `nn.Linear` needs integer sizes; the cast keeps float ratios (e.g. 3.0 or 2.5) usable.
        dim_feedforward = int(mlp_ratio * embed_dim)
        self.pre_norm = nn.LayerNorm(embed_dim)

        self.linear1 = nn.Linear(embed_dim, dim_feedforward)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.linear2 = nn.Linear(dim_feedforward, embed_dim)
        self.self_attn = CctSelfAttention(
            embed_dim=embed_dim, num_heads=num_heads, attention_drop_rate=attention_drop_rate, drop_rate=drop_rate
        )
        self.dropout1 = nn.Dropout(drop_rate)
        self.dropout2 = nn.Dropout(drop_rate)
        self.drop_path = CctDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.activation = F.gelu

    def forward(self, hidden_state):
        """Run the layer on `hidden_state` and return a tensor of the same shape."""
        # Attention branch: normalise first, then add back onto the un-normalised input.
        attended = self.self_attn(self.pre_norm(hidden_state))
        hidden_state = hidden_state + self.drop_path(attended)

        # The normalised tensor is both the feed-forward input and the residual it is added to.
        hidden_state = self.norm1(hidden_state)
        feed_forward = self.dropout1(self.activation(self.linear1(hidden_state)))
        feed_forward = self.dropout2(self.linear2(feed_forward))
        return hidden_state + self.drop_path(feed_forward)
|
|
|
|
class CctEncoder(nn.Module):
    """
    Transformer encoder of CCT. It operates on already tokenized inputs: it adds a positional embedding, runs a
    stack of `CctStage` layers, applies a final layer norm, and reduces the token sequence to one vector per sample
    with a learned softmax weighting over the tokens (sequence pooling).

    Output is of type BaseModelOutputWithSeqPool if return_dict is set to True, else the output is a Tuple.

    Args:
        config (`CctConfig`): Supplies `embed_dim`, `pos_emb_type`, `drop_rate`, `drop_path_rate`,
            `num_transformer_layers`, `num_heads`, `mlp_ratio` and `attention_drop_rate`.
        sequence_length (`int`): Number of tokens per sample; fixes the size of the positional embedding.
    """

    def __init__(self, config: CctConfig, sequence_length: int):
        super().__init__()
        assert sequence_length is not None, "Sequence Length required to initialize positional embedding"

        # Produces one score per token; the softmax over tokens gives the pooling weights.
        self.attention_pool = nn.Linear(config.embed_dim, 1)

        if config.pos_emb_type == "learnable":
            self.positional_emb = nn.Parameter(
                self.learnable_embedding(sequence_length, config.embed_dim), requires_grad=True
            )
        else:
            self.positional_emb = nn.Parameter(
                self.sinusoidal_embedding(sequence_length, config.embed_dim), requires_grad=False
            )

        self.dropout = nn.Dropout(config.drop_rate)
        # Stochastic-depth rate grows linearly from 0 (first layer) to `drop_path_rate` (last layer).
        stochastic_dropout_rate = [
            x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_transformer_layers)
        ]

        self.blocks = nn.ModuleList(
            [
                CctStage(
                    config.embed_dim,
                    config.num_heads,
                    config.mlp_ratio,
                    config.drop_rate,
                    config.attention_drop_rate,
                    stochastic_dropout_rate[i],
                )
                for i in range(config.num_transformer_layers)
            ]
        )
        self.norm = nn.LayerNorm(config.embed_dim)

    def forward(self, pixel_values, output_hidden_states=False, return_dict=True) -> BaseModelOutputWithSeqPool:
        """
        Encode a batch of token sequences.

        Args:
            pixel_values (`torch.Tensor`): Token embeddings; must be broadcastable with the positional embedding
                of shape `(1, sequence_length, embed_dim)`.
            output_hidden_states (`bool`, defaults to `False`): Also return the intermediate hidden states.
            return_dict (`bool`, defaults to `True`): Return a `BaseModelOutputWithSeqPool` instead of a tuple.

        Returns:
            `BaseModelOutputWithSeqPool`, or the tuple `(pre_pool, post_pool)` with the tuple of hidden states
            appended when `output_hidden_states` is set. The hidden states are: the embedding output, the output
            of each layer (the last entry replaced by its layer-normalised version), and the pooled tensor.
        """
        all_hidden_states = ()

        hidden_state = pixel_values + self.positional_emb
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)
        hidden_state = self.dropout(hidden_state)

        for blk in self.blocks:
            hidden_state = blk(hidden_state)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

        hidden_state_pre_pool = self.norm(hidden_state)
        if output_hidden_states:
            # Report the final layer after normalisation rather than its raw output.
            all_hidden_states = all_hidden_states[:-1] + (hidden_state_pre_pool,)

        # Sequence pooling: softmax over the token axis, then a weighted sum of the tokens.
        seq_pool_attn = F.softmax(self.attention_pool(hidden_state_pre_pool), dim=1)
        hidden_state_post_pool = torch.matmul(seq_pool_attn.transpose(-1, -2), hidden_state_pre_pool).squeeze(-2)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state_post_pool,)

        if not return_dict:
            if output_hidden_states:
                return (hidden_state_pre_pool, hidden_state_post_pool, all_hidden_states)
            return (hidden_state_pre_pool, hidden_state_post_pool)

        return BaseModelOutputWithSeqPool(
            last_hidden_state=hidden_state_pre_pool,
            hidden_state_post_pool=hidden_state_post_pool,
            hidden_states=all_hidden_states if output_hidden_states else None,
        )

    @staticmethod
    def learnable_embedding(sequence_length, embed_dim):
        """Return a `(1, sequence_length, embed_dim)` tensor drawn from a truncated normal with std 0.2."""
        pe = torch.zeros(1, sequence_length, embed_dim)
        return nn.init.trunc_normal_(pe, std=0.2)

    @staticmethod
    def sinusoidal_embedding(sequence_length, embed_dim):
        """
        Return a fixed `(1, sequence_length, embed_dim)` sinusoidal table: sine on even channels, cosine on odd
        channels, with the frequency shared by each even/odd channel pair.
        """
        # The divisor depends only on the channel index, so compute it once per channel.
        divisors = [10000 ** (2 * (i // 2) / embed_dim) for i in range(embed_dim)]
        pe = torch.FloatTensor([[p / divisor for divisor in divisors] for p in range(sequence_length)])
        pe[:, 0::2] = torch.sin(pe[:, 0::2])
        pe[:, 1::2] = torch.cos(pe[:, 1::2])
        return pe.unsqueeze(0)
|
|
|
|
class CctPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CctConfig
    base_model_prefix = "cct"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """
        Initialize the parameters of a single module.

        Only the module passed in is touched. Visiting the submodules is left to the caller (the base class is
        expected to apply this hook to every submodule, in the manner of `nn.Module.apply`). Recursing here as well
        would re-initialize every descendant once per ancestor and would reset a whole subtree when the hook is
        called on a container.

        - `nn.Linear`: weights from a truncated normal with std 0.02, bias set to zero.
        - `nn.LayerNorm`: weight set to one, bias set to zero.
        - `nn.Conv2d`: Kaiming-normal weights; the bias is left as constructed.
        - Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            # Both are None when the layer was created without affine parameters.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
            if module.weight is not None:
                nn.init.constant_(module.weight, 1.0)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight)
|
|
|
|
class CctModel(CctPreTrainedModel):
    """
    Bare CCT model: a convolutional tokenizer (`CctConvEmbeddings`) followed by the transformer encoder with
    sequence pooling (`CctEncoder`).
    """

    def __init__(self, config, add_pooling_layer=True):
        # `add_pooling_layer` is accepted but not used by this constructor.
        super().__init__(config)
        self.config = config
        self.embedder = CctConvEmbeddings(config)
        # The encoder's positional embedding is sized from the tokenizer's output length.
        sequence_length = self.embedder.get_sequence_length()
        self.encoder = CctEncoder(config, sequence_length)
        self.post_init()

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithSeqPool]:
        """
        Tokenize `pixel_values` and encode the tokens.

        `output_hidden_states` and `return_dict` fall back to the configuration when left as `None`.

        Raises:
            ValueError: If `pixel_values` is `None`.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values (input image)")

        tokens = self.embedder(pixel_values)
        return self.encoder(
            tokens,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
|
|
class CctForImageClassification(CctPreTrainedModel):
    """CCT model with a linear classification head applied to the sequence-pooled representation."""

    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.cct = CctModel(config, add_pooling_layer=False)

        # Without labels the head is a pass-through that exposes the pooled features.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.embed_dim, config.num_labels)
        else:
            self.classifier = nn.Identity()

        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        outputs = self.cct(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The pooled vector is the second element of the tuple form.
        pooled_output = outputs.hidden_state_post_pool if return_dict else outputs[1]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            config = self.config
            # Infer the problem type once, from the label count and label dtype, and store it on the config.
            if config.problem_type is None:
                if config.num_labels == 1:
                    config.problem_type = "regression"
                elif config.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    config.problem_type = "single_label_classification"
                else:
                    config.problem_type = "multi_label_classification"

            problem_type = config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if config.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, config.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return ImageClassifierOutputWithNoAttention(
                loss=loss, logits=logits, hidden_states=outputs.hidden_states if output_hidden_states else None
            )

        # Tuple form: (loss?, logits, hidden_states?)
        result = (logits,)
        if output_hidden_states:
            result = result + (outputs[2],)
        if loss is not None:
            result = (loss,) + result
        return result
|
|