import math import torch import torch.nn as nn import timm from typing import Optional, Union from transformers.activations import ACT2FN from transformers.models.mask2former.modeling_mask2former import ( Mask2FormerPixelDecoder, Mask2FormerPixelDecoderOutput, Mask2FormerPixelDecoderEncoderOnly, Mask2FormerPixelDecoderEncoderLayer, Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention, ) from transformers.modeling_outputs import BaseModelOutput from src.models.loupe.configuration_loupe import LoupeConfig # ----------------------------------- CLASSIFICATION-RELATED MODULES ----------------------------------- class LoupeClsHead(nn.Module): def __init__( self, input_dim: int, hidden_dim: Optional[int] = None, num_layers: int = 2, hidden_act: Union[str, type] = "gelu", ) -> None: super().__init__() self.hidden_size = hidden_dim self.num_layers = num_layers if num_layers >= 2 and hidden_dim is None: raise ValueError( "If num_layers >= 2, hidden_dim must be specified. " "Otherwise, the model will not be able to learn." ) if isinstance(hidden_act, str): self.hidden_act = ACT2FN[hidden_act] else: self.hidden_act = hidden_act self.layers = nn.ModuleList( [ nn.Linear( input_dim if i == 0 else self.hidden_size, self.hidden_size, ) for i in range(self.num_layers - 1) ] ) self.layers.append( nn.Linear( (input_dim if self.num_layers == 1 else self.hidden_size), 1, # logits for predicting if it is forgery ) ) def init_tensors(self): for layer in self.layers: layer.reset_parameters() def forward(self, x): for layer in self.layers: x = layer(x) if layer != self.layers[-1]: x = self.hidden_act(x) return x class FuseHead(nn.Module): def __init__(self, config: LoupeConfig) -> None: super().__init__() num_patches = (config.image_size // config.patch_size) ** 2 self.fuse = nn.Linear(1 + num_patches, 1, bias=False) def init_tensors(self): nn.init.constant_(self.fuse.weight.data, 1 / self.fuse.in_features) def forward(self, x): x = self.fuse(x) return x # ----------------------------------- SEGMENTATION-RELATED MODULES ----------------------------------- class ScaleBlock(nn.Module): """ Upscale or downscale the input feature map 2x times using nn.ConvTranspose2d or nn.Conv2d. """ def __init__( self, n_channels, conv1_layer=nn.ConvTranspose2d, hidden_act: Union[str, type] = "gelu", ) -> None: super().__init__() if conv1_layer == nn.ConvTranspose2d: channel_args = ( n_channels, n_channels, ) else: channel_args = () self.conv1 = conv1_layer( *channel_args, kernel_size=2, stride=2, ) if isinstance(hidden_act, str): self.hidden_act = ACT2FN[hidden_act] else: self.hidden_act = hidden_act self.conv2 = nn.Conv2d( n_channels, n_channels, kernel_size=3, padding=1, groups=n_channels, bias=False, ) self.norm = timm.layers.LayerNorm2d(n_channels) def forward(self, x): x = self.conv1(x) x = self.hidden_act(x) x = self.conv2( x.contiguous() ) # who knows why I have to add contiguous here ????? x = self.norm(x) return x class FeaturePyramid(nn.Module): def __init__(self, n_channels: int, scales: list[float | int] = None): """ Initializes the FeaturePyramid with the given scales. Args: n_channels (int): The number of channels in the input feature map. scales (list[float or int]): A list whose length=4 representing the scales for the pyramid. Should be integer powers of 2 in ascending order. Defaults to [1/2, 1, 2, 4]. """ super().__init__() self.hidden_dim = n_channels self.scales = scales or [1 / 2, 1, 2, 4] is_power_of_2 = lambda n: n > 0 and math.isclose( math.log2(n), round(math.log2(n)) ) if any(not is_power_of_2(scale) for scale in self.scales): raise ValueError( f"All scales must be integer powers of 2, but got {self.scales}" ) self.scale_layers = nn.ModuleList( [ nn.Sequential(*self._make_layer(scale)) for scale in sorted(self.scales, reverse=True) ] ) def _make_layer(self, scale: float | int) -> list[nn.Module]: if scale == 1: return [nn.Identity()] conv1_layer = nn.ConvTranspose2d if scale > 1 else nn.MaxPool2d num_steps = abs(int(round(math.log2(scale)))) return [ ScaleBlock(self.hidden_dim, conv1_layer=conv1_layer) for _ in range(num_steps) ] def forward(self, x): return [layer(x) for layer in self.scale_layers] class PixelDecoderEncoderLayer(Mask2FormerPixelDecoderEncoderLayer): def __init__(self, config: LoupeConfig): mask2former_config = config.mask2former_config super().__init__(mask2former_config) if config.enable_conditional_queries: self.cross_attn = nn.MultiheadAttention( embed_dim=self.embed_dim, num_heads=mask2former_config.num_attention_heads, batch_first=True, ) self.cross_attn_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, conditional_queries: Optional[torch.Tensor] = None, position_embeddings: Optional[torch.Tensor] = None, reference_points=None, spatial_shapes_list=None, level_start_index=None, output_attentions: bool = False, ): """ Args: hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Input to the layer. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Attention mask. conditional_queries (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`): Pseudo query for the cross attention, added from Loupe. position_embeddings (`torch.FloatTensor`, *optional*): Position embeddings, to be added to `hidden_states`. reference_points (`torch.FloatTensor`, *optional*): Reference points. spatial_shapes_list (`list` of `tuple`): Spatial shapes of the backbone feature maps as a list of tuples. level_start_index (`torch.LongTensor`, *optional*): Level start index. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps. hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout( hidden_states, p=self.dropout, training=self.training ) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) if conditional_queries is not None: residual = hidden_states # Apply cross attention on the pseudo query. hidden_states, _ = self.cross_attn( query=conditional_queries, key=hidden_states, value=hidden_states, key_padding_mask=attention_mask, need_weights=False, ) hidden_states = nn.functional.dropout( hidden_states, p=self.dropout, training=self.training ) hidden_states = residual + hidden_states hidden_states = self.cross_attn_layer_norm(hidden_states) residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout( hidden_states, p=self.activation_dropout, training=self.training ) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout( hidden_states, p=self.dropout, training=self.training ) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) if self.training: if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp( hidden_states, min=-clamp_value, max=clamp_value ) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights.transpose(1, 0),) return outputs class PixelDecoderConditionalEncoder(Mask2FormerPixelDecoderEncoderOnly): def __init__(self, config: LoupeConfig): mask2former_config = config.mask2former_config super().__init__(mask2former_config) # replace the original encoder layer with our loupe encoder layer self.layers = nn.ModuleList( [ PixelDecoderEncoderLayer(config) for _ in range(mask2former_config.encoder_layers) ] ) def forward( self, inputs_embeds=None, attention_mask=None, conditional_queries=None, position_embeddings=None, spatial_shapes_list=None, level_start_index=None, valid_ratios=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: - 1 for pixel features that are real (i.e. **not masked**), - 0 for pixel features that are padding (i.e. **masked**). [What are attention masks?](../glossary#attention-mask) conditional_queries (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`): Conditional query for the cross attention, added from Loupe. position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Position embeddings that are added to the queries and keys in each self-attention layer. spatial_shapes_list (`list` of `tuple`): Spatial shapes of each feature map as a list of tuples. level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): Starting index of each feature map. valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): Ratio of valid area in each feature level. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) hidden_states = inputs_embeds reference_points = self.get_reference_points( spatial_shapes_list, valid_ratios, device=inputs_embeds.device ) all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, encoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states.transpose(1, 0),) layer_outputs = encoder_layer( hidden_states, attention_mask, conditional_queries=conditional_queries, position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states += (hidden_states.transpose(1, 0),) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, ) class PixelDecoder(Mask2FormerPixelDecoder): def __init__(self, config: LoupeConfig): mask2former_config = config.mask2former_config super().__init__(mask2former_config, config.feature_channels) # replace the original encoder with our loupe conditional encoder self.encoder = PixelDecoderConditionalEncoder(config) # modified from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelDecoder def forward( self, features, conditional_queries=None, encoder_outputs=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # Apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) input_embeds = [] position_embeddings = [] for level, x in enumerate(features[::-1][: self.num_feature_levels]): input_embeds.append(self.input_projections[level](x)) position_embeddings.append(self.position_embedding(x)) masks = [ torch.zeros( (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool ) for x in input_embeds ] # Prepare encoder inputs (by flattening) spatial_shapes_list = [ (embed.shape[2], embed.shape[3]) for embed in input_embeds ] input_embeds_flat = torch.cat( [embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1 ) spatial_shapes = torch.as_tensor( spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device ) masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1) position_embeddings = [ embed.flatten(2).transpose(1, 2) for embed in position_embeddings ] level_pos_embed_flat = [ x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings) ] level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1) level_start_index = torch.cat( (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) ) valid_ratios = torch.stack( [ self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks ], 1, ) # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder if encoder_outputs is None: encoder_outputs = self.encoder( inputs_embeds=input_embeds_flat, attention_mask=masks_flat, conditional_queries=conditional_queries, position_embeddings=level_pos_embed_flat, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, valid_ratios=valid_ratios, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs.last_hidden_state batch_size = last_hidden_state.shape[0] # We compute level_start_index_list separately from the tensor version level_start_index # to avoid iterating over a tensor which breaks torch.compile/export. level_start_index_list = [0] for height, width in spatial_shapes_list[:-1]: level_start_index_list.append(level_start_index_list[-1] + height * width) split_sizes = [None] * self.num_feature_levels for i in range(self.num_feature_levels): if i < self.num_feature_levels - 1: split_sizes[i] = ( level_start_index_list[i + 1] - level_start_index_list[i] ) else: split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i] encoder_output = torch.split(last_hidden_state, split_sizes, dim=1) # Compute final features outputs = [ x.transpose(1, 2).view( batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1] ) for i, x in enumerate(encoder_output) ] # Append extra FPN levels to outputs, ordered from low to high resolution for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]): lateral_conv = self.lateral_convolutions[idx] output_conv = self.output_convolutions[idx] current_fpn = lateral_conv(feature) # Following FPN implementation, we use nearest upsampling here out = current_fpn + nn.functional.interpolate( outputs[-1], size=current_fpn.shape[-2:], mode="bilinear", align_corners=False, ) out = output_conv(out) outputs.append(out) num_cur_levels = 0 multi_scale_features = [] for out in outputs: if num_cur_levels < self.num_feature_levels: multi_scale_features.append(out) num_cur_levels += 1 return Mask2FormerPixelDecoderOutput( mask_features=self.mask_projection(outputs[-1]), multi_scale_features=tuple(multi_scale_features), attentions=encoder_outputs.attentions, )