"""
AnyAttack Decoder Network.
Takes a CLIP embedding (512-dim for ViT-B/32) and generates an adversarial
noise image (3 x 224 x 224). The noise is clamped externally to [-eps, eps].
Architecture:
FC(512 -> 256*14*14) -> 4x(ResBlock + UpBlock) -> Conv(16->3)
ResBlocks include EfficientAttention for spatial self-attention.
Adapted from: https://github.com/jiamingzhang94/AnyAttack/blob/master/models/model.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class EfficientAttention(nn.Module):
    """Linear-complexity spatial self-attention (O(N*C^2) instead of O(N^2*C)).

    Keys are softmax-normalized over pixels and queries over channels; each
    head first forms a small global context matrix (key_ch x val_ch), then
    redistributes it per pixel, avoiding the N x N attention map.
    """

    def __init__(self, in_channels: int, key_channels: int,
                 head_count: int, value_channels: int):
        super().__init__()
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels
        # 1x1 convs play the role of the usual linear projections.
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, _, height, width = x.size()
        n_pix = height * width
        # Flatten the spatial dims so each projection is (B, C, H*W).
        flat_k = self.keys(x).reshape(batch, self.key_channels, n_pix)
        flat_q = self.queries(x).reshape(batch, self.key_channels, n_pix)
        flat_v = self.values(x).reshape(batch, self.value_channels, n_pix)
        # Per-head widths via integer division; if head_count does not divide
        # the channel count, trailing channels are dropped (and a zero-width
        # key slice makes a head's output all zeros) — NOTE(review): confirm
        # that callers only pass divisible configurations where this matters.
        dk = self.key_channels // self.head_count
        dv = self.value_channels // self.head_count
        head_outputs = []
        for idx in range(self.head_count):
            k_soft = F.softmax(flat_k[:, idx * dk:(idx + 1) * dk, :], dim=2)
            q_soft = F.softmax(flat_q[:, idx * dk:(idx + 1) * dk, :], dim=1)
            v_slice = flat_v[:, idx * dv:(idx + 1) * dv, :]
            # Global context (dk x dv), then project back onto every pixel.
            global_ctx = k_soft @ v_slice.transpose(1, 2)
            per_pixel = global_ctx.transpose(1, 2) @ q_soft
            head_outputs.append(per_pixel.reshape(batch, dv, height, width))
        merged = torch.cat(head_outputs, dim=1)
        # Residual connection around the attention.
        return self.reprojection(merged) + x
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv+BN layers with spatial self-attention.

    Uses a 1x1-conv shortcut when the channel count changes, otherwise an
    identity skip. Activation is LeakyReLU(0.2).
    """

    def __init__(self, in_ch: int, out_ch: int,
                 key_ch: int, head_count: int, val_ch: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.attention = EfficientAttention(out_ch, key_ch, head_count, val_ch)
        # Project the skip path only when widths differ.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.skip(x)
        h = self.act(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h = self.attention(h)
        # Final activation is applied after the residual sum.
        return self.act(h + shortcut)
class UpBlock(nn.Module):
    """2x nearest-neighbour upsampling followed by conv + BN + LeakyReLU."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = self.up(x)
        features = self.bn(self.conv(upsampled))
        return self.act(features)
class Decoder(nn.Module):
    """
    AnyAttack noise generator: CLIP embedding -> adversarial noise image.

    Args:
        embed_dim: Input embedding dimension (512 for ViT-B/32, 1024 for ViT-L/14).
        img_channels: Output image channels (3 for RGB).
        img_size: Output spatial resolution (224).
    """

    def __init__(self, embed_dim: int = 512, img_channels: int = 3, img_size: int = 224):
        super().__init__()
        # Four 2x UpBlocks in total, so the FC output starts at img_size / 16.
        self.init_size = img_size // 16  # 14 for 224
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 256 * self.init_size ** 2)
        )
        # Channel schedule 256 -> 128 -> 64 -> 32 -> 16: each stage is a
        # ResBlock (key channels = width / 4, 8 heads, value channels = width)
        # followed by a 2x UpBlock, except the last stage which keeps its
        # resolution.
        # NOTE(review): at width 16 the per-head key slice is 4 // 8 == 0
        # channels, so that stage's attention output is all zeros before
        # reprojection (effectively x + bias) — confirm this is intended.
        stages = []
        width = 256
        while width >= 16:
            stages.append(ResBlock(width, width, width // 4, 8, width))
            if width > 16:
                stages.append(UpBlock(width, width // 2))
            width //= 2
        self.blocks = nn.ModuleList(stages)
        self.head = nn.Conv2d(16, img_channels, 3, 1, 1)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """
        Generate noise from CLIP embedding.

        Args:
            embedding: (B, embed_dim) CLIP image embedding.

        Returns:
            (B, 3, img_size, img_size) raw noise (NOT clamped to [-eps, eps]).
        """
        flat = embedding.float().view(embedding.size(0), -1)
        feat = self.fc(flat).view(flat.size(0), 256, self.init_size, self.init_size)
        for stage in self.blocks:
            feat = stage(feat)
        return self.head(feat)