File size: 5,166 Bytes
e1887f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""

AnyAttack Decoder Network.



Takes a CLIP embedding (512-dim for ViT-B/32) and generates an adversarial

noise image (3 x 224 x 224). The noise is clamped externally to [-eps, eps].



Architecture:

  FC(512 -> 256*14*14) -> 4x(ResBlock + UpBlock) -> Conv(16->3)

  ResBlocks include EfficientAttention for spatial self-attention.



Adapted from: https://github.com/jiamingzhang94/AnyAttack/blob/master/models/model.py

"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class EfficientAttention(nn.Module):
    """Spatial self-attention with cost linear in the pixel count.

    Rather than materializing the full N x N attention map (N = H*W), each
    head builds a small key-channels x value-channels "context" matrix and
    applies it to the queries, reducing O(N^2 * C) to O(N * C^2).
    """

    def __init__(self, in_channels: int, key_channels: int,

                 head_count: int, value_channels: int):
        super().__init__()
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels

        # 1x1 convolutions act as per-pixel linear projections.
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, _, height, width = x.size()
        n_pix = height * width
        k_all = self.keys(x).reshape(batch, self.key_channels, n_pix)
        q_all = self.queries(x).reshape(batch, self.key_channels, n_pix)
        v_all = self.values(x).reshape(batch, self.value_channels, n_pix)

        kc = self.key_channels // self.head_count
        vc = self.value_channels // self.head_count

        head_outputs = []
        for h_idx in range(self.head_count):
            # Keys are normalized over pixels; queries over channels.
            key = F.softmax(k_all[:, h_idx * kc:(h_idx + 1) * kc, :], dim=2)
            query = F.softmax(q_all[:, h_idx * kc:(h_idx + 1) * kc, :], dim=1)
            value = v_all[:, h_idx * vc:(h_idx + 1) * vc, :]
            # (B, kc, vc) context summarizes the values weighted by the keys.
            context = torch.bmm(key, value.transpose(1, 2))
            attn = torch.bmm(context.transpose(1, 2), query)
            head_outputs.append(attn.reshape(batch, vc, height, width))

        merged = torch.cat(head_outputs, dim=1)
        # Residual connection around the attended features.
        return self.reprojection(merged) + x


class ResBlock(nn.Module):
    """Conv-BN residual block with an EfficientAttention stage before the sum."""

    def __init__(self, in_ch: int, out_ch: int,

                 key_ch: int, head_count: int, val_ch: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.attention = EfficientAttention(out_ch, key_ch, head_count, val_ch)
        # 1x1 projection aligns channel counts on the skip path when needed.
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.skip(x)
        h = self.act(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h = self.attention(h)
        # Activation applied after the residual addition.
        return self.act(h + shortcut)


class UpBlock(nn.Module):
    """Doubles spatial resolution: nearest-neighbor upsample, then conv-BN-LeakyReLU."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = self.up(x)
        features = self.bn(self.conv(upsampled))
        return self.act(features)


class Decoder(nn.Module):
    """

    AnyAttack noise generator: CLIP embedding -> adversarial noise image.



    Args:

        embed_dim: Input embedding dimension (512 for ViT-B/32, 1024 for ViT-L/14).

        img_channels: Output image channels (3 for RGB).

        img_size: Output spatial resolution (224).

    """

    def __init__(self, embed_dim: int = 512, img_channels: int = 3, img_size: int = 224):
        super().__init__()
        self.init_size = img_size // 16  # 14 for 224

        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 256 * self.init_size ** 2)
        )

        self.blocks = nn.ModuleList([
            ResBlock(256, 256, 64, 8, 256),
            UpBlock(256, 128),
            ResBlock(128, 128, 32, 8, 128),
            UpBlock(128, 64),
            ResBlock(64, 64, 16, 8, 64),
            UpBlock(64, 32),
            ResBlock(32, 32, 8, 8, 32),
            UpBlock(32, 16),
            ResBlock(16, 16, 4, 8, 16),
        ])

        self.head = nn.Conv2d(16, img_channels, 3, 1, 1)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """

        Generate noise from CLIP embedding.



        Args:

            embedding: (B, embed_dim) CLIP image embedding.



        Returns:

            (B, 3, img_size, img_size) raw noise (NOT clamped to [-eps, eps]).

        """
        out = self.fc(embedding.float().view(embedding.size(0), -1))
        out = out.view(out.size(0), 256, self.init_size, self.init_size)
        for block in self.blocks:
            out = block(out)
        return self.head(out)