import torch
from torch import nn
import math
from fractions import Fraction
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel
class BilinearDownsampler:
    def __init__(self, config):
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = int(self.orig_image_side * Fraction(config.downsample_rate))

    def __call__(self, image_features):
        batch_size, _, dim = image_features.size()
        up_shape = [batch_size] + [self.orig_image_side] * 2 + [dim]
        # interpolate expects B, C, H, W
        large_image_permuted = image_features.view(up_shape).permute(0, 3, 1, 2)
        # "area" mode is equivalent to adaptive average pooling over each target cell
        small_image_permuted = torch.nn.functional.interpolate(
            large_image_permuted,
            size=(self.new_image_side, self.new_image_side),
            mode="area",
        )
        # back to B, H*W, C
        final = small_image_permuted.permute(0, 2, 3, 1).flatten(1, 2)
        return final
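
# A minimal shape-check sketch for BilinearDownsampler. The config below is a
# stand-in (a SimpleNamespace with hypothetical image_size=336, patch_size=14,
# downsample_rate="1/2"); any object exposing the same fields would work.
def _demo_bilinear_downsampler():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=336, patch_size=14),
        downsample_rate="1/2",
    )
    downsampler = BilinearDownsampler(cfg)
    # 336 // 14 = 24 patches per side -> 576 input tokens
    feats = torch.randn(2, 24 * 24, 64)
    out = downsampler(feats)
    # 24 * 1/2 = 12 per side -> 144 output tokens, feature dim unchanged
    assert out.shape == (2, 12 * 12, 64)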
class QFormerDownsampler(nn.Module):
    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.bilinear = BilinearDownsampler(config)
        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        downsample_rate = Fraction(config.downsample_rate)
        self.query_side = int(downsample_rate * self.image_side)
        # query length is a perfect square so the output stays a square grid,
        # for seamless integration with LLaVA-NeXT
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)
        # the Q-Former model has no positional embeddings of its own, so we add
        # learned positions to the flattened patches
        self.image_positions = nn.Parameter(torch.randn(1, self.image_side ** 2, llm_hidden_size) * embed_std)

    def forward(self, image_features):
        batch_size, _, dim = image_features.size()
        # expand the shared (1, Q, D) queries over the batch, as the HF BLIP-2
        # model does before calling its Q-Former
        query_output = self.qformer(
            query_embeds=self.query.expand(batch_size, -1, -1),
            encoder_hidden_states=image_features + self.image_positions,
            return_dict=True,
        ).last_hidden_state
        bilinear_output = self.bilinear(image_features)
        # learned queries plus an interpolation-based residual
        return query_output + bilinear_output
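
# A hedged usage sketch for QFormerDownsampler. The config fields mirror what
# the class reads; the sizes (hidden_size=128, image_size=336, patch_size=14,
# downsample_rate="1/2") are illustrative assumptions, not values from a real
# checkpoint. hidden_size must be divisible by the 32 attention heads above.
def _demo_qformer_downsampler():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=336, patch_size=14),
        text_config=SimpleNamespace(hidden_size=128),
        downsample_rate="1/2",
    )
    module = QFormerDownsampler(cfg)
    # image features are assumed already projected to the LLM hidden size,
    # since the Q-Former's encoder_hidden_size is set to it
    feats = torch.randn(2, 24 * 24, 128)
    out = module(feats)
    # 144 learned queries plus the bilinear residual: a 12 x 12 token grid
    assert out.shape == (2, 12 * 12, 128)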
class WindowQFormerDownsampler(nn.Module):
    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.bilinear = BilinearDownsampler(config)
        # non-causal attention layer
        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        # _normalize=False keeps the ratio un-reduced (e.g. "2/4" stays 2/4 rather
        # than 1/2), so numerator and denominator double as query and window sides;
        # note this private Fraction parameter was removed in Python 3.12
        downsample_rate = Fraction(config.downsample_rate, _normalize=False)
        self.query_side, self.window_side = downsample_rate.as_integer_ratio()
        # query length is a perfect square so each window yields a square grid,
        # for seamless integration with LLaVA-NeXT
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)
        # the Q-Former model has no positional embeddings of its own, so we add
        # learned positions to the flattened patches of each window
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, llm_hidden_size) * embed_std)

    def forward(self, image_features):
        batch_size, _, dim = image_features.size()
        num_side_windows = self.image_side // self.window_side
        # split the patch grid into num_side_windows ** 2 windows of
        # window_side x window_side patches each
        up_shape = [batch_size] + [num_side_windows, self.window_side] * 2 + [dim]
        # bring the two window indices next to each other, then flatten the
        # batch and window dims together and each window's patches together
        image_features_permuted = image_features.view(up_shape).transpose(2, 3)
        image_features_flattened = image_features_permuted.flatten(0, 2).flatten(1, 2)
        # expand the shared queries over batch * windows, as the HF BLIP-2
        # model does before calling its Q-Former
        query_output = self.qformer(
            query_embeds=self.query.expand(image_features_flattened.size(0), -1, -1),
            encoder_hidden_states=image_features_flattened + self.image_positions,
            return_dict=True,
        ).last_hidden_state
        # reorder per-window queries back into raster order over the downsampled
        # grid so they line up with the bilinear residual below
        query_output = (
            query_output.view(batch_size, num_side_windows, num_side_windows,
                              self.query_side, self.query_side, dim)
            .transpose(2, 3)
            .reshape(batch_size, -1, dim)
        )
        bilinear_output = self.bilinear(image_features)
        return query_output + bilinear_output
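
# A hedged sketch for WindowQFormerDownsampler with an un-reduced rate of "2/4":
# each 4 x 4 window of patches is compressed to 2 x 2 queries. The sizes are
# illustrative assumptions (hidden_size=128 so the 32 heads divide evenly), and
# the un-reduced "2/4" relies on Fraction's _normalize flag (pre-Python-3.12).
def _demo_window_qformer_downsampler():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=336, patch_size=14),
        text_config=SimpleNamespace(hidden_size=128),
        downsample_rate="2/4",
    )
    module = WindowQFormerDownsampler(cfg)
    feats = torch.randn(2, 24 * 24, 128)
    out = module(feats)
    # 24 patches / 4 per window = 6 windows per side, each yielding 2 x 2
    # queries -> a 12 x 12 output grid, matching the bilinear branch
    assert out.shape == (2, 12 * 12, 128)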