| |
| |
| |
| import torch |
| import torch.nn as nn |
|
|
| from .clip_dense_encoder import CLIPDenseEncoder |
|
|
|
|
def weighted_pool_2d(fmap, wmap, eps=1e-6):
    """Spatially pool a feature map using a normalised weight map.

    The weights are clamped to ``[0, 1]``, normalised so that they sum to
    (approximately) one over the last two axes, and then used for a weighted
    sum of ``fmap`` over those same axes.

    The reduction always runs over the trailing two (spatial) axes, so any
    number of leading axes is supported as long as ``wmap`` broadcasts
    against ``fmap``:

        fmap: [B, C, H, W]      wmap: [B, 1, H, W]      -> [B, C]
        fmap: [B, T, C, H, W]   wmap: [B, T, 1, H, W]   -> [B, T, C]

    Args:
        fmap: feature tensor whose last two axes are spatial.
        wmap: weight tensor broadcastable to ``fmap``.
        eps: added to the weight sum to avoid division by zero. A weight map
            that is entirely zero therefore yields an all-zero pooled vector.

    Returns:
        ``fmap`` with its last two axes reduced away.
    """
    spatial = (-2, -1)
    w = wmap.clamp(0.0, 1.0)
    w = w / (w.sum(dim=spatial, keepdim=True) + eps)
    return (fmap * w).sum(dim=spatial)
|
|
def default_gate_stats(w_art, w_str, fmap_bt, eps=1e-6):
    """Build the three per-sample scalars fed to the three-way gate.

    Each input is averaged over every axis except the batch axis (axes 1-4,
    so all inputs must be rank 5):

        w_art:   [B, 1, T, H, W]  -> mean weight
        w_str:   [B, 1, T, H, W]  -> mean weight
        fmap_bt: [B, T, C, H, W]  -> mean absolute activation

    ``eps`` is accepted for interface compatibility but is not used.

    Returns:
        Tensor of shape [B, 3], columns ordered (art, str, raw).
    """
    non_batch_axes = (1, 2, 3, 4)
    columns = [
        w_art.mean(dim=non_batch_axes),
        w_str.mean(dim=non_batch_axes),
        fmap_bt.abs().mean(dim=non_batch_axes),
    ]
    return torch.stack(columns, dim=1)
|
|
def two_gate_stats(w_art, w_str):
    """Build the two per-sample scalars fed to the two-way gate.

    Both weight tensors are averaged over every axis except the batch axis
    (axes 1-4, so both must be rank 5).

    Returns:
        Tensor of shape [B, 2], columns ordered (art, str).
    """
    non_batch_axes = (1, 2, 3, 4)
    columns = [weights.mean(dim=non_batch_axes) for weights in (w_art, w_str)]
    return torch.stack(columns, dim=1)
