""" AnyAttack Decoder Network. Takes a CLIP embedding (512-dim for ViT-B/32) and generates an adversarial noise image (3 x 224 x 224). The noise is clamped externally to [-eps, eps]. Architecture: FC(512 -> 256*14*14) -> 4x(ResBlock + UpBlock) -> Conv(16->3) ResBlocks include EfficientAttention for spatial self-attention. Adapted from: https://github.com/jiamingzhang94/AnyAttack/blob/master/models/model.py """ import torch import torch.nn as nn import torch.nn.functional as F class EfficientAttention(nn.Module): """Linear-complexity spatial self-attention (O(N*C^2) instead of O(N^2*C)).""" def __init__(self, in_channels: int, key_channels: int, head_count: int, value_channels: int): super().__init__() self.key_channels = key_channels self.head_count = head_count self.value_channels = value_channels self.keys = nn.Conv2d(in_channels, key_channels, 1) self.queries = nn.Conv2d(in_channels, key_channels, 1) self.values = nn.Conv2d(in_channels, value_channels, 1) self.reprojection = nn.Conv2d(value_channels, in_channels, 1) def forward(self, x: torch.Tensor) -> torch.Tensor: n, _, h, w = x.size() keys = self.keys(x).reshape(n, self.key_channels, h * w) queries = self.queries(x).reshape(n, self.key_channels, h * w) values = self.values(x).reshape(n, self.value_channels, h * w) head_key_ch = self.key_channels // self.head_count head_val_ch = self.value_channels // self.head_count attended = [] for i in range(self.head_count): k = F.softmax(keys[:, i * head_key_ch:(i + 1) * head_key_ch, :], dim=2) q = F.softmax(queries[:, i * head_key_ch:(i + 1) * head_key_ch, :], dim=1) v = values[:, i * head_val_ch:(i + 1) * head_val_ch, :] context = k @ v.transpose(1, 2) out = (context.transpose(1, 2) @ q).reshape(n, head_val_ch, h, w) attended.append(out) aggregated = torch.cat(attended, dim=1) return self.reprojection(aggregated) + x class ResBlock(nn.Module): """Residual block with EfficientAttention.""" def __init__(self, in_ch: int, out_ch: int, key_ch: int, head_count: int, val_ch: int): super().__init__() self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1) self.bn1 = nn.BatchNorm2d(out_ch) self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1) self.bn2 = nn.BatchNorm2d(out_ch) self.act = nn.LeakyReLU(0.2, inplace=True) self.attention = EfficientAttention(out_ch, key_ch, head_count, val_ch) self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: residual = self.skip(x) out = self.act(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out = self.attention(out) return self.act(out + residual) class UpBlock(nn.Module): """2x spatial upsampling with conv.""" def __init__(self, in_ch: int, out_ch: int): super().__init__() self.up = nn.Upsample(scale_factor=2, mode="nearest") self.conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1) self.bn = nn.BatchNorm2d(out_ch) self.act = nn.LeakyReLU(0.2, inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.act(self.bn(self.conv(self.up(x)))) class Decoder(nn.Module): """ AnyAttack noise generator: CLIP embedding -> adversarial noise image. Args: embed_dim: Input embedding dimension (512 for ViT-B/32, 1024 for ViT-L/14). img_channels: Output image channels (3 for RGB). img_size: Output spatial resolution (224). """ def __init__(self, embed_dim: int = 512, img_channels: int = 3, img_size: int = 224): super().__init__() self.init_size = img_size // 16 # 14 for 224 self.fc = nn.Sequential( nn.Linear(embed_dim, 256 * self.init_size ** 2) ) self.blocks = nn.ModuleList([ ResBlock(256, 256, 64, 8, 256), UpBlock(256, 128), ResBlock(128, 128, 32, 8, 128), UpBlock(128, 64), ResBlock(64, 64, 16, 8, 64), UpBlock(64, 32), ResBlock(32, 32, 8, 8, 32), UpBlock(32, 16), ResBlock(16, 16, 4, 8, 16), ]) self.head = nn.Conv2d(16, img_channels, 3, 1, 1) def forward(self, embedding: torch.Tensor) -> torch.Tensor: """ Generate noise from CLIP embedding. Args: embedding: (B, embed_dim) CLIP image embedding. Returns: (B, 3, img_size, img_size) raw noise (NOT clamped to [-eps, eps]). """ out = self.fc(embedding.float().view(embedding.size(0), -1)) out = out.view(out.size(0), 256, self.init_size, self.init_size) for block in self.blocks: out = block(out) return self.head(out)