File size: 16,292 Bytes
bb8f662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import torch
from torch import nn
import clip
from transformers import GPT2Model
import math
class SpatialAdapter(nn.Module):
    """
    Question-guided spatial reasoning adapter over CLIP patch features.

    Pipeline: add a fixed 2-D sinusoidal positional encoding to the 14x14
    patch grid, cross-attend the patches to the question token, self-attend
    among the patches, apply a feed-forward block, then pool the patches into
    one vector using question-similarity weights.
    """
    def __init__(self, patch_dim=512, question_dim=512, hidden_dim=512, num_heads=8, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        # Fixed (non-learned) positional table for a 14x14 patch grid.
        self.register_buffer('pos_encoding_2d', self._create_2d_positional_encoding(14, 14, patch_dim))
        self.patch_proj = nn.Linear(patch_dim, hidden_dim)
        self.question_proj = nn.Linear(question_dim, hidden_dim)
        # Cross-attention projections (patches -> question).
        self.cross_attn_query = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_key = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_value = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_out = nn.Linear(hidden_dim, hidden_dim)
        # Self-attention projections (patches -> patches).
        self.self_attn_query = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_key = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_value = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_out = nn.Linear(hidden_dim, hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout)
        )
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _create_2d_positional_encoding(self, height, width, dim):
        """Build a [1, height*width, dim] sinusoidal table encoding row/col."""
        # Row index of each patch in row-major order, then column index.
        row_idx = torch.arange(height).repeat_interleave(width)
        col_idx = torch.arange(width).repeat(height)
        # Standard transformer frequency ladder; only the first dim//4
        # frequencies are used per (axis, sin/cos) quadrant.
        freqs = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        quarter = freqs[:dim // 4]
        row_angles = row_idx.unsqueeze(1) * quarter
        col_angles = col_idx.unsqueeze(1) * quarter
        table = torch.zeros(height * width, dim)
        # First half of the channels encodes the row, second half the column,
        # with sin/cos interleaved inside each half.
        table[:, 0:dim // 2:2] = torch.sin(row_angles)
        table[:, 1:dim // 2:2] = torch.cos(row_angles)
        table[:, dim // 2::2] = torch.sin(col_angles)
        table[:, dim // 2 + 1::2] = torch.cos(col_angles)
        return table.unsqueeze(0)

    def _multi_head_attention(self, query, key, value, num_heads):
        """Scaled dot-product attention; inputs are [batch, seq, hidden].

        Returns (context [batch, seq_q, hidden], attention weights).
        """
        bsz = query.size(0)

        def split_heads(t):
            # [B, S, hidden] -> [B, heads, S, head_dim]
            return t.view(bsz, -1, num_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(query), split_heads(key), split_heads(value)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = (weights @ v).transpose(1, 2).contiguous().view(bsz, -1, self.hidden_dim)
        return context, weights

    def forward(self, patch_features, question_features):
        """
        Args:
            patch_features: [batch, num_patches, patch_dim] CLIP patch grid.
            question_features: [batch, question_dim] pooled question encoding.
        Returns:
            [batch, hidden_dim] question-weighted spatial context vector.
        """
        num_patches = patch_features.size(1)
        pos = self.pos_encoding_2d[:, :num_patches, :].to(patch_features.device)
        patches = self.patch_proj(patch_features + pos)
        question = self.question_proj(question_features.unsqueeze(1))
        # Cross-attention: each patch queries the (single-token) question.
        cross_ctx, _ = self._multi_head_attention(
            self.cross_attn_query(patches),
            self.cross_attn_key(question),
            self.cross_attn_value(question),
            self.num_heads,
        )
        patches = self.ln1(patches + self.dropout(self.cross_attn_out(cross_ctx)))
        # Self-attention among the (question-conditioned) patches.
        self_ctx, _ = self._multi_head_attention(
            self.self_attn_query(patches),
            self.self_attn_key(patches),
            self.self_attn_value(patches),
            self.num_heads,
        )
        patches = self.ln2(patches + self.dropout(self.self_attn_out(self_ctx)))
        patches = self.ln3(patches + self.ffn(patches))
        # Pool patches by their similarity to the projected question token.
        pool_weights = torch.softmax(patches @ question.transpose(1, 2), dim=1)
        return (patches * pool_weights).sum(dim=1)
class VQAModelWithSpatialAdapter(nn.Module):
    """
    Enhanced VQA Model with Spatial Adapter for spatial reasoning.
    Uses patch-based CLIP features instead of global encoding.

    Wraps an existing base VQA model: reuses its CLIP visual encoder,
    GPT-2 question encoder, and answer decoder, and inserts a
    SpatialAdapter between the encoders and the decoder.
    """
    def __init__(
        self,
        base_model,
        hidden_size=512,
        num_heads=8,
        dropout=0.3
    ):
        super().__init__()
        # Mirror the base model's configuration and special-token ids so this
        # wrapper can be used as a drop-in replacement.
        self.device = base_model.device
        self.question_max_len = base_model.question_max_len
        self.answer_max_len = base_model.answer_max_len
        self.vocab_size = base_model.vocab_size
        self.hidden_size = hidden_size
        self.num_layers = base_model.num_layers
        self.fine_tuning_mode = base_model.fine_tuning_mode
        self.pad_token_id = base_model.pad_token_id
        self.bos_token_id = base_model.bos_token_id
        self.eos_token_id = base_model.eos_token_id
        self.unk_token_id = base_model.unk_token_id
        # Shared (possibly frozen) encoders and the answer decoder.
        self.clip_model = base_model.clip_model
        self.clip_preprocess = base_model.clip_preprocess
        self.gpt2_model = base_model.gpt2_model
        self.decoder = base_model.decoder
        self.spatial_adapter = SpatialAdapter(
            patch_dim=512,      # CLIP projected patch dim
            question_dim=768,   # GPT-2 hidden size
            hidden_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout
        )
        self.spatial_context_proj = nn.Linear(hidden_size, hidden_size)
        self.q_proj = nn.Linear(768, hidden_size)
        # Fuses [spatial context ; projected question] into the decoder context.
        self.spatial_fusion = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size)
        )

    def _clip_patch_tokens(self, images):
        """Run the CLIP ViT trunk and return per-patch tokens (CLS dropped).

        Re-implements CLIP's VisionTransformer.forward but keeps the patch
        tokens instead of pooling to the class token.
        """
        visual = self.clip_model.visual
        x = visual.conv1(images)                                   # [B, C, g, g]
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # [B, g*g, C]
        class_token = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([class_token, x], dim=1)
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)   # CLIP transformer expects [seq, batch, dim]
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)
        patch_features = x[:, 1:, :]   # drop the CLS token
        # Apply CLIP's output projection; it is a Parameter in the reference
        # implementation but tolerate a module as well.
        if hasattr(visual, 'proj') and visual.proj is not None:
            if isinstance(visual.proj, torch.nn.Parameter):
                patch_features = patch_features @ visual.proj
            else:
                patch_features = visual.proj(patch_features)
        return patch_features

    def extract_clip_patch_features(self, images):
        """
        Extract patch features from CLIP instead of global features.
        Returns: [batch_size, num_patches, patch_dim]
        """
        clip_dtype = self.clip_model.visual.conv1.weight.dtype
        images = images.to(clip_dtype)
        # Same computation either way; only gradient tracking differs.
        if self.fine_tuning_mode:
            patch_features = self._clip_patch_tokens(images)
        else:
            with torch.no_grad():
                patch_features = self._clip_patch_tokens(images)
        return patch_features.float()

    def encode_question(self, input_ids, attention_mask):
        """Encode question using GPT-2 (same as base model).

        Mean-pools the last hidden state over non-padding positions and
        L2-normalizes; returns [batch, 768] float features.
        """
        if self.fine_tuning_mode:
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        masked = last_hidden * mask
        sum_hidden = masked.sum(dim=1)
        # clamp avoids division by zero for an all-padding row
        lengths = mask.sum(dim=1).clamp(min=1e-6)
        text_features = sum_hidden / lengths
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.float()

    def forward(self, images, questions, answer_input_ids=None):
        """
        Forward pass with spatial adapter.

        Args:
            images: [batch, 3, H, W] preprocessed image batch.
            questions: dict with "input_ids" and "attention_mask".
            answer_input_ids: teacher-forcing answer tokens; when None the
                model greedily decodes instead.
        Returns:
            Decoder logits when answer_input_ids is given, otherwise a
            [batch, answer_max_len] tensor of generated token ids.
        """
        patch_features = self.extract_clip_patch_features(images)
        q_features = self.encode_question(questions["input_ids"], questions["attention_mask"])
        spatial_context = self.spatial_adapter(patch_features, q_features)
        spatial_context = self.spatial_context_proj(spatial_context)
        q_projected = self.q_proj(q_features)
        fused = self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))
        batch_size = images.size(0)
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                           device=self.device, dtype=torch.float)
        if answer_input_ids is not None:
            # Teacher forcing: decode the whole answer sequence at once.
            logits, _ = self.decoder(answer_input_ids, fused, hidden)
            return logits
        else:
            # Greedy decoding, one token at a time, starting from BOS.
            generated = torch.full((batch_size, self.answer_max_len), self.pad_token_id,
                                 dtype=torch.long, device=self.device)
            generated[:, 0] = self.bos_token_id
            for t in range(1, self.answer_max_len):
                current_input = generated[:, t-1]
                logits, hidden = self.decoder(current_input, fused, hidden)
                next_tokens = logits.squeeze(1).argmax(dim=-1)
                generated[:, t] = next_tokens
                # Stop early only once every sequence in the batch emitted EOS.
                if (next_tokens == self.eos_token_id).all():
                    break
            return generated

    def generate_with_beam_search(self, images, questions, beam_width=5):
        """Beam search generation (same as base model but with spatial features).

        Decodes each batch element independently; scores are summed token
        log-probs with length normalization (alpha=0.7) at selection time.
        Returns [batch, answer_max_len] token ids padded with pad_token_id.
        """
        batch_size = images.size(0)
        all_results = []
        for b in range(batch_size):
            img = images[b:b+1]
            q_ids = questions["input_ids"][b:b+1]
            q_mask = questions["attention_mask"][b:b+1]
            patch_features = self.extract_clip_patch_features(img)
            q_features = self.encode_question(q_ids, q_mask)
            spatial_context = self.spatial_adapter(patch_features, q_features)
            spatial_context = self.spatial_context_proj(spatial_context)
            q_projected = self.q_proj(q_features)
            context = self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))
            initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                         device=self.device, dtype=torch.float)
            # Each beam is (token sequence, cumulative log-prob, decoder hidden).
            beams = [(
                torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
                0.0,
                initial_hidden
            )]
            completed_beams = []
            for t in range(1, self.answer_max_len):
                candidates = []
                for seq, score, hidden in beams:
                    # A beam ending in EOS is finished; move it aside.
                    if seq[0, -1].item() == self.eos_token_id:
                        completed_beams.append((seq, score))
                        continue
                    current_input = seq[:, -1]
                    logits, new_hidden = self.decoder(current_input, context, hidden)
                    log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                    top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                    for i in range(beam_width):
                        next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                        new_seq = torch.cat([seq, next_token], dim=1)
                        new_score = score + top_log_probs[i].item()
                        candidates.append((new_seq, new_score, new_hidden))
                if len(candidates) == 0:
                    break
                beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
            if len(all_beams) == 0:
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                dtype=torch.long, device=self.device)
            else:
                # Length-normalized selection (GNMT-style, alpha = 0.7).
                best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                   dtype=torch.long, device=self.device)
                seq_len = min(best_beam[0].size(1), self.answer_max_len)
                result[:, :seq_len] = best_beam[0][:, :seq_len]
            all_results.append(result)
        return torch.cat(all_results, dim=0)
