kricko committed on
Commit
0ef6495
·
verified ·
1 Parent(s): 24bf801

Clear repository before fresh upload

Browse files
Files changed (1) hide show
  1. auditor_inference.py +0 -296
auditor_inference.py DELETED
@@ -1,296 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torchvision import transforms
5
- from PIL import Image
6
- import os
7
- import math
8
- import numpy as np
9
-
10
- # Use same tokenization and model classes from the original file
11
- # without all the training/evaluation boilerplates.
12
- # To keep this script truly standalone, we bring over the tokenizer and CompleteMultiTaskAuditor.
13
-
14
class SimpleTokenizer:
    """Minimal word-level tokenizer with fixed special tokens.

    Loads a persisted ``vocab.json`` (a word -> index mapping) from
    *vocab_dir* when present, so inference reuses the training vocabulary;
    otherwise falls back to the four special tokens only.
    """

    def __init__(self, vocab_dir='./tokenizer_vocab'):
        # Default vocabulary: only the special tokens.
        self.word_to_idx = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3}
        self.idx_to_word = {0: "<PAD>", 1: "<UNK>", 2: "<SOS>", 3: "<EOS>"}

        # Try to load existing vocab if doing inference.
        import json
        vocab_path = os.path.join(vocab_dir, 'vocab.json')
        if os.path.exists(vocab_path):
            with open(vocab_path, 'r') as f:
                self.word_to_idx = json.load(f)
            # BUG FIX: vocab.json maps word -> index, so the reverse lookup
            # must swap keys and values. The previous `{int(k): v ...}`
            # called int() on word keys, raising ValueError for any
            # non-numeric token and mapping index->index even when it worked.
            self.idx_to_word = {v: k for k, v in self.word_to_idx.items()}

    def encode(self, text, max_length=77):
        """Encode *text* into a fixed-length LongTensor of token ids.

        The token stream is wrapped in <SOS>/<EOS>, truncated to
        *max_length* (always keeping a trailing <EOS>), and right-padded
        with <PAD>. Non-string input is treated as the empty string.
        """
        import re
        if not isinstance(text, str):
            text = ""
        text = str(text).lower()
        words = re.findall(r'\w+', text)

        tokens = [self.word_to_idx["<SOS>"]]
        for word in words:
            tokens.append(self.word_to_idx.get(word, self.word_to_idx["<UNK>"]))
        tokens.append(self.word_to_idx["<EOS>"])

        if len(tokens) > max_length:
            # Truncate but guarantee the sequence still ends with <EOS>.
            tokens = tokens[:max_length-1] + [self.word_to_idx["<EOS>"]]
        else:
            tokens = tokens + [self.word_to_idx["<PAD>"]] * (max_length - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)
45
-
46
class DenseBlock(nn.Module):
    """DenseNet-style block for feature extraction.

    Each inner layer consumes the channel-wise concatenation of the input
    and every previously produced feature map, and contributes
    ``growth_rate`` new channels. Output channels:
    ``in_channels + num_layers * growth_rate``.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            # Channel count grows by `growth_rate` for every earlier layer.
            channels_in = in_channels + idx * growth_rate
            self.layers.append(
                nn.Sequential(
                    nn.BatchNorm2d(channels_in),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(channels_in, growth_rate, kernel_size=3, padding=1, bias=False),
                )
            )

    def forward(self, x):
        accumulated = [x]
        for layer in self.layers:
            accumulated.append(layer(torch.cat(accumulated, 1)))
        return torch.cat(accumulated, 1)
66
-
67
class TransitionLayer(nn.Module):
    """Compress channels with a 1x1 conv and halve spatial resolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.transition = nn.Sequential(*stages)

    def forward(self, x):
        out = self.transition(x)
        return out
79
-
80
class ExtractorBackbone(nn.Module):
    """DenseNet-like convolutional feature extractor.

    For a 3-channel input the stem downsamples by 4x and the two transition
    layers by 2x each (16x total); the final dense block emits 1024 channels
    (256 + 24 * 32).
    """

    def __init__(self):
        super().__init__()
        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max pool (4x downsample).
        self.init_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Channel progression: 64 -> 256 -> 128 -> 512 -> 256 -> 1024.
        self.block1 = DenseBlock(64, 32, 6)
        self.trans1 = TransitionLayer(256, 128)
        self.block2 = DenseBlock(128, 32, 12)
        self.trans2 = TransitionLayer(512, 256)
        self.block3 = DenseBlock(256, 32, 24)

    def forward(self, x):
        out = self.init_conv(x)
        for stage in (self.block1, self.trans1, self.block2, self.trans2, self.block3):
            out = stage(out)
        return out
