| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torchvision import transforms |
| | from PIL import Image |
| | import os |
| | import math |
| | import numpy as np |
| |
|
| | |
| | |
| | |
| |
|
class SimpleTokenizer:
    """Word-level tokenizer backed by an optional JSON vocabulary.

    The vocabulary maps lower-cased words to integer ids. The built-in
    default contains only the special tokens ``<PAD>`` (0), ``<UNK>`` (1),
    ``<SOS>`` (2) and ``<EOS>`` (3); a ``vocab.json`` file, when present,
    replaces that default mapping entirely.
    """

    def __init__(self, vocab_dir='./tokenizer_vocab'):
        """Initialise the vocabulary, loading ``vocab.json`` if it exists.

        Args:
            vocab_dir: Directory expected to contain ``vocab.json``, a JSON
                object mapping word -> id. If the file is missing, only the
                four special tokens are known.
        """
        import json

        self.word_to_idx = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3}
        self.idx_to_word = {0: "<PAD>", 1: "<UNK>", 2: "<SOS>", 3: "<EOS>"}

        vocab_path = os.path.join(vocab_dir, 'vocab.json')
        if os.path.exists(vocab_path):
            with open(vocab_path, 'r') as f:
                self.word_to_idx = json.load(f)
            # The file stores word -> id, so the reverse table must be keyed
            # by the *id* (the dict value). Keying it by int(word) raised
            # ValueError on the first non-numeric word such as "<PAD>".
            self.idx_to_word = {
                int(idx): word for word, idx in self.word_to_idx.items()
            }

    def encode(self, text, max_length=77):
        """Convert *text* into a fixed-length tensor of token ids.

        The text is lower-cased and split into ``\\w+`` runs. The sequence is
        wrapped in ``<SOS>`` / ``<EOS>``, unknown words map to ``<UNK>``, and
        the result is truncated (keeping a trailing ``<EOS>``) or right-padded
        with ``<PAD>`` to exactly ``max_length`` entries.

        Args:
            text: Prompt string. Any non-string value is treated as "".
            max_length: Length of the returned sequence.

        Returns:
            A 1-D ``torch.long`` tensor of length ``max_length``.
        """
        import re

        if not isinstance(text, str):
            text = ""
        words = re.findall(r'\w+', text.lower())

        unk_id = self.word_to_idx["<UNK>"]
        eos_id = self.word_to_idx["<EOS>"]

        tokens = [self.word_to_idx["<SOS>"]]
        tokens.extend(self.word_to_idx.get(word, unk_id) for word in words)
        tokens.append(eos_id)

        if len(tokens) > max_length:
            # Reserve the final slot so the sequence still ends with <EOS>.
            tokens = tokens[:max_length - 1] + [eos_id]
        else:
            tokens = tokens + [self.word_to_idx["<PAD>"]] * (max_length - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)
| |
|
| | |
class DenseBlock(nn.Module):
    """Densely connected stack of BatchNorm -> ReLU -> 3x3 Conv layers.

    Every layer receives the channel-wise concatenation of the block input
    and all previous layer outputs, and contributes ``growth_rate`` new
    channels. The block output therefore has
    ``in_channels + num_layers * growth_rate`` channels; spatial size is
    unchanged (3x3 convolutions with padding 1).
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        """Build ``num_layers`` layers, each widening the input by ``growth_rate``."""
        super().__init__()
        stages = []
        width = in_channels  # channels seen by the layer being built
        for _ in range(num_layers):
            stages.append(
                nn.Sequential(
                    nn.BatchNorm2d(width),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(width, growth_rate, kernel_size=3, padding=1, bias=False),
                )
            )
            width += growth_rate
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Return the input concatenated with every layer's output along dim 1."""
        collected = [x]
        for stage in self.layers:
            collected.append(stage(torch.cat(collected, 1)))
        return torch.cat(collected, 1)
