Spaces:
Runtime error
Runtime error
File size: 5,377 Bytes
6a8104b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | # ==============================================================================
# Main Gradio App for DATDA + SDATDA Defense with CNN Classification Ensemble
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import numpy as np
from datda_defense import DATDA, DATDAConfig # primary defense
from datda_index import DATDAIndex # index calculator
from sdatda_defense import SDATDAUltra, SDATDAConfig # secondary ultra-defense
# Import torchvision models
import torchvision.models as models
import torchvision.transforms as T
# Prefer the GPU whenever CUDA reports one; otherwise everything runs on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ------------------------------
# Image preprocessing transforms
# ------------------------------
# Resize every upload to 224x224 and turn the PIL image into a float CHW
# tensor (ToTensor rescales 8-bit pixels to [0, 1]). No mean/std
# normalisation is applied at this stage.
preprocess = T.Compose([T.Resize(size=(224, 224)), T.ToTensor()])
# ------------------------------
# Load candidate CNN models
# ------------------------------
# name -> (torchvision constructor, pretrained weights). Models are built in
# this order, switched to inference mode with .eval() and moved to `device`.
_MODEL_SPECS = {
    "resnet50": (models.resnet50, models.ResNet50_Weights.DEFAULT),
    "vgg16": (models.vgg16, models.VGG16_Weights.DEFAULT),
    "efficientnet_b0": (models.efficientnet_b0, models.EfficientNet_B0_Weights.DEFAULT),
    "mobilenet_v3_large": (models.mobilenet_v3_large, models.MobileNet_V3_Large_Weights.DEFAULT),
}
MODEL_CATALOG = {
    name: builder(weights=weights).eval().to(device)
    for name, (builder, weights) in _MODEL_SPECS.items()
}

# Hand-assigned robustness scores used to rank and weight the models
# (higher = treated as more robust to adversarial perturbations). These are
# fixed constants, not measured at runtime.
MODEL_ROBUSTNESS = dict(
    resnet50=0.8,
    vgg16=0.6,
    efficientnet_b0=0.75,
    mobilenet_v3_large=0.7,
)
# ------------------------------
# Initialize defenses and index
# ------------------------------
# Primary DATDA defense (configured for `device`) and the secondary
# SDATDAUltra purifier (default config), built in that order, followed by
# the DATDA index calculator that scores images before/after defense.
datda_model, sdatda_model = (
    DATDA(DATDAConfig(device=device)),
    SDATDAUltra(SDATDAConfig()),
)
index_calc = DATDAIndex()
# ------------------------------
# Ensemble model selection function
# ------------------------------
# Per-channel mean/std that torchvision's ImageNet-pretrained weights expect.
# `preprocess` deliberately leaves pixels in [0, 1] (the defenses and the PIL
# display path work in that range), so CNN inputs are normalised here instead.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


@torch.no_grad()
def _high_frequency_ratio(img: torch.Tensor, radius_frac: float = 0.25) -> float:
    """Return the share of FFT magnitude outside a centred low-frequency disc.

    Only ``img[0]`` is analysed: its first three channels are mixed into luma
    with 0.299/0.587/0.114 weights, transformed with a 2-D FFT, and the
    magnitude farther than ``int(radius_frac * min(H, W))`` pixels from the
    centre of the shifted spectrum is divided by the total magnitude.
    """
    gray = 0.299 * img[0, 0] + 0.587 * img[0, 1] + 0.114 * img[0, 2]
    magnitude = torch.abs(torch.fft.fftshift(torch.fft.fft2(gray)))
    h, w = gray.shape
    # Build the radial distance grid on the image's own device so the boolean
    # mask below does not force a host<->device transfer.
    ys = torch.arange(h, device=img.device, dtype=torch.float32) - h // 2
    xs = torch.arange(w, device=img.device, dtype=torch.float32) - w // 2
    dist = torch.sqrt(ys[:, None] ** 2 + xs[None, :] ** 2)
    radius = int(min(h, w) * radius_frac)
    return (magnitude[dist > radius].sum() / (magnitude.sum() + 1e-8)).item()


