Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| import clip | |
| from transformers import GPT2Model | |
| import math | |
class SpatialAdapter(nn.Module):
    """
    Spatial Adapter with Multi-Head Cross-Attention for spatial reasoning.
    Processes CLIP patch features (14x14 grid) with question guidance.

    Pipeline (see forward):
      1. Add a fixed 2D sinusoidal positional encoding to the patch features.
      2. Cross-attention: patches (queries) attend to the question (key/value).
      3. Self-attention among the patches.
      4. Position-wise feed-forward block.
      5. Question-guided attention pooling down to a single context vector.
    Each sub-block uses a post-norm residual connection (LayerNorm after the
    residual add), transformer style.
    """
    def __init__(self, patch_dim=512, question_dim=512, hidden_dim=512, num_heads=8, dropout=0.3):
        """
        Args:
            patch_dim: dimensionality of the incoming CLIP patch features;
                must be divisible by 4 (see _create_2d_positional_encoding).
            question_dim: dimensionality of the question embedding.
            hidden_dim: internal width; must be divisible by num_heads.
            num_heads: number of attention heads.
            dropout: dropout probability used throughout the adapter.
        """
        super().__init__()
        # Validate before deriving head_dim from the same quantities.
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Fixed (non-learned) encoding; a buffer follows the module across
        # devices and is persisted in the state dict.
        self.register_buffer('pos_encoding_2d', self._create_2d_positional_encoding(14, 14, patch_dim))
        self.patch_proj = nn.Linear(patch_dim, hidden_dim)
        self.question_proj = nn.Linear(question_dim, hidden_dim)
        # Cross-attention projections: patches attend to the question.
        self.cross_attn_query = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_key = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_value = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_out = nn.Linear(hidden_dim, hidden_dim)
        # Self-attention projections among patches.
        self.self_attn_query = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_key = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_value = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_out = nn.Linear(hidden_dim, hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout)
        )
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _create_2d_positional_encoding(self, height, width, dim):
        """Create a fixed 2D sinusoidal positional encoding for a spatial grid.

        The first half of the channels encodes the row index, the second half
        the column index, each as interleaved sin/cos banks.

        Returns:
            Tensor of shape [1, height * width, dim].
        """
        # Four interleaved sin/cos banks each get dim // 4 channels, so the
        # strided assignments below only line up when dim is a multiple of 4.
        # Fail loudly instead of with an opaque shape-mismatch RuntimeError.
        assert dim % 4 == 0, "positional encoding dim must be divisible by 4"
        pos_h = torch.arange(height).unsqueeze(1).repeat(1, width).flatten()
        pos_w = torch.arange(width).unsqueeze(0).repeat(height, 1).flatten()
        pe = torch.zeros(height * width, dim)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0:dim//2:2] = torch.sin(pos_h.unsqueeze(1) * div_term[:dim//4])
        pe[:, 1:dim//2:2] = torch.cos(pos_h.unsqueeze(1) * div_term[:dim//4])
        pe[:, dim//2::2] = torch.sin(pos_w.unsqueeze(1) * div_term[:dim//4])
        pe[:, dim//2+1::2] = torch.cos(pos_w.unsqueeze(1) * div_term[:dim//4])
        return pe.unsqueeze(0)

    def _multi_head_attention(self, query, key, value, num_heads):
        """Scaled dot-product multi-head attention.

        Args:
            query, key, value: already-projected tensors of shape
                [batch, seq_len, hidden_dim] (query length may differ from
                key/value length).
        Returns:
            (context [batch, q_len, hidden_dim], attention weights).
        """
        batch_size = query.size(0)
        # Split hidden_dim into heads: [B, heads, seq, head_dim].
        Q = query.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
        K = key.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
        V = value.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge heads back into hidden_dim.
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)
        return context, attn_weights

    def forward(self, patch_features, question_features):
        """
        Args:
            patch_features: [batch_size, num_patches, patch_dim] - CLIP patch features
            question_features: [batch_size, question_dim] - Question encoding
        Returns:
            spatial_context: [batch_size, hidden_dim] - Spatially-aware context
        """
        batch_size, num_patches, _ = patch_features.shape
        # Inject spatial position information before projecting to hidden_dim.
        patch_features = patch_features + self.pos_encoding_2d[:, :num_patches, :].to(patch_features.device)
        patches = self.patch_proj(patch_features)
        question = self.question_proj(question_features.unsqueeze(1))  # [B, 1, H]
        # --- cross-attention block: patches attend to the question ---
        Q_cross = self.cross_attn_query(patches)
        K_cross = self.cross_attn_key(question)
        V_cross = self.cross_attn_value(question)
        cross_context, _ = self._multi_head_attention(Q_cross, K_cross, V_cross, self.num_heads)
        cross_out = self.cross_attn_out(cross_context)
        patches = self.ln1(patches + self.dropout(cross_out))
        # --- self-attention block among patches ---
        Q_self = self.self_attn_query(patches)
        K_self = self.self_attn_key(patches)
        V_self = self.self_attn_value(patches)
        self_context, _ = self._multi_head_attention(Q_self, K_self, V_self, self.num_heads)
        self_out = self.self_attn_out(self_context)
        patches = self.ln2(patches + self.dropout(self_out))
        # --- feed-forward block ---
        ffn_out = self.ffn(patches)
        patches = self.ln3(patches + ffn_out)
        # Question-guided attention pooling over the patch dimension.
        attn_scores = torch.matmul(patches, question.transpose(1, 2))  # [B, P, 1]
        attn_weights = torch.softmax(attn_scores, dim=1)
        spatial_context = (patches * attn_weights).sum(dim=1)
        return spatial_context
class VQAModelWithSpatialAdapter(nn.Module):
    """
    Enhanced VQA Model with Spatial Adapter for spatial reasoning.
    Uses patch-based CLIP features instead of global encoding.

    All heavy components (CLIP, GPT-2, answer decoder) are shared with the
    wrapped base model; only the spatial adapter and fusion layers are new.
    """
    def __init__(
        self,
        base_model,
        hidden_size=512,
        num_heads=8,
        dropout=0.3
    ):
        """
        Args:
            base_model: initialized base VQA model supplying token ids, CLIP,
                GPT-2 and the answer decoder.
            hidden_size: width of the spatial adapter and fusion layers.
            num_heads: attention heads inside the spatial adapter.
            dropout: dropout probability for the new layers.
        """
        super().__init__()
        # Inherit configuration and special token ids from the base model.
        self.device = base_model.device
        self.question_max_len = base_model.question_max_len
        self.answer_max_len = base_model.answer_max_len
        self.vocab_size = base_model.vocab_size
        self.hidden_size = hidden_size
        self.num_layers = base_model.num_layers
        self.fine_tuning_mode = base_model.fine_tuning_mode
        self.pad_token_id = base_model.pad_token_id
        self.bos_token_id = base_model.bos_token_id
        self.eos_token_id = base_model.eos_token_id
        self.unk_token_id = base_model.unk_token_id
        # Shared encoder/decoder components (not copies).
        self.clip_model = base_model.clip_model
        self.clip_preprocess = base_model.clip_preprocess
        self.gpt2_model = base_model.gpt2_model
        self.decoder = base_model.decoder
        # question_dim=768 matches the GPT-2 hidden size used in encode_question.
        self.spatial_adapter = SpatialAdapter(
            patch_dim=512,
            question_dim=768,
            hidden_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout
        )
        self.spatial_context_proj = nn.Linear(hidden_size, hidden_size)
        self.q_proj = nn.Linear(768, hidden_size)
        # Combines the spatial context with the projected question embedding.
        self.spatial_fusion = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size)
        )

    def extract_clip_patch_features(self, images):
        """
        Extract patch features from CLIP instead of global features.

        Runs the CLIP visual transformer step by step so the per-patch tokens
        can be recovered (CLIP's stock forward only returns the pooled CLS
        embedding). Gradients flow through CLIP only in fine-tuning mode.

        Returns: [batch_size, num_patches, patch_dim] float32 tensor.
        """
        clip_dtype = self.clip_model.visual.conv1.weight.dtype
        images = images.to(clip_dtype)  # CLIP weights may be fp16
        # Single code path replaces the two previously duplicated branches:
        # grad is enabled only when fine-tuning AND the ambient mode allows it,
        # so an outer torch.no_grad() is still respected, exactly as before.
        with torch.set_grad_enabled(self.fine_tuning_mode and torch.is_grad_enabled()):
            visual = self.clip_model.visual
            x = visual.conv1(images)                    # [B, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)   # [B, width, grid*grid]
            x = x.permute(0, 2, 1)                      # [B, grid*grid, width]
            # Prepend the class token, broadcast over the batch.
            class_token = visual.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([class_token, x], dim=1)
            x = x + visual.positional_embedding.to(x.dtype)
            x = visual.ln_pre(x)
            x = x.permute(1, 0, 2)   # CLIP transformer expects [seq, batch, dim]
            x = visual.transformer(x)
            x = x.permute(1, 0, 2)
            patch_features = x[:, 1:, :]  # drop the CLS token, keep patch tokens
            # NOTE(review): ln_post is skipped before the projection here,
            # unlike CLIP's pooled CLS path — confirm this is intentional.
            if hasattr(visual, 'proj') and visual.proj is not None:
                if isinstance(visual.proj, torch.nn.Parameter):
                    patch_features = patch_features @ visual.proj
                else:
                    patch_features = visual.proj(patch_features)
        return patch_features.float()

    def encode_question(self, input_ids, attention_mask):
        """Encode question using GPT-2 (same as base model).

        Mean-pools the final hidden states over non-padding tokens and
        L2-normalizes the result.

        Returns: [batch_size, 768] float32 tensor.
        """
        # Grad only when fine-tuning (and not inside an outer no_grad()).
        with torch.set_grad_enabled(self.fine_tuning_mode and torch.is_grad_enabled()):
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        masked = last_hidden * mask
        sum_hidden = masked.sum(dim=1)
        lengths = mask.sum(dim=1).clamp(min=1e-6)  # guard against all-pad rows
        text_features = sum_hidden / lengths
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.float()

    def forward(self, images, questions, answer_input_ids=None):
        """
        Forward pass with spatial adapter.

        Args:
            images: preprocessed image batch for CLIP.
            questions: dict with "input_ids" and "attention_mask" tensors.
            answer_input_ids: teacher-forcing inputs; when given, decoder
                logits are returned, otherwise token ids are greedily decoded.
        Returns:
            Decoder logits (training) or generated token ids
            [batch_size, answer_max_len] (inference).
        """
        patch_features = self.extract_clip_patch_features(images)
        q_features = self.encode_question(questions["input_ids"], questions["attention_mask"])
        spatial_context = self.spatial_adapter(patch_features, q_features)
        spatial_context = self.spatial_context_proj(spatial_context)
        q_projected = self.q_proj(q_features)
        # Fuse spatial context with the question embedding for the decoder.
        fused = self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))
        batch_size = images.size(0)
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                             device=self.device, dtype=torch.float)
        if answer_input_ids is not None:
            # Training / teacher forcing: return logits over the vocabulary.
            logits, _ = self.decoder(answer_input_ids, fused, hidden)
            return logits
        # Inference: greedy decoding up to answer_max_len tokens.
        generated = torch.full((batch_size, self.answer_max_len), self.pad_token_id,
                               dtype=torch.long, device=self.device)
        generated[:, 0] = self.bos_token_id
        for t in range(1, self.answer_max_len):
            current_input = generated[:, t-1]
            logits, hidden = self.decoder(current_input, fused, hidden)
            next_tokens = logits.squeeze(1).argmax(dim=-1)
            generated[:, t] = next_tokens
            # Stop early only once every sequence in the batch emitted EOS.
            if (next_tokens == self.eos_token_id).all():
                break
        return generated

    def generate_with_beam_search(self, images, questions, beam_width=5):
        """Beam search generation (same as base model but with spatial features).

        Decodes each batch item independently. Beams are (sequence, log-prob,
        decoder hidden state) triples; the best beam is picked with a
        length-normalized score (alpha = 0.7).
        """
        batch_size = images.size(0)
        all_results = []
        for b in range(batch_size):
            img = images[b:b+1]
            q_ids = questions["input_ids"][b:b+1]
            q_mask = questions["attention_mask"][b:b+1]
            # Same fusion pipeline as forward(), for a single example.
            patch_features = self.extract_clip_patch_features(img)
            q_features = self.encode_question(q_ids, q_mask)
            spatial_context = self.spatial_adapter(patch_features, q_features)
            spatial_context = self.spatial_context_proj(spatial_context)
            q_projected = self.q_proj(q_features)
            context = self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))
            initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                         device=self.device, dtype=torch.float)
            beams = [(
                torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
                0.0,
                initial_hidden
            )]
            completed_beams = []
            for t in range(1, self.answer_max_len):
                candidates = []
                for seq, score, hidden in beams:
                    # A beam ending in EOS is finished; set it aside.
                    if seq[0, -1].item() == self.eos_token_id:
                        completed_beams.append((seq, score))
                        continue
                    current_input = seq[:, -1]
                    logits, new_hidden = self.decoder(current_input, context, hidden)
                    log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                    top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                    for i in range(beam_width):
                        next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                        new_seq = torch.cat([seq, next_token], dim=1)
                        new_score = score + top_log_probs[i].item()
                        candidates.append((new_seq, new_score, new_hidden))
                if len(candidates) == 0:
                    break  # every live beam has finished
                beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
            if len(all_beams) == 0:
                # Degenerate fallback: emit an all-padding sequence.
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                    dtype=torch.long, device=self.device)
            else:
                # Length normalization avoids favoring short sequences.
                best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                    dtype=torch.long, device=self.device)
                seq_len = min(best_beam[0].size(1), self.answer_max_len)
                result[:, :seq_len] = best_beam[0][:, :seq_len]
            all_results.append(result)
        return torch.cat(all_results, dim=0)
