| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torchvision.models as models |
| | import torchvision.transforms as transforms |
| | from torchvision.models.resnet import ResNet50_Weights |
| | from PIL import Image |
| | import numpy as np |
| | import os |
| | import requests |
| | import time |
| | import copy |
| | from collections import OrderedDict |
| | from pathlib import Path |
| |
|
| | |
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
_backend = "cpu"
if torch.cuda.is_available():
    _backend = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    _backend = "mps"
device = torch.device(_backend)
print(f"Using device: {device}")
| |
|
| | |
# Checkpoint download URLs for each supported model variant; entries set to
# None (or missing keys) mean "no checkpoint — use torchvision's pretrained
# weights instead" (see download_model / load_model below).
MODEL_URLS = {
    'resnet50_robust': 'https://huggingface.co/ttoosi/robust_resnet50_eps_3.00_epoch_best/resolve/main/resnet50_robust.pt',
    'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
    'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/100_checkpoint.pt'
}

# Standard ImageNet per-channel statistics used for input normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
| |
|
| | |
def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
    """Build the image preprocessing pipeline.

    Resizes (shorter side to ``input_size``), center-crops to a square, and
    converts to a float tensor in [0, 1]; when ``normalize`` is True, appends
    per-channel mean/std normalization.

    Args:
        input_size: Target side length for resize and center crop.
        normalize: Whether to append a Normalize step.
        norm_mean: Per-channel means used when normalizing.
        norm_std: Per-channel stds used when normalizing.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    steps = [
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(norm_mean, norm_std))
    return transforms.Compose(steps)
| |
|
| | |
# Default preprocessing pipeline: 224x224 center crop, tensor in [0, 1],
# deliberately WITHOUT normalization (inference operates in pixel space).
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Standalone ImageNet normalization step, applied separately when a model
# expects normalized inputs.
normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
| |
|
| |
|
def create_flat_top_gaussian_mask(image_shape, center_x, center_y, flat_radius, sigma, max_multiplier=4.0, min_multiplier=1.0, device=None):
    """
    Create a flat-top Gaussian epsilon/step-size mask for adaptive constraint application.
    Coordinates are normalized: -1 = left/top edge, 1 = right/bottom edge, 0 = center.

    Args:
        image_shape: (batch, channels, height, width)
        center_x: Center x (normalized -1 to 1)
        center_y: Center y (normalized -1 to 1)
        flat_radius: Radius of flat-top region (normalized)
        sigma: Std for Gaussian fall-off outside flat region
        max_multiplier: Multiplier inside flat region
        min_multiplier: Multiplier far from center
        device: torch device for the output tensor
    Returns:
        Tensor (1, 1, H, W)
    """
    height, width = image_shape[2], image_shape[3]
    dev = torch.device("cpu") if device is None else device
    # Normalized pixel-center coordinate grids over [-1, 1].
    row_coords = torch.linspace(-1, 1, height, device=dev)
    col_coords = torch.linspace(-1, 1, width, device=dev)
    grid_y, grid_x = torch.meshgrid(row_coords, col_coords, indexing='ij')
    # Euclidean distance of every pixel from the mask center.
    radial = torch.sqrt((grid_x - center_x) ** 2 + (grid_y - center_y) ** 2)
    # Gaussian fall-off from max_multiplier down toward min_multiplier,
    # measured from the edge of the flat-top disc.
    falloff = min_multiplier + (max_multiplier - min_multiplier) * torch.exp(
        -((radial - flat_radius) ** 2) / (2 * sigma ** 2)
    )
    # Inside the flat-top disc the multiplier is constant at max_multiplier.
    mask = torch.where(radial <= flat_radius, torch.full_like(radial, max_multiplier), falloff)
    return mask.reshape(1, 1, height, width)
| |
|
| |
|
def extract_middle_layers(model, layer_index):
    """
    Extract a subset of the model up to a specific layer.

    Args:
        model: The neural network model
        layer_index: String 'all' for the full model, or a layer identifier (string or int)
                     For ResNet: integers 0-8 representing specific layers
                     For ViT: strings like 'encoder.layers.encoder_layer_3'

    Returns:
        A modified model that outputs features from the specified layer

    Raises:
        ValueError: if the requested layer cannot be found/parsed.
    """
    # 'all' is a pass-through: no truncation requested.
    if isinstance(layer_index, str) and layer_index == 'all':
        return model

    # torchvision-style ViT: keep encoder layers 0..N, drop classifier heads.
    if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
        try:
            target_layer_idx = int(layer_index.split('_')[-1])

            # Deep copy so the caller's model is not mutated in place.
            new_model = copy.deepcopy(model)

            # DataParallel-style wrapper: the real network lives under .module.
            if hasattr(new_model, 'module'):
                # Rebuild the encoder with only the first N+1 layers.
                encoder_layers = nn.Sequential()
                for i in range(target_layer_idx + 1):
                    layer_name = f"encoder_layer_{i}"
                    if hasattr(new_model.module.encoder.layers, layer_name):
                        encoder_layers.add_module(layer_name,
                                                  getattr(new_model.module.encoder.layers, layer_name))

                new_model.module.encoder.layers = encoder_layers

                # Identity head so the forward pass returns encoder features.
                new_model.module.heads = nn.Identity()

                return new_model
            else:
                # Unwrapped model: same truncation on the model itself.
                encoder_layers = nn.Sequential()
                for i in range(target_layer_idx + 1):
                    layer_name = f"encoder_layer_{i}"
                    if hasattr(new_model.encoder.layers, layer_name):
                        encoder_layers.add_module(layer_name,
                                                  getattr(new_model.encoder.layers, layer_name))

                new_model.encoder.layers = encoder_layers

                new_model.heads = nn.Identity()

                return new_model

        except (ValueError, IndexError) as e:
            raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")

    # Models exposing a 'blocks' list (e.g. timm-style transformers):
    # truncate the block list at the requested integer index.
    elif hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
        # NOTE(review): base_model is computed but not used below — the copy
        # (base_new_model) is what gets truncated; presumably leftover code.
        base_model = model.module if hasattr(model, 'module') else model

        new_model = copy.deepcopy(model)
        base_new_model = new_model.module if hasattr(new_model, 'module') else new_model

        # NOTE(review): only integer indices are handled on this path; a
        # non-int layer_index falls through and returns None — confirm intended.
        if isinstance(layer_index, int):
            base_new_model.blocks = base_new_model.blocks[:layer_index+1]

            return new_model

    else:
        # Generic fallback: cut at the named top-level child (works for
        # nn.Sequential-style models whose children are named '0', '1', ...).
        modules = list(model.named_children())
        print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")

        cutoff_idx = next((i for i, (name, _) in enumerate(modules)
                           if name == str(layer_index)), None)

        if cutoff_idx is not None:
            # Repackage the retained children as a Sequential feature extractor.
            new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx+1]))
            return new_model
        else:
            raise ValueError(f"Module {layer_index} not found in model")
| |
|
| | |
def get_imagenet_labels():
    """Fetch the 1000 human-readable ImageNet class labels.

    Downloads the 'imagenet-simple-labels' JSON list from GitHub.

    Returns:
        list[str]: labels indexed by ImageNet class id.

    Raises:
        RuntimeError: if the labels cannot be fetched.
    """
    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    # Bound the request: without a timeout a hung connection would block
    # model setup indefinitely (requests has no default timeout).
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        return response.json()
    else:
        raise RuntimeError(f"Failed to fetch ImageNet labels (HTTP {response.status_code})")
| |
|
| | |
def download_model(model_type):
    """Ensure the checkpoint for ``model_type`` exists locally and return its path.

    Downloads the file from MODEL_URLS on first use and caches it under
    ``models/``; subsequent calls reuse the cached file.

    Args:
        model_type: Key into MODEL_URLS.

    Returns:
        pathlib.Path to the checkpoint file, or None when no URL is registered
        for this model type.

    Raises:
        RuntimeError: if the download request returns a non-200 status.
    """
    if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
        return None

    # The face checkpoint keeps its upstream filename; others use the key.
    if model_type == 'resnet50_robust_face':
        model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
    else:
        model_path = Path(f"models/{model_type}.pt")

    if not model_path.exists():
        # Fix: create the cache directory — open() below would otherwise fail
        # with FileNotFoundError on a fresh checkout.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"Downloading {model_type} model...")
        url = MODEL_URLS[model_type]
        response = requests.get(url, stream=True, timeout=60)
        if response.status_code == 200:
            try:
                with open(model_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            except Exception:
                # Fix: remove a partially-written file so a later call does not
                # mistake a truncated checkpoint for a completed download.
                model_path.unlink(missing_ok=True)
                raise
            print(f"Model downloaded and saved to {model_path}")
        else:
            raise RuntimeError(f"Failed to download model: {response.status_code}")
    return model_path
| |
|
class NormalizeByChannelMeanStd(nn.Module):
    """Differentiable per-channel normalization layer.

    Stores mean/std as buffers so they move with the module across devices
    and are saved/loaded with state_dict, and normalizes NCHW tensors without
    breaking the autograd graph.
    """

    def __init__(self, mean, std):
        super().__init__()
        mean = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable version of torchvision.functional.normalize"""
        # Broadcast (C,) statistics against an NCHW batch.
        return tensor.sub(mean[None, :, None, None]).div(std[None, :, None, None])
