import torch
import torch.nn as nn
from tqdm import tqdm

from openood.postprocessors.base_postprocessor import BasePostprocessor


class MahalanobisPlusPlusPostprocessor(BasePostprocessor):
    """OOD scorer using Mahalanobis distance on L2-normalized features.

    ``setup`` projects training features onto the unit sphere, then fits one
    mean per training class plus a single covariance matrix shared by all
    classes (pooled within-class scatter, lightly regularized).
    ``postprocess`` scores a sample by the negated squared Mahalanobis
    distance to its nearest class mean, so larger scores mean "more
    in-distribution".
    """

    def __init__(self, config):
        super().__init__(config)
        # Fitted statistics, populated by ``setup``.
        self.class_means = None  # (num_present_classes, feat_dim)
        self.precision = None  # inverse shared covariance, (feat_dim, feat_dim)
        self.APS_mode = False
        self.hyperparam_search_done = True
        # Fallback device; ``setup`` replaces it with the network's device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _l2_normalize(features):
        """Scale each row of ``features`` to unit L2 norm (eps avoids 0/0)."""
        return features / (features.norm(dim=1, keepdim=True) + 1e-10)

    def _infer_device(self, net):
        """Return the device of ``net``'s first parameter, else ``self.device``."""
        try:
            return next(net.parameters()).device
        except StopIteration:
            # Parameter-less module: nothing to follow, keep the default.
            return self.device

    @torch.no_grad()
    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        """Fit class means and the shared precision matrix.

        Args:
            net: feature extractor called as ``net(x, return_feature=True)``;
                the second element of its output is used as the features.
            id_loader_dict: mapping whose ``'train'`` entry yields batches as
                dicts with ``'data'`` and ``'label'`` tensors.
            ood_loader_dict: unused; kept for the postprocessor interface.

        Raises:
            RuntimeError: if the ``'train'`` loader yields no batches.
        """
        # Compute where the network lives so inputs and weights always match.
        self.device = self._infer_device(net)
        print(f"Computing Mahalanobis statistics on {self.device}...")

        net.eval()
        features_list = []
        labels_list = []
        for batch in tqdm(id_loader_dict['train']):
            data = batch['data'].to(self.device)
            label = batch['label'].to(self.device)
            _, features = net(data, return_feature=True)
            features_list.append(self._l2_normalize(features))
            labels_list.append(label)

        if not features_list:
            raise RuntimeError(
                "MahalanobisPlusPlusPostprocessor.setup(): "
                "id_loader_dict['train'] produced no batches."
            )

        features = torch.cat(features_list, dim=0)
        labels = torch.cat(labels_list, dim=0)
        feat_dim = features.size(1)

        # Fit a mean only for labels that actually occur: an empty class would
        # get a NaN mean, which turns every distance (and score) into NaN.
        # torch.unique returns sorted values, so with all classes present the
        # row order matches label order 0..C-1.
        present_classes = torch.unique(labels)
        class_means = torch.zeros(
            len(present_classes), feat_dim, device=self.device
        )
        # Accumulate the pooled within-class scatter class by class instead of
        # materializing a second (N, feat_dim) centered copy of all features.
        cov = torch.zeros(feat_dim, feat_dim, device=self.device)
        for row, c in enumerate(present_classes):
            class_feats = features[labels == c]
            mean = class_feats.mean(dim=0)
            class_means[row] = mean
            centered = class_feats - mean
            cov += centered.t().mm(centered)
        cov /= features.size(0)

        # Ridge term keeps the covariance invertible when features are
        # (near-)degenerate, e.g. after projection onto the unit sphere.
        cov += 1e-4 * torch.eye(feat_dim, device=self.device)
        self.precision = torch.linalg.inv(cov)
        self.class_means = class_means
        print("Mahalanobis setup complete.")

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data):
        """Score a batch.

        Returns:
            ``(pred, score)`` where ``pred`` is the argmax of the network's
            first output (its class prediction) and ``score`` is the negated
            squared Mahalanobis distance to the nearest fitted class mean.

        Raises:
            RuntimeError: if ``setup`` has not been run.
        """
        if self.class_means is None or self.precision is None:
            raise RuntimeError(
                "MahalanobisPlusPlusPostprocessor.setup() must be called "
                "before postprocess()."
            )

        # NOTE(review): assumes the first output is the classifier logits
        # (OpenOOD's ``return_feature=True`` convention) - confirm per net.
        logits, features = net(data, return_feature=True)
        features = self._l2_normalize(features)

        # Follow the features' device; ``.to`` is a no-op when already there.
        class_means = self.class_means.to(features.device)
        precision = self.precision.to(features.device)

        # diff: (batch, num_classes, feat_dim); dist: (batch, num_classes)
        diff = features.unsqueeze(1) - class_means.unsqueeze(0)
        dist = (torch.matmul(diff, precision) * diff).sum(dim=2)
        score = -dist.min(dim=1).values

        # Return the network's prediction (not a constant) so ID accuracy
        # computed from ``pred`` by callers is meaningful.
        pred = logits.argmax(dim=1)
        return pred, score