"""Vision-token downsamplers: a parameter-free bilinear resize, a single-layer
Q-Former downsampler, and a windowed Q-Former variant that cross-attends within
local windows of the patch grid."""

import math
from fractions import Fraction

import torch
from torch import nn
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel


class BilinearDownsampler:
    """Parameter-free downsampling of a flattened patch grid via bilinear interpolation."""

    def __init__(self, config):
        # Side length of the square patch grid produced by the vision tower.
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = int(self.orig_image_side * Fraction(config.downsample_rate))

    def __call__(self, image_features):
        # image_features: (batch, num_patches, dim), patches in raster order.
        batch_size, _, dim = image_features.size()
        up_shape = [batch_size] + [self.orig_image_side] * 2 + [dim]

        # Restore the 2D grid and move channels first for interpolation.
        large_image_permuted = image_features.view(up_shape).permute(0, 3, 1, 2)
        small_image_permuted = torch.nn.functional.interpolate(
            large_image_permuted,
            size=(self.new_image_side, self.new_image_side),
            mode="bilinear",
            align_corners=False,
        )

        # Back to (batch, new_num_patches, dim) in raster order.
        return small_image_permuted.permute(0, 2, 3, 1).flatten(1, 2)


class QFormerDownsampler(nn.Module):
    """Downsamples image tokens with a one-layer Q-Former plus a bilinear residual."""

    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.bilinear = BilinearDownsampler(config)

        # Single-layer Q-Former operating directly at the LLM hidden size,
        # with cross-attention to the image tokens in every layer.
        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)

        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        downsample_rate = Fraction(config.downsample_rate)
        self.query_side = int(downsample_rate * self.image_side)

        # One learned query per output token: a query_side x query_side grid.
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)

        # Learned position embeddings covering the full patch grid.
        self.image_positions = nn.Parameter(torch.randn(1, self.image_side ** 2, llm_hidden_size) * embed_std)

    def forward(self, image_features):
        # image_features: (batch, num_patches, dim)
        batch_size = image_features.size(0)
        query_output = self.qformer(
            query_embeds=self.query.expand(batch_size, -1, -1),
            encoder_hidden_states=image_features + self.image_positions,
            return_dict=True,
        ).last_hidden_state

        # Residual through the parameter-free bilinear path.
        return query_output + self.bilinear(image_features)


class WindowQFormerDownsampler(nn.Module):
    """Like QFormerDownsampler, but the queries cross-attend within local
    windows of the patch grid rather than over the whole image."""

    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.bilinear = BilinearDownsampler(config)

        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)

        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        # Read the downsample rate as an *unreduced* ratio: the numerator is the
        # query grid side per window and the denominator is the window side, so
        # "2/4" means each 4x4 window is summarized by 2x2 queries.
        numerator, _, denominator = str(config.downsample_rate).partition("/")
        self.query_side = int(numerator)
        self.window_side = int(denominator) if denominator else 1

        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)

        # Position embeddings cover one window and are shared across windows.
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, llm_hidden_size) * embed_std)

    def forward(self, image_features):
        # image_features: (batch, num_patches, dim); the patch grid side must be
        # divisible by the window side.
        batch_size, _, dim = image_features.size()
        num_side_windows = self.image_side // self.window_side

        # Split the grid into non-overlapping windows:
        # (batch, side, side, dim) -> (batch * num_windows**2, window_side**2, dim).
        up_shape = [batch_size] + [num_side_windows, self.window_side] * 2 + [dim]
        image_features_permuted = image_features.view(up_shape).transpose(2, 3)
        image_features_flattened = image_features_permuted.flatten(0, 2).flatten(1, 2)

        query_output = self.qformer(
            query_embeds=self.query.expand(image_features_flattened.size(0), -1, -1),
            encoder_hidden_states=image_features_flattened + self.image_positions,
            return_dict=True,
        ).last_hidden_state

        # Un-window the queries back to raster order so each output token lines
        # up spatially with the bilinear residual.
        query_output = query_output.view(
            batch_size, num_side_windows, num_side_windows, self.query_side, self.query_side, dim
        )
        query_output = query_output.transpose(2, 3).reshape(batch_size, -1, dim)

        return query_output + self.bilinear(image_features)
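

# Minimal smoke test: a sketch rather than part of the model code. It assumes a
# config object exposing the fields the classes above read
# (vision_config.image_size / patch_size, text_config.hidden_size, and
# downsample_rate); the sizes here are illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=336, patch_size=14),  # 24x24 patch grid
        text_config=SimpleNamespace(hidden_size=512),
        downsample_rate="1/2",  # 24 -> 12 tokens per side
    )
    features = torch.randn(2, 24 * 24, 512)

    for module in (
        BilinearDownsampler(config),
        QFormerDownsampler(config),
        WindowQFormerDownsampler(config),
    ):
        # Each downsampler maps (2, 576, 512) -> (2, 144, 512).
        print(type(module).__name__, module(features).shape)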