| |
|
class InferStep:
    """Pixel-space projected-gradient helper for the inference loop.

    Keeps the original image together with scalar or spatially-varying
    (flat-top Gaussian) epsilon and step-size budgets, and provides the
    per-pixel L2 projection and the normalized gradient step.
    """

    def __init__(self, orig_image, eps, step_size, adaptive_eps_config=None, adaptive_step_config=None):
        self.orig_image = orig_image
        dev = orig_image.device

        # Epsilon budget: a per-pixel map when adaptive mode is enabled,
        # otherwise a single scalar.
        self.use_adaptive_eps = bool(adaptive_eps_config and adaptive_eps_config.get("enabled", False))
        if self.use_adaptive_eps:
            self.eps_map = self._build_map(adaptive_eps_config, "base_epsilon", eps, orig_image.shape, dev)
        else:
            self.eps = eps

        # Step size: same scalar-vs-map treatment.
        self.use_adaptive_step_size = bool(adaptive_step_config and adaptive_step_config.get("enabled", False))
        if self.use_adaptive_step_size:
            self.step_size_map = self._build_map(adaptive_step_config, "base_step_size", step_size, orig_image.shape, dev)
        else:
            self.step_size = step_size

    @staticmethod
    def _build_map(cfg, base_key, fallback, shape, dev):
        """Build a (1, 1, H, W) budget map: flat-top Gaussian mask x base value."""
        base_value = float(cfg.get(base_key, fallback))
        mask = create_flat_top_gaussian_mask(
            shape,
            float(cfg.get("center_x", 0.0)),
            float(cfg.get("center_y", 0.0)),
            float(cfg.get("flat_radius", 0.3)),
            float(cfg.get("sigma", 0.2)),
            float(cfg.get("max_multiplier", 4.0)),
            float(cfg.get("min_multiplier", 1.0)),
            device=dev,
        )
        return (mask * base_value).to(dev)

    def project(self, x):
        """Project x into the per-pixel L2 ball around the original image, then clamp to [0, 1]."""
        delta = x - self.orig_image
        # Per-location perturbation magnitude across the channel dimension.
        pixel_norm = torch.norm(delta, dim=1, keepdim=True)
        budget = self.eps_map if self.use_adaptive_eps else self.eps
        # Shrink (never grow) the perturbation wherever it exceeds the budget.
        shrink = torch.clamp(budget / (pixel_norm + 1e-10), max=1.0)
        return torch.clamp(self.orig_image + delta * shrink, 0, 1)

    def step(self, x, grad):
        """Return the L2-normalized gradient scaled by the step size (scalar or map)."""
        trailing = len(x.shape) - 1
        # Per-sample gradient norm, reshaped to broadcast over all trailing dims.
        flat_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1] * trailing))
        unit_grad = grad / (flat_norm + 1e-10)
        if self.use_adaptive_step_size:
            return unit_grad * self.step_size_map
        return unit_grad * self.step_size
