File size: 6,698 Bytes
26af353
 
 
 
 
 
 
ea98391
 
bc60cc3
ea98391
bc60cc3
 
8cd11f9
26af353
 
 
 
 
ea98391
 
 
 
 
 
 
 
 
 
 
 
 
26af353
 
 
 
 
 
 
ea98391
26af353
 
bdecc48
 
 
 
 
 
 
 
 
 
3da6aaa
bdecc48
8cd11f9
bdecc48
8cd11f9
bdecc48
8cd11f9
3de0f2a
9c56da2
 
 
 
 
 
e98a2b8
3de0f2a
 
e98a2b8
3de0f2a
e98a2b8
3de0f2a
ff3c24a
74e890e
ff3c24a
74e890e
 
c80e494
82fbda0
 
 
 
 
 
 
 
 
3de0f2a
82fbda0
e98a2b8
82fbda0
e98a2b8
 
3de0f2a
 
 
 
c80e494
e98a2b8
c80e494
e98a2b8
c80e494
e98a2b8
82fbda0
3de0f2a
 
e98a2b8
3de0f2a
e98a2b8
 
 
75aa396
 
 
 
ab559c9
e98a2b8
 
 
2dbd2b9
ff3c24a
 
0657874
5050ebe
 
3de0f2a
 
e98a2b8
 
3de0f2a
 
 
 
c80e494
3de0f2a
 
 
 
9c56da2
8cd11f9
c80e494
 
 
3de0f2a
9c56da2
c80e494
e98a2b8
 
 
 
3de0f2a
 
 
 
 
e98a2b8
 
8cd11f9
26af353
3de0f2a
c80e494
26af353
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
import gradio as gr
from torchvision import models, transforms
from PIL import Image
import requests
from huggingface_hub import hf_hub_download

from PIL import Image
import numpy as np
import random




# Fetch the pretrained robust-face ResNet-50 checkpoint from the Hugging Face Hub.
checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt")

# Build a ResNet-50 backbone and swap the final layer for a 500-way classifier head.
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 500)

# The checkpoint stores the weights under a 'model' key; load onto CPU.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))['model']

# Strip the wrapper prefixes ('module.', 'model.', 'attacker.') left by
# DataParallel / robustness-toolkit wrappers so the keys line up with a
# vanilla torchvision ResNet-50 state dict.
new_state_dict = checkpoint
for prefix in ('module.', 'model.', 'attacker.'):
    new_state_dict = {key.replace(prefix, ''): weight for key, weight in new_state_dict.items()}


print(new_state_dict.keys())
print('********************')
# strict=False: the checkpoint carries extra normalizer buffers
# ("normalizer.new_mean", "normalizer.new_std", "normalize.new_mean",
# "normalize.new_std") that the bare model does not have.
model.load_state_dict(new_state_dict, strict=False)
model.eval()

# Image preprocessing pipeline: resize, center-crop to 224x224, convert to a
# tensor, then normalize to [-1, 1] with mean/std 0.5 (vggface2 convention).
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # vggface2
]
preprocess = transforms.Compose(_preprocess_steps)

# # Function to make predictions
# def predict(image):
#     if isinstance(image, np.ndarray):
#         image = Image.fromarray(image)  # Convert to PIL Image if i
#     image = preprocess(image).unsqueeze(0)  # Add batch dimension
#     with torch.no_grad():
#         output = model(image)  # Perform inference on CPU
#     _, predicted_class = output.max(1)
#     # Fetch 9 random samples from the predicted class
#     class_samples = ds.filter(lambda example: example['label'] == predicted_class.item())

#     sample_images = random.sample(list(class_samples), min(len(class_samples), 9))
    
#     sample_images_urls = [sample['image'] for sample in sample_images]
    
#     return f"Predicted class: {predicted_class.item()}", sample_images_urls


import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np