# Smoke test: build the base model, wrap it with the spatial adapter, and push
# a tiny fake batch through patch extraction and greedy generation.
if __name__ == "__main__":
    print("Testing Spatial Adapter Architecture...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Local project module providing the base VQA model (CLIP + GPT-2 + decoder).
    from model import VQAModel
    base_model = VQAModel(device=device).to(device)
    spatial_model = VQAModelWithSpatialAdapter(base_model).to(device)
    spatial_model.eval()
    # Fake batch of 2 RGB images at CLIP's 224x224 input resolution.
    fake_image = torch.randn(2, 3, 224, 224).to(device)
    # Hand-written token ids; presumably 1=BOS, 2=EOS, 0=pad — TODO confirm
    # against the project's tokenizer.
    fake_question_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0], [1, 15, 25, 35, 2, 0, 0]]).to(device)
    fake_question_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0, 0]]).to(device)
    question_batch = {
        "input_ids": fake_question_ids,
        "attention_mask": fake_question_mask
    }
    print(f"\nInput shapes:")
    print(f" Images: {fake_image.shape}")
    print(f" Questions: {fake_question_ids.shape}")
    # Inference only — no gradients needed for the smoke test.
    with torch.no_grad():
        patch_features = spatial_model.extract_clip_patch_features(fake_image)
        print(f"\nPatch features shape: {patch_features.shape}")
        print(f" Expected: [2, 196, 512] (batch_size, num_patches, patch_dim)")
        # No answer ids supplied, so this exercises the greedy decoding path.
        output = spatial_model(fake_image, question_batch)
        print(f"\nGenerated output shape: {output.shape}")
        print(f" Expected: [2, {spatial_model.answer_max_len}]")
    # Parameter accounting: total, adapter-only, and trainable.
    total_params = sum(p.numel() for p in spatial_model.parameters())
    spatial_adapter_params = sum(p.numel() for p in spatial_model.spatial_adapter.parameters())
    trainable_params = sum(p.numel() for p in spatial_model.parameters() if p.requires_grad)
    print(f"\nParameter counts:")
    print(f" Total parameters: {total_params:,}")
    print(f" Spatial adapter parameters: {spatial_adapter_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print("\n✓ Spatial adapter architecture test passed!")