Spaces:
Sleeping
Sleeping
File size: 2,468 Bytes
eb1fd70 7cf7abf eb1fd70 7cf7abf eb1fd70 7cf7abf eb1fd70 7cf7abf eb1fd70 7cf7abf eb1fd70 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | import torch
import torch.nn as nn
from tqdm import tqdm
from openood.postprocessors.base_postprocessor import BasePostprocessor
class MahalanobisPlusPlusPostprocessor(BasePostprocessor):
    """OOD scorer using Mahalanobis distance on L2-normalised features.

    ``setup`` fits one mean per class seen in the ID ``'train'`` loader and a
    single covariance shared by all classes (pooled over within-class
    residuals). ``postprocess`` scores samples by the negative squared
    Mahalanobis distance to the nearest class mean, so higher scores are more
    in-distribution-like. Features are L2-normalised before both steps (the
    "++" in the name presumably refers to this normalisation).
    """

    # Added to the feature norms before dividing, to avoid division by zero.
    norm_eps = 1e-10
    # Ridge added to the covariance diagonal so that it stays invertible.
    cov_reg = 1e-4

    def __init__(self, config):
        super().__init__(config)
        self.class_means = None  # (n_classes_seen, feat_dim), set by setup()
        self.precision = None  # (feat_dim, feat_dim) inverse shared covariance
        # NOTE(review): presumably tells OpenOOD to skip its automatic
        # hyperparameter search (this scorer has no tunable knobs) — confirm.
        self.APS_mode = False
        self.hyperparam_search_done = True
        # Set the device dynamically based on availability
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _normalize(self, features):
        """Return 2-D *features* L2-normalised along dim 1, at least float32.

        Half/bfloat16 features are promoted because ``torch.linalg.inv`` does
        not support them. float64 features keep their precision.
        """
        features = features.to(torch.promote_types(features.dtype, torch.float32))
        return features / (features.norm(dim=1, keepdim=True) + self.norm_eps)

    def _collect_features(self, net, loader):
        """Run *net* over *loader*; return (normalised features, labels) on self.device.

        Raises:
            RuntimeError: if *loader* yields no batches.
        """
        features_list = []
        labels_list = []
        for batch in tqdm(loader):
            # Move data and labels to the detected device
            data = batch['data'].to(self.device)
            labels_list.append(batch['label'].to(self.device))
            _, feats = net(data, return_feature=True)
            features_list.append(self._normalize(feats))
        if not features_list:
            raise RuntimeError("ID 'train' loader yielded no batches; cannot fit Mahalanobis statistics")
        return torch.cat(features_list, dim=0), torch.cat(labels_list, dim=0)

    @torch.no_grad()
    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        """Fit class means and the shared precision matrix on ID train data.

        Args:
            net: model whose ``net(data, return_feature=True)`` returns a
                ``(logits, features)`` pair with 2-D ``features``.
            id_loader_dict: must contain a ``'train'`` iterable of batches,
                each a mapping with ``'data'`` and ``'label'`` tensors.
            ood_loader_dict: unused; kept for the postprocessor interface.

        Raises:
            RuntimeError: if the ``'train'`` loader yields no batches.
        """
        print(f"Computing Mahalanobis statistics on {self.device}...")
        net.eval()
        features, labels = self._collect_features(net, id_loader_dict['train'])

        # Map label values to dense indices 0..C-1. Classes absent from the
        # training data then get no (NaN) mean row, so they cannot poison the
        # scores.
        classes, inverse = torch.unique(labels, return_inverse=True)
        n_classes = classes.numel()
        feat_dim = features.size(1)

        sums = torch.zeros(n_classes, feat_dim, device=features.device, dtype=features.dtype)
        sums.index_add_(0, inverse, features)
        counts = torch.bincount(inverse, minlength=n_classes).unsqueeze(1)
        class_means = sums / counts  # every count is >= 1 by construction

        # Shared covariance of within-class residuals (biased: divided by N).
        centered = features - class_means[inverse]
        cov = centered.t().mm(centered) / features.size(0)
        cov += self.cov_reg * torch.eye(feat_dim, device=cov.device, dtype=cov.dtype)

        self.class_means = class_means
        self.precision = torch.linalg.inv(cov)
        print("Mahalanobis setup complete.")

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data):
        """Score a batch and return ``(pred, score)``.

        ``pred`` is the network's argmax class. ``score`` is the negative
        squared Mahalanobis distance to the nearest class mean (higher means
        more ID-like).

        Raises:
            RuntimeError: if called before ``setup``.
        """
        if self.class_means is None or self.precision is None:
            raise RuntimeError("setup() must be called before postprocess()")
        logits, features = net(data, return_feature=True)
        # Previously a constant zero tensor. OpenOOD presumably uses ``pred``
        # for ID accuracy, so return the classifier's actual prediction.
        pred = logits.argmax(dim=1)

        means = self.class_means
        precision = self.precision
        x = self._normalize(features).to(device=means.device, dtype=means.dtype)
        # Expand (x - mu)^T P (x - mu) = x^T P x - 2 x^T P mu + mu^T P mu
        # (P is symmetric). This avoids materialising a
        # (batch, n_classes, feat_dim) difference tensor.
        x_proj = x.mm(precision)                                  # (B, D)
        x_term = (x_proj * x).sum(dim=1, keepdim=True)            # (B, 1)
        mu_term = (means.mm(precision) * means).sum(dim=1)        # (C,)
        dist = x_term - 2.0 * x_proj.mm(means.t()) + mu_term      # (B, C)
        # Rounding can push the expansion slightly below zero; true distances are >= 0.
        dist = dist.clamp_min(0)
        # Mahalanobis score = negative distance to the closest class mean
        score = -dist.min(dim=1).values
        return pred, score