# vqa-backend / model_spatial.py
# Deployed to the VQA Space with the model downloader (commit bb8f662).
import torch
from torch import nn
import clip
from transformers import GPT2Model
import math
class SpatialAdapter(nn.Module):
    """Question-guided spatial reasoning head over CLIP patch features.

    Pipeline (each attention/FFN stage is residual + LayerNorm):
      1. add fixed 2D sinusoidal position encodings to the 14x14 patch grid,
      2. cross-attention: patches attend to the (projected) question vector,
      3. self-attention among the patches,
      4. position-wise feed-forward network,
      5. question-conditioned attention pooling down to one context vector.
    """

    def __init__(self, patch_dim=512, question_dim=512, hidden_dim=512, num_heads=8, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        # Fixed (non-trainable) positional table for a 14x14 patch grid.
        self.register_buffer('pos_encoding_2d', self._create_2d_positional_encoding(14, 14, patch_dim))
        self.patch_proj = nn.Linear(patch_dim, hidden_dim)
        self.question_proj = nn.Linear(question_dim, hidden_dim)
        # Cross-attention projections (patches attend to the question token).
        self.cross_attn_query = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_key = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_value = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_out = nn.Linear(hidden_dim, hidden_dim)
        # Self-attention projections (patches attend to each other).
        self.self_attn_query = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_key = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_value = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_out = nn.Linear(hidden_dim, hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout)
        )
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _create_2d_positional_encoding(self, height, width, dim):
        """Build a [1, height*width, dim] sinusoidal 2D position table.

        The first dim//2 channels encode the row index, the second dim//2 the
        column index, each as interleaved sin/cos at dim//4 frequencies.
        """
        rows = torch.arange(height).unsqueeze(1).repeat(1, width).flatten()
        cols = torch.arange(width).unsqueeze(0).repeat(height, 1).flatten()
        freqs = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        freqs = freqs[:dim // 4]
        half = dim // 2
        table = torch.zeros(height * width, dim)
        table[:, 0:half:2] = torch.sin(rows.unsqueeze(1) * freqs)
        table[:, 1:half:2] = torch.cos(rows.unsqueeze(1) * freqs)
        table[:, half::2] = torch.sin(cols.unsqueeze(1) * freqs)
        table[:, half + 1::2] = torch.cos(cols.unsqueeze(1) * freqs)
        return table.unsqueeze(0)

    def _multi_head_attention(self, query, key, value, num_heads):
        """Scaled dot-product attention split across `num_heads` heads.

        Returns (context [B, Tq, hidden_dim], weights [B, H, Tq, Tk]); the
        returned weights are post-dropout, matching what produced the context.
        """
        bsz = query.size(0)

        def split_heads(t):
            return t.view(bsz, -1, num_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(query), split_heads(key), split_heads(value)
        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = self.dropout(torch.softmax(logits, dim=-1))
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous()
        return merged.view(bsz, -1, self.hidden_dim), weights

    def forward(self, patch_features, question_features):
        """Fuse patch features with the question into one context vector.

        Args:
            patch_features: [batch_size, num_patches, patch_dim] CLIP patches.
            question_features: [batch_size, question_dim] question encoding.
        Returns:
            [batch_size, hidden_dim] spatially-aware context vector.
        """
        num_patches = patch_features.size(1)
        pos = self.pos_encoding_2d[:, :num_patches, :].to(patch_features.device)
        patches = self.patch_proj(patch_features + pos)
        question = self.question_proj(question_features.unsqueeze(1))
        # Stage 1: patches attend to the single question token.
        attended, _ = self._multi_head_attention(
            self.cross_attn_query(patches),
            self.cross_attn_key(question),
            self.cross_attn_value(question),
            self.num_heads,
        )
        patches = self.ln1(patches + self.dropout(self.cross_attn_out(attended)))
        # Stage 2: patches attend to each other.
        mixed, _ = self._multi_head_attention(
            self.self_attn_query(patches),
            self.self_attn_key(patches),
            self.self_attn_value(patches),
            self.num_heads,
        )
        patches = self.ln2(patches + self.dropout(self.self_attn_out(mixed)))
        # Stage 3: position-wise FFN.
        patches = self.ln3(patches + self.ffn(patches))
        # Stage 4: question-conditioned soft pooling over the patch axis.
        scores = torch.matmul(patches, question.transpose(1, 2))
        pooling = torch.softmax(scores, dim=1)
        return (patches * pooling).sum(dim=1)
class VQAModelWithSpatialAdapter(nn.Module):
    """
    Enhanced VQA model with a SpatialAdapter for spatial reasoning.

    Wraps a base CLIP+GPT-2 VQA model but represents the image by its 14x14
    grid of CLIP patch tokens (instead of the single global embedding), fusing
    them with the question through a SpatialAdapter before decoding.
    """

    def __init__(
        self,
        base_model,
        hidden_size=512,
        num_heads=8,
        dropout=0.3
    ):
        super().__init__()
        # Mirror the base model's configuration and special-token ids.
        self.device = base_model.device
        self.question_max_len = base_model.question_max_len
        self.answer_max_len = base_model.answer_max_len
        self.vocab_size = base_model.vocab_size
        self.hidden_size = hidden_size
        self.num_layers = base_model.num_layers
        self.fine_tuning_mode = base_model.fine_tuning_mode
        self.pad_token_id = base_model.pad_token_id
        self.bos_token_id = base_model.bos_token_id
        self.eos_token_id = base_model.eos_token_id
        self.unk_token_id = base_model.unk_token_id
        # Reuse the base model's encoders and answer decoder.
        self.clip_model = base_model.clip_model
        self.clip_preprocess = base_model.clip_preprocess
        self.gpt2_model = base_model.gpt2_model
        self.decoder = base_model.decoder
        # New spatial-reasoning components.
        self.spatial_adapter = SpatialAdapter(
            patch_dim=512,       # CLIP patch width after the output projection
            question_dim=768,    # GPT-2 hidden size
            hidden_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout
        )
        self.spatial_context_proj = nn.Linear(hidden_size, hidden_size)
        self.q_proj = nn.Linear(768, hidden_size)
        # Fuses [spatial context | projected question] into the decoder context.
        self.spatial_fusion = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size)
        )

    def _visual_patch_tokens(self, images):
        """Run CLIP's visual transformer and keep the per-patch tokens.

        Replicates the visual encoder's forward pass up to (but excluding) the
        CLS pooling, so all patch tokens survive. Factored out so the
        grad/no-grad branches of extract_clip_patch_features share one copy.
        Returns: [batch_size, num_patches, patch_dim]
        """
        visual = self.clip_model.visual
        x = visual.conv1(images)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        # Prepend the class token, broadcast across the batch.
        class_token = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([class_token, x], dim=1)
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)   # NLD -> LND for CLIP's transformer
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)   # LND -> NLD
        patch_features = x[:, 1:, :]   # drop the CLS token
        # Apply CLIP's output projection so patch features live in the same
        # space as the global image embedding.
        if hasattr(visual, 'proj') and visual.proj is not None:
            if isinstance(visual.proj, torch.nn.Parameter):
                patch_features = patch_features @ visual.proj
            else:
                patch_features = visual.proj(patch_features)
        return patch_features

    def extract_clip_patch_features(self, images):
        """
        Extract patch features from CLIP instead of global features.
        Gradients flow into CLIP only in fine-tuning mode.
        Returns: [batch_size, num_patches, patch_dim] as float32
        """
        clip_dtype = self.clip_model.visual.conv1.weight.dtype
        images = images.to(clip_dtype)
        if self.fine_tuning_mode:
            patch_features = self._visual_patch_tokens(images)
        else:
            with torch.no_grad():
                patch_features = self._visual_patch_tokens(images)
        return patch_features.float()

    def encode_question(self, input_ids, attention_mask):
        """Encode the question with GPT-2 (same scheme as the base model).

        Mean-pools the last hidden states over valid (unpadded) tokens and
        L2-normalises. Returns a [batch_size, 768] float32 tensor.
        """
        if self.fine_tuning_mode:
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        # Masked mean over the sequence dimension; clamp avoids 0-division
        # for an (unexpected) all-padding row.
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        masked = last_hidden * mask
        sum_hidden = masked.sum(dim=1)
        lengths = mask.sum(dim=1).clamp(min=1e-6)
        text_features = sum_hidden / lengths
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.float()

    def _build_context(self, images, input_ids, attention_mask):
        """Compute the fused decoder context for a batch of image/question pairs.

        Shared by forward() and generate_with_beam_search().
        Returns: [batch_size, hidden_size]
        """
        patch_features = self.extract_clip_patch_features(images)
        q_features = self.encode_question(input_ids, attention_mask)
        spatial_context = self.spatial_adapter(patch_features, q_features)
        spatial_context = self.spatial_context_proj(spatial_context)
        q_projected = self.q_proj(q_features)
        return self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))

    def forward(self, images, questions, answer_input_ids=None):
        """
        Forward pass with the spatial adapter.

        Args:
            images: preprocessed image batch.
            questions: dict with "input_ids" and "attention_mask".
            answer_input_ids: optional gold answer tokens (teacher forcing).
        Returns:
            Decoder logits when answer_input_ids is given, otherwise greedily
            generated token ids of shape [batch_size, answer_max_len].
        """
        fused = self._build_context(images, questions["input_ids"], questions["attention_mask"])
        batch_size = images.size(0)
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                             device=self.device, dtype=torch.float)
        if answer_input_ids is not None:
            logits, _ = self.decoder(answer_input_ids, fused, hidden)
            return logits
        # Greedy decoding from BOS; remaining positions stay PAD. Stops early
        # once every sequence in the batch has just emitted EOS.
        generated = torch.full((batch_size, self.answer_max_len), self.pad_token_id,
                               dtype=torch.long, device=self.device)
        generated[:, 0] = self.bos_token_id
        for t in range(1, self.answer_max_len):
            current_input = generated[:, t - 1]
            logits, hidden = self.decoder(current_input, fused, hidden)
            next_tokens = logits.squeeze(1).argmax(dim=-1)
            generated[:, t] = next_tokens
            if (next_tokens == self.eos_token_id).all():
                break
        return generated

    def generate_with_beam_search(self, images, questions, beam_width=5):
        """Beam search generation (same as base model but with spatial features).

        Decodes each sample independently; picks the beam with the best
        length-normalised score. Returns [batch_size, answer_max_len] ids,
        PAD-filled past the end of the chosen sequence.
        """
        batch_size = images.size(0)
        all_results = []
        for b in range(batch_size):
            context = self._build_context(
                images[b:b + 1],
                questions["input_ids"][b:b + 1],
                questions["attention_mask"][b:b + 1],
            )
            initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                         device=self.device, dtype=torch.float)
            # Each beam is (token sequence, cumulative log-prob, decoder hidden).
            beams = [(
                torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
                0.0,
                initial_hidden
            )]
            completed_beams = []
            for t in range(1, self.answer_max_len):
                candidates = []
                for seq, score, hidden in beams:
                    # A beam ending in EOS is frozen and leaves the active pool.
                    if seq[0, -1].item() == self.eos_token_id:
                        completed_beams.append((seq, score))
                        continue
                    current_input = seq[:, -1]
                    logits, new_hidden = self.decoder(current_input, context, hidden)
                    log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                    top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                    for i in range(beam_width):
                        next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                        candidates.append((
                            torch.cat([seq, next_token], dim=1),
                            score + top_log_probs[i].item(),
                            new_hidden
                        ))
                if len(candidates) == 0:
                    break  # every active beam has completed
                beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
            result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                dtype=torch.long, device=self.device)
            if len(all_beams) > 0:
                # Length-normalised score (alpha = 0.7) avoids short-answer bias.
                best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
                seq_len = min(best_beam[0].size(1), self.answer_max_len)
                result[:, :seq_len] = best_beam[0][:, :seq_len]
            all_results.append(result)
        return torch.cat(all_results, dim=0)
