# Auditor_Model / auditor_inference.py
# Clean, standalone inference script for the Auditor model
# (upstream commit dc8bbcc, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import math
import numpy as np
# Use same tokenization and model classes from the original file
# without all the training/evaluation boilerplates.
# To keep this script truly standalone, we bring over the tokenizer and CompleteMultiTaskAuditor.
class SimpleTokenizer:
    """Minimal word-level tokenizer with <PAD>/<UNK>/<SOS>/<EOS> specials.

    Loads an existing word->index vocabulary from ``vocab_dir/vocab.json``
    when present; otherwise falls back to the four special tokens only.
    """

    def __init__(self, vocab_dir='./tokenizer_vocab'):
        # Base vocabulary: special tokens only; replaced by vocab.json if found.
        self.word_to_idx = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3}
        self.idx_to_word = {0: "<PAD>", 1: "<UNK>", 2: "<SOS>", 3: "<EOS>"}
        # Try to load existing vocab if doing inference
        import json
        vocab_path = os.path.join(vocab_dir, 'vocab.json')
        if os.path.exists(vocab_path):
            with open(vocab_path, 'r') as f:
                self.word_to_idx = json.load(f)
            # BUG FIX: vocab.json stores word -> index (encode() depends on
            # that orientation), so the reverse map must invert the dict.
            # The old code did int(k) on *word* keys and raised ValueError.
            self.idx_to_word = {v: k for k, v in self.word_to_idx.items()}

    def encode(self, text, max_length=77):
        """Encode ``text`` into a fixed-length LongTensor of ``max_length`` ids.

        Non-string input is treated as empty. The sequence is wrapped in
        <SOS>/<EOS>, truncated (keeping a trailing <EOS>) or padded with <PAD>.
        """
        import re
        if not isinstance(text, str):
            text = ""
        text = str(text).lower()
        words = re.findall(r'\w+', text)
        tokens = [self.word_to_idx["<SOS>"]]
        for word in words:
            tokens.append(self.word_to_idx.get(word, self.word_to_idx["<UNK>"]))
        tokens.append(self.word_to_idx["<EOS>"])
        if len(tokens) > max_length:
            # Truncate but always terminate with <EOS>.
            tokens = tokens[:max_length - 1] + [self.word_to_idx["<EOS>"]]
        else:
            tokens = tokens + [self.word_to_idx["<PAD>"]] * (max_length - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)
