Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| from openood.postprocessors.base_postprocessor import BasePostprocessor | |
class MahalanobisPlusPlusPostprocessor(BasePostprocessor):
    """Mahalanobis-distance OOD scorer operating on L2-normalized features.

    ``setup`` projects every training feature vector onto the unit sphere, fits
    one mean per class, and fits a single covariance matrix shared by all
    classes (computed from class-centered features). ``postprocess`` scores a
    sample by the negative squared Mahalanobis distance to the nearest class
    mean, so samples closer to some training class mean get higher scores.
    """

    # Added to feature norms so an all-zero feature vector does not divide by 0.
    NORM_EPS = 1e-10
    # Ridge added to the covariance diagonal so it is safely invertible.
    COV_REG = 1e-4

    def __init__(self, config):
        super().__init__(config)
        # (num_fitted_classes, feat_dim) class means; filled by setup().
        self.class_means = None
        # (feat_dim, feat_dim) inverse of the shared covariance; filled by setup().
        self.precision = None
        self.APS_mode = False
        self.hyperparam_search_done = True
        # Fallback device; setup() prefers the device the network's parameters live on.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _normalize(self, features):
        """Project each row of ``features`` onto the unit L2 sphere."""
        return features / (features.norm(dim=1, keepdim=True) + self.NORM_EPS)

    def _infer_device(self, net):
        """Return the device of ``net``'s first parameter, else ``self.device``."""
        param = next(net.parameters(), None)
        return self.device if param is None else param.device

    @torch.no_grad()  # statistics only; never keep autograd graphs for the train set
    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        """Fit class means and the shared precision matrix on ID training data.

        Args:
            net: model whose ``net(data, return_feature=True)`` returns a pair
                whose second element is the feature batch.
            id_loader_dict: must contain a ``'train'`` iterable of dicts with
                ``'data'`` and ``'label'`` tensors.
            ood_loader_dict: unused.

        Raises:
            ValueError: if the training loader yields no batches.
        """
        self.device = self._infer_device(net)
        print(f"Computing Mahalanobis statistics on {self.device}...")
        net.eval()

        features_list, labels_list = [], []
        for batch in tqdm(id_loader_dict['train']):
            data = batch['data'].to(self.device)
            labels_list.append(batch['label'].to(self.device))
            _, features = net(data, return_feature=True)
            features_list.append(self._normalize(features))
        if not features_list:
            raise ValueError("id_loader_dict['train'] yielded no batches")

        features = torch.cat(features_list, dim=0)
        labels = torch.cat(labels_list, dim=0)

        # Fit a mean only for labels that actually occur: a label id with no
        # samples would yield a NaN mean and poison every later distance.
        # `inverse[i]` is the row of class_means that sample i belongs to.
        classes, inverse = torch.unique(labels, return_inverse=True)
        class_means = torch.stack(
            [features[inverse == k].mean(dim=0) for k in range(len(classes))]
        )
        centered = features - class_means[inverse]

        # Shared (tied) covariance of class-centered features, plus a ridge.
        feat_dim = features.size(1)
        cov = centered.t().mm(centered) / features.size(0)
        cov += self.COV_REG * torch.eye(feat_dim, device=cov.device, dtype=cov.dtype)

        self.class_means = class_means
        self.precision = torch.linalg.inv(cov)
        print("Mahalanobis setup complete.")

    @torch.no_grad()  # returned scores must not carry autograd history
    def postprocess(self, net: nn.Module, data):
        """Score a batch by negative squared Mahalanobis distance.

        Returns:
            (pred, score): ``pred`` is the arg-max over the first output of
            ``net(data, return_feature=True)`` (the classifier logits under the
            OpenOOD convention -- NOTE(review): confirm for custom nets);
            ``score`` is ``-min_c (x - mu_c)^T P (x - mu_c)`` on the normalized
            features ``x``.
        """
        logits, features = net(data, return_feature=True)
        features = self._normalize(features)
        # No-ops when the statistics already live on the feature device.
        means = self.class_means.to(features.device)
        precision = self.precision.to(features.device)

        # Expand (x - mu)^T P (x - mu) = x^T P x - 2 x^T P mu + mu^T P mu so
        # that no (batch, n_classes, feat_dim) intermediate is materialized.
        feat_p = features @ precision                           # (B, D)
        x_term = (feat_p * features).sum(dim=1, keepdim=True)   # (B, 1)
        mu_term = ((means @ precision) * means).sum(dim=1)      # (C,)
        dist = x_term - 2.0 * (feat_p @ means.t()) + mu_term    # (B, C)
        # P is positive definite, so true distances are >= 0; clamp away
        # tiny negative values produced by floating-point cancellation.
        dist = dist.clamp(min=0)

        score = -dist.min(dim=1).values
        pred = logits.argmax(dim=1)
        return pred, score