"""QFormer-style building blocks and a window-based QFormer downsampler."""

import math
from fractions import Fraction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import nn
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel


class QFormerCrossAttention(nn.Module):
    """Multi-headed cross-attention for QFormer, implemented with PyTorch SDPA.

    Queries are projected from ``hidden_states``; keys and values are projected
    from ``encoder_hidden_states``. Both inputs must have ``hidden_size``
    features because all four projections are ``hidden_size -> hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        attn_bias: bool = False,
        attention_dropout: float = 0.05,
        final_dropout: float = 0.05,
    ) -> None:
        """
        Args:
            hidden_size: Feature size of queries, keys, values and the output.
            num_heads: Number of attention heads; must be positive and divide
                ``hidden_size``.
            attn_bias: Whether the q/k/v/o projections carry a bias term.
            attention_dropout: Dropout probability on the attention weights
                (only active in training mode).
            final_dropout: Dropout probability applied after the output
                projection.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                ``hidden_size``.
        """
        super().__init__()
        # Validate before dividing so that num_heads == 0 produces the intended
        # ValueError instead of a ZeroDivisionError.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size} "
                f"and `num_heads`: {num_heads})."
            )

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attention_dropout = attention_dropout

        # Q from queries, K and V from encoder
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.dropout = nn.Dropout(final_dropout)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """Reshape (B, L, hidden_size) into (B, num_heads, L, head_dim)."""
        batch_size, seq_len, _ = projected.shape
        return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (B, query_len, hidden_size) - queries
            encoder_hidden_states: (B, encoder_len, hidden_size) - keys and values
            attention_mask: optional mask forwarded unchanged to SDPA as ``attn_mask``

        Returns:
            (B, query_len, hidden_size)
        """
        batch_size, query_len, _ = hidden_states.shape

        query_states = self._split_heads(self.q_proj(hidden_states))
        key_states = self._split_heads(self.k_proj(encoder_hidden_states))
        value_states = self._split_heads(self.v_proj(encoder_hidden_states))

        # SDPA picks a fused kernel (e.g. Flash Attention) when one is available.
        # Attention dropout must be disabled manually outside training because
        # the functional API does not look at the module's training flag.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=False,
        )

        # Merge heads back to (B, query_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.hidden_size)
        return self.dropout(self.o_proj(attn_output))


class QFormerMLP(nn.Module):
    """Feed-forward network (MLP) for QFormer with SiLU activation."""

    def __init__(
        self,
        hidden_size: int,
        mlp_hidden_size: int,
        mlp_bias: bool = False,
        dropout_prob: float = 0.05,
    ) -> None:
        """
        Args:
            hidden_size: Input and output feature size.
            mlp_hidden_size: Width of the intermediate layer.
            mlp_bias: Whether both linear layers carry a bias term.
            dropout_prob: Dropout probability applied to the block output.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(hidden_size, mlp_hidden_size, bias=mlp_bias)
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(mlp_hidden_size, hidden_size, bias=mlp_bias)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (B, seq_len, hidden_size)

        Returns:
            (B, seq_len, hidden_size)
        """
        expanded = self.act(self.fc1(hidden_states))
        return self.dropout(self.fc2(expanded))


