"""Standalone inference script for the adversarial image auditor.

Re-declares the tokenizer and model classes used at training time (without the
training/evaluation boilerplate) so a trained checkpoint can be loaded and
applied to a single image + prompt pair via :func:`audit_image`.
"""
import json
import math
import os
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# Tokenizer and AdversarialImageAuditor below mirror the training-time
# definitions; module attribute names must stay unchanged so state_dict keys
# in saved checkpoints still match.

# Default labels for the 4-class head; extra logits get generic "class_<i>" names.
DEFAULT_CLASS_NAMES = ('Safe', 'NSFW', 'Gore', 'Weapons')


class SimpleTokenizer:
    """Lower-casing, word-level tokenizer.

    Special tokens default to ids 0-3. Id 0 is padding: the model builds its
    attention padding mask with ``text_tokens == 0``.
    If ``<vocab_dir>/vocab.json`` exists it replaces the built-in vocabulary
    (expected format: ``{word: id}``).
    """

    PAD_TOKEN = "<pad>"
    UNK_TOKEN = "<unk>"
    SOS_TOKEN = "<sos>"
    EOS_TOKEN = "<eos>"
    _WORD_RE = re.compile(r'\w+')

    def __init__(self, vocab_dir='./tokenizer_vocab'):
        # Each special token needs its own distinct key: identical keys collapse
        # into one dict entry, which would make padding a non-zero id and shrink
        # the default vocabulary to a single entry.
        self.word_to_idx = {
            self.PAD_TOKEN: 0,
            self.UNK_TOKEN: 1,
            self.SOS_TOKEN: 2,
            self.EOS_TOKEN: 3,
        }

        # Try to load an existing vocab (the one written at training time).
        vocab_path = os.path.join(vocab_dir, 'vocab.json')
        if os.path.exists(vocab_path):
            with open(vocab_path, 'r', encoding='utf-8') as f:
                self.word_to_idx = json.load(f)

        # word_to_idx maps word -> id, so the reverse map swaps the pair.
        # int() tolerates ids serialized as strings.
        self.idx_to_word = {int(idx): word for word, idx in self.word_to_idx.items()}

        # Resolve special ids once; fall back to the built-in ids if a loaded
        # vocabulary spells the special tokens differently.
        self.pad_id = int(self.word_to_idx.get(self.PAD_TOKEN, 0))
        self.unk_id = int(self.word_to_idx.get(self.UNK_TOKEN, 1))
        self.sos_id = int(self.word_to_idx.get(self.SOS_TOKEN, 2))
        self.eos_id = int(self.word_to_idx.get(self.EOS_TOKEN, 3))

    def encode(self, text, max_length=77):
        """Encode ``text`` into a 1-D ``torch.long`` tensor of length ``max_length``.

        Layout: ``[SOS, word ids..., EOS, PAD...]``. Unknown words map to UNK.
        Over-long sequences are truncated and re-terminated with EOS.
        Non-string input is treated as the empty string.
        """
        if not isinstance(text, str):
            text = ""
        words = self._WORD_RE.findall(text.lower())

        tokens = [self.sos_id]
        tokens.extend(int(self.word_to_idx.get(word, self.unk_id)) for word in words)
        tokens.append(self.eos_id)

        if len(tokens) > max_length:
            tokens = tokens[:max_length - 1] + [self.eos_id]
        else:
            tokens += [self.pad_id] * (max_length - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)


class DenseBlock(nn.Module):
    """DenseNet-style block for feature extraction.

    Each layer (BN -> ReLU -> 3x3 conv) sees the concatenation of the block
    input and all previous layer outputs. Output channels are
    ``in_channels + num_layers * growth_rate``; spatial size is unchanged.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            channels = in_channels + i * growth_rate
            self.layers.append(
                nn.Sequential(
                    nn.BatchNorm2d(channels),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(channels, growth_rate, kernel_size=3, padding=1, bias=False),
                )
            )

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            features.append(layer(torch.cat(features, 1)))
        return torch.cat(features, 1)


class TransitionLayer(nn.Module):
    """BN -> ReLU -> 1x1 conv (channel reduction) -> 2x2 average pool (halves H and W)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.transition = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        return self.transition(x)


class ExtractorBackbone(nn.Module):
    """Dense CNN backbone mapping a 3-channel image to a 1024-channel feature map.

    Channel flow: 64 -> block1 (+6*32) = 256 -> trans1 128 -> block2 (+12*32)
    = 512 -> trans2 256 -> block3 (+24*32) = 1024. Overall spatial stride is 16
    (a 224x224 input yields a 14x14 map).
    """

    def __init__(self):
        super().__init__()
        self.init_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.block1 = DenseBlock(64, 32, 6)
        self.trans1 = TransitionLayer(256, 128)
        self.block2 = DenseBlock(128, 32, 12)
        self.trans2 = TransitionLayer(512, 256)
        self.block3 = DenseBlock(256, 32, 24)

    def forward(self, x):
        x = self.init_conv(x)
        x = self.trans1(self.block1(x))
        x = self.trans2(self.block2(x))
        return self.block3(x)


