# Hosting-page residue ("Spaces: Sleeping"); not part of the module.
import os
import pickle
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from sklearn.covariance import EmpiricalCovariance
from tqdm import tqdm

from openood.postprocessors.base_postprocessor import BasePostprocessor
class AdaptiveNormGate(nn.Module):
    """Scalar, norm-only gate that blends raw and L2-normalized features.

    For every row ``f`` of the input (the norm is taken over ``dim=1``)::

        g       = sigmoid(a * (log(||f|| + eps) - b))
        f_adapt = (1 - g) * f + g * f / (||f|| + eps)

    ``a`` and ``b`` are learnable scalar parameters.  ``eps`` is added to the
    norm before both the logarithm and the division so that zero-norm rows
    produce finite values.
    """

    def __init__(self,
                 init_a: float = 1.0,
                 init_b: float = 0.0,
                 eps: float = 1e-10):
        """Create the gate.

        Args:
            init_a: Initial value of the slope parameter ``a``.
            init_b: Initial value of the offset parameter ``b``.
            eps: Constant added to the feature norm.
        """
        super().__init__()
        self.a = nn.Parameter(
            torch.tensor(float(init_a), dtype=torch.float32))
        self.b = nn.Parameter(
            torch.tensor(float(init_b), dtype=torch.float32))
        self.eps = eps

    def forward(self,
                features: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the gate.

        Args:
            features: Tensor whose ``dim=1`` holds the feature vector.

        Returns:
            ``(features_adapt, g)``: the blended features (same shape as
            ``features``) and the gate values, which keep a size-1 ``dim=1``
            so they broadcast against ``features``.
        """
        # torch.norm is deprecated; vector_norm is its documented replacement.
        norms = torch.linalg.vector_norm(features, ord=2, dim=1,
                                         keepdim=True)
        # Shared by the log and the division so both see the same guard.
        safe_norms = norms + self.eps
        g = torch.sigmoid(self.a * (torch.log(safe_norms) - self.b))
        features_norm = features / safe_norms
        features_adapt = (1.0 - g) * features + g * features_norm
        return features_adapt, g
class AdaptiveNormMahalanobisPostprocessor(BasePostprocessor):
    """Mahalanobis-distance OOD scorer operating on gate-adapted features.

    ``setup`` extracts features from the ``'train'`` ID loader, trains an
    ``AdaptiveNormGate`` to reduce the true-class Mahalanobis distance, and
    then fits per-class means plus one shared precision matrix on the adapted
    features.  ``postprocess`` scores a batch with the negative minimum
    class-wise Mahalanobis distance.
    """

    def __init__(self, config):
        """Read options from ``config`` and build an untrained gate.

        Options are looked up on ``config.postprocessor.postprocessor_args``
        when that attribute exists, otherwise on ``config.postprocessor``
        itself.  Every option falls back to the default shown below.
        """
        super().__init__(config)
        args = getattr(config.postprocessor, 'postprocessor_args',
                       config.postprocessor)

        # Gate initialisation and optimisation settings.
        self.gate_init_a = getattr(args, 'gate_init_a', 1.0)
        self.gate_init_b = getattr(args, 'gate_init_b', 0.0)
        self.gate_lr = getattr(args, 'gate_lr', 1e-2)
        self.gate_weight_decay = getattr(args, 'gate_weight_decay', 0.0)
        self.gate_epochs = getattr(args, 'gate_epochs', 20)
        self.gate_batch_size = getattr(args, 'gate_batch_size', 1024)
        # Fraction of samples used to fit Gaussian stats during gate
        # training; the remainder is used for the gate loss.
        self.gate_fit_ratio = getattr(args, 'gate_fit_ratio', 0.9)

        # Numerical settings.
        self.covariance_reg = getattr(args, 'covariance_reg', 1e-6)
        self.eps = getattr(args, 'eps', 1e-10)

        # Feature cache settings.
        self.cache_dir = getattr(args, 'cache_dir', './cache')
        self.save_cache = getattr(args, 'save_cache', False)
        self.use_cache = getattr(args, 'use_cache', False)

        self.print_progress = getattr(args, 'print_progress', True)

        # Penalty on the gate parameters added to the training loss.
        self.reg_lambda = getattr(args, 'reg_lambda', 1e-4)
        self.reg_type = getattr(args, 'reg_type', 'l2')

        self.setup_flag = False
        # NOTE(review): the two flags below are not read anywhere in this
        # class; presumably the surrounding framework consumes them.
        self.hyperparam_search_done = True
        self.APS_mode = False

        # Filled in by setup().
        self.class_mean: Optional[torch.Tensor] = None
        self.precision: Optional[torch.Tensor] = None
        self.num_classes: Optional[int] = None
        self.feature_dim: Optional[int] = None

        self.gate = AdaptiveNormGate(self.gate_init_a,
                                     self.gate_init_b,
                                     self.eps)
        # Use CUDA when available, otherwise fall back to the CPU.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

    def _get_cache_path(self, net: nn.Module) -> str:
        """Return the feature-cache file path for ``net``.

        The file name depends only on the network's class name, so networks
        of the same class share one cache file.
        """
        net_name = net.__class__.__name__
        filename = f'adaptive_norm_mahalanobis_{net_name}.pkl'
        return os.path.join(self.cache_dir, filename)

    def _extract_id_features(
            self,
            net: nn.Module,
            id_loader_dict: Dict[str, torch.utils.data.DataLoader]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features, labels)`` for the ``'train'`` ID loader.

        If ``use_cache`` is set and the cache file exists, its contents are
        returned instead of running the network.  Otherwise each batch's
        ``'data'`` entry is passed through ``net(..., return_feature=True)``
        and the second return value is collected.  The result is written to
        the cache file when ``save_cache`` is set.

        Raises:
            RuntimeError: If the loader yields no batches.
        """
        cache_path = self._get_cache_path(net)

        if self.use_cache and os.path.exists(cache_path):
            # SECURITY: pickle.load can execute arbitrary code; only point
            # cache_dir at locations you trust.
            # NOTE(review): the cache key is the architecture name only, so
            # a stale file from another checkpoint/dataset would be reused.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            features = torch.from_numpy(
                cache['features']).float().to(self.device)
            labels = torch.from_numpy(
                cache['labels']).long().to(self.device)
            return features, labels

        net.eval()
        feature_list = []
        label_list = []
        iterator = tqdm(id_loader_dict['train'],
                        desc='Extracting ID features',
                        disable=not self.print_progress)
        # Disable autograd here rather than relying on the caller, so no
        # graph through the backbone is kept for any collected batch.
        with torch.no_grad():
            for batch in iterator:
                data = batch['data'].to(self.device)
                label = batch['label'].to(self.device)
                _, feature = net(data, return_feature=True)
                feature_list.append(feature.detach())
                label_list.append(label.detach())

        if not feature_list:
            raise RuntimeError("The 'train' ID loader yielded no batches; "
                               'cannot extract features.')

        features = torch.cat(feature_list, dim=0)
        labels = torch.cat(label_list, dim=0)

        if self.save_cache:
            os.makedirs(self.cache_dir, exist_ok=True)
            with open(cache_path, 'wb') as f:
                pickle.dump({
                    'features': features.detach().cpu().numpy(),
                    'labels': labels.detach().cpu().numpy(),
                }, f)
        return features, labels

    def _adaptive_transform(
            self,
            features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``features`` through the gate; returns ``(adapted, g)``."""
        return self.gate(features)

    def _fit_gaussian_stats(
            self,
            features: torch.Tensor,
            labels: torch.Tensor,
            num_classes: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fit per-class means and one shared precision matrix.

        Each sample is centered with its own class mean; a single covariance
        is then estimated from all centered samples.  When
        ``covariance_reg > 0`` that value times the identity is added to the
        covariance before it is inverted; otherwise the estimator's own
        precision is used.

        Args:
            features: ``(N, D)`` feature matrix.
            labels: ``(N,)`` integer class labels in ``[0, num_classes)``.
            num_classes: Number of classes to fit.

        Returns:
            ``(class_mean, precision)`` with shapes ``(num_classes, D)`` and
            ``(D, D)``, on the device and in the dtype of ``features``.

        Raises:
            ValueError: If some class has no samples.
        """
        device = features.device
        dtype = features.dtype
        feat_dim = features.shape[1]

        class_mean = torch.zeros(num_classes, feat_dim,
                                 device=device, dtype=dtype)
        centered_chunks = []
        for c in range(num_classes):
            class_features = features[labels == c]
            if class_features.shape[0] == 0:
                raise ValueError(f'No samples found for class {c} while fitting '
                                 'AdaptiveNormMahalanobisPostprocessor.')
            class_mean[c] = class_features.mean(dim=0)
            centered_chunks.append(class_features - class_mean[c])

        centered_np = torch.cat(centered_chunks,
                                dim=0).detach().cpu().numpy()

        use_ridge = self.covariance_reg > 0
        # With a ridge term the precision is recomputed below, so do not let
        # the estimator compute (and then discard) its own inverse.
        cov = EmpiricalCovariance(assume_centered=True,
                                  store_precision=not use_ridge)
        cov.fit(centered_np)

        if use_ridge:
            cov_reg = torch.from_numpy(cov.covariance_).float().to(device)
            cov_reg = cov_reg + self.covariance_reg * torch.eye(
                feat_dim, device=device)
            precision = torch.linalg.inv(cov_reg)
        else:
            precision = torch.from_numpy(cov.precision_).float().to(device)

        # Keep the precision in the feature dtype so later matmuls agree.
        return class_mean, precision.to(dtype)

    def _true_class_mahalanobis(self,
                                features: torch.Tensor,
                                labels: torch.Tensor,
                                class_mean: torch.Tensor,
                                precision: torch.Tensor
                                ) -> torch.Tensor:
        """Return each sample's squared Mahalanobis distance to its own class.

        Computes ``(f - mu_y)^T P (f - mu_y)`` per row, where ``mu_y`` is the
        mean of the sample's labelled class and ``P`` is ``precision``.
        """
        diff = features - class_mean[labels]
        left = torch.matmul(diff, precision)
        return torch.sum(left * diff, dim=1)

    def _gate_regularization(self) -> torch.Tensor:
        """Return ``reg_lambda`` times the gate-parameter penalty.

        The penalty is ``a**2 + b**2`` for ``reg_type == 'l2'`` and zero for
        any other ``reg_type``.
        """
        if self.reg_type == 'l2':
            penalty = self.gate.a.pow(2) + self.gate.b.pow(2)
        else:
            penalty = torch.tensor(0.0, device=self.gate.a.device)
        return self.reg_lambda * penalty

    def _train_gate(self,
                    features: torch.Tensor,
                    labels: torch.Tensor,
                    num_classes: int) -> None:
        """Optimise the gate parameters with Adam.

        The samples are randomly split into a "fit" part (``gate_fit_ratio``
        of the data, at least one sample) and a "gate" part (the rest).  At
        the start of every epoch Gaussian statistics are refit on the
        adapted fit part with the current gate; the gate is then updated on
        mini-batches of the gate part to minimise the mean true-class
        Mahalanobis distance plus the parameter penalty.  The parameters
        recorded after the epoch with the lowest average loss are restored
        at the end, and the gate is put in eval mode.

        Raises:
            ValueError: Propagated from ``_fit_gaussian_stats`` when the fit
                part lacks samples for some class.
        """
        device = features.device
        n = features.shape[0]

        perm = torch.randperm(n, device=device)
        split_idx = int(self.gate_fit_ratio * n)
        # Keep at least one sample on the fit side and, when n > 1, at least
        # one on the gate side.
        split_idx = max(1, min(split_idx, n - 1))
        fit_idx, gate_idx = perm[:split_idx], perm[split_idx:]

        fit_features, fit_labels = features[fit_idx], labels[fit_idx]
        gate_features, gate_labels = features[gate_idx], labels[gate_idx]
        num_gate = gate_features.shape[0]

        optimizer = torch.optim.Adam(
            self.gate.parameters(),
            lr=self.gate_lr,
            weight_decay=self.gate_weight_decay,
        )

        best_state = None
        best_loss = float('inf')

        for epoch in range(self.gate_epochs):
            self.gate.train()

            # Statistics are treated as constants for this epoch's updates.
            with torch.no_grad():
                fit_features_adapt, _ = self._adaptive_transform(fit_features)
                class_mean, precision = self._fit_gaussian_stats(
                    fit_features_adapt, fit_labels, num_classes)

            epoch_loss = 0.0
            num_seen = 0
            batch_perm = torch.randperm(num_gate, device=device)
            iterator = tqdm(
                range(0, num_gate, self.gate_batch_size),
                desc=f'Training gate epoch {epoch + 1}/{self.gate_epochs}',
                leave=False,
                disable=not self.print_progress)

            for start in iterator:
                idx = batch_perm[start:start + self.gate_batch_size]
                batch_features = gate_features[idx]
                batch_labels = gate_labels[idx]

                # Force grad mode on so training still works if a caller
                # wraps setup() in torch.no_grad().
                with torch.enable_grad():
                    batch_features_adapt, _ = self._adaptive_transform(
                        batch_features)
                    d_true = self._true_class_mahalanobis(
                        batch_features_adapt, batch_labels,
                        class_mean, precision)
                    loss = d_true.mean() + self._gate_regularization()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_size = batch_features.shape[0]
                epoch_loss += loss.item() * batch_size
                num_seen += batch_size

            epoch_loss /= max(num_seen, 1)
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_state = {
                    'a': self.gate.a.detach().clone(),
                    'b': self.gate.b.detach().clone(),
                }

        if best_state is not None:
            with torch.no_grad():
                self.gate.a.copy_(best_state['a'])
                self.gate.b.copy_(best_state['b'])
        self.gate.eval()

    def setup(self,
              net: nn.Module,
              id_loader_dict,
              ood_loader_dict):
        """Train the gate and fit the class statistics (runs once).

        Subsequent calls return immediately once ``setup_flag`` is set.
        ``ood_loader_dict`` is accepted but not used.
        """
        if self.setup_flag:
            return

        net.eval()
        self.gate.to(self.device)
        self.gate.train()

        # Collect all ID features/labels once; they drive both the gate
        # training and the final Gaussian statistics.
        with torch.no_grad():
            features, labels = self._extract_id_features(net, id_loader_dict)

        # The class count is inferred from the largest label seen.
        self.num_classes = int(labels.max().item()) + 1
        self.feature_dim = features.shape[1]

        self._train_gate(features, labels, self.num_classes)

        # Refit means and shared precision on all samples with the final gate.
        with torch.no_grad():
            features_adapt, _ = self._adaptive_transform(features)
            self.class_mean, self.precision = self._fit_gaussian_stats(
                features_adapt, labels, self.num_classes)

        self.setup_flag = True

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Return ``(pred, score)`` for one batch.

        ``pred`` is the argmax of the network's first output; ``score`` is
        the negative minimum squared Mahalanobis distance from the adapted
        feature to any class mean, so larger values are closer to a fitted
        class.  Autograd is disabled: the score depends on the gate's
        parameters and would otherwise carry a graph and require grad.

        NOTE(review): assumes ``data`` is already on the same device as
        ``net`` and the fitted statistics - confirm against the caller.

        Raises:
            RuntimeError: If called before ``setup``.
        """
        if not self.setup_flag:
            raise RuntimeError('AdaptiveNormMahalanobisPostprocessor must be '
                               'setup before calling postprocess().')

        net.eval()
        self.gate.eval()

        output, feature = net(data, return_feature=True)
        feature_adapt, _ = self._adaptive_transform(feature)

        # (B, 1, D) - (1, C, D) -> (B, C, D): every sample against every
        # class centroid.
        diff = feature_adapt.unsqueeze(1) - self.class_mean.unsqueeze(0)
        left = torch.matmul(diff, self.precision)
        mahalanobis_distance = torch.sum(left * diff, dim=2)

        score = -torch.min(mahalanobis_distance, dim=1)[0]
        pred = torch.argmax(output, dim=1)
        return pred, score