103
-
104
class AdversarialImageAuditor(nn.Module):
    """Multi-task image auditor with text and timestep conditioning.

    A dense CNN backbone extracts image features, which are modulated by a
    FiLM layer driven by (mean-pooled text features + timestep embedding)
    and refined by cross-attention from image tokens to text tokens.
    Task heads then produce a binary adversarial logit, per-class logits,
    a global quality logit, a seam-quality score, a relative adversarial
    score, and unit-normalized image/text embeddings for faithfulness.
    """

    def __init__(self, num_classes=4, vocab_size=10000):
        super().__init__()
        self.backbone = ExtractorBackbone()
        feature_dim = 1024  # channel count produced by the backbone

        # Text branch: embedding -> bidirectional GRU (2*256=512) -> projection.
        self.text_embedding = nn.Embedding(vocab_size, 256)
        self.text_rnn = nn.GRU(256, 256, batch_first=True, bidirectional=True)
        self.text_proj = nn.Linear(512, feature_dim)

        # Scalar timestep -> feature_dim conditioning vector.
        self.timestep_embed = nn.Sequential(
            nn.Linear(1, 128), nn.ReLU(),
            nn.Linear(128, feature_dim)
        )

        # FiLM gamma/beta derived from the combined conditioning vector.
        self.film_gamma = nn.Linear(feature_dim, feature_dim)
        self.film_beta = nn.Linear(feature_dim, feature_dim)

        # Cross-attention: image tokens query the text token sequence.
        self.cross_attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=8, batch_first=True)
        self.norm1 = nn.LayerNorm(feature_dim)

        # Shared bottleneck (1024 -> 256 channels) feeding all task heads.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(feature_dim, 256, kernel_size=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True)
        )

        # Per-pixel task heads.
        self.adversarial_head = nn.Conv2d(256, 1, kernel_size=1)
        self.class_head = nn.Conv2d(256, num_classes, kernel_size=1)
        self.seam_quality_head = nn.Conv2d(256, 1, kernel_size=1)
        # Global quality head on the pooled feature vector.
        self.quality_head = nn.Linear(256, 1)

        self.relative_adv_head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1), nn.Sigmoid()
        )

        # Projections into a shared 128-d space for image/text faithfulness.
        self.img_faith_proj = nn.Linear(256, 128)
        self.txt_faith_proj = nn.Linear(feature_dim, 128)
        # NOTE(review): log_temperature is a learnable parameter that is not
        # used in this forward pass -- presumably consumed by a contrastive
        # loss during training; confirm against the training script.
        self.log_temperature = nn.Parameter(torch.tensor([0.0]))

    def forward(self, image, text_tokens=None, timestep=None, return_features=False):
        """Run the auditor on a batch.

        Args:
            image: (B, 3, H, W) image batch.
            text_tokens: optional (B, T) LongTensor of prompt token ids,
                where id 0 is the <PAD> token.
            timestep: optional (B, 1) float tensor; zeros are used when None.
            return_features: when True, also include sigmoid-activated
                spatial maps in the output dict.

        Returns:
            dict mapping output names to tensors (logits, scores, embeddings).
        """
        batch_size = image.size(0)
        img_features = self.backbone(image)
        _, f_c, f_h, f_w = img_features.shape

        # Defaults when no text is supplied: zero conditioning, skip attention.
        global_text = torch.zeros(batch_size, f_c, device=image.device)
        text_seq = None
        padding_mask = None

        if text_tokens is not None:
            text_emb = self.text_embedding(text_tokens)
            text_out, _ = self.text_rnn(text_emb)
            text_seq = self.text_proj(text_out)
            # Mean-pool token features into a single conditioning vector.
            global_text = torch.mean(text_seq, dim=1)
            # Mask <PAD> (id 0) positions so attention ignores them.
            padding_mask = (text_tokens == 0)
            if padding_mask.all():
                # A fully-masked row makes softmax produce NaNs; keep the
                # first key visible as a safe fallback.
                padding_mask[:, 0] = False

        time_emb = self.timestep_embed(timestep) if timestep is not None else torch.zeros(batch_size, f_c, device=image.device)
        cond_vec = global_text + time_emb

        # FiLM modulation, clamped to [-3, 3] for numerical stability.
        gamma = torch.clamp(self.film_gamma(cond_vec), -3.0, 3.0)
        beta = torch.clamp(self.film_beta(cond_vec), -3.0, 3.0)

        gamma = gamma.view(batch_size, f_c, 1, 1).expand_as(img_features)
        beta = beta.view(batch_size, f_c, 1, 1).expand_as(img_features)

        fused_features = img_features * (1 + gamma) + beta
        # Flatten the spatial grid into a token sequence: (B, H*W, C).
        img_seq = fused_features.flatten(2).transpose(1, 2)

        if text_seq is not None:
            # Pre-norm residual cross-attention (image queries, text keys/values).
            img_seq_normed = self.norm1(img_seq)
            attn_out, _ = self.cross_attn(query=img_seq_normed, key=text_seq, value=text_seq, key_padding_mask=padding_mask)
            img_seq = img_seq + attn_out
            if torch.isnan(img_seq).any():
                # Fall back to the normalized features if attention blew up.
                img_seq = img_seq_normed

        # Restore the (B, C, H, W) layout and run the shared bottleneck.
        fused_features = img_seq.transpose(1, 2).view(batch_size, f_c, f_h, f_w)
        enhanced_features = self.bottleneck(fused_features)

        # Dense per-pixel prediction maps.
        adv_map = self.adversarial_head(enhanced_features)
        class_map = self.class_head(enhanced_features)
        seam_map = torch.sigmoid(self.seam_quality_head(enhanced_features))

        # Pool maps into per-image scores: max-pool for detection logits,
        # mean-pool for quality-style scores.
        global_pool = F.adaptive_avg_pool2d(enhanced_features, (1, 1)).view(batch_size, -1)
        quality_logits = self.quality_head(global_pool)
        adv_logits = F.adaptive_max_pool2d(adv_map, (1, 1)).view(batch_size, -1)
        class_logits = F.adaptive_max_pool2d(class_map, (1, 1)).view(batch_size, -1)
        seam_score = F.adaptive_avg_pool2d(seam_map, (1, 1)).view(batch_size, -1)
        relative_adv = self.relative_adv_head(global_pool)

        # Faithfulness embeddings: unit-normalized image and text vectors.
        v_img = self.img_faith_proj(global_pool)
        v_txt = self.txt_faith_proj(global_text)

        v_img = F.normalize(v_img, p=2, dim=1)
        v_txt = F.normalize(v_txt, p=2, dim=1)

        out = {
            'binary_logits': adv_logits,
            'class_logits': class_logits,
            'quality_logits': quality_logits,
            'seam_quality_score': seam_score,
            'relative_adv_score': relative_adv,
            'img_embed': v_img,
            'txt_embed': v_txt
        }

        if return_features:
            out['adversarial_map'] = torch.sigmoid(adv_map)
            out['object_heatmaps'] = torch.sigmoid(class_map)
            out['seam_quality_map'] = seam_map
            out['class_map'] = class_map

        return out
