import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CarliniL2:
    """Carlini-Wagner-style L2 attack that optimises the input of a generator.

    Rather than perturbing pixels directly, the attack optimises a generator
    input ``z`` so that ``model(gan(z))`` is misclassified (or classified as a
    chosen target) while ``gan(z)`` stays close, in squared L2 distance, to
    the reference images.  The weight of the classification term is tuned per
    sample with a binary search.

    Call :meth:`parse_params` before :meth:`generate`.
    """

    def __init__(self, model, device):
        """Store the classifier under attack and the device to work on."""
        self.model = model
        self.device = device

    def parse_params(self, gan, confidence=0, targeted=False,
                     learning_rate=1e-1, binary_search_steps=5,
                     max_iterations=10000, abort_early=False, initial_const=1,
                     clip_min=0, clip_max=1):
        """Configure the attack.

        Args:
            gan: callable mapping the optimised input to images.
            confidence: margin added to the classification loss and used when
                deciding whether an example counts as adversarial.
            targeted: ``True`` to push towards the given labels, ``False`` to
                push away from them.  ``None`` makes every example count as
                successful while the untargeted loss is still optimised.
            learning_rate: learning rate of the Adam optimiser.
            binary_search_steps: number of binary-search rounds over the
                trade-off constant.
            max_iterations: optimisation steps per binary-search round.
            abort_early: stop a round when the loss stops improving.
            initial_const: initial value of the trade-off constant.
            clip_min: stored only; not used by the methods of this class.
            clip_max: stored only; not used by the methods of this class.
        """
        self.TARGETED = targeted
        self.LEARNING_RATE = learning_rate
        self.MAX_ITERATIONS = max_iterations
        self.BINARY_SEARCH_STEPS = binary_search_steps
        self.ABORT_EARLY = abort_early
        self.CONFIDENCE = confidence
        self.initial_const = initial_const
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.gan = gan
        self.learning_rate = learning_rate
        # With many search rounds, the final round re-runs with the upper
        # bound as the constant (see ``generate``).
        self.repeat = binary_search_steps >= 10

    def get_or_guess_labels(self, x, y=None):
        """Return the labels to attack for ``x`` as a float tensor.

        If ``y`` is given it is used as-is (converted to float).  Otherwise
        the model's own prediction is used: the result is a 0/1 mask with the
        same shape as the model output, set where the class probability equals
        the row maximum (so exact ties mark more than one class).
        """
        if y is not None:
            labels = y
        else:
            # Labels are constants for the attack; no graph is needed.
            with torch.no_grad():
                preds = F.softmax(self.model(x), dim=1)
                preds_max = torch.max(preds, 1, keepdim=True)[0]
                labels = preds == preds_max
        return labels.float()

    def atanh(self, x):
        """Return the element-wise inverse hyperbolic tangent of ``x``."""
        return 0.5 * torch.log((1 + x) / (1 - x))

    def to_one_hot(self, x, num_classes=10):
        """One-hot encode a 1-D tensor of class indices.

        Args:
            x: 1-D integer tensor of class indices.
            num_classes: width of the encoding (defaults to 10).

        Returns:
            Float tensor of shape ``(len(x), num_classes)`` on ``x``'s device.
        """
        index = x.long().unsqueeze(1)
        # ``x.device`` works for CPU tensors too; ``get_device()`` is -1 there.
        one_hot = torch.zeros(x.shape[0], num_classes,
                              dtype=torch.float32, device=x.device)
        return one_hot.scatter_(1, index, 1)

    def _is_adversarial(self, scores, label_idx):
        """Return a boolean mask of samples whose scores satisfy the attack.

        The labelled class is handicapped by ``CONFIDENCE`` before taking the
        arg-max, so an example only counts when it wins by that margin.
        """
        if self.TARGETED is None:
            return torch.ones_like(label_idx, dtype=torch.bool)
        adjusted = scores.clone()
        rows = torch.arange(scores.shape[0], device=scores.device)
        if self.TARGETED:
            adjusted[rows, label_idx] -= self.CONFIDENCE
            return adjusted.argmax(1) == label_idx
        adjusted[rows, label_idx] += self.CONFIDENCE
        return adjusted.argmax(1) != label_idx

    def generate(self, imgs, y, start):
        """Search for adversarial images close to ``imgs``.

        Args:
            imgs: reference images, batch first.
            y: labels to attack: 1-D class indices, a 2-D one-hot tensor, or
                ``None`` to use the model's predictions on ``imgs``.
            start: initial generator input.  It is not modified; every
                binary-search round restarts from a copy of it.

        Returns:
            A detached tensor shaped like ``gan(start)``.  Each row holds the
            successful example with the smallest squared L2 distance to the
            reference image found in any round, or ``gan(start)`` for samples
            where no successful example was found.
        """
        batch_size = imgs.shape[0]
        labs = self.get_or_guess_labels(imgs, y)

        if len(labs.shape) == 1:
            # Size the one-hot encoding from the classifier's output.
            with torch.no_grad():
                num_classes = self.model(imgs).shape[1]
            tlabs = self.to_one_hot(labs.long(), num_classes)
        else:
            tlabs = labs
        label_idx = tlabs.argmax(1)

        start_latent = start.detach().to(self.device)

        # Per-sample binary-search state for the trade-off constant.
        lower_bound = torch.zeros(batch_size, device=self.device)
        const = torch.ones(batch_size, device=self.device) * self.initial_const
        upper_bound = torch.ones(batch_size, device=self.device) * 1e10

        # Best result over all rounds; cloned so in-place row updates can
        # never write through to ``start`` or keep an autograd graph alive.
        o_bestl2 = torch.full((batch_size,), 1e10, dtype=torch.float32,
                              device=self.device)
        with torch.no_grad():
            o_bestattack = self.gan(start_latent).detach().clone()

        check_every = (self.MAX_ITERATIONS // 10) or 1

        for outer_step in range(self.BINARY_SEARCH_STEPS):
            # Fresh copy: optimising must not mutate the caller's ``start``.
            modifier = nn.Parameter(start_latent.clone())
            optimizer = torch.optim.Adam([modifier], lr=self.learning_rate)

            bestl2 = torch.full_like(o_bestl2, 1e10)
            bestscore = torch.full((batch_size,), -1, dtype=torch.long,
                                   device=self.device)

            if self.repeat and outer_step == self.BINARY_SEARCH_STEPS - 1:
                const = upper_bound.clone()
            prev = 1e6

            for i in range(self.MAX_ITERATIONS):
                optimizer.zero_grad()
                nimgs = self.gan(modifier)

                # Squared L2 distance per sample, over all non-batch dims.
                diff = nimgs - imgs
                l2dist = (diff ** 2).reshape(diff.shape[0], -1).sum(1)
                loss2 = torch.sum(l2dist)

                scores = self.model(nimgs)

                # Subtracting 10000 on the labelled class removes it from the
                # max, leaving the best score among the other classes.
                other = torch.max((1 - tlabs) * scores - tlabs * 10000, 1)[0]
                real = torch.sum(tlabs * scores, 1)

                if self.TARGETED:
                    margin = other - real + self.CONFIDENCE
                else:
                    margin = real - other + self.CONFIDENCE
                loss1 = torch.max(torch.zeros_like(other), margin)

                loss1 = torch.sum(const * loss1)
                loss = loss1 + loss2

                loss.backward()
                optimizer.step()

                if self.ABORT_EARLY and i % check_every == 0:
                    current = loss.item()
                    if current > prev * .9999:
                        break
                    prev = current

                # Bookkeeping on detached values from this forward pass.
                with torch.no_grad():
                    l2 = l2dist.detach().to(bestl2.dtype)
                    raw_scores = scores.detach()
                    success = self._is_adversarial(raw_scores, label_idx)
                    predicted = raw_scores.argmax(1)

                    improved = success & (l2 < bestl2)
                    bestl2 = torch.where(improved, l2, bestl2)
                    bestscore = torch.where(improved, predicted, bestscore)

                    improved_overall = success & (l2 < o_bestl2)
                    o_bestl2 = torch.where(improved_overall, l2, o_bestl2)
                    o_bestattack[improved_overall] = \
                        nimgs.detach()[improved_overall]

            # Binary-search update of the constant, per sample.
            recorded = bestscore != -1
            if self.TARGETED is None:
                succeeded = recorded
            elif self.TARGETED:
                succeeded = recorded & (bestscore == label_idx)
            else:
                succeeded = recorded & (bestscore != label_idx)

            # Success tightens the upper bound, failure raises the lower one.
            upper_bound = torch.where(
                succeeded, torch.min(upper_bound, const), upper_bound)
            lower_bound = torch.where(
                succeeded, lower_bound, torch.max(lower_bound, const))

            # Bisect once an upper bound is known; otherwise failed samples
            # grow their constant tenfold.
            bounded = upper_bound < 1e9
            const = torch.where(
                bounded,
                (lower_bound + upper_bound) / 2,
                torch.where(succeeded, const, const * 10))

        return o_bestattack