class SimplifiedQFormer(nn.Module):
    """
    Simplified QFormer with a single cross-attention layer followed by an MLP.

    Lightweight design: queries attend to encoder hidden states via cross-attention,
    then pass through a feed-forward network, similar to a transformer block.
    Both sub-blocks use pre-LayerNorm and a residual connection. Only the query
    stream is normalized here; ``encoder_hidden_states`` is used as given.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 8,
        mlp_hidden_size: int = 2048,
        mlp_bias: bool = False,
        attn_bias: bool = False,
        eps: float = 1e-6,
    ) -> None:
        """
        Args:
            hidden_size: Feature size of queries and encoder states.
            num_heads: Number of cross-attention heads.
            mlp_hidden_size: Width of the MLP's intermediate layer.
            mlp_bias: Whether the MLP linear layers carry a bias term.
            attn_bias: Whether the attention projections carry a bias term.
            eps: Epsilon for both LayerNorm layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # Cross-attention block
        self.attn_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.cross_attention = QFormerCrossAttention(
            hidden_size,
            num_heads,
            attn_bias=attn_bias,
        )

        # MLP block (feed-forward network)
        self.mlp_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = QFormerMLP(hidden_size, mlp_hidden_size, mlp_bias=mlp_bias)

    def forward(self, query_embeds: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query_embeds: (B, num_queries, hidden_size) - learnable queries
            encoder_hidden_states: (B, num_tokens, hidden_size) - input features

        Returns:
            (B, num_queries, hidden_size) - output features
        """
        # Cross-attention block with residual and pre-norm
        attended = self.cross_attention(self.attn_norm(query_embeds), encoder_hidden_states)
        hidden_states = query_embeds + attended

        # MLP block with residual and pre-norm
        return hidden_states + self.mlp(self.mlp_norm(hidden_states))


class InterpolateDownsampler:
    """Downsample a square raster of patch features with ``F.interpolate``."""

    def __init__(self, config: Any, mode: str = "area") -> None:
        """
        Args:
            config: Model configuration; reads ``vision_config.image_size``,
                ``vision_config.patch_size`` and ``downsample_rate`` (a value
                accepted by ``fractions.Fraction``, e.g. ``"1/2"``).
            mode: Interpolation mode passed to ``F.interpolate``.
        """
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = int(self.orig_image_side * Fraction(config.downsample_rate))
        self.mode = mode

    def __call__(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image_features: (B, orig_image_side**2, C) raster-ordered features

        Returns:
            (B, new_image_side**2, C) raster-ordered features
        """
        batch_size, _, dim = image_features.size()
        side = self.orig_image_side

        # interpolate expects B,C,H,W; reshape (not view) so non-contiguous
        # inputs are accepted, matching the other downsamplers in this file.
        grid = image_features.reshape(batch_size, side, side, dim).permute(0, 3, 1, 2)
        resized = F.interpolate(
            grid,
            size=(self.new_image_side, self.new_image_side),
            mode=self.mode,
        )

        # back to B,H*W,C
        return resized.permute(0, 2, 3, 1).flatten(1, 2)


class SpatialOffsetDownsampler:
    """
    Downsampler that keeps one fixed position out of every 2x2 block.

    The image grid is partitioned into non-overlapping 2x2 blocks and the token
    at ``(offset_h, offset_w)`` inside each block is kept. This is equivalent to
    stride-2 subsampling of rows and columns starting at that offset, so the
    result covers the full spatial extent at half resolution.
    """

    def __init__(self, config: Any, offset: int = 0) -> None:
        """
        Args:
            config: Model configuration; reads ``vision_config.image_size`` and
                ``vision_config.patch_size``.
            offset: Integer offset (0, 1, 2, or 3) for position within each 2x2 block
                   0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right
        """
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = self.orig_image_side // 2  # downsample by 2x
        self.offset = offset

        # Map offset to (row, col) position within 2x2 blocks
        self.offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
        self.offset_h, self.offset_w = self.offsets[offset]

    def __call__(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        Extract features by sampling one position from each 2x2 block across the image.

        For a 4x4 image with offset=0 (top-left of each 2x2 block):

            Original:            Sampled (raster order):
            [A B | C D]          [A C]
            [E F | G H]    ->    [I K]
            [----+----]
            [I J | K L]
            [M N | O P]

        Result in sequence: [A, C, I, K]

        Args:
            image_features: Tensor of shape [batch, height*width, hidden_dim]

        Returns:
            Downsampled features of shape [batch, (height/2)*(width/2), hidden_dim]
        """
        batch_size, _, hidden_dim = image_features.shape
        n_blocks = self.new_image_side

        # Split each spatial axis into (block index, position inside block):
        # [batch, n_blocks_h, 2, n_blocks_w, 2, hidden_dim]
        blocks = image_features.reshape(batch_size, n_blocks, 2, n_blocks, 2, hidden_dim)

        # Keep the configured position from every block -> [batch, n_blocks, n_blocks, hidden_dim]
        sampled = blocks[:, :, self.offset_h, :, self.offset_w, :]

        # Flatten spatial dimensions back to [batch, n_blocks*n_blocks, hidden_dim]
        return sampled.reshape(batch_size, -1, hidden_dim)


class SpatialQuadrantDownsampler:
    """
    Alternative downsampler that samples contiguous spatial quadrants.

    Takes a full quadrant of the image rather than sampling across the entire image.
    This creates maximum local continuity but only covers 1/4 of the spatial extent.

    Use case: When you want queries to focus on a specific region with maximum
    local coherence, trading off global spatial coverage.
    """

    def __init__(self, config: Any, offset: int = 0) -> None:
        """
        Args:
            config: Model configuration; reads ``vision_config.image_size`` and
                ``vision_config.patch_size``.
            offset: Integer offset (0, 1, 2, or 3) for quadrant selection
                   0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right
        """
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = self.orig_image_side // 2  # downsample by 2x
        self.offset = offset

        # Map offset to the (row, col) where the quadrant starts
        half = self.new_image_side
        self.offsets = [
            (0, 0),        # top-left
            (0, half),     # top-right
            (half, 0),     # bottom-left
            (half, half),  # bottom-right
        ]
        self.start_h, self.start_w = self.offsets[offset]

    def __call__(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        Extract a contiguous quadrant from the image.

        For a 4x4 image with offset=0 (top-left quadrant):

            Original:            Sampled:
            [A B | C D]          [A B]
            [E F | G H]    ->    [E F]
            [----+----]
            [I J | K L]
            [M N | O P]

        Result in sequence: [A, B, E, F] - maximum local continuity

        Args:
            image_features: Tensor of shape [batch, height*width, hidden_dim]

        Returns:
            Downsampled features of shape [batch, (height/2)*(width/2), hidden_dim]
        """
        batch_size, _, hidden_dim = image_features.shape
        side = self.orig_image_side

        # Reshape to [batch, height, width, hidden_dim]
        features_2d = image_features.reshape(batch_size, side, side, hidden_dim)

        # Extract contiguous quadrant
        rows = slice(self.start_h, self.start_h + self.new_image_side)
        cols = slice(self.start_w, self.start_w + self.new_image_side)
        sampled = features_2d[:, rows, cols, :]

        # Flatten spatial dimensions back to [batch, new_height*new_width, hidden_dim]
        return sampled.reshape(batch_size, -1, hidden_dim)


class WindowQFormerDownsampler(nn.Module):
    """Window-wise QFormer that compresses a square grid of image tokens.

    The ``image_side x image_side`` token grid is cut into non-overlapping
    ``window_side x window_side`` windows. For every window,
    ``query_side ** 2`` queries (a learned query plus a downsampled copy of the
    image features) cross-attend to that window's tokens. The per-window outputs
    are stitched back into raster order and projected to the LLM hidden size.

    ``config.downsample_rate`` is a string ``"<query_side>/<window_side>"``.
    """

    def __init__(
        self,
        config: Any,
        checkerboard_offset: Optional[int] = None,
        use_quadrant_sampling: bool = False,
    ) -> None:
        """
        Args:
            config: Model configuration; reads ``text_config.hidden_size``,
                ``vision_config.{hidden_size,image_size,patch_size}``,
                ``projector_dropout``, ``downsample_rate`` and
                ``simplified_qformer``.
            checkerboard_offset: If not ``None``, use a fixed-2x spatial
                downsampler with this offset (0-3) instead of interpolation.
            use_quadrant_sampling: With ``checkerboard_offset`` set, choose
                ``SpatialQuadrantDownsampler`` instead of
                ``SpatialOffsetDownsampler``.

        Raises:
            ValueError: If the window geometry cannot tile the image grid or
                the chosen downsampler's output grid does not match the query
                grid (both cases would otherwise fail inside ``forward``).
        """
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size

        # Resolve and validate the geometry first so misconfiguration fails
        # here with a clear message rather than as a reshape error in forward().
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        q, w = config.downsample_rate.split("/")
        self.query_side, self.window_side = int(q), int(w)
        if self.query_side <= 0 or self.window_side <= 0:
            raise ValueError(
                f"downsample_rate must be '<query_side>/<window_side>' with positive integers "
                f"(got {config.downsample_rate!r})."
            )
        if self.image_side % self.window_side != 0:
            raise ValueError(
                f"image side ({self.image_side}) must be divisible by window side ({self.window_side})."
            )

        # Dropout rates for robustness (conservative approach)
        self.dropout = nn.Dropout(config.projector_dropout)

        # Choose downsampler based on parameters
        if checkerboard_offset is not None:
            if use_quadrant_sampling:
                # Use quadrant sampling: maximum local continuity, limited spatial coverage
                self.downsampler = SpatialQuadrantDownsampler(config, offset=checkerboard_offset)
            else:
                # Use block sampling: one token per 2x2 block, full spatial coverage (default)
                self.downsampler = SpatialOffsetDownsampler(config, offset=checkerboard_offset)
        else:
            self.downsampler = InterpolateDownsampler(config)

        # forward() re-windows the downsampled grid with side n * query_side, so
        # the downsampler must produce exactly that many tokens per side.
        expected_side = (self.image_side // self.window_side) * self.query_side
        if self.downsampler.new_image_side != expected_side:
            raise ValueError(
                f"downsampler output side ({self.downsampler.new_image_side}) does not match the "
                f"query grid side ({expected_side}) implied by downsample_rate "
                f"{config.downsample_rate!r}."
            )

        self.use_simplified_qformer = config.simplified_qformer

        # Choose between SimplifiedQFormer and Blip2QFormerModel
        if self.use_simplified_qformer:
            # Our simplified QFormer: one cross-attention layer + MLP
            self.qformer = SimplifiedQFormer(
                hidden_size=vision_hidden_size,
                num_heads=vision_hidden_size // 64,
                mlp_hidden_size=3072,
                mlp_bias=True,
                attn_bias=True,
            )
        else:
            # Original Blip2QFormerModel configured as a single layer with cross-attention
            configuration = Blip2QFormerConfig(
                hidden_size=vision_hidden_size,
                num_attention_heads=vision_hidden_size // 64,
                intermediate_size=3072,
                num_hidden_layers=1,
                encoder_hidden_size=vision_hidden_size,
                cross_attention_frequency=1,
                max_position_embeddings=2048,
                use_qformer_text_input=False,
            )
            self.qformer = Blip2QFormerModel(configuration)

        # query length is a perfect square (query_side x query_side grid)
        # for seamless integration with llava next
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(vision_hidden_size)
        self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
        self.query = nn.Parameter(torch.randn(1, self.query_length, vision_hidden_size) * embed_std)
        # qformer model doesn't have positional embeddings, adding to the flat patches
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, vision_hidden_size) * embed_std)
        self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)

    def _win(self, x: torch.Tensor, side: int, win: int) -> torch.Tensor:
        """
        (B, side*side, C) raster -> (B*n*n, win*win, C) where n=side//win
        windows are raster-ordered, and tokens inside each window are raster-ordered.
        """
        batch, _, channels = x.shape
        n = side // win
        # (B, n, win, n, win, C) -> (B, n, n, win, win, C)
        windows = x.reshape(batch, n, win, n, win, channels).transpose(2, 3)
        return windows.reshape(batch * n * n, win * win, channels)

    def _unwin(self, xw: torch.Tensor, n: int, win: int) -> torch.Tensor:
        """
        (B*n*n, win*win, C) -> (B, (n*win)^2, C) raster

        Raises:
            ValueError: If the leading dimension is not a multiple of ``n * n``.
        """
        total, _, channels = xw.shape
        if total % (n * n) != 0:
            raise ValueError(
                f"windowed batch size ({total}) is not a multiple of the window count ({n * n})."
            )
        batch = total // (n * n)
        side = n * win
        # (B, n, n, win, win, C) -> (B, n, win, n, win, C) -> raster
        grid = xw.reshape(batch, n, n, win, win, channels).transpose(2, 3)
        return grid.reshape(batch, side * side, channels)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image_features: (B, image_side**2, vision_hidden_size) raster-ordered
                image tokens.

        Returns:
            (B, (n * query_side)**2, llm_hidden_size) with
            ``n = image_side // window_side``, raster-ordered.

        Raises:
            ValueError: If the token count is not ``image_side ** 2``.
        """
        num_tokens = image_features.shape[1]
        if num_tokens != self.image_side * self.image_side:
            raise ValueError(
                f"expected {self.image_side * self.image_side} image tokens "
                f"({self.image_side}x{self.image_side}), got {num_tokens}."
            )
        n = self.image_side // self.window_side

        image_features = self.norm(image_features)
        enc = self._win(image_features, self.image_side, self.window_side)  # (B*n^2, w^2, C)

        # Apply downsampling (either spatial offset/quadrant or interpolation)
        downsampled = self.downsampler(image_features)  # (B, new_side^2, C) raster
        new_side = n * self.query_side
        downsampled_w = self._win(downsampled, new_side, self.query_side)  # (B*n^2, q^2, C)

        # Apply QFormer based on the chosen mechanism. self.query and
        # self.image_positions have a leading dim of 1 and broadcast over windows.
        if self.use_simplified_qformer:
            # SimplifiedQFormer: queries cross-attend to the window tokens.
            # Dropout is applied to both embedding streams for robustness.
            query_embeds = self.dropout(self.query + downsampled_w)
            encoder_embeds = self.dropout(enc + self.image_positions)
            out_w = self.qformer(
                query_embeds=query_embeds,
                encoder_hidden_states=encoder_embeds,
            )  # (B*n^2, q^2, C)
        else:
            # Blip2QFormerModel: no dropout on the queries here because, per the
            # original author's note, the BLIP-2 model already applies its own.
            query_embeds = self.query + downsampled_w
            encoder_embeds = self.dropout(enc + self.image_positions)
            out_w = self.qformer(
                query_embeds=query_embeds,
                encoder_hidden_states=encoder_embeds,
                return_dict=True,
            ).last_hidden_state  # (B*n^2, q^2, C)

        out = self._unwin(out_w, n=n, win=self.query_side)  # (B, new_side^2, C) raster

        # Apply output dropout before final projection
        out = self.dropout(out)
        return self.out_linear(out)