| |
|
def get_iterations_to_show(n_itr):
    """Generate a dynamic list of iterations to show based on total iterations.

    Args:
        n_itr: Total number of inference iterations.

    Returns:
        Sorted list of unique iteration indices in [1, n_itr], always ending
        at n_itr (empty if n_itr < 1).
    """
    if n_itr <= 50:
        milestones = [1, 5, 10, 20, 30, 40, 50]
    elif n_itr <= 100:
        milestones = [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
    elif n_itr <= 200:
        milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200]
    elif n_itr <= 500:
        milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500]
    else:
        # For very long runs, add percentage-based checkpoints near the end.
        milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
                      int(n_itr*0.6), int(n_itr*0.7), int(n_itr*0.8), int(n_itr*0.9)]
    milestones.append(n_itr)
    # Fix: previously the list could contain duplicates (n_itr appended even
    # when already present) and milestones larger than n_itr (e.g. n_itr=30
    # produced 40 and 50). Keep only valid, unique, sorted checkpoints.
    return sorted({m for m in milestones if 1 <= m <= n_itr})
| |
|
def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
    """Generate inference configuration with customizable parameters.

    Args:
        inference_type (str): Type of inference ('IncreaseConfidence' or 'Prior-Guided Drift Diffusion')
        eps (float): Maximum perturbation size
        n_itr (int): Number of iterations
        step_size (float): Step size for each iteration

    Returns:
        dict: Base configuration plus type-specific loss/noise settings.
    """
    # Shared defaults; type-specific settings are layered on below.
    config = {
        'loss_infer': inference_type,
        'n_itr': n_itr,
        'eps': eps,
        'step_size': step_size,
        'diffusion_noise_ratio': 0.0,
        'initial_inference_noise_ratio': 0.0,
        'top_layer': 'all',
        'inference_normalization': 'off',
        'recognition_normalization': 'off',
        'iterations_to_show': get_iterations_to_show(n_itr),
        'misc_info': {'keep_grads': False},
        'biased_inference': {'enable': False, 'class': None}
    }

    # Flat (top-level) overrides per inference type.
    flat_overrides = {
        'IncreaseConfidence': {'loss_function': 'CE'},
        'Prior-Guided Drift Diffusion': {
            'loss_function': 'MSE',
            'initial_inference_noise_ratio': 0.05,
            'diffusion_noise_ratio': 0.01,
        },
    }

    if inference_type in flat_overrides:
        config.update(flat_overrides[inference_type])
    elif inference_type == 'GradModulation':
        # CE loss plus a blending factor consumed by the inference loop.
        config['loss_function'] = 'CE'
        config['misc_info']['grad_modulation'] = 0.5
    elif inference_type == 'CompositionalFusion':
        # CE loss plus empty class lists to be filled by the caller.
        config['loss_function'] = 'CE'
        config['misc_info']['positive_classes'] = []
        config['misc_info']['negative_classes'] = []

    return config
