Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import cvxpy as cp | |
| import qpth | |
| from . import _qp_solver_patch | |
def solve_qp(Q, P, G, H):
    """Solve a batch of inequality-constrained quadratic programs with qpth.

    The four tensors are handed to ``qpth.qp.QPFunction`` in the order
    ``(Q, P, G, H)`` followed by two empty tensors in the equality-constraint
    slots, i.e. no equality constraints are supplied.

    Args:
        Q: Batched quadratic-term matrices; ``Q.shape[0]`` is the batch size.
        P: Batched linear-term vectors.
        G: Batched inequality-constraint matrices.
        H: Batched inequality-constraint right-hand sides.

    Returns:
        The solver output, with the same leading batch size as the inputs.
    """
    is_single_problem = Q.shape[0] == 1
    if is_single_problem:
        # A batch of one was observed to be unstable (the suspicion is a
        # stray squeeze inside the solver that breaks broadcasting), so the
        # single problem is duplicated and the copy is discarded afterwards.
        Q, P, G, H = (
            Q.expand(2, -1, -1),
            P.expand(2, -1),
            G.expand(2, -1, -1),
            H.expand(2, -1),
        )

    # Empty placeholders for the equality-constraint arguments.
    no_equality = torch.empty(0, device=Q.device)
    solver = qpth.qp.QPFunction(verbose=-1, eps=1e-2, check_Q_spd=False)
    solution = solver(Q, P, G, H, no_equality, no_equality)

    # Drop the duplicated problem again, keeping the batch dimension.
    return solution[:1] if is_single_problem else solution
class CVXProjLoss(nn.Module):
    """Loss measuring how far features are from satisfying a logit ordering.

    For every sample, ``forward`` solves a quadratic program that finds the
    feature vector closest to ``feats_pred`` whose head logits
    (``head_W @ z + head_bias``) respect the ranking given by
    ``attack_targets`` with a margin, and returns the squared distance
    between ``feats_pred`` and that solution.
    """

    def __init__(self, confidence=0):
        """Store ``confidence``; it is not read anywhere else in this class."""
        super().__init__()
        self.confidence = confidence

    def precompute(self, attack_targets, gt_labels, config):
        """Return the extra keyword arguments for ``forward``.

        Only ``config.cvx_proj_margin`` is read; ``attack_targets`` and
        ``gt_labels`` are accepted but unused.
        """
        return {
            "margin": config.cvx_proj_margin
        }

    @staticmethod
    def _ordering_constraints(attack_targets, num_classes, batch_size, device):
        """Build the ``[B, C, C]`` matrix ``D`` of pairwise logit constraints.

        Row ``c`` of ``D`` encodes one constraint ``D[c] @ logits >= margin``,
        indexed by the logit that has to be the smaller one:

        * non-target class ``c``: ``logit[last target] - logit[c]``
        * target ``k`` (``k >= 1``): ``logit[target k-1] - logit[target k]``
        * first target: an all-zero row, since nothing has to exceed it

        Args:
            attack_targets: Integer tensor ``[B, K]`` of class indices, used
                as ordered from the largest desired logit to the smallest.
            num_classes: Number of classes ``C``.
            batch_size: Batch size ``B``.
            device: Device on which ``D`` is created.
        """
        num_targets = attack_targets.shape[-1]

        # Start with: every class must be below the last attack target.
        D = -torch.eye(num_classes, device=device)[None].repeat(batch_size, 1, 1)
        last_target = attack_targets[:, -1][:, None, None].expand(-1, num_classes, -1)
        D.scatter_(dim=2, index=last_target, value=1.0)

        # Clear the constraint row belonging to each attack target.
        target_rows = attack_targets[:, :, None].expand(-1, -1, num_classes)
        D.scatter_(dim=1, index=target_rows, value=0.0)

        # Chain consecutive targets: target k-1 must exceed target k.
        # With a single target there are no pairs, so nothing is written.
        if num_targets > 1:
            batch_inds = torch.arange(batch_size, device=device)[:, None]
            batch_inds = batch_inds.expand(-1, num_targets - 1)
            higher = attack_targets[:, :-1]  # [B, K-1]
            lower = attack_targets[:, 1:]  # [B, K-1]
            D[batch_inds, lower, lower] = -1
            D[batch_inds, lower, higher] = 1
        return D

    def forward(self, logits_pred, feats_pred, feats_pred_0, attack_targets, model, margin, **kwargs):
        """Return the per-sample squared distance to the feasible feature set.

        Args:
            logits_pred: Only its device and leading (batch) size are used.
            feats_pred: Current features; the QP solution is anchored to them.
            feats_pred_0: Accepted for interface compatibility; unused.
            attack_targets: Integer tensor ``[B, K]`` of target class indices
                in the desired logit order (first entry largest).
            model: Object whose ``head_matrices()`` returns
                ``(head_W, head_bias)``; ``head_W.shape[0]`` is taken as the
                number of classes.
            margin: Required gap between constrained logits.
            **kwargs: Ignored.

        Returns:
            ``(feats_pred - z_sol).square().sum(dim=-1)`` where ``z_sol`` is
            the QP solution.
        """
        device = logits_pred.device
        batch_size = logits_pred.shape[0]
        head_W, head_bias = model.head_matrices()
        num_classes = head_W.shape[0]

        D = self._ordering_constraints(attack_targets, num_classes, batch_size, device)

        # Objective ||z - feats_pred||^2 in the solver's 1/2 z'Qz + p'z form:
        # we want the solution features to be as close as possible to the
        # current features.
        Q = 2 * torch.eye(feats_pred.shape[1], device=device)[None].expand(batch_size, -1, -1)
        P = -2 * feats_pred.expand(batch_size, -1)

        # D @ (head_W @ z + head_bias) >= margin, rewritten as G @ z <= H.
        G = -D @ head_W
        H = -(margin - D @ head_bias)

        # Constraints are indexed by the smaller logit. The first attack
        # target is not smaller than any logit, so its row is redundant; it
        # is kept for easier parallelization and its right-hand side is
        # zeroed so the all-zero row reads 0 <= 0.
        first_target = attack_targets[:, 0:1]  # [B, 1]
        H.scatter_(dim=1, index=first_target, value=0.0)

        z_sol = solve_qp(Q, P, G, H)
        return (feats_pred - z_sol).square().sum(dim=-1)