if __name__ == "__main__":
    # Smoke test: build the spatial model around a fresh base model and run
    # a tiny fake batch through feature extraction and generation.
    print("Testing Spatial Adapter Architecture...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    from model import VQAModel
    base_model = VQAModel(device=device).to(device)
    spatial_model = VQAModelWithSpatialAdapter(base_model).to(device)
    spatial_model.eval()
    fake_image = torch.randn(2, 3, 224, 224).to(device)
    question_batch = {
        "input_ids": torch.tensor([[1, 10, 20, 30, 2, 0, 0], [1, 15, 25, 35, 2, 0, 0]]).to(device),
        "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0, 0]]).to(device),
    }
    fake_question_ids = question_batch["input_ids"]
    print(f"\nInput shapes:")
    print(f"  Images: {fake_image.shape}")
    print(f"  Questions: {fake_question_ids.shape}")
    with torch.no_grad():
        patch_features = spatial_model.extract_clip_patch_features(fake_image)
        print(f"\nPatch features shape: {patch_features.shape}")
        print(f"  Expected: [2, 196, 512] (batch_size, num_patches, patch_dim)")
        output = spatial_model(fake_image, question_batch)
        print(f"\nGenerated output shape: {output.shape}")
        print(f"  Expected: [2, {spatial_model.answer_max_len}]")
    # Parameter accounting: totals, adapter-only, and trainable subsets.
    all_params = list(spatial_model.parameters())
    total_params = sum(p.numel() for p in all_params)
    spatial_adapter_params = sum(p.numel() for p in spatial_model.spatial_adapter.parameters())
    trainable_params = sum(p.numel() for p in all_params if p.requires_grad)
    print(f"\nParameter counts:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Spatial adapter parameters: {spatial_adapter_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print("\n✓ Spatial adapter architecture test passed!")