# ==============================================================================
# Main Gradio App for DATDA + SDATDA Defense with CNN Classification Ensemble
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import numpy as np

from datda_defense import DATDA, DATDAConfig  # primary defense
from datda_index import DATDAIndex  # index calculator
from sdatda_defense import SDATDAUltra, SDATDAConfig  # secondary ultra-defense

# Import torchvision models
import torchvision.models as models
import torchvision.transforms as T

device = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------------------
# Image preprocessing transforms
# ------------------------------
# Produces (C, 224, 224) float tensors in [0, 1]. The defenses, the index
# calculator and the display conversion all work in this [0, 1] space.
preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

# The torchvision ImageNet weights loaded below expect inputs normalized with
# these per-channel statistics. Normalization is applied only right before the
# classifiers (see ensemble_cnn_predict), so the rest of the pipeline keeps
# operating on [0, 1] tensors.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
classifier_normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# ------------------------------
# Load candidate CNN models
# ------------------------------
MODEL_CATALOG = {
    "resnet50": models.resnet50(weights=models.ResNet50_Weights.DEFAULT).eval().to(device),
    "vgg16": models.vgg16(weights=models.VGG16_Weights.DEFAULT).eval().to(device),
    "efficientnet_b0": models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).eval().to(device),
    "mobilenet_v3_large": models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT).eval().to(device),
}

# Model robustness scores (higher = more robust to adversarial perturbations)
MODEL_ROBUSTNESS = {
    "resnet50": 0.8,
    "vgg16": 0.6,
    "efficientnet_b0": 0.75,
    "mobilenet_v3_large": 0.7,
}

# DATDA index above which the secondary SDATDAUltra defense is applied
# (threshold can be tuned).
SDATDA_THRESHOLD = 0.3

# Lower bound for the shared ensemble weight scale. Without it, a scale <= 0
# yields a zero total weight (NaN probabilities) or negative weights (which
# invert the robustness ranking).
_MIN_WEIGHT_SCALE = 1e-6

# ------------------------------
# Initialize defenses and index
# ------------------------------
datda_model = DATDA(DATDAConfig(device=device))
sdatda_model = SDATDAUltra(SDATDAConfig())
index_calc = DATDAIndex()


# ------------------------------
# Ensemble model selection function
# ------------------------------
def _high_frequency_ratio(img: torch.Tensor, radius_frac: float = 0.25) -> float:
    """Return the fraction of spectral magnitude outside a centred low-frequency disc.

    Only the first image of the batch is used. It is reduced to luma with the
    0.299/0.587/0.114 channel weights, so ``img`` is indexed as
    ``img[batch, channel, H, W]`` with at least 3 channels.

    Args:
        img: Batched image tensor.
        radius_frac: Disc radius as a fraction of ``min(H, W)``.

    Returns:
        Sum of |FFT| beyond the radius divided by the total sum of |FFT|.
    """
    gray = 0.299 * img[0, 0] + 0.587 * img[0, 1] + 0.114 * img[0, 2]
    # fftshift moves the zero frequency to index (H // 2, W // 2).
    magnitude = torch.abs(torch.fft.fftshift(torch.fft.fft2(gray)))

    H, W = gray.shape
    # Build the distance grid on the same device as the spectrum, so the
    # boolean mask indexes it without a cross-device transfer.
    ys = torch.arange(H, device=gray.device, dtype=torch.float32) - H // 2
    xs = torch.arange(W, device=gray.device, dtype=torch.float32) - W // 2
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    dist = torch.sqrt(xx ** 2 + yy ** 2)

    radius = int(min(H, W) * radius_frac)
    return (magnitude[dist > radius].sum() / (magnitude.sum() + 1e-8)).item()


