"""Visual question answering (VQA) model.

Pipeline: a frozen CLIP ViT-B/16 encodes the image, a frozen distilgpt2
encodes the question, the two projected features are fused with a learned
sigmoid gate plus an MLP, and a GRU decoder with attention generates the
answer token sequence (greedy in ``forward``, optional beam search in
``generate_with_beam_search``).
"""

import torch
from torch import nn
import clip
from transformers import GPT2Model


class AttentionDecoder(nn.Module):
    """GRU decoder that attends over the fused image/question context.

    Args:
        hidden_size: width of embeddings, GRU state, and context vectors.
        vocab_size: size of the answer vocabulary (output logits dim).
        num_layers: number of stacked GRU layers.
        dropout: inter-layer GRU dropout (only active when num_layers > 1,
            per ``nn.GRU`` semantics).
    """

    def __init__(self, hidden_size, vocab_size, num_layers=1, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Scores one attention logit per timestep from [embedding ; context].
        self.attention = nn.Linear(hidden_size * 2, 1)
        self.gru = nn.GRU(
            input_size=hidden_size * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.GRU applies dropout only between stacked layers; a value > 0
            # with num_layers == 1 would emit a warning, so gate it here.
            dropout=dropout if num_layers > 1 else 0,
        )
        self.ln_gru = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, context, hidden):
        """Decode one or more steps.

        Args:
            input_ids: (batch,) or (batch, seq) token ids; a 1-D input is
                treated as a single decoding step.
            context: (batch, hidden_size) fused image/question vector.
            hidden: (num_layers, batch, hidden_size) GRU state.

        Returns:
            Tuple of logits (batch, seq, vocab_size) and the updated hidden
            state.
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(1)
        embeddings = self.embedding(input_ids).float()
        # Broadcast the single context vector across all decoding timesteps.
        context_expanded = context.unsqueeze(1).expand(-1, embeddings.size(1), -1)
        combined = torch.cat([embeddings, context_expanded], dim=-1)
        # NOTE(review): softmax runs over the *time* dimension, so with
        # multi-step (teacher-forced) input the attended context mixes
        # weights derived from all timesteps, including future tokens.
        # Confirm this leakage is intended before training with it.
        attn_weights = torch.softmax(self.attention(combined), dim=1)
        attended_context = (context_expanded * attn_weights).sum(dim=1, keepdim=True)
        gru_input = torch.cat(
            [embeddings, attended_context.expand(-1, embeddings.size(1), -1)],
            dim=-1,
        )
        gru_output, hidden = self.gru(gru_input, hidden)
        gru_output = self.ln_gru(gru_output)
        return self.output(gru_output), hidden


class VQAModel(nn.Module):
    """CLIP + distilgpt2 VQA model with a gated-fusion GRU answer decoder.

    Both pretrained encoders start fully frozen; ``unfreeze_clip_layers`` /
    ``unfreeze_gpt2_layers`` selectively re-enable gradients for fine-tuning.
    """

    def __init__(
        self,
        vocab_size=3600,
        question_max_len=16,
        answer_max_len=10,
        hidden_size=512,
        num_layers=2,
        dropout=0.3,
        device='cuda',
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        unk_token_id=3,
    ):
        super().__init__()
        self.device = device
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Toggled by the unfreeze_* methods; controls no_grad in the encoders.
        self.fine_tuning_mode = False
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.unk_token_id = unk_token_id

        # Frozen CLIP image encoder (ViT-B/16, 512-dim image features).
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=device)
        for p in self.clip_model.parameters():
            p.requires_grad = False

        # Frozen distilgpt2 question encoder (768-dim hidden states).
        self.gpt2_model = GPT2Model.from_pretrained("distilgpt2")
        self.gpt2_model.to(device)
        for p in self.gpt2_model.parameters():
            p.requires_grad = False

        # Project both encoders into the shared hidden_size space.
        self.img_proj = nn.Linear(512, hidden_size)
        self.q_proj = nn.Linear(768, hidden_size)
        # Sigmoid gate blending image vs. question features.
        self.gate_layer = nn.Linear(hidden_size * 2, hidden_size)
        # MLP over [gated ; img ; question] producing the decoder context.
        self.fusion = nn.Sequential(
            nn.Linear(hidden_size * 3, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
        )
        self.decoder = AttentionDecoder(hidden_size, vocab_size, num_layers, dropout)

    def unfreeze_clip_layers(self, num_layers=2):
        """Enable gradients for the last ``num_layers`` CLIP visual blocks,
        plus the final projection and post-LayerNorm, and switch the model
        into fine-tuning mode (disables no_grad in ``encode_image``)."""
        self.clip_model.train()
        # CLIP loads in fp16 on CUDA; cast to fp32 for stable fine-tuning.
        self.clip_model.visual.float()
        total_blocks = len(self.clip_model.visual.transformer.resblocks)
        for i, block in enumerate(self.clip_model.visual.transformer.resblocks):
            if i >= total_blocks - num_layers:
                for p in block.parameters():
                    p.requires_grad = True
        if hasattr(self.clip_model.visual, "proj") and self.clip_model.visual.proj is not None:
            # `visual.proj` is a bare Parameter in ViT models but may be a
            # module elsewhere; handle both.
            if isinstance(self.clip_model.visual.proj, torch.nn.Parameter):
                self.clip_model.visual.proj.requires_grad = True
            else:
                for p in self.clip_model.visual.proj.parameters():
                    p.requires_grad = True
        if hasattr(self.clip_model.visual, "ln_post"):
            for p in self.clip_model.visual.ln_post.parameters():
                p.requires_grad = True
        self.fine_tuning_mode = True
        print(f"Unfrozen last {num_layers} CLIP layers")

    def unfreeze_gpt2_layers(self, num_layers=1):
        """Enable gradients (and cast to fp32) for the last ``num_layers``
        GPT-2 transformer blocks and the final LayerNorm; switch the model
        into fine-tuning mode (disables no_grad in ``encode_question``)."""
        self.gpt2_model.train()
        total_layers = len(self.gpt2_model.h)
        for i, layer in enumerate(self.gpt2_model.h):
            if i >= total_layers - num_layers:
                for p in layer.parameters():
                    p.requires_grad = True
                    p.data = p.data.float()
        for p in self.gpt2_model.ln_f.parameters():
            p.requires_grad = True
            p.data = p.data.float()
        self.fine_tuning_mode = True
        print(f"Unfrozen last {num_layers} GPT-2 layers")

    def encode_image(self, images):
        """Return L2-normalized fp32 CLIP image features (batch, 512).

        Uses no_grad unless fine-tuning mode is active.
        """
        if self.fine_tuning_mode:
            images = images.float()
            img_features = self.clip_model.encode_image(images)
        else:
            with torch.no_grad():
                img_features = self.clip_model.encode_image(images)
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        return img_features.float()

    def encode_question(self, input_ids, attention_mask):
        """Return L2-normalized fp32 question features (batch, 768).

        Mean-pools GPT-2 hidden states over non-padding positions
        (attention-mask weighted average). Uses no_grad unless fine-tuning
        mode is active.
        """
        if self.fine_tuning_mode:
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        masked = last_hidden * mask
        sum_hidden = masked.sum(dim=1)
        # Clamp guards against division by zero for an all-padding question.
        lengths = mask.sum(dim=1).clamp(min=1e-6)
        text_features = sum_hidden / lengths
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.float()

    def fuse_features(self, img_features, q_features):
        """Gated fusion of projected image and question features.

        A learned sigmoid gate interpolates the two modalities; the fusion
        MLP then combines the gated vector with the raw concatenation.
        Returns (batch, hidden_size).
        """
        x = torch.cat([img_features, q_features], dim=-1)
        gate = torch.sigmoid(self.gate_layer(x))
        fused = gate * img_features + (1 - gate) * q_features
        fused = self.fusion(torch.cat([fused, x], dim=-1))
        return fused

    def forward(self, images, questions, answer_input_ids=None):
        """Teacher-forced logits (training) or greedy decoding (inference).

        Args:
            images: preprocessed image batch for CLIP.
            questions: dict with "input_ids" and "attention_mask" tensors.
            answer_input_ids: when given, decode with teacher forcing and
                return logits (batch, seq, vocab); when None, greedily
                generate up to ``answer_max_len`` tokens and return the
                generated id tensor (batch, answer_max_len).
        """
        img_features = self.encode_image(images)
        img_features = self.img_proj(img_features).float()
        q_features = self.encode_question(questions["input_ids"], questions["attention_mask"])
        q_features = self.q_proj(q_features).float()
        batch_size = img_features.size(0)
        context = self.fuse_features(img_features, q_features)
        hidden = torch.zeros(
            self.num_layers, batch_size, self.hidden_size,
            device=self.device, dtype=torch.float,
        )
        if answer_input_ids is not None:
            logits, _ = self.decoder(answer_input_ids, context, hidden)
            return logits

        # Greedy autoregressive decoding, starting from BOS.
        generated = torch.full(
            (batch_size, self.answer_max_len), self.pad_token_id,
            dtype=torch.long, device=self.device,
        )
        generated[:, 0] = self.bos_token_id
        for t in range(1, self.answer_max_len):
            current_input = generated[:, t - 1]
            logits, hidden = self.decoder(current_input, context, hidden)
            next_tokens = logits.squeeze(1).argmax(dim=-1)
            generated[:, t] = next_tokens
            # Stop early only once every sequence in the batch emitted EOS.
            if (next_tokens == self.eos_token_id).all():
                break
        return generated

    def generate_with_beam_search(self, images, questions, beam_width=5):
        """Beam-search decoding, one sample at a time.

        Beams are scored by cumulative log-probability; the final choice is
        length-normalized with exponent 0.7. Returns a (batch,
        answer_max_len) long tensor padded with ``pad_token_id``.
        """
        batch_size = images.size(0)
        all_results = []
        for b in range(batch_size):
            img = images[b:b + 1]
            q_ids = questions["input_ids"][b:b + 1]
            q_mask = questions["attention_mask"][b:b + 1]
            img_features = self.encode_image(img)
            img_features = self.img_proj(img_features).float()
            q_features = self.encode_question(q_ids, q_mask)
            q_features = self.q_proj(q_features).float()
            context = self.fuse_features(img_features, q_features)
            initial_hidden = torch.zeros(
                self.num_layers, 1, self.hidden_size,
                device=self.device, dtype=torch.float,
            )
            # Each beam: (token sequence, cumulative log-prob, GRU hidden).
            beams = [(
                torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
                0.0,
                initial_hidden,
            )]
            completed_beams = []
            for t in range(1, self.answer_max_len):
                candidates = []
                for seq, score, hidden in beams:
                    # Beams that already emitted EOS are retired unchanged.
                    if seq[0, -1].item() == self.eos_token_id:
                        completed_beams.append((seq, score))
                        continue
                    current_input = seq[:, -1]
                    logits, new_hidden = self.decoder(current_input, context, hidden)
                    log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                    top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                    for i in range(beam_width):
                        next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                        new_seq = torch.cat([seq, next_token], dim=1)
                        new_score = score + top_log_probs[i].item()
                        candidates.append((new_seq, new_score, new_hidden))
                if len(candidates) == 0:
                    # Every active beam finished this step.
                    break
                beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
            if len(all_beams) == 0:
                result = torch.full(
                    (1, self.answer_max_len), self.pad_token_id,
                    dtype=torch.long, device=self.device,
                )
            else:
                # Length-normalized score (alpha = 0.7) to avoid favoring
                # very short sequences.
                best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
                result = torch.full(
                    (1, self.answer_max_len), self.pad_token_id,
                    dtype=torch.long, device=self.device,
                )
                seq_len = min(best_beam[0].size(1), self.answer_max_len)
                result[:, :seq_len] = best_beam[0][:, :seq_len]
            all_results.append(result)
        return torch.cat(all_results, dim=0)


if __name__ == "__main__":
    # Smoke test: one fake image and a short padded question.
    # Fix: fall back to CPU when CUDA is unavailable instead of crashing.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VQAModel(device=device).to(device)
    model.eval()
    fake_image = torch.randn(1, 3, 224, 224).to(device)
    fake_question_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0]]).to(device)
    fake_question_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]]).to(device)
    question_batch = {
        "input_ids": fake_question_ids,
        "attention_mask": fake_question_mask,
    }
    output = model(fake_image, question_batch)
    print(output)