import math
from fractions import Fraction

import torch
from torch import nn

from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel


class BilinearDownsampler:
    """Parameter-free downsampler used as a residual path for the learned poolers.

    Despite the name, interpolation runs with mode="area" (average pooling),
    which avoids the aliasing that bilinear interpolation introduces when
    shrinking a feature map.
    """

    def __init__(self, config):
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = int(self.orig_image_side * Fraction(config.downsample_rate))

    def __call__(self, image_features):
        batch_size, _, dim = image_features.size()
        up_shape = [batch_size] + [self.orig_image_side] * 2 + [dim]
        # interpolate expects (B, C, H, W)
        large_image_permuted = image_features.view(up_shape).permute(0, 3, 1, 2)
        small_image_permuted = torch.nn.functional.interpolate(
            large_image_permuted,
            size=(self.new_image_side, self.new_image_side),
            mode="area",
        )
        # back to (B, H*W, C)
        return small_image_permuted.permute(0, 2, 3, 1).flatten(1, 2)


class QFormerDownsampler(nn.Module):
    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.bilinear = BilinearDownsampler(config)
        # one Q-Former layer: bidirectional self-attention plus cross-attention into
        # the image patches. Head count and FFN size are hard-coded, so
        # llm_hidden_size must be divisible by 32.
        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        downsample_rate = Fraction(config.downsample_rate)
        self.query_side = int(downsample_rate * self.image_side)
        # query length is a perfect square for seamless integration with LLaVA-NeXT
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)
        # the Q-Former model has no positional embeddings of its own, so learned
        # positions are added to the flattened patches
        self.image_positions = nn.Parameter(torch.randn(1, self.image_side ** 2, llm_hidden_size) * embed_std)

    def forward(self, image_features):
        batch_size, _, dim = image_features.size()
        query_output = self.qformer(
            # expand the shared queries to the batch size, matching how the
            # HF BLIP-2 models feed their query tokens to the Q-Former
            query_embeds=self.query.expand(batch_size, -1, -1),
            encoder_hidden_states=image_features + self.image_positions,
            return_dict=True,
        ).last_hidden_state
        bilinear_output = self.bilinear(image_features)
        return query_output + bilinear_output
class WindowQFormerDownsampler(nn.Module):
    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.bilinear = BilinearDownsampler(config)
        # non-causal attention layer
        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        # the rate is deliberately kept unnormalized: "2/4" means a 2x2 query grid
        # per 4x4 window, which is not the same pooling as "1/2". Fraction would
        # reduce it (and its _normalize escape hatch is private and removed in
        # Python 3.12), so the two sides are parsed directly.
        numerator, denominator = str(config.downsample_rate).split("/")
        self.query_side, self.window_side = int(numerator), int(denominator)
        # query length is a perfect square for seamless integration with LLaVA-NeXT
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)
        # the Q-Former model has no positional embeddings of its own; one position
        # per patch within a window, shared across windows
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, llm_hidden_size) * embed_std)

    def forward(self, image_features):
        batch_size, _, dim = image_features.size()
        num_side_windows = self.image_side // self.window_side
        # split the patch grid into non-overlapping window_side x window_side windows:
        # (B, H*W, D) -> (B, nW, wS, nW, wS, D)
        up_shape = [batch_size] + [num_side_windows, self.window_side] * 2 + [dim]
        # bring the two window indices together before flattening batch/window dims
        image_features_permuted = image_features.view(up_shape).transpose(2, 3)
        # (B * nW * nW, wS * wS, D): each window becomes its own batch element
        image_features_flattened = image_features_permuted.flatten(0, 2).flatten(1, 2)
        query_output = self.qformer(
            # expand the shared queries to the windowed batch size, matching how the
            # HF BLIP-2 models feed their query tokens to the Q-Former
            query_embeds=self.query.expand(image_features_flattened.size(0), -1, -1),
            encoder_hidden_states=image_features_flattened + self.image_positions,
            return_dict=True,
        ).last_hidden_state
        # un-window back to global row-major order so the tokens line up spatially
        # with the bilinear residual:
        # (B * nW * nW, qS * qS, D) -> (B, nW, nW, qS, qS, D) -> (B, (nW*qS)^2, D)
        query_output = query_output.view(
            batch_size, num_side_windows, num_side_windows, self.query_side, self.query_side, dim
        )
        query_output = query_output.transpose(2, 3).reshape(batch_size, -1, dim)
        bilinear_output = self.bilinear(image_features)
        return query_output + bilinear_output
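

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of any training pipeline. The
    # SimpleNamespace below is a hypothetical stand-in for a LLaVA-style config;
    # only the attributes read above (vision_config.image_size,
    # vision_config.patch_size, text_config.hidden_size, downsample_rate) are
    # assumed, and the concrete values are illustrative, not from any release.
    from types import SimpleNamespace

    config = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=336, patch_size=14),  # 24 x 24 patch grid
        text_config=SimpleNamespace(hidden_size=4096),
        downsample_rate="2/4",  # 2x2 queries per 4x4 window -> 4x fewer tokens
    )
    downsampler = WindowQFormerDownsampler(config)
    image_features = torch.randn(2, 24 * 24, 4096)
    with torch.no_grad():
        out = downsampler(image_features)
    # 24 patches per side * 2/4 = 12 tokens per side -> 144 tokens per image
    print(out.shape)  # expected: torch.Size([2, 144, 4096])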