220
-
221
-
222
def audit_image(model_path, image_path, prompt="", num_classes=4):
    """Audit a single image with the standalone auditor model.

    Loads weights from *model_path* (falling back to random initialization
    with a warning when the file is missing), preprocesses *image_path*,
    encodes *prompt*, and returns a dict with the global safety score,
    adversarial flag, per-category probabilities, faithfulness score, and
    seam quality.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = SimpleTokenizer(vocab_dir='./tokenizer_vocab')
    model = AdversarialImageAuditor(
        num_classes=num_classes,
        vocab_size=len(tokenizer.word_to_idx),
    )

    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Loaded weights from {model_path}")
    else:
        print(f"Warning: {model_path} not found. Running with random weights.")

    model.to(device)
    model.eval()

    # Standard ImageNet preprocessing at 224x224.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    pil_image = Image.open(image_path).convert('RGB')
    image_tensor = preprocess(pil_image).unsqueeze(0).to(device)

    text_tokens = tokenizer.encode(prompt).unsqueeze(0).to(device)
    timestep = torch.tensor([[0.0]], dtype=torch.float32).to(device)

    with torch.no_grad():
        outputs = model(image_tensor, text_tokens=text_tokens, timestep=timestep)

    adv_prob = torch.sigmoid(outputs['binary_logits']).item()

    # We use the generic 4 classes mapping here for the generic auditor
    CLASS_NAMES = ['Safe', 'NSFW', 'Gore', 'Weapons']
    probs = F.softmax(outputs['class_logits'], dim=1)[0].cpu().numpy()
    per_category = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}

    # Map cosine similarity from [-1, 1] onto a [0, 1] faithfulness score.
    cosine = F.cosine_similarity(outputs['img_embed'], outputs['txt_embed'], dim=-1).item()

    return {
        "global_safety_score": 1.0 - adv_prob,
        "is_adversarial": adv_prob > 0.5,
        "category_probabilities": per_category,
        "faithfulness_score": (cosine + 1.0) / 2.0,
        "seam_quality": outputs['seam_quality_score'].item(),
    }
278
-
279
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser("Adversarial Image Auditor Inference")
    cli.add_argument("--model", type=str, required=True, help="Path to best.pth weights")
    cli.add_argument("--image", type=str, required=True, help="Path to internal image")
    cli.add_argument("--prompt", type=str, default="", help="Prompt given to the generator")
    opts = cli.parse_args()

    report = audit_image(opts.model, opts.image, opts.prompt)
    for k, v in report.items():
        if isinstance(v, dict):
            # Nested dict: header line, then one indented entry per key.
            print(f"{k}:")
            for sub_k, sub_v in v.items():
                print(f" {sub_k}: {sub_v:.4f}")
        elif isinstance(v, float):
            print(f"{k}: {v:.4f}")
        else:
            print(f"{k}: {v}")