import os
import pickle
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from sklearn.covariance import EmpiricalCovariance
from tqdm import tqdm

from openood.postprocessors.base_postprocessor import BasePostprocessor


class AdaptiveNormGate(nn.Module):
    """Scalar, norm-only gate that blends raw and L2-normalized features.

    g(x)    = sigmoid(a * (log ||f|| - b))
    f_adapt = (1 - g) * f + g * (f / ||f||)

    ``a`` and ``b`` are the only learnable parameters (both scalars), so the
    gate value of a sample depends on its feature vector only through the
    vector's L2 norm.
    """

    def __init__(self, init_a: float = 1.0, init_b: float = 0.0,
                 eps: float = 1e-10):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(float(init_a), dtype=torch.float32))
        self.b = nn.Parameter(torch.tensor(float(init_b), dtype=torch.float32))
        # Added to the norm before the log and the division so that zero-norm
        # rows do not produce -inf / NaN.
        self.eps = eps

    def forward(self, features: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the gate row-wise.

        Args:
            features: feature batch. The norm is taken over dim 1, so this is
                presumably ``(N, D)`` -- confirm before passing other ranks.

        Returns:
            ``(features_adapt, g)``. ``g`` holds one gate value per row, kept
            with a size-1 dim 1 so that it broadcasts against ``features``.
        """
        norms = torch.norm(features, p=2, dim=1, keepdim=True)
        log_norms = torch.log(norms + self.eps)
        g = torch.sigmoid(self.a * (log_norms - self.b))
        features_norm = features / (norms + self.eps)
        features_adapt = (1.0 - g) * features + g * features_norm
        return features_adapt, g


class AdaptiveNormMahalanobisPostprocessor(BasePostprocessor):
    """Mahalanobis OOD scoring on gate-adapted features.

    ``setup`` proceeds in three steps:

    1. Extract features from the ID ``train`` loader.
    2. Learn the two gate parameters by minimizing the true-class Mahalanobis
       distance on a held-out part of those features.
    3. Fit class means and a shared precision matrix on the adapted features.

    ``postprocess`` scores a batch by the negative Mahalanobis distance to
    the closest class mean.
    """

    def __init__(self, config):
        super().__init__(config)
        # Hyperparameters may live under ``postprocessor_args`` or directly
        # on ``config.postprocessor``.
        args = getattr(config.postprocessor, 'postprocessor_args',
                       config.postprocessor)
        self.gate_init_a = getattr(args, 'gate_init_a', 1.0)
        self.gate_init_b = getattr(args, 'gate_init_b', 0.0)
        self.gate_lr = getattr(args, 'gate_lr', 1e-2)
        self.gate_weight_decay = getattr(args, 'gate_weight_decay', 0.0)
        self.gate_epochs = getattr(args, 'gate_epochs', 20)
        self.gate_batch_size = getattr(args, 'gate_batch_size', 1024)
        # Fraction of ID samples used to fit class statistics; the rest
        # trains the gate.
        self.gate_fit_ratio = getattr(args, 'gate_fit_ratio', 0.9)
        # Ridge term added to the covariance diagonal before inversion
        # (0 or less -> use sklearn's precision_ instead).
        self.covariance_reg = getattr(args, 'covariance_reg', 1e-6)
        self.eps = getattr(args, 'eps', 1e-10)
        self.cache_dir = getattr(args, 'cache_dir', './cache')
        self.save_cache = getattr(args, 'save_cache', False)
        self.use_cache = getattr(args, 'use_cache', False)
        self.print_progress = getattr(args, 'print_progress', True)
        self.reg_lambda = getattr(args, 'reg_lambda', 1e-4)
        # Only 'l2' adds a penalty; any other value disables it.
        self.reg_type = getattr(args, 'reg_type', 'l2')

        self.setup_flag = False
        # NOTE(review): these two flags look like BasePostprocessor /
        # evaluator conventions (no hyperparameter search here) -- confirm.
        self.hyperparam_search_done = True
        self.APS_mode = False

        self.class_mean: Optional[torch.Tensor] = None
        self.precision: Optional[torch.Tensor] = None
        self.num_classes: Optional[int] = None
        self.feature_dim: Optional[int] = None

        self.gate = AdaptiveNormGate(self.gate_init_a, self.gate_init_b,
                                     self.eps)
        # Features, labels and the gate all live on this device.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    def _get_cache_path(self, net: nn.Module) -> str:
        """Return the feature-cache file path for ``net``.

        The name depends only on the network's class name.
        NOTE(review): different checkpoints or datasets that share a backbone
        class will collide on this path.
        """
        net_name = net.__class__.__name__
        filename = f'adaptive_norm_mahalanobis_{net_name}.pkl'
        return os.path.join(self.cache_dir, filename)

    @torch.no_grad()
    def _extract_id_features(
        self, net: nn.Module,
        id_loader_dict: Dict[str, torch.utils.data.DataLoader]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Collect ``(features, labels)`` for the ID ``train`` loader.

        Batches are dicts with ``'data'`` and ``'label'`` keys. Features come
        from ``net(data, return_feature=True)`` (second element of the
        returned pair). Results are moved to ``self.device``. They are
        optionally read from and written to a pickle cache.
        """
        if self.use_cache:
            cache_path = self._get_cache_path(net)
            if os.path.exists(cache_path):
                # NOTE(review): pickle.load can execute arbitrary code; only
                # enable ``use_cache`` for a trusted ``cache_dir``.
                with open(cache_path, 'rb') as f:
                    cache = pickle.load(f)
                features = torch.from_numpy(
                    cache['features']).float().to(self.device)
                labels = torch.from_numpy(
                    cache['labels']).long().to(self.device)
                return features, labels

        net.eval()
        feature_list = []
        label_list = []
        loader = id_loader_dict['train']
        iterator = tqdm(loader, desc='Extracting ID features',
                        disable=not self.print_progress)
        for batch in iterator:
            data = batch['data'].to(self.device)
            label = batch['label'].to(self.device)
            _, feature = net(data, return_feature=True)
            feature_list.append(feature.detach())
            label_list.append(label.detach())

        features = torch.cat(feature_list, dim=0)
        labels = torch.cat(label_list, dim=0)

        if self.save_cache:
            os.makedirs(self.cache_dir, exist_ok=True)
            cache_path = self._get_cache_path(net)
            with open(cache_path, 'wb') as f:
                pickle.dump({
                    'features': features.detach().cpu().numpy(),
                    'labels': labels.detach().cpu().numpy()
                }, f)

        return features, labels

    def _adaptive_transform(self, features: torch.Tensor
                            ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the learned gate; return ``(features_adapt, g)``."""
        return self.gate(features)

    def _fit_gaussian_stats(self, features: torch.Tensor, labels: torch.Tensor,
                            num_classes: int
                            ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fit per-class means and one shared (tied) precision matrix.

        Args:
            features: ``(N, D)`` features.
            labels: integer class ids aligned with ``features``.
            num_classes: every class in ``range(num_classes)`` must occur.

        Returns:
            ``(class_mean, precision)`` with shapes ``(C, D)`` and ``(D, D)``,
            on ``features.device``.

        Raises:
            ValueError: if some class in ``range(num_classes)`` has no samples.
        """
        device = features.device
        feat_dim = features.shape[1]
        class_mean = torch.zeros(num_classes, feat_dim, device=device)
        centered_chunks = []

        for c in range(num_classes):
            class_features = features[labels == c]
            if class_features.shape[0] == 0:
                raise ValueError(f'No samples found for class {c} while fitting '
                                 'AdaptiveNormMahalanobisPostprocessor.')
            class_mean[c] = class_features.mean(dim=0)
            centered_chunks.append(class_features - class_mean[c])

        centered = torch.cat(centered_chunks, dim=0)

        # With ridge regularization the precision is recomputed below, so do
        # not make sklearn spend a D x D pseudo-inverse that gets discarded.
        use_ridge = self.covariance_reg > 0
        cov = EmpiricalCovariance(assume_centered=True,
                                  store_precision=not use_ridge)
        cov.fit(centered.detach().cpu().numpy())

        if use_ridge:
            cov_reg = torch.from_numpy(cov.covariance_).float().to(device)
            identity = torch.eye(cov_reg.shape[0], device=device)
            precision = torch.linalg.inv(cov_reg + self.covariance_reg * identity)
        else:
            precision = torch.from_numpy(cov.precision_).float().to(device)

        return class_mean, precision

    def _true_class_mahalanobis(self, features: torch.Tensor,
                                labels: torch.Tensor,
                                class_mean: torch.Tensor,
                                precision: torch.Tensor) -> torch.Tensor:
        """Return the squared Mahalanobis distance of each row to its own class mean."""
        mu = class_mean[labels]
        diff = features - mu
        left = torch.matmul(diff, precision)
        return torch.sum(left * diff, dim=1)

    def _gate_regularization(self) -> torch.Tensor:
        """Return ``reg_lambda * (a^2 + b^2)`` for ``reg_type == 'l2'``, otherwise 0."""
        reg = torch.tensor(0.0, device=self.gate.a.device)
        if self.reg_type == 'l2':
            reg = self.gate.a.pow(2) + self.gate.b.pow(2)
        return self.reg_lambda * reg

    def _split_fit_gate(self, labels: torch.Tensor, num_classes: int
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Randomly split sample indices into a "fit" part and a "gate" part.

        Both parts come from one random permutation. The fit part takes a
        ``gate_fit_ratio`` fraction, clamped so it keeps at least one sample
        and leaves at least one for the gate part whenever ``n >= 2``.

        A purely random split can leave a rare class entirely in the gate
        part, which would make ``_fit_gaussian_stats`` raise. Every such class
        therefore has one of its samples moved into the fit part.

        Returns:
            ``(fit_idx, gate_idx)`` index tensors on ``labels.device``.
        """
        device = labels.device
        n = labels.shape[0]
        perm = torch.randperm(n, device=device)
        split_idx = max(1, min(int(self.gate_fit_ratio * n), n - 1))
        fit_idx = perm[:split_idx]
        gate_idx = perm[split_idx:]

        fit_classes = set(torch.unique(labels[fit_idx]).tolist())
        missing = [c for c in range(num_classes) if c not in fit_classes]
        if not missing:
            return fit_idx, gate_idx

        gate_labels = labels[gate_idx]
        keep = torch.ones(gate_idx.shape[0], dtype=torch.bool, device=device)
        moved = []
        for c in missing:
            positions = torch.nonzero(gate_labels == c, as_tuple=False).flatten()
            # A class absent from the whole dataset stays missing; that is
            # reported later by _fit_gaussian_stats.
            if positions.numel() > 0:
                moved.append(gate_idx[positions[0]])
                keep[positions[0]] = False
        if moved:
            fit_idx = torch.cat([fit_idx, torch.stack(moved)])
            gate_idx = gate_idx[keep]
        return fit_idx, gate_idx

    @torch.no_grad()
    def _held_out_loss(self, features: torch.Tensor, labels: torch.Tensor,
                       class_mean: torch.Tensor,
                       precision: torch.Tensor) -> float:
        """Return mean true-class distance of adapted ``features`` plus gate regularization.

        Uses the current gate parameters and does not track gradients.
        """
        count = features.shape[0]
        total = 0.0
        for start in range(0, count, self.gate_batch_size):
            end = min(start + self.gate_batch_size, count)
            adapt, _ = self._adaptive_transform(features[start:end])
            total += self._true_class_mahalanobis(
                adapt, labels[start:end], class_mean, precision).sum().item()
        return total / max(count, 1) + self._gate_regularization().item()

    def _train_gate(self, features: torch.Tensor, labels: torch.Tensor,
                    num_classes: int) -> None:
        """Fit the gate's ``a``/``b`` on held-out ID features.

        Each round does the following:

        1. Fit class statistics on the adapted "fit" split.
        2. Score the current gate on the "gate" split with exactly those
           statistics.
        3. Run one Adam pass over the gate split.

        The parameters are scored before they are updated, and the final
        parameters are scored too. The recorded loss therefore always
        describes the snapshot that gets restored. The gate ends at the
        best-scoring setting, in eval mode.
        """
        device = features.device
        fit_idx, gate_idx = self._split_fit_gate(labels, num_classes)
        fit_features = features[fit_idx]
        fit_labels = labels[fit_idx]
        gate_features = features[gate_idx]
        gate_labels = labels[gate_idx]
        num_gate = gate_features.shape[0]

        if self.gate_epochs <= 0 or num_gate == 0:
            # Nothing to optimize against: keep the initial parameters.
            self.gate.eval()
            return

        optimizer = torch.optim.Adam(
            self.gate.parameters(),
            lr=self.gate_lr,
            weight_decay=self.gate_weight_decay,
        )

        best_state = None
        best_loss = float('inf')

        # gate_epochs training passes, gate_epochs + 1 scoring passes (the
        # extra one scores the parameters left by the last training pass).
        for epoch in range(self.gate_epochs + 1):
            self.gate.train()
            with torch.no_grad():
                fit_adapt, _ = self._adaptive_transform(fit_features)
                class_mean, precision = self._fit_gaussian_stats(
                    fit_adapt, fit_labels, num_classes)

            held_out = self._held_out_loss(gate_features, gate_labels,
                                           class_mean, precision)
            if held_out < best_loss:
                best_loss = held_out
                best_state = {
                    'a': self.gate.a.detach().clone(),
                    'b': self.gate.b.detach().clone(),
                }

            if epoch == self.gate_epochs:
                break

            batch_perm = torch.randperm(num_gate, device=device)
            starts = range(0, num_gate, self.gate_batch_size)
            if self.print_progress:
                starts = tqdm(starts,
                              desc=f'Training gate epoch {epoch + 1}/{self.gate_epochs}',
                              leave=False)
            for start in starts:
                idx = batch_perm[start:start + self.gate_batch_size]
                batch_adapt, _ = self._adaptive_transform(gate_features[idx])
                d_true = self._true_class_mahalanobis(
                    batch_adapt, gate_labels[idx], class_mean, precision)
                loss = d_true.mean() + self._gate_regularization()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if best_state is not None:
            with torch.no_grad():
                self.gate.a.copy_(best_state['a'])
                self.gate.b.copy_(best_state['b'])
        self.gate.eval()

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        """Extract ID features, train the gate, then fit the final class statistics.

        Runs only once; later calls return immediately. ``ood_loader_dict``
        is not used.
        """
        if self.setup_flag:
            return

        net.eval()
        self.gate.to(self.device)
        self.gate.train()

        with torch.no_grad():
            features, labels = self._extract_id_features(net, id_loader_dict)

        # Class ids are assumed to be 0..max; a gap raises in
        # _fit_gaussian_stats.
        self.num_classes = int(labels.max().item()) + 1
        self.feature_dim = features.shape[1]

        self._train_gate(features, labels, self.num_classes)

        # Final statistics use all ID samples and the trained gate.
        with torch.no_grad():
            features_adapt, _ = self._adaptive_transform(features)
            self.class_mean, self.precision = self._fit_gaussian_stats(
                features_adapt, labels, self.num_classes)

        self.setup_flag = True

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Return ``(pred, score)`` for a batch.

        ``pred`` is the argmax of the network's logits. ``score`` is the
        negative minimum Mahalanobis distance over all class means; higher
        means more ID-like.

        Raises:
            RuntimeError: if ``setup`` has not completed.
        """
        if not self.setup_flag:
            raise RuntimeError('AdaptiveNormMahalanobisPostprocessor must be '
                               'setup before calling postprocess().')

        net.eval()
        self.gate.eval()

        output, feature = net(data, return_feature=True)
        feature_adapt, _ = self._adaptive_transform(feature)

        # NOTE(review): this materializes a (batch, num_classes, feature_dim)
        # tensor (twice); with many classes/large D consider chunking classes.
        diff = feature_adapt.unsqueeze(1) - self.class_mean.unsqueeze(0)
        left = torch.matmul(diff, self.precision)
        mahalanobis_distance = torch.sum(left * diff, dim=2)

        score = -torch.min(mahalanobis_distance, dim=1)[0]
        pred = torch.argmax(output, dim=1)
        return pred, score