def ensemble_cnn_predict(datda_index: float, img: torch.Tensor, top_k=2):
    """Classify ``img`` with a weighted ensemble of the top-k catalog CNNs.

    Each model's weight is ``robustness * scale``, where
    ``scale = 1 - datda_index + hf_ratio`` and ``hf_ratio`` is the
    high-frequency spectral ratio of ``img``. The scale is clamped to a small
    positive floor so weights never become zero or negative.

    NOTE(review): ``scale`` is shared by all models, so the top-k selection is
    effectively a ranking by MODEL_ROBUSTNESS alone. Only the absolute weights
    (which cancel in the normalized average) depend on the index and image.

    Args:
        datda_index: DATDA index of the (defended) image; must be float-convertible.
        img: Batched image tensor in [0, 1] (the ``preprocess`` output space).
            ImageNet normalization is applied here, before the classifiers.
        top_k: Number of models to ensemble; must be >= 1.

    Returns:
        ``(top_class, top_prob, selected_models)`` for the first batch element.

    Raises:
        ValueError: If ``top_k`` < 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    hf_ratio = _high_frequency_ratio(img)
    scale = max(1.0 - float(datda_index) + hf_ratio, _MIN_WEIGHT_SCALE)

    # Higher weight = more trusted model.
    model_weights = {
        name: robustness * scale for name, robustness in MODEL_ROBUSTNESS.items()
    }
    selected_models = sorted(model_weights, key=model_weights.get, reverse=True)[:top_k]

    # Pretrained torchvision classifiers expect ImageNet-normalized input.
    x_cls = classifier_normalize(img)

    # Weighted average of softmax probabilities (not logits).
    prob_sum = None
    total_weight = 0.0
    with torch.no_grad():
        for name in selected_models:
            weight = model_weights[name]
            probs = F.softmax(MODEL_CATALOG[name](x_cls), dim=1)
            prob_sum = weight * probs if prob_sum is None else prob_sum + weight * probs
            total_weight += weight

    ensemble_probs = prob_sum / total_weight
    top_prob, top_class = ensemble_probs.max(dim=1)
    return top_class.item(), top_prob.item(), selected_models


# ------------------------------
# Main processing pipeline
# ------------------------------
def process_image(img: Image.Image):
    """Run DATDA (and, if needed, SDATDAUltra) on ``img``, then classify it.

    Args:
        img: Uploaded PIL image; ``None`` when nothing was uploaded.

    Returns:
        Tuple of (defended PIL image, initial DATDA index, final DATDA index,
        comma-separated selected model names, ensemble top-class probability).

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    # Force 3-channel RGB. RGBA/grayscale/palette uploads would otherwise
    # yield 4- or 1-channel tensors that the CNNs cannot consume.
    img = img.convert("RGB")

    # Preprocess to tensor
    x_tensor = preprocess(img).unsqueeze(0).to(device)

    # Step 1: Initial DATDA Index
    index_before = index_calc.compute(x_tensor)

    # Step 2: Run DATDA defense
    x_defended = datda_model(x_tensor)
    index_after = index_calc.compute(x_defended)

    # Step 3: If still high, run SDATDAUltra
    if index_after > SDATDA_THRESHOLD:
        # NOTE(review): purify() receives the original upload, not the DATDA
        # output, and reports its own final index. Confirm this is intended.
        # Its image output is assumed to be something `preprocess` accepts
        # (PIL or ndarray).
        x_defended_img, index_final = sdatda_model.purify(img)
        x_defended = preprocess(x_defended_img).unsqueeze(0).to(device)
    else:
        index_final = index_after

    # Step 4: Ensemble CNN classification
    # NOTE(review): top_class is computed but not surfaced in the UI.
    top_class, top_prob, selected_models = ensemble_cnn_predict(index_final, x_defended)

    # Convert back to PIL for display. Detach in case the defense output still
    # tracks gradients, and clamp because ToPILImage scales floats by 255 and
    # casts to uint8, which misbehaves outside [0, 1].
    display_tensor = x_defended[0].detach().clamp(0.0, 1.0).cpu()
    img_defended = T.ToPILImage()(display_tensor)

    return (
        img_defended,
        float(index_before),
        float(index_final),
        ", ".join(selected_models),
        float(top_prob),
    )


# ------------------------------
# Gradio UI
# ------------------------------
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Defended Image"),
        gr.Number(label="Initial DATDA Index"),
        gr.Number(label="Final DATDA Index"),
        gr.Textbox(label="Selected CNN Models (Ensemble)"),
        gr.Number(label="Classification Confidence"),
    ],
    title="DATDA + SDATDA Defense with CNN Ensemble",
    description="Upload an image. The system applies the primary DATDA defense, optionally the SDATDAUltra defense, then uses an ensemble of CNNs selected adaptively to classify the image.",
)

if __name__ == "__main__":
    iface.launch()