File size: 2,468 Bytes
eb1fd70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cf7abf
 
 
eb1fd70
 
7cf7abf
eb1fd70
 
 
 
 
 
 
 
7cf7abf
 
 
eb1fd70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cf7abf
 
eb1fd70
 
 
 
 
 
 
 
 
 
 
7cf7abf
 
eb1fd70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.nn as nn
from tqdm import tqdm
from openood.postprocessors.base_postprocessor import BasePostprocessor


class MahalanobisPlusPlusPostprocessor(BasePostprocessor):
    """Mahalanobis OOD scoring on L2-normalized features ("Mahalanobis++").

    ``setup`` fits one mean per training class plus a single covariance
    shared by all classes, both computed from L2-normalized penultimate
    features. ``postprocess`` scores a batch by the negative Mahalanobis
    distance to the closest class mean (higher = more in-distribution).
    """

    def __init__(self, config):
        super().__init__(config)
        # (n_fitted_classes, feat_dim) class means, filled by setup().
        self.class_means = None
        # (feat_dim, feat_dim) inverse of the regularized shared covariance.
        self.precision = None
        # Ridge added to the covariance diagonal so it is always invertible.
        self.reg_eps = 1e-4

        self.APS_mode = False
        self.hyperparam_search_done = True

        # Fallback device; setup() switches to the device of the network's
        # parameters so inputs, model and statistics always agree.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _normalize(features):
        """Return ``features`` scaled to unit L2 norm per row, in >= float32."""
        # Promote fp16/bf16 (e.g. under autocast) so covariance, inversion and
        # distances are not computed in half precision; float64 is kept.
        features = features.to(torch.promote_types(features.dtype, torch.float32))
        return features / (features.norm(dim=1, keepdim=True) + 1e-10)

    def _model_device(self, net):
        """Return the device of ``net``'s parameters, or ``self.device`` if it has none."""
        first_param = next(net.parameters(), None)
        return self.device if first_param is None else first_param.device

    @torch.no_grad()
    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        """Fit class means and the shared precision matrix on ID training data.

        Args:
            net: model whose ``net(data, return_feature=True)`` returns a
                pair whose second element is the feature batch.
            id_loader_dict: mapping whose ``'train'`` entry iterates over
                dicts holding ``'data'`` and ``'label'`` tensors.
            ood_loader_dict: unused; kept for the postprocessor interface.

        Raises:
            RuntimeError: if the ``'train'`` loader yields no batches.
        """
        # Inputs must go where the model lives, not merely where CUDA exists.
        self.device = self._model_device(net)
        print(f"Computing Mahalanobis statistics on {self.device}...")

        features_list = []
        labels_list = []

        net.eval()

        for batch in tqdm(id_loader_dict['train']):
            data = batch['data'].to(self.device)
            _, features = net(data, return_feature=True)
            features_list.append(self._normalize(features))
            labels_list.append(batch['label'].to(self.device))

        if not features_list:
            raise RuntimeError("Mahalanobis setup: the 'train' loader yielded no batches.")

        features = torch.cat(features_list, dim=0)
        labels = torch.cat(labels_list, dim=0)

        # Fit only labels that actually occur: a label id with no samples
        # would yield an all-NaN mean and, through min(), NaN for every score.
        classes, inverse = torch.unique(labels, return_inverse=True)
        class_means = torch.stack(
            [features[inverse == i].mean(dim=0) for i in range(classes.numel())]
        )
        centered = features - class_means[inverse]

        # Shared (tied) covariance of the class-centered features.
        feat_dim = features.size(1)
        cov = centered.t().mm(centered) / features.size(0)
        # Ridge regularization keeps the matrix invertible when features are
        # rank-deficient.
        cov += self.reg_eps * torch.eye(feat_dim, device=cov.device, dtype=cov.dtype)

        self.class_means = class_means
        self.precision = torch.linalg.inv(cov)

        print("Mahalanobis setup complete.")

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data):
        """Score a batch by negative Mahalanobis distance to the nearest class mean.

        Args:
            net: the same kind of model passed to ``setup``; its first output
                is treated as classification logits.
            data: input batch accepted by ``net``.

        Returns:
            ``(pred, score)``: ``pred`` is the argmax of the model's logits
            (OpenOOD evaluators use it for ID accuracy); ``score`` holds one
            value per sample, higher meaning more in-distribution.

        Raises:
            RuntimeError: if ``setup`` has not been run.
        """
        if self.class_means is None or self.precision is None:
            raise RuntimeError(
                "MahalanobisPlusPlusPostprocessor.setup() must run before postprocess()."
            )

        data = data.to(self._model_device(net))
        logits, features = net(data, return_feature=True)
        features = self._normalize(features)

        # Statistics may sit on a different device/dtype than this batch.
        means = self.class_means.to(device=features.device, dtype=features.dtype)
        precision = self.precision.to(device=features.device, dtype=features.dtype)
        # A quadratic form only sees the symmetric part of P, so this is exact
        # and makes the expansion below valid.
        precision = (precision + precision.t()) / 2

        # (x - mu)^T P (x - mu) = x^T P x - 2 x^T P mu + mu^T P mu.
        # Expanding avoids materializing a (batch, classes, feat_dim) tensor.
        feat_prec = features @ precision                              # (B, D)
        feat_term = (feat_prec * features).sum(dim=1, keepdim=True)   # (B, 1)
        cross_term = feat_prec @ means.t()                            # (B, C)
        mean_term = ((means @ precision) * means).sum(dim=1)          # (C,)
        # Clamp only removes tiny negative round-off; true distances are >= 0.
        dist = (feat_term - 2 * cross_term + mean_term).clamp_min(0)

        # Mahalanobis score = negative distance to the closest class mean.
        score = -dist.min(dim=1)[0]

        pred = logits.argmax(dim=1)

        return pred, score