def ensemble_cnn_predict(datda_index: float, img: torch.Tensor, top_k: int = 2, *,
                         catalog=None, robustness=None, normalize: bool = True):
    """Classify ``img`` with a robustness-weighted soft-voting CNN ensemble.

    Every model listed in ``robustness`` gets weight
    ``robustness[name] * scale`` with
    ``scale = max(1 - datda_index + hf_ratio, 1e-6)``, where ``hf_ratio`` comes
    from ``_high_frequency_ratio``. The ``top_k`` heaviest models each
    contribute ``weight * softmax(logits)``, and the sum is divided by the total
    selected weight.

    Because ``scale`` is identical for every model, it cancels in that
    division. Model selection and the output probabilities therefore depend
    only on the robustness scores, and ``datda_index``/``hf_ratio`` only set
    ``scale``. ``scale`` is clamped positive for two reasons:
    an index of exactly 1 on a flat image would give 0/0 = NaN, and an index
    above ``1 + hf_ratio`` would make every weight negative and reverse the
    ranking towards the least robust models.

    Args:
        datda_index: DATDA index of the image being classified (a float or a
            single-element tensor).
        img: RGB batch expected to hold values in [0, 1] and to live on the
            models' device; presumably shaped (1, 3, H, W). The ``.item()``
            calls on the result require a batch of one.
        top_k: Number of models to ensemble (must be >= 1; silently capped at
            the number of available models).
        catalog: Mapping name -> callable model. Defaults to ``MODEL_CATALOG``.
        robustness: Mapping name -> robustness score. Defaults to
            ``MODEL_ROBUSTNESS``.
        normalize: Apply ImageNet mean/std normalisation before the forward
            pass, as torchvision's pretrained weights expect.

    Returns:
        ``(class_index, probability, selected_model_names)`` where the names
        are ordered from highest to lowest weight.

    Raises:
        ValueError: If ``top_k < 1`` or there are no models to choose from.
        KeyError: If a selected model name is missing from ``catalog``.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if catalog is None:
        catalog = MODEL_CATALOG
    if robustness is None:
        robustness = MODEL_ROBUSTNESS

    hf_ratio = _high_frequency_ratio(img)
    scale = max(1.0 - float(datda_index) + hf_ratio, 1e-6)
    model_weights = {name: score * scale for name, score in robustness.items()}
    # sorted() is stable, so ties keep the mapping's insertion order.
    selected_models = sorted(model_weights, key=model_weights.get, reverse=True)[:top_k]
    if not selected_models:
        raise ValueError("no models available for the ensemble")

    model_input = img
    if normalize:
        mean = img.new_tensor(_IMAGENET_MEAN).view(1, -1, 1, 1)
        std = img.new_tensor(_IMAGENET_STD).view(1, -1, 1, 1)
        model_input = (img - mean) / std

    probs_sum = None
    total_weight = 0.0
    with torch.no_grad():
        for name in selected_models:
            weight = model_weights[name]
            weighted = weight * F.softmax(catalog[name](model_input), dim=1)
            probs_sum = weighted if probs_sum is None else probs_sum + weighted
            total_weight += weight

    ensemble_probs = probs_sum / total_weight
    top_prob, top_class = ensemble_probs.max(dim=1)
    return top_class.item(), top_prob.item(), selected_models
# ------------------------------
# Main processing pipeline
# ------------------------------
def process_image(img: Image.Image):
    """Run the two-stage defense on an uploaded image and classify the result.

    Steps:
      1. Convert to RGB and tensorise with ``preprocess`` (224x224, [0, 1]).
      2. Score the input with the DATDA index, apply the DATDA defense, and
         score the defended tensor.
      3. If that score is still above 0.3, discard the DATDA output and run
         ``sdatda_model.purify`` on the (RGB) upload instead. The second value
         it returns is used as the final index.
      4. Classify the defended tensor with ``ensemble_cnn_predict``.

    Args:
        img: The uploaded image, or ``None`` if nothing was uploaded.

    Returns:
        A tuple in the order of the interface outputs: the defended image
        (PIL), the initial DATDA index, the final DATDA index, the
        comma-separated names of the ensemble models, and the ensemble's top-1
        probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # Greyscale, palette or RGBA uploads would otherwise become 1- or
    # 4-channel tensors, which the RGB CNNs and the ensemble's luma mix reject.
    img = img.convert("RGB")
    x_tensor = preprocess(img).unsqueeze(0).to(device)

    # Step 1: DATDA index of the raw input.
    index_before = index_calc.compute(x_tensor)

    # Step 2: primary DATDA defense, then re-score its output.
    x_defended = datda_model(x_tensor)
    index_after = index_calc.compute(x_defended)

    # Step 3: fall back to SDATDAUltra while the index is still high
    # (threshold can be tuned). Note: it purifies the original upload,
    # not the DATDA-defended tensor.
    if index_after > 0.3:
        x_defended_img, index_final = sdatda_model.purify(img)
        x_defended = preprocess(x_defended_img).unsqueeze(0).to(device)
    else:
        index_final = index_after

    # Step 4: ensemble classification. The class index is computed but the
    # UI has no output slot for it.
    _top_class, top_prob, selected_models = ensemble_cnn_predict(index_final, x_defended)

    # ToPILImage multiplies by 255 and casts to uint8, so values outside
    # [0, 1] would not map to valid pixels; clamp first (display only).
    img_defended = T.ToPILImage()(x_defended[0].detach().clamp(0.0, 1.0).cpu())
    return (
        img_defended,
        float(index_before),
        float(index_final),
        ", ".join(selected_models),
        float(top_prob),
    )
# ------------------------------
# Gradio UI
# ------------------------------
# One PIL image goes in; the five output components are listed in the same
# order as the tuple returned by `process_image`.
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Image(type="pil", label="Defended Image"),
    gr.Number(label="Initial DATDA Index"),
    gr.Number(label="Final DATDA Index"),
    gr.Textbox(label="Selected CNN Models (Ensemble)"),
    gr.Number(label="Classification Confidence"),
]
iface = gr.Interface(
    fn=process_image,
    inputs=_input_component,
    outputs=_output_components,
    title="DATDA + SDATDA Defense with CNN Ensemble",
    description="Upload an image. The system applies the primary DATDA defense, optionally the SDATDAUltra defense, then uses an ensemble of CNNs selected adaptively to classify the image.",
)

if __name__ == "__main__":
    # Start the web server only when this file is executed directly.
    iface.launch()
|