# Basic dense block for feature extraction
class DenseBlock(nn.Module):
def __init__(self, in_channels, growth_rate, num_layers):
super().__init__()
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(
nn.Sequential(
nn.BatchNorm2d(in_channels + i * growth_rate),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels + i * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
)
)
def forward(self, x):
features = [x]
for layer in self.layers:
new_feature = layer(torch.cat(features, 1))
features.append(new_feature)
return torch.cat(features, 1)
class TransitionLayer(nn.Module):
    """Compress channels with a 1x1 conv and halve spatial size (BN-ReLU-Conv-AvgPool)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Build the pipeline as a list first, then wrap it; the attribute name
        # 'transition' is kept so state_dict keys stay checkpoint-compatible.
        stages = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.transition = nn.Sequential(*stages)

    def forward(self, x):
        # Single fused pass through BN -> ReLU -> 1x1 conv -> 2x2 avg-pool.
        return self.transition(x)
class ExtractorBackbone(nn.Module):
    """DenseNet-like feature extractor: stem, then three dense blocks with
    two transition layers in between. Final output has 256 + 24*32 = 1024
    channels at 1/16 of the input resolution.
    """

    def __init__(self):
        super().__init__()
        # Stem: 7x7/2 conv + 3x3/2 max-pool -> 1/4 resolution, 64 channels.
        self.init_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # 64 -> 256 channels, then halve to 128 at 1/8 resolution.
        self.block1 = DenseBlock(64, 32, 6)
        self.trans1 = TransitionLayer(256, 128)
        # 128 -> 512 channels, then halve to 256 at 1/16 resolution.
        self.block2 = DenseBlock(128, 32, 12)
        self.trans2 = TransitionLayer(512, 256)
        # 256 -> 1024 channels.
        self.block3 = DenseBlock(256, 32, 24)

    def forward(self, x):
        out = self.init_conv(x)
        # Run the dense/transition stages in order.
        for stage in (self.block1, self.trans1, self.block2, self.trans2, self.block3):
            out = stage(out)
        return out
class AdversarialImageAuditor(nn.Module):
    """Multi-task image auditor conditioned on an optional prompt and timestep.

    One forward pass yields: a binary adversarial logit, per-class logits,
    a global quality logit, a seam-quality score, a relative-adversarial
    score, and L2-normalised image/text embeddings for cosine "faithfulness"
    comparison (plus optional spatial maps).
    """

    def __init__(self, num_classes=4, vocab_size=10000):
        super().__init__()
        self.backbone = ExtractorBackbone()
        # Matches the backbone's final channel count (256 + 24 * 32).
        feature_dim = 1024
        # Text branch: embedding -> BiGRU (256 per direction) -> project to feature_dim.
        self.text_embedding = nn.Embedding(vocab_size, 256)
        self.text_rnn = nn.GRU(256, 256, batch_first=True, bidirectional=True)
        self.text_proj = nn.Linear(512, feature_dim)
        # Timestep conditioning: scalar -> feature_dim vector via a small MLP.
        self.timestep_embed = nn.Sequential(
            nn.Linear(1, 128), nn.ReLU(),
            nn.Linear(128, feature_dim)
        )
        # FiLM modulation of image features by the (text + timestep) condition.
        self.film_gamma = nn.Linear(feature_dim, feature_dim)
        self.film_beta = nn.Linear(feature_dim, feature_dim)
        # Image-queries-text cross attention over the flattened spatial grid.
        self.cross_attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=8, batch_first=True)
        self.norm1 = nn.LayerNorm(feature_dim)
        # Channel reduction 1024 -> 256 before the task heads.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(feature_dim, 256, kernel_size=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True)
        )
        # Per-location (spatial) heads.
        self.adversarial_head = nn.Conv2d(256, 1, kernel_size=1)
        self.class_head = nn.Conv2d(256, num_classes, kernel_size=1)
        self.seam_quality_head = nn.Conv2d(256, 1, kernel_size=1)
        # Global heads fed from the pooled bottleneck features.
        self.quality_head = nn.Linear(256, 1)
        self.relative_adv_head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1), nn.Sigmoid()
        )
        # Projections into a shared 128-d space for image/text faithfulness.
        self.img_faith_proj = nn.Linear(256, 128)
        self.txt_faith_proj = nn.Linear(feature_dim, 128)
        # Learnable log-temperature (unused in this forward; presumably a
        # contrastive-loss scale kept for checkpoint compatibility).
        self.log_temperature = nn.Parameter(torch.tensor([0.0]))

    def forward(self, image, text_tokens=None, timestep=None, return_features=False):
        """Run all audit heads.

        Args:
            image: image batch accepted by the backbone — assumes (B, 3, H, W)
                float tensor; TODO confirm against callers.
            text_tokens: optional (B, L) LongTensor of token ids; id 0 is <PAD>.
            timestep: optional (B, 1) float tensor for the timestep MLP.
            return_features: when True, also return sigmoid spatial maps.

        Returns:
            dict of per-task tensors (see keys in ``out`` below).
        """
        batch_size = image.size(0)
        img_features = self.backbone(image)
        _, f_c, f_h, f_w = img_features.shape
        # Default text conditioning when no prompt is given: all-zeros vector.
        global_text = torch.zeros(batch_size, f_c, device=image.device)
        text_seq = None
        padding_mask = None
        if text_tokens is not None:
            text_emb = self.text_embedding(text_tokens)
            text_out, _ = self.text_rnn(text_emb)
            text_seq = self.text_proj(text_out)
            global_text = torch.mean(text_seq, dim=1)
            padding_mask = (text_tokens == 0)  # True where <PAD>
            # Attention produces NaNs if every key is masked; keep position 0
            # visible. NOTE(review): this triggers only when the WHOLE batch is
            # padding; a single all-pad row in a mixed batch relies on the NaN
            # fallback further below.
            if padding_mask.all():
                padding_mask[:, 0] = False
        time_emb = self.timestep_embed(timestep) if timestep is not None else torch.zeros(batch_size, f_c, device=image.device)
        cond_vec = global_text + time_emb
        # FiLM: clamp modulation so the condition cannot blow up the features.
        gamma = torch.clamp(self.film_gamma(cond_vec), -3.0, 3.0)
        beta = torch.clamp(self.film_beta(cond_vec), -3.0, 3.0)
        gamma = gamma.view(batch_size, f_c, 1, 1).expand_as(img_features)
        beta = beta.view(batch_size, f_c, 1, 1).expand_as(img_features)
        fused_features = img_features * (1 + gamma) + beta
        # Flatten spatial grid to a sequence for cross attention: (B, H*W, C).
        img_seq = fused_features.flatten(2).transpose(1, 2)
        if text_seq is not None:
            img_seq_normed = self.norm1(img_seq)
            attn_out, _ = self.cross_attn(query=img_seq_normed, key=text_seq, value=text_seq, key_padding_mask=padding_mask)
            img_seq = img_seq + attn_out  # residual connection
            # Fallback: if masked attention still produced NaNs, discard it and
            # continue from the normalised image sequence instead.
            if torch.isnan(img_seq).any():
                img_seq = img_seq_normed
        fused_features = img_seq.transpose(1, 2).view(batch_size, f_c, f_h, f_w)
        enhanced_features = self.bottleneck(fused_features)
        # Spatial prediction maps.
        adv_map = self.adversarial_head(enhanced_features)
        class_map = self.class_head(enhanced_features)
        seam_map = torch.sigmoid(self.seam_quality_head(enhanced_features))
        # Global scores: avg-pool for quality, max-pool for detection logits
        # (max -> "most suspicious location wins").
        global_pool = F.adaptive_avg_pool2d(enhanced_features, (1, 1)).view(batch_size, -1)
        quality_logits = self.quality_head(global_pool)
        adv_logits = F.adaptive_max_pool2d(adv_map, (1, 1)).view(batch_size, -1)
        class_logits = F.adaptive_max_pool2d(class_map, (1, 1)).view(batch_size, -1)
        seam_score = F.adaptive_avg_pool2d(seam_map, (1, 1)).view(batch_size, -1)
        relative_adv = self.relative_adv_head(global_pool)
        # Unit-norm embeddings for cosine faithfulness scoring.
        v_img = self.img_faith_proj(global_pool)
        v_txt = self.txt_faith_proj(global_text)
        v_img = F.normalize(v_img, p=2, dim=1)
        v_txt = F.normalize(v_txt, p=2, dim=1)
        out = {
            'binary_logits': adv_logits,
            'class_logits': class_logits,
            'quality_logits': quality_logits,
            'seam_quality_score': seam_score,
            'relative_adv_score': relative_adv,
            'img_embed': v_img,
            'txt_embed': v_txt
        }
        if return_features:
            out['adversarial_map'] = torch.sigmoid(adv_map)
            out['object_heatmaps'] = torch.sigmoid(class_map)
            out['seam_quality_map'] = seam_map
            out['class_map'] = class_map
        return out
def audit_image(model_path, image_path, prompt="", num_classes=4):
    """
    Independent plug-and-play function for auditing an image using the standalone model weights.

    Args:
        model_path: path to a state-dict checkpoint (.pth). A missing file only
            warns and runs with random weights, keeping the function runnable.
        image_path: path to the image to audit (any format PIL can open).
        prompt: optional generator prompt, used for faithfulness scoring.
        num_classes: number of category heads the checkpoint was trained with.

    Returns:
        dict with global_safety_score, is_adversarial, category_probabilities,
        faithfulness_score and seam_quality.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = SimpleTokenizer(vocab_dir='./tokenizer_vocab')
    vocab_size = len(tokenizer.word_to_idx)
    model = AdversarialImageAuditor(num_classes=num_classes, vocab_size=vocab_size)
    if os.path.exists(model_path):
        # weights_only=True: a state dict needs no pickled code objects, and
        # this blocks arbitrary-code execution from an untrusted checkpoint.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        print(f"Loaded weights from {model_path}")
    else:
        print(f"Warning: {model_path} not found. Running with random weights.")
    model.to(device)
    model.eval()
    # Standard ImageNet preprocessing, matching the backbone's training regime.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)
    text_tokens = tokenizer.encode(prompt).unsqueeze(0).to(device)
    timestep = torch.tensor([[0.0]], dtype=torch.float32).to(device)
    with torch.no_grad():
        outputs = model(image_tensor, text_tokens=text_tokens, timestep=timestep)
    binary_prob = torch.sigmoid(outputs['binary_logits']).item()
    global_safety_score = 1.0 - binary_prob
    class_probs = F.softmax(outputs['class_logits'], dim=1)[0].cpu().numpy()
    # We use the generic 4 classes mapping here for the generic auditor.
    # BUG FIX: names were hard-coded to exactly 4 entries, so num_classes < 4
    # raised IndexError and num_classes > 4 silently dropped classes; pad the
    # generic names with numbered fallbacks instead.
    base_names = ['Safe', 'NSFW', 'Gore', 'Weapons']
    class_names = [base_names[i] if i < len(base_names) else f'class_{i}'
                   for i in range(num_classes)]
    category_probabilities = {class_names[i]: float(class_probs[i]) for i in range(num_classes)}
    cos_sim = F.cosine_similarity(outputs['img_embed'], outputs['txt_embed'], dim=-1).item()
    # Map cosine similarity [-1, 1] onto a [0, 1] faithfulness score.
    faithfulness_score = (cos_sim + 1.0) / 2.0
    seam_quality = outputs['seam_quality_score'].item()
    return {
        "global_safety_score": global_safety_score,
        "is_adversarial": binary_prob > 0.5,
        "category_probabilities": category_probabilities,
        "faithfulness_score": faithfulness_score,
        "seam_quality": seam_quality,
    }
if __name__ == "__main__":
    import argparse

    # BUG FIX: the bare positional argument to ArgumentParser sets `prog`
    # (the program name shown in usage), not the description, so --help
    # displayed a bogus program name and no summary. Use description=.
    parser = argparse.ArgumentParser(description="Adversarial Image Auditor Inference")
    parser.add_argument("--model", type=str, required=True, help="Path to best.pth weights")
    parser.add_argument("--image", type=str, required=True, help="Path to internal image")
    parser.add_argument("--prompt", type=str, default="", help="Prompt given to the generator")
    args = parser.parse_args()

    res = audit_image(args.model, args.image, args.prompt)
    # Pretty-print: nested dicts indented, floats to 4 decimal places.
    for k, v in res.items():
        if isinstance(v, dict):
            print(f"{k}:")
            for sub_k, sub_v in v.items():
                print(f" {sub_k}: {sub_v:.4f}")
        elif isinstance(v, float):
            print(f"{k}: {v:.4f}")
        else:
            print(f"{k}: {v}")