kricko committed on
Commit
dc8bbcc
·
verified ·
1 Parent(s): 0ef6495

Add clean standalone inference script

Browse files
Files changed (1) hide show
  1. auditor_inference.py +296 -0
auditor_inference.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import os
7
+ import math
8
+ import numpy as np
9
+
10
+ # Use same tokenization and model classes from the original file
11
+ # without all the training/evaluation boilerplates.
12
+ # To keep this script truly standalone, we bring over the tokenizer and CompleteMultiTaskAuditor.
13
+
14
class SimpleTokenizer:
    """Minimal word-level tokenizer backed by a JSON-persisted vocabulary.

    The first four ids are reserved for the special tokens <PAD>=0, <UNK>=1,
    <SOS>=2 and <EOS>=3; any out-of-vocabulary word maps to <UNK>.
    """

    def __init__(self, vocab_dir='./tokenizer_vocab'):
        self.word_to_idx = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3}
        self.idx_to_word = {0: "<PAD>", 1: "<UNK>", 2: "<SOS>", 3: "<EOS>"}

        # Load a previously-saved vocabulary (word -> id) when available,
        # e.g. when doing inference against trained weights.
        import json
        vocab_path = os.path.join(vocab_dir, 'vocab.json')
        if os.path.exists(vocab_path):
            with open(vocab_path, 'r') as f:
                self.word_to_idx = json.load(f)
            # BUG FIX: the reverse map must invert word -> id. The previous
            # code computed int(word) on the JSON keys, which raises
            # ValueError for any non-numeric word.
            self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}

    def encode(self, text, max_length=77):
        """Encode *text* into a fixed-length ``torch.long`` tensor of ids.

        The sequence is wrapped in <SOS>/<EOS>, truncated so that <EOS> is
        always the final real token, and right-padded with <PAD> to
        *max_length*. Non-string input is treated as the empty string.
        """
        import re
        if not isinstance(text, str):
            text = ""
        text = str(text).lower()
        words = re.findall(r'\w+', text)

        tokens = [self.word_to_idx["<SOS>"]]
        for word in words:
            tokens.append(self.word_to_idx.get(word, self.word_to_idx["<UNK>"]))
        tokens.append(self.word_to_idx["<EOS>"])

        if len(tokens) > max_length:
            # Keep <EOS> as the last token after truncation.
            tokens = tokens[:max_length - 1] + [self.word_to_idx["<EOS>"]]
        else:
            tokens = tokens + [self.word_to_idx["<PAD>"]] * (max_length - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)
45
+
46
# Basic dense block for feature extraction
class DenseBlock(nn.Module):
    """DenseNet-style block: each conv layer consumes the concatenation of
    the block input and every earlier layer's output, adding *growth_rate*
    channels per layer."""

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        # Layer i sees in_channels + i * growth_rate input channels.
        self.layers = nn.ModuleList(
            nn.Sequential(
                nn.BatchNorm2d(in_channels + i * growth_rate),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels + i * growth_rate, growth_rate,
                          kernel_size=3, padding=1, bias=False),
            )
            for i in range(num_layers)
        )

    def forward(self, x):
        # Accumulate feature maps and feed the running concatenation forward.
        collected = [x]
        for block in self.layers:
            collected.append(block(torch.cat(collected, 1)))
        return torch.cat(collected, 1)