| |
|
| | class GenerativeInferenceModel: |
| | def __init__(self): |
| | self.models = {} |
| | self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device) |
| | self.labels = get_imagenet_labels() |
| | |
    def verify_model_integrity(self, model, model_type):
        """
        Verify model integrity by running a test input through it.
        Returns whether the model passes basic integrity check.

        The probe is a black 224x224 image with a small grey square; the check
        validates output shape, absence of NaNs, and prints logit statistics.
        """
        try:
            print(f"\n=== Running model integrity check for {model_type} ===")
            # Deterministic probe input: zeros plus a 24x24 patch at 0.5
            # in the red channel.
            test_input = torch.zeros(1, 3, 224, 224, device=device)
            test_input[0, 0, 100:124, 100:124] = 0.5

            # No gradients needed for a sanity check.
            with torch.no_grad():
                output = model(test_input)

            # An ImageNet classifier must emit (1, 1000) logits.
            if output.shape != (1, 1000):
                print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
                return False

            # Top-1 prediction and confidence, for the printed report only.
            probs = torch.nn.functional.softmax(output, dim=1)
            confidence, prediction = torch.max(probs, 1)

            # Summary statistics of the raw logits.
            mean = output.mean().item()
            std = output.std().item()
            min_val = output.min().item()
            max_val = output.max().item()

            print(f"Model integrity check results:")
            print(f"- Output shape: {output.shape}")
            print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
            print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")

            # NaNs indicate corrupted weights or a broken forward pass.
            if torch.isnan(output).any():
                print("❌ Model produced NaN outputs")
                return False

            # Near-constant logits suggest untrained/partially loaded weights;
            # warn but do not fail.
            if output.std().item() < 0.1:
                print("⚠️ Low output variance, model may not be discriminative")

            print("✅ Model passes basic integrity check")
            return True

        except Exception as e:
            print(f"❌ Model integrity check failed with error: {e}")
            # NOTE(review): returning True after an exception treats a crashed
            # check as a pass — presumably deliberate best-effort so loading is
            # never blocked by the check itself; confirm this is intended.
            return True