| |
|
class TransitionLayer(nn.Module):
    """Channel-reducing, spatially down-sampling bridge between dense blocks.

    Applies BatchNorm -> ReLU -> 1x1 Conv (``in_channels`` to
    ``out_channels``) followed by 2x2 average pooling with stride 2.
    """

    def __init__(self, in_channels, out_channels):
        """Assemble the BN / ReLU / 1x1 conv / avg-pool pipeline."""
        super().__init__()
        stages = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.transition = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the transition pipeline and return the result."""
        reduced = self.transition(x)
        return reduced
| |
|
class ExtractorBackbone(nn.Module):
    """Convolutional feature extractor built from dense blocks.

    Layout: a 7x7 stride-2 stem with 3x3 stride-2 max pooling, then three
    dense blocks (6, 12 and 24 layers, growth rate 32) separated by two
    transition layers. Channel counts run 64 -> 256 -> 128 -> 512 -> 256 ->
    1024, so the returned feature map has 1024 channels.
    """

    def __init__(self):
        """Instantiate the stem, the three dense blocks and two transitions."""
        super().__init__()
        stem = [
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.init_conv = nn.Sequential(*stem)

        # 64 + 6 * 32 = 256 channels leave block1.
        self.block1 = DenseBlock(64, 32, 6)
        self.trans1 = TransitionLayer(256, 128)
        # 128 + 12 * 32 = 512 channels leave block2.
        self.block2 = DenseBlock(128, 32, 12)
        self.trans2 = TransitionLayer(512, 256)
        # 256 + 24 * 32 = 1024 channels leave block3.
        self.block3 = DenseBlock(256, 32, 24)

    def forward(self, x):
        """Pass ``x`` through stem, blocks and transitions in order."""
        pipeline = (
            self.init_conv,
            self.block1,
            self.trans1,
            self.block2,
            self.trans2,
            self.block3,
        )
        for stage in pipeline:
            x = stage(x)
        return x
| |
|
class AdversarialImageAuditor(nn.Module):
    """Multi-head image auditor conditioned on an optional prompt and timestep.

    Image features from ``ExtractorBackbone`` are modulated FiLM-style by a
    conditioning vector (mean text feature + timestep embedding), optionally
    refined by cross-attention over the text tokens, squeezed through a
    convolutional bottleneck, and fed to several prediction heads.
    """

    def __init__(self, num_classes=4, vocab_size=10000):
        """Build the backbone, text encoder, conditioning layers and heads.

        Args:
            num_classes: Number of channels produced by the class head.
            vocab_size: Number of rows in the token embedding table.
        """
        super().__init__()
        self.backbone = ExtractorBackbone()
        feature_dim = 1024  # channel count of the backbone's output feature map

        # Text encoder: embedding -> bidirectional GRU (2 * 256) -> feature_dim.
        self.text_embedding = nn.Embedding(vocab_size, 256)
        self.text_rnn = nn.GRU(256, 256, batch_first=True, bidirectional=True)
        self.text_proj = nn.Linear(512, feature_dim)

        # Maps a scalar timestep (last dim of size 1) to feature_dim.
        self.timestep_embed = nn.Sequential(
            nn.Linear(1, 128), nn.ReLU(),
            nn.Linear(128, feature_dim)
        )

        # FiLM scale / shift predicted from the conditioning vector.
        self.film_gamma = nn.Linear(feature_dim, feature_dim)
        self.film_beta = nn.Linear(feature_dim, feature_dim)

        self.cross_attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=8, batch_first=True)
        self.norm1 = nn.LayerNorm(feature_dim)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(feature_dim, 256, kernel_size=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True)
        )

        # Dense (per-location) heads.
        self.adversarial_head = nn.Conv2d(256, 1, kernel_size=1)
        self.class_head = nn.Conv2d(256, num_classes, kernel_size=1)
        self.seam_quality_head = nn.Conv2d(256, 1, kernel_size=1)
        # Global heads operating on the average-pooled bottleneck features.
        self.quality_head = nn.Linear(256, 1)

        self.relative_adv_head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1), nn.Sigmoid()
        )

        # Projections into a shared 128-d space for image/text comparison.
        self.img_faith_proj = nn.Linear(256, 128)
        self.txt_faith_proj = nn.Linear(feature_dim, 128)
        # Not read in forward(); presumably consumed by a training loss
        # defined elsewhere -- TODO confirm.
        self.log_temperature = nn.Parameter(torch.tensor([0.0]))

    def forward(self, image, text_tokens=None, timestep=None, return_features=False):
        """Score an image batch, optionally conditioned on text and timestep.

        Args:
            image: Image batch of shape (batch, 3, H, W).
            text_tokens: Optional integer token ids of shape (batch, seq);
                id 0 is treated as padding. ``None`` disables text
                conditioning and cross-attention.
            timestep: Optional float tensor whose last dimension is 1 (it is
                fed to ``nn.Linear(1, 128)``), e.g. shape (batch, 1).
                ``None`` contributes a zero embedding.
            return_features: When True, spatial maps are added to the output.

        Returns:
            dict with ``binary_logits`` (batch, 1), ``class_logits``
            (batch, num_classes), ``quality_logits`` (batch, 1),
            ``seam_quality_score`` (batch, 1), ``relative_adv_score``
            (batch, 1), and L2-normalised ``img_embed`` / ``txt_embed``
            (batch, 128). With ``return_features`` the dict also holds
            ``adversarial_map``, ``object_heatmaps``, ``seam_quality_map``
            and ``class_map``.
        """
        batch_size = image.size(0)
        img_features = self.backbone(image)
        _, f_c, f_h, f_w = img_features.shape

        global_text = torch.zeros(batch_size, f_c, device=image.device)
        text_seq = None
        padding_mask = None

        if text_tokens is not None:
            text_emb = self.text_embedding(text_tokens)
            text_out, _ = self.text_rnn(text_emb)
            text_seq = self.text_proj(text_out)
            global_text = torch.mean(text_seq, dim=1)
            padding_mask = (text_tokens == 0)
            # A row whose keys are ALL masked yields NaN attention weights.
            # Un-mask position 0 for exactly those rows. The previous
            # batch-wide `.all()` check missed fully padded rows inside a
            # mixed batch, which then triggered the NaN fallback below for
            # every sample.
            fully_padded = padding_mask.all(dim=1)
            padding_mask[:, 0] = padding_mask[:, 0] & ~fully_padded

        time_emb = self.timestep_embed(timestep) if timestep is not None else torch.zeros(batch_size, f_c, device=image.device)
        cond_vec = global_text + time_emb

        # Clamp the FiLM parameters to keep the modulation bounded.
        gamma = torch.clamp(self.film_gamma(cond_vec), -3.0, 3.0)
        beta = torch.clamp(self.film_beta(cond_vec), -3.0, 3.0)

        gamma = gamma.view(batch_size, f_c, 1, 1).expand_as(img_features)
        beta = beta.view(batch_size, f_c, 1, 1).expand_as(img_features)

        fused_features = img_features * (1 + gamma) + beta
        # (batch, C, H, W) -> (batch, H*W, C): one token per spatial location.
        img_seq = fused_features.flatten(2).transpose(1, 2)

        if text_seq is not None:
            img_seq_normed = self.norm1(img_seq)
            attn_out, _ = self.cross_attn(query=img_seq_normed, key=text_seq, value=text_seq, key_padding_mask=padding_mask)
            img_seq = img_seq + attn_out
            # Defensive fallback: if anything went NaN, discard the attention
            # update for the whole batch and continue from the normalised
            # image tokens.
            if torch.isnan(img_seq).any():
                img_seq = img_seq_normed

        fused_features = img_seq.transpose(1, 2).view(batch_size, f_c, f_h, f_w)
        enhanced_features = self.bottleneck(fused_features)

        adv_map = self.adversarial_head(enhanced_features)
        class_map = self.class_head(enhanced_features)
        seam_map = torch.sigmoid(self.seam_quality_head(enhanced_features))

        global_pool = F.adaptive_avg_pool2d(enhanced_features, (1, 1)).view(batch_size, -1)
        quality_logits = self.quality_head(global_pool)
        # Max pooling: the strongest spatial response decides the image-level logit.
        adv_logits = F.adaptive_max_pool2d(adv_map, (1, 1)).view(batch_size, -1)
        class_logits = F.adaptive_max_pool2d(class_map, (1, 1)).view(batch_size, -1)
        seam_score = F.adaptive_avg_pool2d(seam_map, (1, 1)).view(batch_size, -1)
        relative_adv = self.relative_adv_head(global_pool)

        v_img = self.img_faith_proj(global_pool)
        v_txt = self.txt_faith_proj(global_text)

        v_img = F.normalize(v_img, p=2, dim=1)
        v_txt = F.normalize(v_txt, p=2, dim=1)

        out = {
            'binary_logits': adv_logits,
            'class_logits': class_logits,
            'quality_logits': quality_logits,
            'seam_quality_score': seam_score,
            'relative_adv_score': relative_adv,
            'img_embed': v_img,
            'txt_embed': v_txt
        }

        if return_features:
            out['adversarial_map'] = torch.sigmoid(adv_map)
            out['object_heatmaps'] = torch.sigmoid(class_map)
            out['seam_quality_map'] = seam_map
            out['class_map'] = class_map

        return out
| |
|
| |
|
def audit_image(model_path, image_path, prompt="", num_classes=4):
    """Audit a single image with a standalone ``AdversarialImageAuditor``.

    Builds the tokenizer and model, loads weights from ``model_path`` when
    that file exists (otherwise warns and runs with random weights), runs one
    forward pass on the image at timestep 0, and summarises the outputs.

    Args:
        model_path: Path to a saved ``state_dict`` for the model.
        image_path: Path to the image file to audit.
        prompt: Text prompt associated with the image.
        num_classes: Number of classes the model's class head predicts.

    Returns:
        dict with:
            ``global_safety_score``: ``1 - sigmoid(binary_logits)``.
            ``is_adversarial``: True when the binary probability exceeds 0.5.
            ``category_probabilities``: class name -> softmax probability.
            ``faithfulness_score``: image/text cosine similarity rescaled
                from [-1, 1] to [0, 1].
            ``seam_quality``: the model's pooled seam-quality score.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = SimpleTokenizer(vocab_dir='./tokenizer_vocab')
    vocab_size = len(tokenizer.word_to_idx)

    model = AdversarialImageAuditor(num_classes=num_classes, vocab_size=vocab_size)

    if os.path.exists(model_path):
        # NOTE(security): torch.load unpickles the checkpoint. Depending on
        # the installed PyTorch version this can execute arbitrary code, so
        # only load weights from trusted sources.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Loaded weights from {model_path}")
    else:
        print(f"Warning: {model_path} not found. Running with random weights.")

    model.to(device)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # convert() returns an independent in-memory copy, so the underlying
    # file can be closed as soon as the conversion is done.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    text_tokens = tokenizer.encode(prompt).unsqueeze(0).to(device)
    timestep = torch.tensor([[0.0]], dtype=torch.float32).to(device)

    with torch.no_grad():
        outputs = model(image_tensor, text_tokens=text_tokens, timestep=timestep)

    binary_prob = torch.sigmoid(outputs['binary_logits']).item()
    global_safety_score = 1.0 - binary_prob

    class_probs = F.softmax(outputs['class_logits'], dim=1)[0].cpu().numpy()

    CLASS_NAMES = ['Safe', 'NSFW', 'Gore', 'Weapons']
    # Label by position. With fewer than four classes only the leading names
    # are used; classes past the known names get a generic "class_<i>" label
    # instead of raising IndexError or being silently dropped.
    category_probabilities = {
        (CLASS_NAMES[i] if i < len(CLASS_NAMES) else f"class_{i}"): float(prob)
        for i, prob in enumerate(class_probs)
    }

    cos_sim = F.cosine_similarity(outputs['img_embed'], outputs['txt_embed'], dim=-1).item()
    faithfulness_score = (cos_sim + 1.0) / 2.0

    seam_quality = outputs['seam_quality_score'].item()

    return {
        "global_safety_score": global_safety_score,
        "is_adversarial": binary_prob > 0.5,
        "category_probabilities": category_probabilities,
        "faithfulness_score": faithfulness_score,
        "seam_quality": seam_quality,
    }
| |
|
if __name__ == "__main__":
    # Command-line entry point: audit one image and print the report.
    import argparse

    cli = argparse.ArgumentParser("Adversarial Image Auditor Inference")
    cli.add_argument("--model", type=str, required=True, help="Path to best.pth weights")
    cli.add_argument("--image", type=str, required=True, help="Path to internal image")
    cli.add_argument("--prompt", type=str, default="", help="Prompt given to the generator")
    options = cli.parse_args()

    report = audit_image(options.model, options.image, options.prompt)
    for field, value in report.items():
        if isinstance(value, dict):
            # Nested mapping: header line, then one line per entry.
            print(f"{field}:")
            for name, score in value.items():
                print(f" {name}: {score:.4f}")
            continue
        # Floats are shown with four decimals; anything else (e.g. bool) as-is.
        rendered = f"{value:.4f}" if isinstance(value, float) else f"{value}"
        print(f"{field}: {rendered}")
| |
|