# FGSVQA / src / model / qd_model.py
# Xinyi Wang
# project files
# db25ead
# ----------------------------
# Model: CLIP encoder + w_art/w_str weighted pooling -> per-frame feature sequences -> temporal mean pooling -> per-branch heads -> gated head fusion
# ----------------------------
import torch
import torch.nn as nn
from .clip_dense_encoder import CLIPDenseEncoder
def weighted_pool_2d(fmap, wmap, eps=1e-6):
    """Spatially pool a feature map using a per-pixel weight map.

    The weights are clipped to [0, 1] and rescaled so that they sum to
    (approximately) one over the spatial axes; the result is the weighted
    sum of the features over those axes.

    Args:
        fmap: features of shape [B, C, H, W].
        wmap: weights of shape [B, 1, H, W]; broadcast across channels.
        eps: small constant added to the weight mass to avoid division by
            zero. An all-zero weight map therefore yields an all-zero output.

    Returns:
        Pooled features of shape [B, C].
    """
    spatial_axes = (2, 3)
    clipped = torch.clamp(wmap, min=0.0, max=1.0)
    # Total weight per sample, kept as [B, 1, 1, 1] so it broadcasts back.
    mass = clipped.sum(dim=spatial_axes, keepdim=True)
    normalized = clipped / (mass + eps)
    return torch.sum(fmap * normalized, dim=spatial_axes)
def default_gate_stats(w_art, w_str, fmap_bt, eps=1e-6):
    """Build the three per-sample scalars fed to the 3-way fusion gate.

    Each statistic is a mean over every non-batch axis of its input:
    the mean of ``w_art``, the mean of ``w_str`` and the mean absolute
    value of ``fmap_bt``.

    Args:
        w_art: weight volume of shape [B, 1, T, H, W].
        w_str: weight volume of shape [B, 1, T, H, W].
        fmap_bt: feature volume of shape [B, T, C, H, W].
        eps: accepted for interface compatibility; not used by the
            computation.

    Returns:
        Tensor of shape [B, 3] ordered as (art, str, raw).
    """
    non_batch_axes = (1, 2, 3, 4)
    per_sample = (
        w_art.mean(dim=non_batch_axes),
        w_str.mean(dim=non_batch_axes),
        fmap_bt.abs().mean(dim=non_batch_axes),
    )
    return torch.stack(per_sample, dim=1)
def two_gate_stats(w_art, w_str):
    """Build the two per-sample scalars fed to the 2-way (art/str) gate.

    Args:
        w_art: weight volume with five axes, batch first.
        w_str: weight volume with five axes, batch first.

    Returns:
        Tensor of shape [B, 2] holding the mean of ``w_art`` and the mean of
        ``w_str`` over all non-batch axes, in that order.
    """
    non_batch_axes = (1, 2, 3, 4)
    means = [volume.mean(dim=non_batch_axes) for volume in (w_art, w_str)]
    return torch.stack(means, dim=1)