|
|
class QD_MODEL(nn.Module):
    """Three-branch regressor on dense CLIP features with a learned gate.

    A clip ``[B, 3, T, H, W]`` is encoded frame by frame into a dense feature
    map. Three pooled descriptors are derived from it:

    * ``art``: spatially pooled with the ``w_art`` weight map,
    * ``str``: spatially pooled with the ``w_str`` weight map,
    * ``raw``: plain spatial mean.

    Each descriptor is averaged over time and passed through its own MLP head
    to produce one scalar per sample. ``ablation_mode`` selects how the three
    scalars are combined:

    * ``"art"``, ``"str"``, ``"raw"``: use that single head,
    * ``"art+str"``: softmax gate over the art and str heads,
    * ``"full"``: softmax gate over all three heads.

    The mode string is compared case-insensitively.
    """

    # Modes accepted by ``ablation_mode``; the first three are single-branch
    # modes and their position doubles as the column index in ``weights``.
    _ABLATION_MODES = ("art", "str", "raw", "art+str", "full")

    def __init__(
        self,
        *,
        clip_model="openai/clip-vit-base-patch16",
        head_hidden=384,
        gate_hidden=32,
        head_dropout=0.2,
        gate_dropout=0.1,
        ablation_mode="full",
    ):
        """Create the encoder, the three regression heads and the two gates.

        Args:
            clip_model: model name forwarded to ``CLIPDenseEncoder``.
            head_hidden: width of the first hidden layer of each head; the
                second hidden layer is ``head_hidden // 2`` wide.
            gate_hidden: hidden width of both gate MLPs.
            head_dropout: dropout probability inside the heads.
            gate_dropout: dropout probability inside the gates.
            ablation_mode: branch-combination mode, see the class docstring.
        """
        super().__init__()
        self.ablation_mode = ablation_mode

        self.encoder = CLIPDenseEncoder(model_name=str(clip_model))
        c = int(self.encoder.hidden_size)

        # Sub-modules are created in the order art, str, fmap, gate, gate_two
        # so parameter names and seeded initialisation stay stable.
        self.head_art = self._make_head(c, head_hidden, head_dropout)
        self.head_str = self._make_head(c, head_hidden, head_dropout)
        self.head_fmap = self._make_head(c, head_hidden, head_dropout)

        self.gate = self._make_gate(3, gate_hidden, gate_dropout)
        self.gate_two = self._make_gate(2, gate_hidden, gate_dropout)

    @staticmethod
    def _make_head(in_dim, hidden, dropout):
        """Build one regression head: in_dim -> hidden -> hidden // 2 -> 1."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, 1),
        )

    @staticmethod
    def _make_gate(n_branches, hidden, dropout):
        """Build a gate MLP mapping n_branches stats to n_branches logits."""
        return nn.Sequential(
            nn.Linear(n_branches, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, n_branches),
        )

    @staticmethod
    def _pool_frames(fmap2d, wmap, batch, frames):
        """Weighted-pool all frames in one call.

        fmap2d: [B*T, C, Hp, Wp], wmap: [B, 1, T, h, w] -> [B, T, C].
        """
        # Move time next to batch and merge the two axes so the row order
        # matches fmap2d (index = b * T + t).
        w2d = wmap.transpose(1, 2).flatten(0, 1)
        return weighted_pool_2d(fmap2d, w2d).reshape(batch, frames, -1)

    def freeze_clip_all(self):
        """Delegate to the encoder's ``freeze_all``."""
        self.encoder.freeze_all()

    def unfreeze_clip_last_blocks(self, n_blocks=2, also_unfreeze_ln=True):
        """Delegate to the encoder's ``unfreeze_last_blocks``."""
        self.encoder.unfreeze_last_blocks(
            n_blocks=n_blocks,
            also_unfreeze_ln=also_unfreeze_ln,
        )

    def temporal_pool(self, x):
        """
        x: [B, T, C]
        return: [B, C]
        """
        return x.mean(dim=1)

    def forward(self, rgb, w_art, w_str, gate_stats=None):
        """Predict one score per clip.

        Args:
            rgb: [B, 3, T, H, W] (rgb in [0, 1]).
            w_art, w_str: [B, 1, T, h, w] weight maps (0..1). They are
                multiplied element-wise with the encoder feature map, so
                their spatial size must broadcast against the encoder's
                output grid.
            gate_stats: optional precomputed gate input. It is fed directly
                to the gate MLP, so it must be [B, 3] in ``"full"`` mode and
                [B, 2] in ``"art+str"`` mode. When ``None`` it is computed
                from the weight maps (and, in ``"full"`` mode, the feature
                map). Ignored in the single-branch modes.

        Returns:
            ``(y_hat, aux)`` where ``y_hat`` is [B] and ``aux`` is the tuple
            ``(q_art [B], q_str [B], q_fmap [B], weights [B, 3])``; the
            ``weights`` columns are ordered (art, str, raw).

        Raises:
            ValueError: if ``ablation_mode`` is not a known mode.
        """
        # Validate first so a bad mode fails before the encoder pass.
        mode = self.ablation_mode.lower()
        if mode not in self._ABLATION_MODES:
            raise ValueError(f"Unknown ablation_mode: {self.ablation_mode}")

        B, _C, T, H, W = rgb.shape

        # Fold time into the batch axis: [B, 3, T, H, W] -> [B*T, 3, H, W].
        x2d = rgb.permute(0, 2, 1, 3, 4).reshape(B * T, 3, H, W)
        fmap2d = self.encoder(x2d)
        _, C, Hp, Wp = fmap2d.shape
        # reshape rather than view: also accepts a non-contiguous encoder output.
        fmap_bt = fmap2d.reshape(B, T, C, Hp, Wp)

        # Per-frame descriptors, each [B, T, C].
        z_art = self._pool_frames(fmap2d, w_art, B, T)
        z_str = self._pool_frames(fmap2d, w_str, B, T)
        z_raw = fmap_bt.mean(dim=(-2, -1))

        # Average over time, then regress one scalar per branch: [B, 1].
        # All three heads always run because all three are returned in aux.
        q_art = self.head_art(self.temporal_pool(z_art))
        q_str = self.head_str(self.temporal_pool(z_str))
        q_fmap = self.head_fmap(self.temporal_pool(z_raw))

        if mode == "full":
            if gate_stats is None:
                gate_stats = default_gate_stats(w_art, w_str, fmap_bt)
            weights = torch.softmax(self.gate(gate_stats), dim=1)
            a, b_, c_ = weights.split(1, dim=1)
            y_hat = a * q_art + b_ * q_str + c_ * q_fmap
        elif mode == "art+str":
            if gate_stats is None:
                gate_stats = two_gate_stats(w_art, w_str)
            a, b_ = torch.softmax(self.gate_two(gate_stats), dim=1).split(1, dim=1)
            y_hat = a * q_art + b_ * q_str
            # The raw branch is unused here; report its weight as zero.
            weights = torch.cat([a, b_, torch.zeros_like(a)], dim=1)
        else:
            # Single-branch mode: one-hot weights on the selected branch.
            index = self._ABLATION_MODES.index(mode)
            y_hat = (q_art, q_str, q_fmap)[index]
            weights = torch.zeros(B, 3, device=q_art.device)
            weights[:, index] = 1.0

        aux = (
            q_art.squeeze(1),
            q_str.squeeze(1),
            q_fmap.squeeze(1),
            weights,
        )
        return y_hat.squeeze(1), aux