class AdversarialImageAuditor(nn.Module):
    """Multi-task image auditor conditioned on a text prompt and a timestep.

    Pipeline:
    1. Backbone image features are FiLM-modulated by (mean text embedding +
       timestep embedding).
    2. The modulated features cross-attend to the text sequence.
    3. A conv bottleneck feeds several heads: adversarial, per-class,
       seam quality, global quality, relative-adversarial, and image/text
       faithfulness embeddings.
    """

    def __init__(self, num_classes=4, vocab_size=10000):
        super().__init__()
        self.backbone = ExtractorBackbone()
        feature_dim = 1024  # must equal the backbone's output channel count

        # Text encoder: embedding -> bidirectional GRU (2*256) -> projection to feature_dim.
        self.text_embedding = nn.Embedding(vocab_size, 256)
        self.text_rnn = nn.GRU(256, 256, batch_first=True, bidirectional=True)
        self.text_proj = nn.Linear(512, feature_dim)

        # Scalar timestep -> feature_dim embedding.
        self.timestep_embed = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, feature_dim),
        )

        # FiLM conditioning parameters.
        self.film_gamma = nn.Linear(feature_dim, feature_dim)
        self.film_beta = nn.Linear(feature_dim, feature_dim)

        self.cross_attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=8, batch_first=True)
        self.norm1 = nn.LayerNorm(feature_dim)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(feature_dim, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Dense (per-location) heads.
        self.adversarial_head = nn.Conv2d(256, 1, kernel_size=1)
        self.class_head = nn.Conv2d(256, num_classes, kernel_size=1)
        self.seam_quality_head = nn.Conv2d(256, 1, kernel_size=1)

        # Global heads on the average-pooled bottleneck features.
        self.quality_head = nn.Linear(256, 1)
        self.relative_adv_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Projections into a shared 128-d space for image/text faithfulness.
        self.img_faith_proj = nn.Linear(256, 128)
        self.txt_faith_proj = nn.Linear(feature_dim, 128)
        # Not used in forward(); kept so checkpoints load with matching keys.
        self.log_temperature = nn.Parameter(torch.tensor([0.0]))

    def forward(self, image, text_tokens=None, timestep=None, return_features=False):
        """Run all heads.

        Args:
            image: float tensor, (B, 3, H, W).
            text_tokens: optional long tensor, (B, L); id 0 is treated as padding.
            timestep: optional float tensor, (B, 1).
            return_features: if True, also return the dense per-location maps.

        Returns:
            Dict with these entries:
            - ``binary_logits`` (B, 1): max-pooled adversarial map.
            - ``class_logits`` (B, num_classes): max-pooled class map.
            - ``quality_logits`` (B, 1).
            - ``seam_quality_score`` (B, 1): mean of the sigmoid seam map.
            - ``relative_adv_score`` (B, 1): in [0, 1].
            - ``img_embed`` / ``txt_embed`` (B, 128): L2-normalised.

            With ``return_features=True`` the dict also contains
            ``adversarial_map``, ``object_heatmaps``, ``seam_quality_map`` and
            ``class_map``.
        """
        batch_size = image.size(0)
        img_features = self.backbone(image)
        _, f_c, f_h, f_w = img_features.shape

        global_text = torch.zeros(batch_size, f_c, device=image.device)
        text_seq = None
        padding_mask = None
        if text_tokens is not None:
            text_out, _ = self.text_rnn(self.text_embedding(text_tokens))
            text_seq = self.text_proj(text_out)
            global_text = torch.mean(text_seq, dim=1)
            padding_mask = text_tokens == 0
            # A row that is entirely padding leaves attention nothing to attend
            # to (NaN output); unmask its first slot. Checked per row, not only
            # when the whole batch is padding.
            fully_padded = padding_mask.all(dim=1)
            if fully_padded.any():
                padding_mask[fully_padded, 0] = False

        if timestep is not None:
            time_emb = self.timestep_embed(timestep)
        else:
            time_emb = torch.zeros(batch_size, f_c, device=image.device)

        # FiLM: per-channel scale/shift from the conditioning vector, clamped for stability.
        cond_vec = global_text + time_emb
        gamma = torch.clamp(self.film_gamma(cond_vec), -3.0, 3.0)
        beta = torch.clamp(self.film_beta(cond_vec), -3.0, 3.0)
        gamma = gamma.view(batch_size, f_c, 1, 1).expand_as(img_features)
        beta = beta.view(batch_size, f_c, 1, 1).expand_as(img_features)
        fused_features = img_features * (1 + gamma) + beta

        # (B, C, H, W) -> (B, H*W, C) token sequence for cross-attention.
        img_seq = fused_features.flatten(2).transpose(1, 2)
        if text_seq is not None:
            img_seq_normed = self.norm1(img_seq)
            attn_out, _ = self.cross_attn(
                query=img_seq_normed,
                key=text_seq,
                value=text_seq,
                key_padding_mask=padding_mask,
            )
            img_seq = img_seq + attn_out
            # Last-resort fallback if attention still produced NaNs.
            if torch.isnan(img_seq).any():
                img_seq = img_seq_normed
        fused_features = img_seq.transpose(1, 2).reshape(batch_size, f_c, f_h, f_w)

        enhanced_features = self.bottleneck(fused_features)
        adv_map = self.adversarial_head(enhanced_features)
        class_map = self.class_head(enhanced_features)
        seam_map = torch.sigmoid(self.seam_quality_head(enhanced_features))

        global_pool = F.adaptive_avg_pool2d(enhanced_features, (1, 1)).view(batch_size, -1)
        quality_logits = self.quality_head(global_pool)
        adv_logits = F.adaptive_max_pool2d(adv_map, (1, 1)).view(batch_size, -1)
        class_logits = F.adaptive_max_pool2d(class_map, (1, 1)).view(batch_size, -1)
        seam_score = F.adaptive_avg_pool2d(seam_map, (1, 1)).view(batch_size, -1)
        relative_adv = self.relative_adv_head(global_pool)

        v_img = F.normalize(self.img_faith_proj(global_pool), p=2, dim=1)
        v_txt = F.normalize(self.txt_faith_proj(global_text), p=2, dim=1)

        out = {
            'binary_logits': adv_logits,
            'class_logits': class_logits,
            'quality_logits': quality_logits,
            'seam_quality_score': seam_score,
            'relative_adv_score': relative_adv,
            'img_embed': v_img,
            'txt_embed': v_txt,
        }
        if return_features:
            out['adversarial_map'] = torch.sigmoid(adv_map)
            out['object_heatmaps'] = torch.sigmoid(class_map)
            out['seam_quality_map'] = seam_map
            out['class_map'] = class_map
        return out


def audit_image(model_path, image_path, prompt="", num_classes=4, *,
                vocab_dir='./tokenizer_vocab', class_names=None):
    """Audit a single image with standalone model weights.

    Args:
        model_path: path to a saved ``state_dict``. If missing, the model
            runs with random weights (a warning is printed).
        image_path: image file to audit; converted to RGB and resized to 224x224.
        prompt: generator prompt, used for text conditioning and faithfulness.
        num_classes: size of the model's class head.
        vocab_dir: directory holding the tokenizer's ``vocab.json``.
        class_names: labels for the class logits, in order. Defaults to
            ``DEFAULT_CLASS_NAMES``. Logits beyond the supplied names are
            reported as ``"class_<i>"``.

    Returns:
        Dict with these entries:
        - ``global_safety_score``: 1 - adversarial probability.
        - ``is_adversarial``: bool.
        - ``category_probabilities``: ``{name: prob}``.
        - ``faithfulness_score``: cosine similarity mapped to [0, 1].
        - ``seam_quality``: float.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = SimpleTokenizer(vocab_dir=vocab_dir)
    # Embedding size must match the vocabulary the checkpoint was trained with.
    vocab_size = len(tokenizer.word_to_idx)
    model = AdversarialImageAuditor(num_classes=num_classes, vocab_size=vocab_size)

    if os.path.exists(model_path):
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints
        # (consider weights_only=True on torch versions that support it).
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Loaded weights from {model_path}")
    else:
        print(f"Warning: {model_path} not found. Running with random weights.")
    model.to(device)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Context manager closes the underlying file handle once pixels are loaded.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)
    text_tokens = tokenizer.encode(prompt).unsqueeze(0).to(device)
    timestep = torch.zeros(1, 1, dtype=torch.float32, device=device)

    with torch.no_grad():
        outputs = model(image_tensor, text_tokens=text_tokens, timestep=timestep)

    binary_prob = torch.sigmoid(outputs['binary_logits']).item()
    global_safety_score = 1.0 - binary_prob

    # Label every class logit the model produced, so a num_classes that differs
    # from the number of names neither crashes nor silently drops classes.
    names = DEFAULT_CLASS_NAMES if class_names is None else class_names
    class_probs = F.softmax(outputs['class_logits'], dim=1)[0].cpu().tolist()
    category_probabilities = {
        (names[i] if i < len(names) else f"class_{i}"): float(prob)
        for i, prob in enumerate(class_probs)
    }

    cos_sim = F.cosine_similarity(outputs['img_embed'], outputs['txt_embed'], dim=-1).item()
    faithfulness_score = (cos_sim + 1.0) / 2.0  # map [-1, 1] -> [0, 1]
    seam_quality = outputs['seam_quality_score'].item()

    return {
        "global_safety_score": global_safety_score,
        "is_adversarial": binary_prob > 0.5,
        "category_probabilities": category_probabilities,
        "faithfulness_score": faithfulness_score,
        "seam_quality": seam_quality,
    }


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Adversarial Image Auditor Inference")
    parser.add_argument("--model", type=str, required=True, help="Path to best.pth weights")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--prompt", type=str, default="", help="Prompt given to the generator")
    args = parser.parse_args()

    res = audit_image(args.model, args.image, args.prompt)
    for k, v in res.items():
        if isinstance(v, dict):
            print(f"{k}:")
            for sub_k, sub_v in v.items():
                print(f"  {sub_k}: {sub_v:.4f}")
        elif isinstance(v, float):
            print(f"{k}: {v:.4f}")
        else:
            print(f"{k}: {v}")