def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0.01, eps=0.1, noise_ratio=0.1):
    """Iteratively perturb *image* by gradient-ascent steps through *model*.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image; resized to 224x224 and normalized to a
        (1, 3, 224, 224) tensor before inference.
    mode : str
        "increase confidence" — steps along the gradient of the
        cross-entropy loss toward the model's two least-likely classes
        (selected once from the initial prediction);
        "ReverseDiffuse" — adds Gaussian noise once, then steps along the
        gradient of the MSE between the current and noisy images.
    model : torch.nn.Module
        Classifier used for the forward/backward passes; its parameters
        are not updated.
    n_iterations : int
        Number of gradient steps to take.
    step_size : float
        Scale applied to the L2-normalized pixel gradient each step.
    eps : float
        L-infinity bound on the accumulated perturbation relative to the
        original (normalized) image.
    noise_ratio : float
        Standard-deviation multiplier for the noise in "ReverseDiffuse".

    Returns
    -------
    (PIL.Image.Image, PIL.Image.Image) or (None, None)
        The perturbed image and a gradient-magnitude visualization;
        (None, None) if a NaN gradient is encountered.

    Raises
    ------
    ValueError
        If *mode* is not one of the two supported strings.
    """
    # Preprocess to the model's fixed input size, normalized to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Enforce fixed size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    image_original = transform(image).unsqueeze(0)

    # Work on a leaf copy so gradients w.r.t. the pixels are available.
    image_tensor = image_original.detach().clone()
    image_tensor.requires_grad = True

    if mode == "increase confidence":
        # Pick the two currently least-likely classes as ascent targets.
        output = model(image_tensor)
        probs = torch.nn.functional.softmax(output, dim=1)
        _, least_likely_indices = torch.topk(probs, k=2, largest=False)
    elif mode == "ReverseDiffuse":
        # Fixed noisy target computed once from the starting image.
        noisy_image = image_tensor + torch.randn_like(image_tensor) * noise_ratio
    else:
        raise ValueError("Invalid mode selected. Choose 'increase confidence' or 'ReverseDiffuse'.")

    # Fallback so the visualization below is defined even if n_iterations < 1
    # (the original code raised NameError on `grad` in that case).
    grad = torch.zeros_like(image_tensor)

    for _ in range(n_iterations):

        # Clear any parameter gradients left over from the previous step.
        model.zero_grad()

        # Forward pass
        output = model(image_tensor)

        # Define inference loss based on mode
        if mode == "increase confidence":
            losses = []
            for idx in least_likely_indices[0]:
                target = torch.full((1,), idx, dtype=torch.long, device=output.device)
                losses.append(torch.nn.CrossEntropyLoss()(output, target))
            loss = torch.stack(losses).mean()
        elif mode == "ReverseDiffuse":
            loss = torch.nn.functional.mse_loss(image_tensor, noisy_image)

        # Backward pass populates image_tensor.grad (it is a leaf tensor).
        loss.backward()

        grad = image_tensor.grad
        if torch.isnan(grad).any():  # Abort rather than propagate NaNs.
            print("Warning: Gradient contains NaN values. Aborting inference.")
            return None, None

        # L2-normalize the gradient per sample before stepping.
        grad_norm = grad.view(grad.shape[0], -1).norm(dim=1, keepdim=True).view(grad.shape[0], 1, 1, 1)
        grad = grad / (grad_norm + 1e-10)  # Avoid division by zero

        # Gradient-ascent step, then project back into the eps-ball around
        # the original image and clamp to the valid pixel range.
        image_tensor = image_tensor + step_size * grad
        delta = torch.clamp(image_tensor - image_original, -eps, eps)
        image_tensor = torch.clamp(image_original + delta, 0, 1)
        # Re-create as a fresh leaf tensor so .grad is populated next step.
        image_tensor = image_tensor.clone().detach().requires_grad_(True)

    # Gradient-magnitude visualization (mean |grad| over channels),
    # min-max scaled. Guard against a constant map: the original divided by
    # (max - min), which is 0 for a uniform gradient and yielded a NaN image.
    grad_image = grad.abs().mean(dim=1).squeeze().cpu().numpy()
    grad_range = grad_image.max() - grad_image.min()
    grad_image = (grad_image - grad_image.min()) / (grad_range if grad_range > 0 else 1.0)
    grad_image = Image.fromarray((grad_image * 255).astype(np.uint8))

    # Convert the final tensor back to an HWC uint8 PIL image, min-max
    # scaled with the same zero-range guard.
    processed_image = image_tensor.detach().squeeze().permute(1, 2, 0).cpu().numpy()
    proc_range = processed_image.max() - processed_image.min()
    processed_image = (processed_image - processed_image.min()) / (proc_range if proc_range > 0 else 1.0)
    processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))

    return processed_image, grad_image

# Gradio front-end: expose simple_generative_inference with user-tunable knobs.
def _run_inference(image, mode, step_size, eps, noise_ratio, n_iterations):
    """Adapter binding the UI widget values to simple_generative_inference."""
    return simple_generative_inference(
        image,
        mode,
        model,
        step_size=step_size,
        eps=eps,
        noise_ratio=noise_ratio,
        n_iterations=n_iterations,
    )


iface = gr.Interface(
    fn=_run_inference,
    inputs=[
        gr.Image(type="pil", label="Input Image"),  # Source picture
        gr.Radio(["increase confidence", "ReverseDiffuse"], label="Inference Mode"),  # Algorithm choice
        gr.Slider(0.1, 20, value=1, step=0.1, label="Step Size"),  # Per-iteration step scale
        gr.Slider(0.1, 40, value=0.5, step=0.1, label="Epsilon (eps)"),  # Perturbation bound
        gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Noise Ratio"),  # ReverseDiffuse noise level
        gr.Slider(1, 1000, value=100, step=1, label="Number of Iterations"),  # Gradient steps
    ],
    outputs=[
        gr.Image(label="Processed Image"),  # Perturbed result
        gr.Image(label="Gradient Visualization"),  # |grad| heatmap
    ],
    title="Generative Inference",
    description="Perform generative inference on input images using adjustable parameters such as step size, epsilon, noise ratio, and number of iterations.",
)


iface.launch()