| | |
    def load_model(self, model_type):
        """Load model from checkpoint or use pretrained model.

        Downloads the checkpoint for ``model_type`` if needed, heuristically
        maps its state-dict keys onto a fresh ResNet50, reports loading
        coverage, and falls back to torchvision's pretrained weights when the
        checkpoint is unusable. Loaded models are cached on the instance.
        """
        # Reuse an already-constructed model for this type.
        if model_type in self.models:
            print(f"Using cached {model_type} model")
            return self.models[model_type]

        start_time = time.time()
        model_path = download_model(model_type)

        # Start from an uninitialized ResNet50 wrapped with the shared
        # normalizer; checkpoint weights (if any) are loaded into `resnet`.
        resnet = models.resnet50()
        model = nn.Sequential(
            self.normalizer,
            resnet
        )

        if model_path:
            print(f"Loading {model_type} model from {model_path}...")
            try:
                # weights_only=False: checkpoints may contain pickled objects.
                checkpoint = torch.load(model_path, map_location=device, weights_only=False)

                # --- Diagnostics: describe the checkpoint layout ---
                print("\n=== Analyzing checkpoint structure ===")
                if isinstance(checkpoint, dict):
                    print(f"Checkpoint contains keys: {list(checkpoint.keys())}")

                    if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
                        model_dict = checkpoint['model']
                        first_keys = list(model_dict.keys())[:5]
                        print(f"'model' contains keys like: {first_keys}")

                        # Sample the first 100 keys to report common prefixes.
                        prefixes = set()
                        for key in list(model_dict.keys())[:100]:
                            parts = key.split('.')
                            if len(parts) > 1:
                                prefixes.add(parts[0])
                        if prefixes:
                            print(f"Common prefixes in model dict: {prefixes}")
                else:
                    print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")

                # Pick the actual state dict out of the checkpoint container.
                if 'model' in checkpoint:
                    state_dict = checkpoint['model']
                    print("Using 'model' key from checkpoint")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    print("Using 'state_dict' key from checkpoint")
                else:
                    state_dict = checkpoint
                    print("Using checkpoint directly as state_dict")

                # Accumulates checkpoint parameters remapped to ResNet names.
                resnet_state_dict = {}
                prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
                resnet_keys = set(resnet.state_dict().keys())

                print("\n=== Phase 1: Checking for specific model structures ===")

                # DataParallel-wrapped robustness checkpoints: 'module.model.*'.
                module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')]
                if module_model_keys:
                    print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
                    for source_key, value in state_dict.items():
                        if source_key.startswith('module.model.'):
                            target_key = source_key[len('module.model.'):]
                            resnet_state_dict[target_key] = value
                    print(f"Extracted {len(resnet_state_dict)} parameters from module.model")

                # Attacker-wrapper checkpoints: 'attacker.model.*'.
                attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
                if attacker_model_keys:
                    print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
                    for source_key, value in state_dict.items():
                        if source_key.startswith('attacker.model.'):
                            target_key = source_key[len('attacker.model.'):]
                            resnet_state_dict[target_key] = value
                    print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")

                # Plain 'model.*' keys not under 'attacker.' (only if still
                # short of a full parameter set).
                model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')]
                if model_keys and len(resnet_state_dict) < len(resnet_keys):
                    print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
                    for source_key, value in state_dict.items():
                        if source_key.startswith('model.'):
                            target_key = source_key[len('model.'):]
                            if target_key in resnet_keys and target_key not in resnet_state_dict:
                                resnet_state_dict[target_key] = value

                else:
                    # Generic fallback when no (new) 'model.' keys were found.
                    structure_found = False

                    model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
                    if model_keys:
                        print(f"Found 'model.' structure with {len(model_keys)} parameters")
                        for source_key, value in state_dict.items():
                            if source_key.startswith('model.'):
                                target_key = source_key[len('model.'):]
                                resnet_state_dict[target_key] = value
                        structure_found = True

                    # Count ResNet parameters stored without any prefix.
                    top_level_resnet_keys = 0
                    for key in resnet_keys:
                        if key in state_dict:
                            top_level_resnet_keys += 1

                    if top_level_resnet_keys > 0:
                        print(f"Found {top_level_resnet_keys} ResNet parameters at top level")
                        for target_key in resnet_keys:
                            if target_key in state_dict:
                                resnet_state_dict[target_key] = state_dict[target_key]
                        structure_found = True

                    # Try each known prefix for every ResNet parameter name.
                    if not structure_found:
                        print("No standard model structure found, trying prefix mappings...")
                        for target_key in resnet_keys:
                            for prefix in prefixes_to_try:
                                source_key = prefix + target_key
                                if source_key in state_dict:
                                    resnet_state_dict[target_key] = state_dict[source_key]
                                    break

                # Last resort: strip known prefixes from every checkpoint key
                # and keep whatever maps onto an unclaimed ResNet parameter.
                if len(resnet_state_dict) < len(resnet_keys):
                    print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...")

                    prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
                    layer_matches = {}

                    # Per-layer totals for the coverage report below.
                    for key in resnet_keys:
                        layer_name = key.split('.')[0] if '.' in key else key
                        if layer_name not in layer_matches:
                            layer_matches[layer_name] = {'total': 0, 'matched': 0}
                        layer_matches[layer_name]['total'] += 1

                    for source_key, value in state_dict.items():
                        target_key = source_key
                        matched_prefix = None

                        # Strip the first matching prefix, if any.
                        for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']:
                            if source_key.startswith(prefix):
                                target_key = source_key[len(prefix):]
                                matched_prefix = prefix
                                break

                        # Accept only parameters not already mapped.
                        if target_key in resnet_keys and target_key not in resnet_state_dict:
                            resnet_state_dict[target_key] = value

                            if matched_prefix:
                                prefix_matches[matched_prefix] += 1

                            layer_name = target_key.split('.')[0] if '.' in target_key else target_key
                            if layer_name in layer_matches:
                                layer_matches[layer_name]['matched'] += 1

                    # --- Report prefix-removal statistics ---
                    print("\n=== Prefix Removal Statistics ===")
                    total_matches = sum(prefix_matches.values())
                    print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)")

                    print("\nMatches by prefix:")
                    for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
                        if count > 0:
                            print(f" {prefix}: {count} parameters")

                    print("\nMatches by layer type:")
                    for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
                        match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
                        print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")

                    # Highlight the layers a working ResNet50 cannot do without.
                    critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
                    print("\nStatus of critical layers:")
                    for layer in critical_layers:
                        if layer in layer_matches:
                            match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
                            status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
                            print(f" {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
                        else:
                            print(f" {layer}: Not found in model")

                # --- Apply whatever was recovered to the ResNet ---
                if resnet_state_dict:
                    try:
                        # strict=False: tolerate partial checkpoints; coverage
                        # is reported explicitly below.
                        result = resnet.load_state_dict(resnet_state_dict, strict=False)
                        missing_keys, unexpected_keys = result

                        loading_report = []
                        loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====")
                        loading_report.append(f"Total parameters in checkpoint: {len(resnet_state_dict):,}")
                        loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}")
                        loading_report.append(f"Missing keys: {len(missing_keys):,} parameters")
                        loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters")

                        loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
                        loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100

                        if loaded_percent >= 99.5:
                            status = "✅ COMPLETE - All important parameters loaded"
                        elif loaded_percent >= 90:
                            status = "🟡 PARTIAL - Most parameters loaded, should still function"
                        elif loaded_percent >= 50:
                            status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
                        else:
                            status = "❌ FAILED - Critical parameters missing, will not function properly"

                        loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
                        loading_report.append(f"Loading status: {status}")

                        # Below 50% coverage the checkpoint is unusable: swap
                        # in torchvision's pretrained weights instead.
                        if loaded_percent < 50:
                            loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
                            loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")

                            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                            model = nn.Sequential(self.normalizer, resnet)
                            loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")

                        # Break down missing keys by top-level layer name.
                        if missing_keys:
                            loading_report.append("\nMissing keys by layer type:")
                            layer_types = {}
                            for key in missing_keys:
                                parts = key.split('.')
                                if len(parts) > 0:
                                    layer_type = parts[0]
                                    if layer_type not in layer_types:
                                        layer_types[layer_type] = 0
                                    layer_types[layer_type] += 1

                            for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
                                loading_report.append(f" {layer_type}: {count:,} parameters")

                            loading_report.append("\nFirst 10 missing keys:")
                            for i, key in enumerate(sorted(missing_keys)[:10]):
                                loading_report.append(f" {i+1}. {key}")

                        if unexpected_keys:
                            loading_report.append("\nFirst 10 unexpected keys:")
                            for i, key in enumerate(sorted(unexpected_keys)[:10]):
                                loading_report.append(f" {i+1}. {key}")

                        loading_report.append("========================================")

                        report_text = "\n".join(loading_report)
                        print(report_text)

                        # Persist the report for post-mortem debugging.
                        os.makedirs("logs", exist_ok=True)
                        with open(f"logs/model_loading_{model_type}.log", "w") as f:
                            f.write(report_text)

                        # Robustness-library checkpoints also store the input
                        # normalizer under 'attacker.normalize.*'; load it into
                        # the shared normalizer when present.
                        if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
                            norm_state_dict = {}
                            for key, value in state_dict.items():
                                if key.startswith('attacker.normalize.'):
                                    norm_key = key[len('attacker.normalize.'):]
                                    norm_state_dict[norm_key] = value

                            if norm_state_dict:
                                try:
                                    self.normalizer.load_state_dict(norm_state_dict, strict=False)
                                    print("Successfully loaded normalizer parameters")
                                except Exception as e:
                                    print(f"Warning: Could not load normalizer parameters: {e}")
                    except Exception as e:
                        print(f"Warning: Error loading ResNet parameters: {e}")

                    # NOTE(review): this replaces the Sequential(normalizer,
                    # resnet) wrapper with the bare ResNet whenever checkpoint
                    # parameters were found. The inference path does handle
                    # non-Sequential models, but confirm dropping the
                    # normalizer here is intentional.
                    model = resnet
            except Exception as e:
                print(f"Error loading model checkpoint: {e}")

                # Any failure while reading the checkpoint falls back to
                # torchvision's pretrained weights.
                print("Falling back to PyTorch's pretrained model")
                resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                model = nn.Sequential(self.normalizer, resnet)
        else:
            # No URL registered for this type: use pretrained weights.
            print("No checkpoint available, using PyTorch's pretrained model")
            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            model = nn.Sequential(self.normalizer, resnet)

        model = model.to(device)
        model.eval()

        # Sanity-check the network before caching (best-effort; see
        # verify_model_integrity).
        self.verify_model_integrity(model, model_type)

        # Cache and report timing.
        self.models[model_type] = model
        end_time = time.time()
        load_time = end_time - start_time
        print(f"Model {model_type} loaded in {load_time:.2f} seconds")
        return model
