Spaces:
Runtime error
Runtime error
| import torch | |
| from torchvision import transforms | |
class Attack:
    """Generate an image from a text prompt, then perturb it with one FGSM step.

    The first image returned by ``pipe`` is converted to a batched tensor,
    classified by ``classifer``, and pushed ``epsilon`` along the sign of the
    input gradient of the NLL loss for a fixed label (an untargeted fast
    gradient sign attack).
    """

    def __init__(self, pipe, classifer, device="cpu"):
        """Store the pipeline, classifier and device.

        Args:
            pipe: Callable text-to-image pipeline. It is called with the
                keyword arguments used in ``__call__`` and must return an
                object exposing an ``images`` sequence.
            classifer: Model under attack; ``zero_grad()`` is called on it, so
                it is expected to be a ``torch.nn.Module``. (The misspelled
                name is kept for compatibility with existing callers.)
            device: Device for the random generator, the classifier output
                and the label tensor.
        """
        self.device = device
        self.pipe = pipe
        # Seeded once per instance: the first call is reproducible, later
        # calls continue from the advanced generator state.
        self.generator = torch.Generator(device=self.device).manual_seed(1024)
        self.classifer = classifer

    def __call__(
        self,
        prompt,
        negative_prompt="",
        size=512,
        guidance_scale=8,
        epsilon=0,
        target_label=0,
        num_inference_steps=30,
    ):
        """Generate an image for ``prompt`` and return it with its perturbation.

        Args:
            prompt: What to generate.
            negative_prompt: What NOT to generate.
            size: Height and width requested from the pipeline.
            guidance_scale: How strongly the pipeline follows the prompt.
            epsilon: FGSM step size; 0 returns the (clamped) input unchanged.
            target_label: Class index whose loss the attack increases.
            num_inference_steps: Number of pipeline denoising steps.

        Returns:
            ``(init_image, perturbed)``. ``init_image`` is the first pipeline
            image. ``perturbed`` is the detached tensor from
            ``untargeted_attack``.
        """
        pipe_output = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=size,
            width=size,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
        )
        init_image = pipe_output.images[0]
        # Feed the classifier on its own device (a GPU model would reject a
        # CPU tensor). The tensor must be a leaf so its .grad gets populated.
        image = self.transform(init_image).to(self._classifier_device())
        image.requires_grad_(True)
        outputs = self.classifer(image).to(self.device)
        target = torch.tensor([target_label], device=self.device)
        return init_image, self.untargeted_attack(image, outputs, target, epsilon)

    def _classifier_device(self):
        """Return the device of the classifier's first parameter (CPU if none)."""
        parameters = getattr(self.classifer, "parameters", None)
        if parameters is None:
            return torch.device("cpu")
        first = next(iter(parameters()), None)
        return torch.device("cpu") if first is None else first.device

    def transform(self, image):
        """Resize ``image`` and return it as a tensor with a leading batch dim.

        Resize(32) scales the smaller edge to 32 pixels. Presumably this
        matches the classifier's input size (TODO confirm).
        """
        img_tfms = transforms.Compose(
            [transforms.Resize(32), transforms.ToTensor()]
        )
        image = img_tfms(image)
        return torch.unsqueeze(image, dim=0)

    def untargeted_attack(self, image, pred, target, epsilon):
        """Apply one fast-gradient-sign step that increases the loss on ``target``.

        Args:
            image: Leaf tensor with ``requires_grad=True`` that produced ``pred``.
            pred: Classifier output. It is passed to ``nll_loss``, which expects
                log-probabilities. NOTE(review): confirm the classifier ends in
                LogSoftmax; raw logits would need ``cross_entropy``.
            target: Label tensor for ``nll_loss``.
            epsilon: Step size.

        Returns:
            ``image + epsilon * sign(grad)`` clamped to [0, 1], detached from
            the autograd graph.
        """
        loss = torch.nn.functional.nll_loss(pred, target)
        self.classifer.zero_grad()
        loss.backward()
        # Build the result outside autograd so callers get a plain tensor.
        with torch.no_grad():
            perturbed_image = image + epsilon * image.grad.sign()
            perturbed_image = torch.clamp(perturbed_image, 0, 1)
        return perturbed_image