import torch
from torch import nn
import math
from fractions import Fraction
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel
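
# Image-token downsamplers for a llava-next style multimodal model:
#   * BilinearDownsampler      - parameter-free area pooling of the ViT patch grid
#   * QFormerDownsampler       - learned queries cross-attending over the whole patch grid
#   * WindowQFormerDownsampler - learned queries cross-attending within local patch windows
# Both Q-Former variants add the bilinear output to their own as a residual path.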


class BilinearDownsampler:
    """Pools the flattened ViT patch grid down to a smaller grid with torch.nn.functional.interpolate."""

    def __init__(self, config):
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = int(self.orig_image_side * Fraction(config.downsample_rate))

    def __call__(self, image_features):
        batch_size, _, dim = image_features.size()
        up_shape = [batch_size] + [self.orig_image_side] * 2 + [dim]
        # interpolate expects B, C, H, W
        large_image_permuted = image_features.view(up_shape).permute(0, 3, 1, 2)
        # area mode averages the patches that fall into each output cell
        small_image_permuted = torch.nn.functional.interpolate(
            large_image_permuted,
            size=(self.new_image_side, self.new_image_side),
            mode="area",
        )
        # back to B, H*W, C
        final = small_image_permuted.permute(0, 2, 3, 1).flatten(1, 2)
        return final
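
# A minimal shape sketch for BilinearDownsampler (example values only, not taken from any
# real config): a 336px image with 14px patches gives a 24x24 = 576-token grid, and a
# downsample_rate of "1/2" pools it to 12x12 = 144 tokens.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       vision_config=SimpleNamespace(image_size=336, patch_size=14),
#       downsample_rate="1/2",
#   )
#   pool = BilinearDownsampler(cfg)
#   feats = torch.randn(2, 576, 4096)           # (batch, patches, hidden)
#   assert pool(feats).shape == (2, 144, 4096)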
    

class QFormerDownsampler(nn.Module):
    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.bilinear = BilinearDownsampler(config)
        
        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)

        
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        downsample_rate = Fraction(config.downsample_rate)
        self.query_side = int(downsample_rate * self.image_side)
        # the number of queries is a perfect square so the output still forms a square grid, as llava-next expects
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)
        # the Q-Former has no positional embeddings of its own, so learned positions are added to the flattened patches
        self.image_positions = nn.Parameter(torch.randn(1, self.image_side ** 2, llm_hidden_size) * embed_std)


    def forward(self, image_features):
        batch_size, image_size, dim = image_features.size()
        query_output = self.qformer(
            query_embeds=self.query,
            encoder_hidden_states=image_features + self.image_positions,
            return_dict=True,
        ).last_hidden_state
        
        bilinear_output = self.bilinear(image_features)
        return query_output + bilinear_output

class WindowQFormerDownsampler(nn.Module):
    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.bilinear = BilinearDownsampler(config)
        # a single bidirectional (non-causal) Q-Former attention layer
        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)

        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        # keep the fraction unreduced: e.g. "2/4" means a 2x2 query grid per 4x4 patch window
        # (_normalize is a private Fraction parameter and may not exist in newer Python versions)
        downsample_rate = Fraction(config.downsample_rate, _normalize=False)
        self.query_side, self.window_side = downsample_rate.as_integer_ratio()
        # the per-window query count is a perfect square so the output still tiles a square grid
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)
        # the Q-Former has no positional embeddings of its own, so learned positions are added to each flattened window
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, llm_hidden_size) * embed_std)


    def forward(self, image_features):
        batch_size, image_size, dim = image_features.size()
        num_side_windows = self.image_side // self.window_side
        # split the patch grid into non-overlapping window_side x window_side windows:
        # (B, H*W, D) -> (B, nW, wS, nW, wS, D)
        up_shape = [batch_size] + [num_side_windows, self.window_side] * 2 + [dim]
        # group the two window-grid axes together, then fold the windows into the batch dim:
        # (B, nW, nW, wS, wS, D) -> (B*nW*nW, wS*wS, D)
        image_features_permuted = image_features.view(up_shape).transpose(2, 3)
        image_features_flattened = image_features_permuted.flatten(0, 2).flatten(1, 2)
        
        query_output = self.qformer(
            query_embeds=self.query,
            encoder_hidden_states=image_features_flattened + self.image_positions,
            return_dict=True,
        ).last_hidden_state
    
        # stitch the per-window queries back into one sequence per image
        query_output = query_output.reshape(batch_size, -1, dim)
        bilinear_output = self.bilinear(image_features)
        return query_output + bilinear_output
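

if __name__ == "__main__":
    # Minimal smoke-test sketch with a hypothetical config (example values; a real
    # config would come from a llava-next style model): checks that both Q-Former
    # downsamplers emit the same token count as the bilinear path they are added to.
    from types import SimpleNamespace

    config = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=336, patch_size=14),  # 24x24 patch grid
        text_config=SimpleNamespace(hidden_size=64),                   # tiny LLM width, test only
        downsample_rate="2/4",                                         # 2x2 queries per 4x4 window
    )

    # image features are assumed to be already projected to the LLM hidden size
    feats = torch.randn(2, 24 * 24, 64)

    window_pool = WindowQFormerDownsampler(config)
    # 24 / 4 = 6 windows per side, each yielding 2x2 queries -> 12x12 = 144 tokens per image
    assert window_pool(feats).shape == (2, 144, 64)

    global_pool = QFormerDownsampler(config)
    # the global variant uses the reduced rate 1/2 directly: int(24 * 1/2) ** 2 = 144 queries
    assert global_pool(feats).shape == (2, 144, 64)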