"""Complete generative inference module with model loading and inference capabilities.

The module loads ResNet-50 classifiers (checkpoints listed in ``MODEL_URLS``,
falling back to torchvision's pretrained weights) and runs "generative
inference": iterative gradient-guided edits of an input image. Each edit is a
gradient step normalized to unit L2 norm (scaled by ``step_size``), after which
the image is clamped element-wise to within ``eps`` of the original and to the
[0, 1] pixel range.
"""
import copy
import os
import time
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models.resnet import ResNet50_Weights

# Check for available hardware acceleration (chosen once, at import time).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")  # Use Apple Metal Performance Shaders for M-series Macs
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Constants for model URLs
MODEL_URLS = {
    'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
    'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
    'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_vggface2_L2_eps_0.50_checkpoint150.pt',
}

# Model-specific preprocessing configurations
MODEL_CONFIGS = {
    'resnet50_robust_face': {
        'input_size': 112,
        'norm_mean': [0.5, 0.5, 0.5],
        'norm_std': [0.5, 0.5, 0.5],
        'n_classes': 500,
        'dataset': 'VGGFace2',
    },
    'resnet50_standard': {
        'input_size': 224,
        'norm_mean': [0.485, 0.456, 0.406],
        'norm_std': [0.229, 0.224, 0.225],
        'n_classes': 1000,
        'dataset': 'ImageNet',
    },
    'resnet50_robust': {
        'input_size': 224,
        'norm_mean': [0.485, 0.456, 0.406],
        'norm_std': [0.229, 0.224, 0.225],
        'n_classes': 1000,
        'dataset': 'ImageNet',
    },
}

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# (connect, read) timeouts in seconds for checkpoint downloads; without a
# timeout a stalled connection would hang model loading forever.
_DOWNLOAD_TIMEOUT = (10, 60)

# Inference methods that `GenerativeInferenceModel.inference` actually implements.
_SUPPORTED_INFERENCE_METHODS = ('IncreaseConfidence', 'Prior-Guided Drift Diffusion')

# Wrapper prefixes stripped from checkpoint keys, longest first so the most
# specific prefix wins. (An empty prefix must not be in this list: every key
# starts with '' so it would match first and disable all stripping.)
_CHECKPOINT_PREFIXES = ('attacker.model.', 'attacker.', 'module.', 'model.')


def get_iterations_to_show(n_itr):
    """Return the 1-based iteration numbers whose intermediate images are recorded.

    The schedule is dense early on and sparser later, and always ends with
    ``n_itr``. It may contain duplicates, or values larger than ``n_itr`` when
    ``n_itr`` is small; ``GenerativeInferenceModel.inference`` only tests
    membership, so both are harmless there.
    """
    if n_itr <= 50:
        return [1, 5, 10, 20, 30, 40, 50, n_itr]
    elif n_itr <= 100:
        return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
    elif n_itr <= 200:
        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
    elif n_itr <= 500:
        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
    else:
        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
                int(n_itr * 0.6), int(n_itr * 0.7), int(n_itr * 0.8), int(n_itr * 0.9), n_itr]


def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
    """Build an inference configuration dict with method-specific defaults.

    Args:
        inference_type: name of the inference method (stored as ``loss_infer``).
        eps: maximum per-pixel deviation from the original image.
        n_itr: number of inference iterations.
        step_size: L2 length of each normalized gradient step.

    Returns:
        A fresh dict (safe to mutate) consumed by ``GenerativeInferenceModel.inference``.
        Note that ``inference`` currently only implements the methods listed in
        ``_SUPPORTED_INFERENCE_METHODS``.
    """
    config = {
        'loss_infer': inference_type,
        'n_itr': n_itr,
        'eps': eps,
        'step_size': step_size,
        'diffusion_noise_ratio': 0.0,
        'initial_inference_noise_ratio': 0.0,
        'top_layer': 'all',
        'inference_normalization': False,
        'recognition_normalization': False,
        'iterations_to_show': get_iterations_to_show(n_itr),
        'misc_info': {'keep_grads': False},
    }

    if inference_type == 'IncreaseConfidence':
        config['loss_function'] = 'CE'
    elif inference_type == 'Prior-Guided Drift Diffusion':
        config['loss_function'] = 'MSE'
        config['initial_inference_noise_ratio'] = 0.05
        config['diffusion_noise_ratio'] = 0.01
        config['top_layer'] = 'layer4'
    elif inference_type == 'GradModulation':
        config['loss_function'] = 'CE'
        config['misc_info']['grad_modulation'] = 0.5
    elif inference_type == 'CompositionalFusion':
        config['loss_function'] = 'CE'
        config['misc_info']['positive_classes'] = []
        config['misc_info']['negative_classes'] = []

    return config