66
+
67
class TransitionLayer(nn.Module):
    """Compress channels with a 1x1 conv and halve the spatial resolution
    with average pooling (standard DenseNet transition)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.transition = nn.Sequential(*stages)

    def forward(self, x):
        return self.transition(x)
79
+
80
class ExtractorBackbone(nn.Module):
    """DenseNet-like image encoder.

    Three dense blocks separated by transition layers; the final block's
    output has 256 + 24 * 32 = 1024 channels.
    """

    def __init__(self):
        super().__init__()
        # Stem: 7x7 stride-2 conv plus 3x3 stride-2 max-pool (1/4 resolution).
        self.init_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Channel bookkeeping: 64 + 6*32 = 256 -> 128; 128 + 12*32 = 512 -> 256.
        self.block1 = DenseBlock(64, 32, 6)
        self.trans1 = TransitionLayer(256, 128)
        self.block2 = DenseBlock(128, 32, 12)
        self.trans2 = TransitionLayer(512, 256)
        self.block3 = DenseBlock(256, 32, 24)

    def forward(self, x):
        # Run every stage strictly in sequence.
        out = x
        for stage in (self.init_conv, self.block1, self.trans1,
                      self.block2, self.trans2, self.block3):
            out = stage(out)
        return out
103
+
104
class AdversarialImageAuditor(nn.Module):
    """Multi-task image auditor conditioned on an optional text prompt and
    an optional scalar timestep.

    A shared 256-channel bottleneck feeds several heads:
      - a binary "adversarial" logit and per-class logits,
      - a global quality logit and a seam-quality map/score,
      - a relative adversarial score squashed into [0, 1],
      - L2-normalised image/text embeddings for a faithfulness comparison.
    """

    def __init__(self, num_classes=4, vocab_size=10000):
        super().__init__()
        self.backbone = ExtractorBackbone()
        feature_dim = 1024  # channel count produced by ExtractorBackbone

        # Prompt encoder: embedding -> bi-GRU (2 * 256 = 512) -> feature_dim.
        self.text_embedding = nn.Embedding(vocab_size, 256)
        self.text_rnn = nn.GRU(256, 256, batch_first=True, bidirectional=True)
        self.text_proj = nn.Linear(512, feature_dim)

        # Scalar timestep -> feature_dim conditioning vector.
        self.timestep_embed = nn.Sequential(
            nn.Linear(1, 128), nn.ReLU(),
            nn.Linear(128, feature_dim)
        )

        # FiLM modulation parameters derived from the combined condition.
        self.film_gamma = nn.Linear(feature_dim, feature_dim)
        self.film_beta = nn.Linear(feature_dim, feature_dim)

        # Image tokens attend over prompt tokens.
        self.cross_attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=8, batch_first=True)
        self.norm1 = nn.LayerNorm(feature_dim)

        # Shared 1024 -> 256 refinement before all task heads.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(feature_dim, 256, kernel_size=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True)
        )

        # Per-location heads (1x1 convs) plus a global quality regressor.
        self.adversarial_head = nn.Conv2d(256, 1, kernel_size=1)
        self.class_head = nn.Conv2d(256, num_classes, kernel_size=1)
        self.seam_quality_head = nn.Conv2d(256, 1, kernel_size=1)
        self.quality_head = nn.Linear(256, 1)

        # Pooled-feature head producing a score already in [0, 1].
        self.relative_adv_head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1), nn.Sigmoid()
        )

        # Projections into a shared 128-d space for image/text faithfulness.
        self.img_faith_proj = nn.Linear(256, 128)
        self.txt_faith_proj = nn.Linear(feature_dim, 128)
        # NOTE(review): learned temperature — not used in forward(); presumably
        # consumed by a contrastive loss in the original training script.
        self.log_temperature = nn.Parameter(torch.tensor([0.0]))

    def forward(self, image, text_tokens=None, timestep=None, return_features=False):
        """Run the auditor.

        Args:
            image: float tensor (batch, 3, H, W) — assumes H, W are sized so
                the backbone output flattens/reshapes cleanly (e.g. 224x224);
                TODO confirm against callers.
            text_tokens: optional (batch, seq_len) long tensor of prompt ids,
                where id 0 is padding.
            timestep: optional (batch, 1) float tensor; zeros are used when
                absent.
            return_features: when True, also return the spatial maps.

        Returns:
            dict of pooled logits/scores and embeddings (plus spatial maps
            when *return_features* is True).
        """
        batch_size = image.size(0)
        img_features = self.backbone(image)
        _, f_c, f_h, f_w = img_features.shape  # f_c == feature_dim == 1024

        # Defaults for the prompt-free path.
        global_text = torch.zeros(batch_size, f_c, device=image.device)
        text_seq = None
        padding_mask = None

        if text_tokens is not None:
            text_emb = self.text_embedding(text_tokens)
            text_out, _ = self.text_rnn(text_emb)
            text_seq = self.text_proj(text_out)
            # Mean over token positions gives the global prompt vector.
            global_text = torch.mean(text_seq, dim=1)
            # Token id 0 (<PAD>) is masked out of the attention.
            padding_mask = (text_tokens == 0)
            if padding_mask.all():
                # Avoid fully-masked attention (all-NaN) by unmasking token 0.
                # NOTE(review): this only triggers when EVERY row is fully
                # padded; a batch where just some rows are all-padding would
                # still produce NaNs (then caught by the fallback below) —
                # confirm whether per-row handling is needed.
                padding_mask[:, 0] = False

        # Combine text and timestep conditioning into one vector.
        time_emb = self.timestep_embed(timestep) if timestep is not None else torch.zeros(batch_size, f_c, device=image.device)
        cond_vec = global_text + time_emb

        # FiLM: scale/shift the image features, clamped for stability.
        gamma = torch.clamp(self.film_gamma(cond_vec), -3.0, 3.0)
        beta = torch.clamp(self.film_beta(cond_vec), -3.0, 3.0)

        gamma = gamma.view(batch_size, f_c, 1, 1).expand_as(img_features)
        beta = beta.view(batch_size, f_c, 1, 1).expand_as(img_features)

        fused_features = img_features * (1 + gamma) + beta
        # (B, C, H, W) -> (B, H*W, C) token sequence for attention.
        img_seq = fused_features.flatten(2).transpose(1, 2)

        if text_seq is not None:
            # Pre-norm cross-attention with a residual connection.
            img_seq_normed = self.norm1(img_seq)
            attn_out, _ = self.cross_attn(query=img_seq_normed, key=text_seq, value=text_seq, key_padding_mask=padding_mask)
            img_seq = img_seq + attn_out
            if torch.isnan(img_seq).any():
                # Fallback: drop the attention output entirely if it NaN'd.
                img_seq = img_seq_normed

        # Back to (B, C, H, W) and refine through the shared bottleneck.
        fused_features = img_seq.transpose(1, 2).view(batch_size, f_c, f_h, f_w)
        enhanced_features = self.bottleneck(fused_features)

        adv_map = self.adversarial_head(enhanced_features)
        class_map = self.class_head(enhanced_features)
        seam_map = torch.sigmoid(self.seam_quality_head(enhanced_features))

        # Global pooling: avg for quality/embeddings, max for detections.
        global_pool = F.adaptive_avg_pool2d(enhanced_features, (1, 1)).view(batch_size, -1)
        quality_logits = self.quality_head(global_pool)
        adv_logits = F.adaptive_max_pool2d(adv_map, (1, 1)).view(batch_size, -1)
        class_logits = F.adaptive_max_pool2d(class_map, (1, 1)).view(batch_size, -1)
        seam_score = F.adaptive_avg_pool2d(seam_map, (1, 1)).view(batch_size, -1)
        relative_adv = self.relative_adv_head(global_pool)

        # Unit-norm embeddings so cosine similarity is a plain dot product.
        v_img = self.img_faith_proj(global_pool)
        v_txt = self.txt_faith_proj(global_text)

        v_img = F.normalize(v_img, p=2, dim=1)
        v_txt = F.normalize(v_txt, p=2, dim=1)

        out = {
            'binary_logits': adv_logits,
            'class_logits': class_logits,
            'quality_logits': quality_logits,
            'seam_quality_score': seam_score,
            'relative_adv_score': relative_adv,
            'img_embed': v_img,
            'txt_embed': v_txt
        }

        if return_features:
            out['adversarial_map'] = torch.sigmoid(adv_map)
            out['object_heatmaps'] = torch.sigmoid(class_map)
            out['seam_quality_map'] = seam_map
            out['class_map'] = class_map

        return out
220
+
221
+
222
def audit_image(model_path, image_path, prompt="", num_classes=4):
    """Independent plug-and-play function for auditing an image using the
    standalone model weights.

    Args:
        model_path: path to a saved ``state_dict`` checkpoint. If the file is
            missing, the model runs with random weights and a warning is
            printed.
        image_path: path to the image to audit (any PIL-readable format).
        prompt: optional generation prompt, used for the faithfulness score.
        num_classes: number of category heads the checkpoint was trained with.

    Returns:
        dict with ``global_safety_score``, ``is_adversarial``,
        ``category_probabilities``, ``faithfulness_score`` and ``seam_quality``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = SimpleTokenizer(vocab_dir='./tokenizer_vocab')
    vocab_size = len(tokenizer.word_to_idx)

    model = AdversarialImageAuditor(num_classes=num_classes, vocab_size=vocab_size)

    if os.path.exists(model_path):
        # SECURITY: torch.load deserialises via pickle. weights_only=True
        # restricts loading to tensors/containers so an untrusted checkpoint
        # cannot execute arbitrary code (requires torch >= 1.13; it is the
        # default from torch 2.6). A plain state_dict loads fine this way.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        print(f"Loaded weights from {model_path}")
    else:
        print(f"Warning: {model_path} not found. Running with random weights.")

    model.to(device)
    model.eval()

    # Standard ImageNet preprocessing (224x224, ImageNet mean/std).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    text_tokens = tokenizer.encode(prompt).unsqueeze(0).to(device)
    # Timestep 0.0: auditing a finished image rather than a diffusion step.
    timestep = torch.tensor([[0.0]], dtype=torch.float32).to(device)

    with torch.no_grad():
        outputs = model(image_tensor, text_tokens=text_tokens, timestep=timestep)

    # Probability the image is adversarial; safety is its complement.
    binary_prob = torch.sigmoid(outputs['binary_logits']).item()
    global_safety_score = 1.0 - binary_prob

    class_probs = F.softmax(outputs['class_logits'], dim=1)[0].cpu().numpy()

    # We use the generic 4 classes mapping here for the generic auditor
    # (assumes the checkpoint was trained with this ordering — TODO confirm).
    CLASS_NAMES = ['Safe', 'NSFW', 'Gore', 'Weapons']
    category_probabilities = {CLASS_NAMES[i]: float(class_probs[i]) for i in range(len(CLASS_NAMES))}

    # Map cosine similarity from [-1, 1] onto a [0, 1] faithfulness score.
    cos_sim = F.cosine_similarity(outputs['img_embed'], outputs['txt_embed'], dim=-1).item()
    faithfulness_score = (cos_sim + 1.0) / 2.0

    seam_quality = outputs['seam_quality_score'].item()

    return {
        "global_safety_score": global_safety_score,
        "is_adversarial": binary_prob > 0.5,
        "category_probabilities": category_probabilities,
        "faithfulness_score": faithfulness_score,
        "seam_quality": seam_quality,
    }
278
+
279
if __name__ == "__main__":
    import argparse

    # CLI entry point: run one audit and pretty-print the resulting dict.
    parser = argparse.ArgumentParser("Adversarial Image Auditor Inference")
    parser.add_argument("--model", type=str, required=True, help="Path to best.pth weights")
    parser.add_argument("--image", type=str, required=True, help="Path to internal image")
    parser.add_argument("--prompt", type=str, default="", help="Prompt given to the generator")
    args = parser.parse_args()

    report = audit_image(args.model, args.image, args.prompt)

    def _fmt(value):
        # Floats get 4 decimal places; everything else prints as-is.
        return f"{value:.4f}" if isinstance(value, float) else f"{value}"

    for key, value in report.items():
        if isinstance(value, dict):
            print(f"{key}:")
            for inner_key, inner_value in value.items():
                print(f"  {inner_key}: {inner_value:.4f}")
        else:
            print(f"{key}: {_fmt(value)}")