| | |
| | def inference(self, image, model_type, config): |
| | """Run generative inference on the image.""" |
| | |
| | inference_start = time.time() |
| | |
| | |
| | model = self.load_model(model_type) |
| | |
| | |
| | if isinstance(image, str): |
| | if os.path.exists(image): |
| | image = Image.open(image).convert('RGB') |
| | else: |
| | raise ValueError(f"Image path does not exist: {image}") |
| | elif isinstance(image, torch.Tensor): |
| | raise ValueError(f"Image type {type(image)}, looks like already a transformed tensor") |
| | |
| | |
| | load_start = time.time() |
| | use_norm = config['inference_normalization'] == 'on' |
| | custom_transform = get_transform( |
| | input_size=224, |
| | normalize=use_norm, |
| | norm_mean=IMAGENET_MEAN, |
| | norm_std=IMAGENET_STD |
| | ) |
| | |
| | |
| | if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']: |
| | grad_modulation = config['misc_info']['grad_modulation'] |
| | image_tensor = custom_transform(image).unsqueeze(0).to(device) |
| | image_tensor = image_tensor * (1-grad_modulation) + grad_modulation * torch.randn_like(image_tensor).to(device) |
| | else: |
| | image_tensor = custom_transform(image).unsqueeze(0).to(device) |
| | |
| | image_tensor.requires_grad = True |
| | print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds") |
| | |
| | |
| | is_sequential = isinstance(model, nn.Sequential) |
| | |
| | |
| | with torch.no_grad(): |
| | |
| | if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd): |
| | print("Model is sequential with normalization") |
| | |
| | core_model = model[1] |
| | if config['inference_normalization'] == 'on': |
| | output_original = model(image_tensor) |
| | else: |
| | output_original = core_model(image_tensor) |
| | |
| | else: |
| | print("Model is not sequential with normalization") |
| | |
| | if config['inference_normalization'] == 'on': |
| | normalized_tensor = normalize_transform(image_tensor) |
| | output_original = model(normalized_tensor) |
| | else: |
| | output_original = model(image_tensor) |
| | core_model = model |
| | |
| | probs_orig = F.softmax(output_original, dim=1) |
| | conf_orig, classes_orig = torch.max(probs_orig, 1) |
| | |
| | |
| | _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False) |
| | |
| | |
| | adaptive_eps = config.get("adaptive_epsilon") |
| | adaptive_step = config.get("adaptive_step_size") |
| | if adaptive_eps and adaptive_eps.get("enabled"): |
| | print(f"Applying adaptive epsilon mask: center=({adaptive_eps.get('center_x'):.2f}, {adaptive_eps.get('center_y'):.2f}), radius={adaptive_eps.get('flat_radius'):.2f}") |
| | if adaptive_step and adaptive_step.get("enabled"): |
| | print(f"Applying adaptive step-size mask: center=({adaptive_step.get('center_x'):.2f}, {adaptive_step.get('center_y'):.2f}), radius={adaptive_step.get('flat_radius'):.2f}") |
| | infer_step = InferStep( |
| | image_tensor, |
| | config["eps"], |
| | config["step_size"], |
            adaptive_eps_config=adaptive_eps,
            adaptive_step_config=adaptive_step,
        )

        # --- Optional biased inference ------------------------------------
        # When enabled, each iteration subtracts the gradient of a
        # cross-entropy loss toward a user-chosen class, steering the image
        # toward that class (see the gradient-combination step in the loop).
        biased_inference_config = config.get('biased_inference', {'enable': False, 'class': None})
        biased_class_index = None   # int index into self.labels, or None
        biased_class_tensor = None  # 1-element LongTensor target on `device`, or None
        if biased_inference_config.get('enable', False):
            # Accept either 'class' or 'class_name' as the config key.
            class_name = biased_inference_config.get('class') or biased_inference_config.get('class_name')
            if class_name:
                try:
                    # Case-insensitive lookup of the class name in the label list.
                    biased_class_index = next(
                        i for i, label in enumerate(self.labels)
                        if label.lower() == class_name.lower()
                    )
                    biased_class_tensor = torch.tensor(
                        [biased_class_index], device=device, dtype=torch.long
                    )
                    print(f"Biased inference: biasing toward class '{self.labels[biased_class_index]}' (index {biased_class_index})")
                except StopIteration:
                    # next() found no matching label — surface a helpful error.
                    raise ValueError(
                        f"biased_inference class '{class_name}' not found in ImageNet simple labels. "
                        "Use a label from imagenet-simple-labels (e.g. 'goldfish', 'tabby cat')."
                    )
            else:
                print("Biased inference enabled but no class specified; ignoring.")

        # --- Iterative image optimization ---------------------------------
        # `x` is the image being optimized; it needs grad so torch.autograd.grad
        # can differentiate the loss w.r.t. the pixels.
        x = image_tensor.clone().detach().requires_grad_(True)
        # Snapshot list returned to the caller; starts with the original image.
        all_steps = [image_tensor[0].detach().cpu()]

        # Prior-Guided Drift Diffusion: match features of a noise-corrupted
        # copy of the input at an intermediate layer, instead of class logits.
        noisy_features = None  # fixed feature target computed once below
        layer_model = None     # truncated model up to config['top_layer']
        if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
            print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")

            # Extract the sub-network up to the requested layer; fall back to
            # the full model if extraction fails.
            try:
                base_model = model

                # Unwrap DataParallel/DistributedDataParallel if present.
                if hasattr(base_model, 'module'):
                    base_model = base_model.module

                print(f"DEBUG - Initial model structure: {type(base_model)}")

                # A Sequential wrapper here is assumed to be
                # (normalizer, backbone) — see the is_sequential check further
                # down; child [1] is then the actual ResNet.
                if isinstance(base_model, nn.Sequential):
                    print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")

                    if len(list(base_model.children())) >= 2:
                        # Skip the normalization layer and truncate the backbone.
                        actual_model = list(base_model.children())[1]
                        print(f"DEBUG - Using ResNet component: {type(actual_model)}")
                        print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")

                        layer_model = extract_middle_layers(actual_model, config['top_layer'])
                    else:
                        # Sequential with < 2 children: truncate it directly.
                        layer_model = extract_middle_layers(base_model, config['top_layer'])
                else:
                    # Plain (non-Sequential) model: truncate it directly.
                    print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
                    layer_model = extract_middle_layers(base_model, config['top_layer'])

                print(f"Successfully extracted model up to layer: {config['top_layer']}")
            except ValueError as e:
                print(f"Layer extraction failed: {e}. Using full model.")
                layer_model = model

            # Build the fixed drift target: features of the input corrupted
            # with Gaussian noise scaled by initial_inference_noise_ratio.
            added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
            noisy_image_tensor = image_tensor + added_noise

            # Target is computed once, outside the loop, and reused every step.
            noisy_features = layer_model(noisy_image_tensor)

            print(f"Noisy features computed for Prior-Guided Drift Diffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")

        # --- Main optimization loop ---------------------------------------
        print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
        loop_start = time.time()
        for i in range(config['n_itr']):
            # Clear any stale gradient on the image.
            x.grad = None

            # Forward pass: truncated model for drift diffusion, otherwise the
            # full classifier.
            if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
                output = layer_model(x)
            else:
                output = model(x)

            # Loss + gradient; any failure falls through to the random-noise
            # fallback in the except branch below.
            try:
                if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
                    # NOTE(review): assert is stripped under `python -O`;
                    # consider raising instead if this must always hold.
                    assert config['loss_function'] == 'MSE', "Reverse Diffusion loss function must be MSE"
                    if noisy_features is not None:
                        # Pull current features toward the fixed noisy target.
                        loss = F.mse_loss(output, noisy_features)
                        grad = torch.autograd.grad(loss, x)[0]
                    else:
                        raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")

                else:
                    # Classifier-based loss: push softmax output toward one-hot
                    # targets for up to 10 least-confident classes (computed
                    # before this span).
                    num_classes = min(10, least_confident_classes.size(1))
                    target_classes = least_confident_classes[0, :num_classes]

                    targets = torch.tensor([idx.item() for idx in target_classes], device=device)

                    # Sum of per-target MSE(softmax(output), one_hot(target)).
                    loss = 0
                    for target in targets:
                        one_hot = torch.zeros_like(output)
                        one_hot[0, target] = 1

                        loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)

                    # retain_graph: the biased-inference branch below may
                    # differentiate through the same graph again.
                    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]

                # Subtract the gradient of the cross-entropy toward the biased
                # class, i.e. descend toward that class while following the
                # main objective.
                if biased_inference_config.get('enable', False) and biased_class_tensor is not None:
                    output_full = model(x)
                    loss_biased = F.cross_entropy(output_full, biased_class_tensor)
                    grad_biased = torch.autograd.grad(loss_biased, x)[0]
                    grad = grad - grad_biased

                # NOTE(review): torch.autograd.grad raises rather than
                # returning None when x is unused, so this branch is likely
                # unreachable; kept as a defensive fallback.
                if grad is None:
                    print("Warning: Direct gradient calculation failed")
                    # Uniform noise in [-step_size, step_size] as a fallback move.
                    random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                    x = infer_step.project(x + random_noise)
                else:
                    # infer_step scales/clips the raw gradient into an update.
                    adjusted_grad = infer_step.step(x, grad)

                    # Langevin-style stochastic term added to each update.
                    diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x).to(device)

                    # Apply update + noise, then project back into the feasible set.
                    x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)

            except Exception as e:
                # Best-effort recovery: log and take a random projected step so
                # the loop can continue.
                print(f"Error in gradient calculation: {e}")

                random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                x = infer_step.project(x.clone() + random_noise)

            # Record snapshots at requested iterations and at the final one.
            if i+1 in config['iterations_to_show'] or i+1 == config['n_itr']:
                all_steps.append(x[0].detach().cpu())

        # --- Final classification of the optimized image ------------------
        with torch.no_grad():
            # If the model already embeds a normalization layer (Sequential
            # whose first child is NormalizeByChannelMeanStd), choose between
            # the full wrapped model and the bare core model; otherwise apply
            # the external normalize_transform when normalization is on.
            if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
                if config['inference_normalization'] == 'on':
                    final_output = model(x)
                else:
                    final_output = core_model(x)
            else:
                if config['inference_normalization'] == 'on':
                    normalized_x = normalize_transform(x)
                    final_output = model(normalized_x)
                else:
                    final_output = model(x)

        final_probs = F.softmax(final_output, dim=1)
        final_conf, final_classes = torch.max(final_probs, 1)

        # --- Timing + summary ---------------------------------------------
        loop_time = time.time() - loop_start
        total_time = time.time() - inference_start
        avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0

        print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
        print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
        print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
        print(f"Total inference time: {total_time:.2f} seconds")

        # Results: final image, intermediate snapshots, and before/after
        # top-1 class and confidence.
        return {
            'final_image': x[0].detach().cpu(),
            'steps': all_steps,
            'original_class': classes_orig.item(),
            'original_confidence': conf_orig.item(),
            'final_class': final_classes.item(),
            'final_confidence': final_conf.item()
        }
| |
|
| | |
def show_inference_steps(steps, figsize=(15, 10)):
    """Plot the saved inference snapshots side by side in one row.

    Args:
        steps: sequence of CHW image tensors (e.g. the 'steps' list returned
            by the inference routine); each must support .permute()/.numpy().
        figsize: figure size forwarded to plt.subplots.

    Returns:
        The matplotlib Figure containing the row of images (empty axis if
        `steps` is empty).
    """
    import matplotlib.pyplot as plt

    n_steps = len(steps)
    # squeeze=False keeps `axes` a 2-D array even when n_steps == 1, where
    # plt.subplots would otherwise return a bare Axes and axes[i] would raise.
    # max(n_steps, 1) avoids subplots(1, 0), which is invalid for empty input.
    fig, axes = plt.subplots(1, max(n_steps, 1), figsize=figsize, squeeze=False)

    for i, step_img in enumerate(steps):
        # CHW tensor -> HWC numpy array for imshow.
        img = step_img.permute(1, 2, 0).numpy()
        axes[0, i].imshow(img)
        axes[0, i].set_title(f"Step {i}")
        axes[0, i].axis('off')

    plt.tight_layout()
    return fig