def get_model_preprocessing(model_type: str) -> Dict:
    """Return the preprocessing config for ``model_type`` (ImageNet defaults if unknown)."""
    if model_type not in MODEL_CONFIGS:
        print(f"Fall-back: Unknown model type {model_type}, using ImageNet defaults")
        return MODEL_CONFIGS['resnet50_standard']
    return MODEL_CONFIGS[model_type]


class NormalizeByChannelMeanStd(nn.Module):
    """Differentiable per-channel ``(x - mean) / std`` normalization layer.

    ``mean`` and ``std`` are registered as buffers so they move with the module
    across devices and are saved in its state dict.
    """

    def __init__(self, mean, std):
        super().__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable version of torchvision.functional.normalize.

        ``mean``/``std`` are 1-D and are broadcast over an (N, C, H, W) input.
        """
        mean = mean[None, :, None, None]
        std = std[None, :, None, None]
        return tensor.sub(mean).div(std)


class InferStep:
    """Gradient-step and projection rules for one inference run.

    Args:
        orig_image: the starting image; iterates stay within ``eps`` of it.
        eps: maximum absolute per-element deviation from ``orig_image``.
        step_size: L2 length of each (per-sample normalized) gradient step.
    """

    def __init__(self, orig_image: torch.Tensor, eps: float, step_size: float):
        self.orig_image = orig_image
        self.eps = eps
        self.step_size = step_size

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Clamp ``x`` element-wise to [orig - eps, orig + eps] and then to [0, 1]."""
        diff = x - self.orig_image
        diff = torch.clamp(diff, -self.eps, self.eps)
        return torch.clamp(self.orig_image + diff, 0, 1)

    def step(self, x: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Return ``grad`` rescaled per sample to L2 norm ``step_size``.

        The 1e-10 term avoids division by zero for an all-zero gradient.
        """
        dim = len(x.shape) - 1
        grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=1).reshape(-1, *([1] * dim))
        scaled_grad = grad / (grad_norm + 1e-10)
        return scaled_grad * self.step_size


def extract_middle_layers(model: nn.Module, layer_index: Union[str, int]) -> nn.Module:
    """Return the prefix of ``model``'s direct children up to and including ``layer_index``.

    ``'all'`` returns ``model`` unchanged. If no direct child is named
    ``str(layer_index)`` a message is printed and the full model is returned.
    The returned ``nn.Sequential`` shares (does not copy) the child modules.
    """
    if isinstance(layer_index, str) and layer_index == 'all':
        return model

    modules = list(model.named_children())
    cutoff_idx = next(
        (i for i, (name, _) in enumerate(modules) if name == str(layer_index)),
        None,
    )
    if cutoff_idx is not None:
        return nn.Sequential(OrderedDict(modules[:cutoff_idx + 1]))

    print(f"Fall-back: Module {layer_index} not found, using full model")
    return model


def calculate_loss(output_model: torch.Tensor, class_indices: List[int], loss_inference: str) -> torch.Tensor:
    """Average a per-class loss over ``class_indices``.

    For each class index the whole batch is scored against that class:
    ``'CE'`` uses cross-entropy, ``'MSE'`` uses mean squared error against a
    one-hot target. The per-class losses are averaged.

    Args:
        output_model: (N, C) logits.
        class_indices: non-empty list of target class indices.
        loss_inference: ``'CE'`` or ``'MSE'``.

    Raises:
        ValueError: ``class_indices`` is empty or ``loss_inference`` is unsupported.
    """
    if len(class_indices) == 0:
        raise ValueError("class_indices must contain at least one class index")

    batch_size = output_model.shape[0]
    losses = []
    for idx in class_indices:
        if loss_inference == 'CE':
            target = torch.full((batch_size,), idx, dtype=torch.long, device=output_model.device)
            loss = F.cross_entropy(output_model, target)
        elif loss_inference == 'MSE':
            one_hot_target = torch.zeros_like(output_model)
            one_hot_target[:, idx] = 1
            loss = F.mse_loss(output_model, one_hot_target)
        else:
            raise ValueError(f"Unsupported loss_inference: {loss_inference}")
        losses.append(loss)

    return torch.stack(losses).mean()


def download_model(model_type):
    """Download the checkpoint for ``model_type`` into ``./models`` if not already present.

    Returns:
        The checkpoint ``Path``, or ``None`` if ``model_type`` has no URL.

    Raises:
        RuntimeError: the server answered with a non-200 status.
        requests.RequestException: network failure or timeout.

    The file is first written to a ``.part`` sibling and renamed only after the
    download completes, so an interrupted download never leaves a truncated
    checkpoint that later calls would mistake for a valid one.
    """
    url = MODEL_URLS.get(model_type)
    if url is None:
        return None

    models_dir = Path("models")
    models_dir.mkdir(parents=True, exist_ok=True)

    if model_type == 'resnet50_robust_face':
        model_path = models_dir / "resnet50_vggface2_L2_eps_0.50_checkpoint150.pt"
    else:
        model_path = models_dir / f"{model_type}.pt"

    if model_path.exists():
        return model_path

    print(f"Downloading {model_type} model...")
    partial_path = model_path.with_name(model_path.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Failed to download model: {response.status_code}")
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_path, model_path)
    finally:
        # Only present if the download failed before the rename.
        if partial_path.exists():
            partial_path.unlink()

    print(f"Model downloaded and saved to {model_path}")
    return model_path


def _unwrap_state_dict(checkpoint):
    """Return the parameter dict stored in ``checkpoint`` (under 'model', 'state_dict', or top-level)."""
    if 'model' in checkpoint:
        print("Using 'model' key from checkpoint")
        return checkpoint['model']
    if 'state_dict' in checkpoint:
        print("Using 'state_dict' key from checkpoint")
        return checkpoint['state_dict']
    print("Using checkpoint directly as state_dict")
    return checkpoint


def _extract_resnet_state_dict(state_dict, resnet_keys, model_type):
    """Map checkpoint keys onto torchvision ResNet parameter names.

    Checkpoints store the ResNet under wrapper prefixes (e.g. ``module.``,
    ``model.``, ``attacker.model.``). Only entries whose stripped key is in
    ``resnet_keys`` are kept.

    For the face model, ``module.model.model.`` keys are tried first and
    ``module.model.`` keys fill in anything still missing. If nothing matched
    (or for other models), the generic prefixes in ``_CHECKPOINT_PREFIXES``
    are stripped, followed by one nested ``model.``.
    """
    resnet_state_dict = {}

    if model_type == 'resnet50_robust_face':
        nested_prefix = 'module.model.model.'
        nested_keys = [key for key in state_dict if key.startswith(nested_prefix)]
        if nested_keys:
            print(f"Found 'module.model.model.' structure with {len(nested_keys)} parameters (face model)")
            for source_key in nested_keys:
                target_key = source_key[len(nested_prefix):]
                if target_key in resnet_keys:
                    resnet_state_dict[target_key] = state_dict[source_key]
            print(f"Extracted {len(resnet_state_dict)} parameters from module.model.model.")

        if len(resnet_state_dict) < len(resnet_keys):
            wrapper_prefix = 'module.model.'
            wrapper_keys = [key for key in state_dict
                            if key.startswith(wrapper_prefix) and not key.startswith(nested_prefix)]
            if wrapper_keys:
                print(f"Found additional 'module.model.' structure with {len(wrapper_keys)} parameters")
                for source_key in wrapper_keys:
                    target_key = source_key[len(wrapper_prefix):]
                    if target_key in resnet_keys and target_key not in resnet_state_dict:
                        resnet_state_dict[target_key] = state_dict[source_key]
                print(f"Now have {len(resnet_state_dict)} parameters after adding module.model. keys")

    if not resnet_state_dict:
        for source_key, value in state_dict.items():
            target_key = source_key
            for prefix in _CHECKPOINT_PREFIXES:
                if source_key.startswith(prefix):
                    target_key = source_key[len(prefix):]
                    break
            # Handle nested model keys, e.g. 'module.model.conv1.weight'.
            if target_key.startswith('model.'):
                target_key = target_key[len('model.'):]
            if target_key in resnet_keys:
                resnet_state_dict[target_key] = value

    return resnet_state_dict


def _to_pil_image(image):
    """Convert a path, NumPy array, PIL image, or array-like into an RGB PIL image.

    NumPy arrays may be HWC or CHW (leading dimension of 1 or 3). Non-uint8
    arrays with max <= 1 are treated as [0, 1] floats and scaled by 255. Values
    are clipped to [0, 255] before the uint8 cast, so out-of-range values
    saturate instead of wrapping around. Alpha channels are dropped and single
    channels are repeated to RGB.

    Raises:
        ValueError: the path does not exist or the object cannot be converted.
    """
    if isinstance(image, str):
        if not os.path.exists(image):
            raise ValueError(f"Image path does not exist: {image}")
        with Image.open(image) as opened:
            image = opened.convert('RGB')
    elif isinstance(image, np.ndarray):
        if image.dtype != np.uint8:
            scaled = image * 255 if image.max() <= 1.0 else image
            image = np.clip(scaled, 0, 255).astype(np.uint8)
        if image.ndim == 3:
            if image.shape[0] in (1, 3):
                image = np.transpose(image, (1, 2, 0))
            if image.shape[2] == 4:
                image = image[:, :, :3]
            elif image.shape[2] == 1:
                image = np.repeat(image, 3, axis=2)
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(np.array(image)).convert('RGB')
        except Exception as e:
            raise ValueError(f"Cannot convert image type {type(image)} to PIL Image: {e}") from e

    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image


def _image_to_tensor(image, input_size, norm_mean, norm_std, normalize):
    """Resize, center-crop and convert a PIL image to a (3, H, W) float tensor in [0, 1].

    If ``normalize`` is true the tensor is additionally normalized with
    ``norm_mean``/``norm_std``. If torchvision's ``ToTensor`` fails with a
    NumPy compatibility ``TypeError``, the conversion is redone manually.
    """
    norm_mean = list(norm_mean)
    norm_std = list(norm_std)
    pipeline = [
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
    ]
    if normalize:
        pipeline.append(transforms.Normalize(norm_mean, norm_std))
        print(f"Fall-back: Using normalization with mean={norm_mean}, std={norm_std}")
    else:
        print("Normalization OFF - feeding raw [0,1] tensors to model (normalization applied in the model)")
    transform = transforms.Compose(pipeline)

    try:
        return transform(image)
    except TypeError as e:
        if "expected np.ndarray" not in str(e) and "got numpy.ndarray" not in str(e):
            raise
        print("[WARNING] Transform failed with numpy compatibility issue, using manual conversion")
        cropped = transforms.CenterCrop(input_size)(transforms.Resize(input_size)(image))
        img_array = np.array(cropped, dtype=np.uint8)
        # torch.tensor (a copy) instead of torch.from_numpy to avoid NumPy/torch
        # interop issues; HWC uint8 -> CHW float in [0, 1].
        img_tensor = torch.tensor(img_array, dtype=torch.float32).div(255.0).permute(2, 0, 1)
        if normalize:
            img_tensor = transforms.Normalize(norm_mean, norm_std)(img_tensor)
        return img_tensor


def _validate_inference_config(config):
    """Reject inference configs that ``inference`` cannot execute.

    Raises:
        ValueError: unsupported ``loss_infer`` or a non-MSE loss for
            Prior-Guided Drift Diffusion.
    """
    loss_infer = config['loss_infer']
    if loss_infer not in _SUPPORTED_INFERENCE_METHODS:
        raise ValueError(f"Loss inference method {loss_infer} not supported")
    if loss_infer == 'Prior-Guided Drift Diffusion' and config.get('loss_function', 'MSE') != 'MSE':
        raise ValueError("Prior-Guided Drift Diffusion requires MSE loss")


class GenerativeInferenceModel:
    """Load/cache ResNet-50 classifiers and run generative inference with them.

    Attributes:
        models: cache of loaded models keyed by model type.
        model_preproc: preprocessing config used for each loaded model type.
        labels: ImageNet class names (placeholders if they could not be fetched).
    """

    def __init__(self):
        self.models = {}
        self.model_preproc = {}
        self.labels = self.get_imagenet_labels()

    def get_imagenet_labels(self):
        """Fetch the 1000 ImageNet class names, falling back to ``class_<i>`` placeholders."""
        url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
        try:
            response = requests.get(url, timeout=10)  # Add timeout to prevent hanging
            if response.status_code == 200:
                return response.json()
            print("Fall-back: Failed to fetch ImageNet labels, using placeholder")
            return [f"class_{i}" for i in range(1000)]
        except Exception as e:
            print(f"Fall-back: Error fetching labels: {e}")
            return [f"class_{i}" for i in range(1000)]

    def load_model(self, model_type):
        """Load (or return the cached) ``nn.Sequential(normalizer, resnet50)`` for ``model_type``.

        The checkpoint from ``MODEL_URLS`` is downloaded if needed and its
        ResNet weights are loaded non-strictly. ImageNet models fall back to
        torchvision's IMAGENET1K_V1 weights if there is no checkpoint, fewer
        than 50% of parameters match, or loading fails. The face model has no
        such fallback; a warning is printed instead.

        Raises:
            RuntimeError / requests.RequestException: checkpoint download failed.
        """
        if model_type in self.models:
            print(f"Using cached {model_type} model")
            return self.models[model_type]

        start_time = time.time()
        is_face_model = model_type == 'resnet50_robust_face'

        preproc_config = get_model_preprocessing(model_type)
        self.model_preproc[model_type] = preproc_config

        # Normalization lives inside the model, so callers feed [0, 1] images.
        normalizer = NormalizeByChannelMeanStd(
            preproc_config['norm_mean'], preproc_config['norm_std']
        ).to(device)
        resnet = models.resnet50(num_classes=preproc_config['n_classes'])
        model = nn.Sequential(normalizer, resnet)

        use_pretrained = False
        model_path = download_model(model_type)
        if model_path:
            print(f"Loading {model_type} model from {model_path}...")
            try:
                # NOTE(security): torch.load unpickles the downloaded file (unless the
                # installed torch defaults to weights_only=True); only load checkpoints
                # from trusted sources.
                checkpoint = torch.load(model_path, map_location=device)
                state_dict = _unwrap_state_dict(checkpoint)
                resnet_keys = set(resnet.state_dict().keys())
                resnet_state_dict = _extract_resnet_state_dict(state_dict, resnet_keys, model_type)

                if resnet_state_dict:
                    resnet.load_state_dict(resnet_state_dict, strict=False)
                    loaded_percent = (len(resnet_state_dict) / len(resnet_keys)) * 100
                    print(f"Model loading: {len(resnet_state_dict)}/{len(resnet_keys)} parameters ({loaded_percent:.1f}%)")
                    if loaded_percent < 50:
                        print(f"Fall-back: Loading too incomplete ({loaded_percent:.1f}%), using PyTorch pretrained")
                        use_pretrained = True
                else:
                    print("Fall-back: No matching keys found in checkpoint, using PyTorch pretrained")
                    use_pretrained = True
            except Exception as e:
                print(f"Fall-back: Error loading checkpoint: {e}")
                if is_face_model:
                    print("Fall-back: Face model checkpoint failed, model may not work properly")
                else:
                    print("Fall-back: Using PyTorch pretrained model")
                    use_pretrained = True
        elif is_face_model:
            print("Fall-back: Face model requires checkpoint, model may not work properly")
        else:
            print(f"No checkpoint for {model_type}, using PyTorch pretrained")
            use_pretrained = True

        # torchvision's pretrained weights are ImageNet-only, so never apply them to the face model.
        if use_pretrained and not is_face_model:
            model = nn.Sequential(normalizer, models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1))

        model = model.to(device)
        model.eval()

        self.verify_model_integrity(model, model_type)

        self.models[model_type] = model
        print(f"Model {model_type} loaded in {time.time() - start_time:.2f} seconds")
        return model

    def verify_model_integrity(self, model, model_type):
        """Run a synthetic image through ``model`` and check the output shape.

        Returns:
            True if the output is ``(1, n_classes)``, False otherwise or on error.
        """
        try:
            print(f"Fall-back: Running model integrity check for {model_type}")
            config = get_model_preprocessing(model_type)
            H = W = config['input_size']
            # Black image with a mid-grey square in the red channel.
            test_input = torch.zeros(1, 3, H, W, device=device)
            test_input[0, 0, H // 4:3 * H // 4, W // 4:3 * W // 4] = 0.5

            with torch.no_grad():
                output = model(test_input)

            expected_classes = config['n_classes']
            if output.shape != (1, expected_classes):
                print(f"Fall-back: Unexpected output shape: {output.shape}, expected (1, {expected_classes})")
                return False

            probs = F.softmax(output, dim=1)
            confidence, prediction = torch.max(probs, 1)
            print("Model integrity check passed:")
            print(f"- Output shape: {output.shape}")
            print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
            return True
        except Exception as e:
            print(f"Fall-back: Model integrity check failed with error: {e}")
            return False

    def inference(self, image, model_type, config):
        """Run generative inference on ``image`` using the ``model_type`` classifier.

        Iteration 1 records the unmodified image. Each later iteration
        computes a loss on the (optionally truncated) network, takes a
        normalized gradient step plus optional Gaussian diffusion noise, and
        projects back into the eps-box around the original image. If the
        gradient computation fails, a uniform random step is taken instead.

        Args:
            image: file path, NumPy array, PIL image, or array-like.
            model_type: key of ``MODEL_CONFIGS`` / ``MODEL_URLS``.
            config: dict as produced by ``get_inference_configs``. If
                ``inference_normalization`` is true the input tensor is normalized
                during preprocessing and the in-model normalizer is bypassed;
                otherwise raw [0, 1] tensors go through the in-model normalizer.

        Returns:
            dict with ``final_image``, recorded ``steps`` (original image first),
            original/final top-1 class and confidence, and the top-1 category
            and confidence at each recorded iteration.

        Raises:
            ValueError: unsupported inference method / loss, or unusable image.
        """
        inference_start = time.time()
        _validate_inference_config(config)
        loss_infer = config['loss_infer']

        model = self.load_model(model_type)
        image = _to_pil_image(image)

        preproc_config = get_model_preprocessing(model_type)
        n_classes = preproc_config['n_classes']
        input_is_normalized = config.get('inference_normalization', False)

        image_tensor = _image_to_tensor(
            image,
            preproc_config['input_size'],
            preproc_config['norm_mean'],
            preproc_config['norm_std'],
            input_is_normalized,
        ).unsqueeze(0).to(device)

        # Split off the in-model normalizer, if present.
        if isinstance(model, nn.Sequential) and isinstance(model[0], NormalizeByChannelMeanStd):
            normalizer, core_model = model[0], model[1]
        else:
            normalizer, core_model = None, model

        # Normalize exactly once: bypass the in-model normalizer only when the
        # input was already normalized during preprocessing.
        classifier = core_model if input_is_normalized else model

        top_layer = config.get('top_layer', 'all')
        if top_layer != 'all':
            feature_model = extract_middle_layers(core_model, top_layer)
            if not input_is_normalized and normalizer is not None:
                feature_model = nn.Sequential(normalizer, feature_model)
        else:
            feature_model = classifier

        with torch.no_grad():
            output_original = classifier(image_tensor)
        probs_orig = F.softmax(output_original, dim=1)
        conf_orig, classes_orig = torch.max(probs_orig, 1)

        # IncreaseConfidence pushes the image away from the 10% least likely classes.
        target_classes = None
        if loss_infer == 'IncreaseConfidence':
            _, least_confident_classes = torch.topk(probs_orig, k=int(n_classes / 10), largest=False)
            target_classes = least_confident_classes[0].tolist()

        # Prior-Guided Drift Diffusion pulls features toward those of a noisy copy
        # of the input; the target is a constant, so no graph is recorded for it.
        noisy_features = None
        if loss_infer == 'Prior-Guided Drift Diffusion':
            print("Setting up Prior-Guided Drift Diffusion...")
            noise_ratio = config.get('initial_inference_noise_ratio', 0.05)
            with torch.no_grad():
                noisy_features = feature_model(image_tensor + noise_ratio * torch.randn_like(image_tensor))

        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
        n_itr = config['n_itr']
        iterations_to_show = set(config.get('iterations_to_show', []))
        diffusion_ratio = config.get('diffusion_noise_ratio', 0.0)
        loss_function = config.get('loss_function', 'CE')

        x = image_tensor.clone()
        all_steps = [image_tensor[0].detach().cpu()]
        perceived_categories = []
        confidence_list = []

        print(f"Starting inference loop with {n_itr} iterations for {loss_infer}...")
        for i in range(n_itr):
            if i > 0:
                # Fresh leaf each iteration so autograd history does not accumulate
                # across iterations.
                x_var = x.detach().requires_grad_(True)
                try:
                    features = feature_model(x_var)
                    if loss_infer == 'Prior-Guided Drift Diffusion':
                        loss = F.mse_loss(features, noisy_features)
                    else:
                        loss = calculate_loss(features, target_classes, loss_function)
                    grad = torch.autograd.grad(loss, x_var)[0]
                    with torch.no_grad():
                        update = infer_step.step(x_var, grad) + diffusion_ratio * torch.randn_like(x_var)
                except Exception as e:
                    print(f"Fall-back: Error in gradient calculation: {e}")
                    update = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                with torch.no_grad():
                    x = infer_step.project(x + update)

            if (i + 1) in iterations_to_show or (i + 1) == n_itr:
                all_steps.append(x[0].detach().cpu())
                with torch.no_grad():
                    current_output = classifier(x)
                if isinstance(current_output, torch.Tensor) and current_output.size(-1) == n_classes:
                    current_probs = F.softmax(current_output, dim=1)
                    current_conf, current_classes = torch.max(current_probs, 1)
                    perceived_categories.append(current_classes.item())
                    confidence_list.append(current_conf.item())
                else:
                    perceived_categories.append('N/A')
                    confidence_list.append(0.0)

        with torch.no_grad():
            final_output = classifier(x)
        final_probs = F.softmax(final_output, dim=1)
        final_conf, final_classes = torch.max(final_probs, 1)

        total_time = time.time() - inference_start
        print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
        print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
        print(f"Total inference time: {total_time:.2f} seconds")

        # Result keys are part of the public contract; keep them stable.
        return {
            'final_image': x[0].detach().cpu(),
            'steps': all_steps,
            'original_class': classes_orig.item(),
            'original_confidence': conf_orig.item(),
            'final_class': final_classes.item(),
            'final_confidence': final_conf.item(),
            'all_categories': perceived_categories,
            'all_confidences': confidence_list,
        }


def show_inference_steps(steps, figsize=(15, 10)):
    """Plot the recorded inference steps side by side.

    Tensors are assumed to be (C, H, W) images in [0, 1]; other entries are
    passed to ``imshow`` unchanged.

    Returns:
        The matplotlib figure, or None if matplotlib is unavailable or plotting fails.
    """
    try:
        import matplotlib.pyplot as plt

        n_steps = len(steps)
        fig, axes = plt.subplots(1, n_steps, figsize=figsize)
        if n_steps == 1:
            axes = [axes]

        for i, step_img in enumerate(steps):
            if isinstance(step_img, torch.Tensor):
                img = np.clip(step_img.permute(1, 2, 0).numpy(), 0, 1)
            else:
                img = step_img
            axes[i].imshow(img)
            axes[i].set_title(f"Step {i+1}")
            axes[i].axis('off')

        plt.tight_layout()
        return fig
    except ImportError:
        print("Fall-back: matplotlib not available for visualization")
        return None
    except Exception as e:
        print(f"Fall-back: Visualization failed: {e}")
        return None


# Export the main classes and functions
__all__ = ['GenerativeInferenceModel', 'get_inference_configs', 'show_inference_steps']