class QD_MODEL(nn.Module):
    """Video quality regressor fusing three pooled views of dense CLIP features.

    Every frame is encoded into a dense feature map. The map is pooled three
    ways -- weighted by ``w_art``, weighted by ``w_str`` and uniformly
    ("raw") -- each pooled sequence is averaged over time and passed through
    its own MLP head, and the three scalar predictions are combined
    according to ``ablation_mode``:

    * ``"art"`` / ``"str"`` / ``"raw"``: the prediction of that single head.
    * ``"art+str"``: softmax-gated mix of the art and str heads.
    * ``"full"``: softmax-gated mix of all three heads.

    The mode is compared case-insensitively and is read on every forward
    pass, so it may be changed after construction.
    """

    # Modes in which exactly one head produces the prediction, mapped to that
    # head's column in the returned fusion weights (art, str, raw).
    _SINGLE_HEAD_INDEX = {"art": 0, "str": 1, "raw": 2}
    # Every accepted (lower-cased) value of ``ablation_mode``.
    _ABLATION_MODES = ("full", "art+str", "art", "str", "raw")

    def __init__(
        self,
        *,
        clip_model="openai/clip-vit-base-patch16",
        head_hidden=384,
        gate_hidden=32,
        head_dropout=0.2,
        gate_dropout=0.1,
        ablation_mode="full",  # "full" | "art+str" | "art" | "str" | "raw"
    ):
        """Create the encoder, the three regression heads and both gates.

        Args:
            clip_model: model identifier forwarded (as ``str``) to
                ``CLIPDenseEncoder``.
            head_hidden: width of the first hidden layer of each head; the
                second hidden layer is half as wide.
            gate_hidden: hidden width of both gating MLPs.
            head_dropout: dropout probability inside the heads.
            gate_dropout: dropout probability inside the gates.
            ablation_mode: fusion mode, see the class docstring. It is
                validated when ``forward`` runs, not here.
        """
        super().__init__()
        self.ablation_mode = ablation_mode
        self.encoder = CLIPDenseEncoder(model_name=str(clip_model))
        c = int(self.encoder.hidden_size)

        # One head per pooled view. All five sub-modules are always created
        # (in this order) so parameter names do not depend on ablation_mode.
        self.head_art = self._make_head(c, head_hidden, head_dropout)
        self.head_str = self._make_head(c, head_hidden, head_dropout)
        self.head_fmap = self._make_head(c, head_hidden, head_dropout)

        # "full" mode: 3 statistics in, 3 fusion logits out.
        self.gate = self._make_gate(3, gate_hidden, gate_dropout)
        # "art+str" mode: 2 statistics in, 2 fusion logits out.
        self.gate_two = self._make_gate(2, gate_hidden, gate_dropout)

    @staticmethod
    def _make_head(in_dim, hidden, dropout):
        """Return a 3-layer MLP mapping ``in_dim`` features to one scalar."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, 1),
        )

    @staticmethod
    def _make_gate(n_way, hidden, dropout):
        """Return a 2-layer MLP mapping ``n_way`` statistics to ``n_way`` logits."""
        return nn.Sequential(
            nn.Linear(n_way, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, n_way),
        )

    @staticmethod
    def _pool_frames(fmap2d, wmap, batch, frames):
        """Weighted-pool every frame in one batched call.

        ``fmap2d`` is [B*T, C, Hp, Wp] and ``wmap`` is [B, 1, T, h, w]; the
        result is [B, T, C]. Frames are folded into the batch axis in the
        same (batch-major) order as ``fmap2d``.
        """
        # [B, 1, T, h, w] -> [B, T, 1, h, w]; only the first `frames` maps
        # are used, one per encoded frame.
        w_bt = wmap.transpose(1, 2)[:, :frames]
        pooled = weighted_pool_2d(fmap2d, w_bt.flatten(0, 1))
        return pooled.reshape(batch, frames, pooled.shape[-1])

    def freeze_clip_all(self):
        """Delegate to the encoder's ``freeze_all``."""
        self.encoder.freeze_all()

    def unfreeze_clip_last_blocks(self, n_blocks=2, also_unfreeze_ln=True):
        """Delegate to the encoder's ``unfreeze_last_blocks`` with the same arguments."""
        self.encoder.unfreeze_last_blocks(
            n_blocks=n_blocks,
            also_unfreeze_ln=also_unfreeze_ln,
        )

    def temporal_pool(self, x):
        """Average a feature sequence over time.

        Args:
            x: tensor of shape [B, T, C].

        Returns:
            Tensor of shape [B, C].
        """
        return x.mean(dim=1)

    def forward(self, rgb, w_art, w_str, gate_stats=None):
        """Predict one score per clip.

        Args:
            rgb: frames of shape [B, 3, T, H, W] (documented range [0, 1]).
            w_art: weight maps of shape [B, 1, T, h, w] (documented range
                0..1); h, w must broadcast against the encoder's feature-map
                size.
            w_str: weight maps, same layout as ``w_art``.
            gate_stats: optional precomputed gate input, [B, 3] in ``"full"``
                mode or [B, 2] in ``"art+str"`` mode. When ``None`` it is
                derived from the inputs. Ignored in single-head modes.

        Returns:
            ``(y_hat, aux)`` where ``y_hat`` is [B] and ``aux`` is the tuple
            ``(q_art, q_str, q_fmap, weights)``: the three per-head
            predictions, each [B], and the [B, 3] fusion weights in
            (art, str, raw) order.

        Raises:
            ValueError: if ``ablation_mode`` is not a supported mode.
        """
        # Validate before running the encoder so a misconfigured model fails
        # immediately instead of after the most expensive step.
        mode = self.ablation_mode.lower()
        if mode not in self._ABLATION_MODES:
            raise ValueError(f"Unknown ablation_mode: {self.ablation_mode}")

        B, _C, T, H, W = rgb.shape
        # Fold time into the batch axis: [B, 3, T, H, W] -> [B*T, 3, H, W].
        x2d = rgb.permute(0, 2, 1, 3, 4).reshape(B * T, 3, H, W)
        fmap2d = self.encoder(x2d)
        _, C, Hp, Wp = fmap2d.shape
        # reshape (not view) so a non-contiguous encoder output is accepted.
        fmap_bt = fmap2d.reshape(B, T, C, Hp, Wp)

        # Per-frame pooled features, each [B, T, C].
        z_art = self._pool_frames(fmap2d, w_art, B, T)
        z_str = self._pool_frames(fmap2d, w_str, B, T)
        z_raw = fmap_bt.mean(dim=(-2, -1))

        # Temporal average then one scalar per head, each [B, 1].
        q_art = self.head_art(self.temporal_pool(z_art))
        q_str = self.head_str(self.temporal_pool(z_str))
        q_fmap = self.head_fmap(self.temporal_pool(z_raw))

        if mode in self._SINGLE_HEAD_INDEX:
            column = self._SINGLE_HEAD_INDEX[mode]
            y_hat = (q_art, q_str, q_fmap)[column]
            # One-hot weights marking the only active head.
            weights = torch.zeros(B, 3, device=q_art.device)
            weights[:, column] = 1.0
        elif mode == "art+str":
            if gate_stats is None:
                gate_stats = two_gate_stats(w_art, w_str)  # [B, 2]
            a, b_ = torch.softmax(self.gate_two(gate_stats), dim=1).split(1, dim=1)
            y_hat = a * q_art + b_ * q_str
            # The raw head does not contribute in this mode.
            weights = torch.cat([a, b_, torch.zeros_like(a)], dim=1)
        else:  # mode == "full"
            if gate_stats is None:
                gate_stats = default_gate_stats(w_art, w_str, fmap_bt)  # [B, 3]
            weights = torch.softmax(self.gate(gate_stats), dim=1)
            a, b_, c_ = weights.split(1, dim=1)
            y_hat = a * q_art + b_ * q_str + c_ * q_fmap

        aux = (
            q_art.squeeze(1),
            q_str.squeeze(1),
            q_fmap.squeeze(1),
            weights,
        )
        return y_hat.squeeze(1), aux