if __name__ == "__main__":
    # Smoke test: build the spatial model on dummy data and report shapes
    # and parameter counts.
    print("Testing Spatial Adapter Architecture...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    from model import VQAModel
    wrapped = VQAModelWithSpatialAdapter(VQAModel(device=device).to(device)).to(device)
    wrapped.eval()
    dummy_images = torch.randn(2, 3, 224, 224).to(device)
    dummy_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0], [1, 15, 25, 35, 2, 0, 0]]).to(device)
    dummy_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0, 0]]).to(device)
    batch = {"input_ids": dummy_ids, "attention_mask": dummy_mask}
    print(f"\nInput shapes:")
    print(f" Images: {dummy_images.shape}")
    print(f" Questions: {dummy_ids.shape}")
    with torch.no_grad():
        grid_feats = wrapped.extract_clip_patch_features(dummy_images)
        print(f"\nPatch features shape: {grid_feats.shape}")
        print(f" Expected: [2, 196, 512] (batch_size, num_patches, patch_dim)")
        answer_ids = wrapped(dummy_images, batch)
        print(f"\nGenerated output shape: {answer_ids.shape}")
        print(f" Expected: [2, {wrapped.answer_max_len}]")
    total = sum(p.numel() for p in wrapped.parameters())
    adapter_total = sum(p.numel() for p in wrapped.spatial_adapter.parameters())
    trainable = sum(p.numel() for p in wrapped.parameters() if p.requires_grad)
    print(f"\nParameter counts:")
    print(f" Total parameters: {total:,}")
    print(f" Spatial adapter parameters: {adapter_total:,}")
    print(f" Trainable parameters: {trainable:,}")
    print("\n✓ Spatial adapter architecture test passed!")