Zorrojurro commited on
Commit
a2b8a0b
·
verified ·
1 Parent(s): 6b1b45c

Upload src/models/anomaly_detector.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/models/anomaly_detector.py +184 -0
src/models/anomaly_detector.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Anomaly Detector — Cosine-similarity-based anomaly scoring.
3
+
4
+ Compares temporal pattern encodings against a learned "normal"
5
+ baseline to flag abnormal heat-distribution sequences.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ from typing import Optional
13
+
14
+
15
+ class AnomalyDetector(nn.Module):
16
+ """
17
+ Anomaly detector using cosine similarity against a
18
+ reference baseline embedding.
19
+
20
+ During training, the baseline is updated as a running mean of
21
+ embeddings from *normal* sequences. At inference, an anomaly
22
+ score is produced: 1 − cosine_similarity.
23
+
24
+ Attributes:
25
+ threshold: similarity below this → abnormal.
26
+ baseline: registered buffer; running-average normal embedding.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ embedding_dim: int = 256,
32
+ threshold: float = 0.7,
33
+ momentum: float = 0.99,
34
+ ):
35
+ super().__init__()
36
+ self.threshold = threshold
37
+ self.momentum = momentum
38
+
39
+ # The normal-pattern baseline (non-trainable, persisted with model)
40
+ self.register_buffer(
41
+ "baseline", torch.zeros(embedding_dim)
42
+ )
43
+ self.register_buffer(
44
+ "baseline_initialised", torch.tensor(False)
45
+ )
46
+
47
+ @classmethod
48
+ def from_config(cls, config) -> "AnomalyDetector":
49
+ """Construct from a Config object."""
50
+ ad = config.model.anomaly_detector
51
+ fe = config.model.feature_extractor
52
+ return cls(
53
+ embedding_dim=fe.embedding_dim,
54
+ threshold=ad.threshold,
55
+ )
56
+
57
+ # ------------------------------------------------------------------
58
+ # Baseline management
59
+ # ------------------------------------------------------------------
60
+
61
+ @torch.no_grad()
62
+ def update_baseline(self, normal_embeddings: torch.Tensor):
63
+ """
64
+ Update the running-average baseline with new normal embeddings.
65
+
66
+ Args:
67
+ normal_embeddings: (N, D) embeddings from normal sequences.
68
+ """
69
+ batch_mean = normal_embeddings.mean(dim=0)
70
+ if not self.baseline_initialised:
71
+ self.baseline.copy_(batch_mean)
72
+ self.baseline_initialised.fill_(True)
73
+ else:
74
+ self.baseline.mul_(self.momentum).add_(
75
+ batch_mean, alpha=1.0 - self.momentum
76
+ )
77
+
78
+ def set_baseline(self, baseline: torch.Tensor):
79
+ """Directly set the baseline embedding."""
80
+ self.baseline.copy_(baseline)
81
+ self.baseline_initialised.fill_(True)
82
+
83
+ # ------------------------------------------------------------------
84
+ # Scoring
85
+ # ------------------------------------------------------------------
86
+
87
+ def compute_similarity(self, embeddings: torch.Tensor) -> torch.Tensor:
88
+ """
89
+ Cosine similarity between each embedding and the baseline.
90
+
91
+ Args:
92
+ embeddings: (B, D)
93
+
94
+ Returns:
95
+ similarities: (B,) in range [-1, 1].
96
+ """
97
+ baseline = self.baseline.unsqueeze(0) # (1, D)
98
+ return F.cosine_similarity(embeddings, baseline, dim=1)
99
+
100
+ def compute_anomaly_score(self, embeddings: torch.Tensor) -> torch.Tensor:
101
+ """
102
+ Anomaly score = 1 − similarity.
103
+ Higher score → more abnormal.
104
+ """
105
+ return 1.0 - self.compute_similarity(embeddings)
106
+
107
+ def forward(self, embeddings: torch.Tensor) -> dict:
108
+ """
109
+ Full anomaly detection inference.
110
+
111
+ Args:
112
+ embeddings: (B, D) temporal pattern encodings.
113
+
114
+ Returns:
115
+ dict with keys:
116
+ similarity_score: (B,)
117
+ anomaly_score: (B,)
118
+ is_normal: (B,) boolean
119
+ confidence: (B,) distance from threshold
120
+ """
121
+ similarity = self.compute_similarity(embeddings)
122
+ anomaly_score = 1.0 - similarity
123
+ is_normal = similarity >= self.threshold
124
+ confidence = torch.abs(similarity - self.threshold)
125
+
126
+ return {
127
+ "similarity_score": similarity,
128
+ "anomaly_score": anomaly_score,
129
+ "is_normal": is_normal,
130
+ "confidence": confidence,
131
+ }
132
+
133
+
134
+ class ThermalPatternPipeline(nn.Module):
135
+ """
136
+ End-to-end pipeline combining all three stages:
137
+ 1. ThermalFeatureExtractor (CNN)
138
+ 2. SequenceAnalyzer (LSTM + Attention)
139
+ 3. AnomalyDetector (Cosine similarity)
140
+
141
+ Accepts raw image sequences and returns anomaly predictions.
142
+ """
143
+
144
+ def __init__(self, feature_extractor, sequence_analyzer, anomaly_detector):
145
+ super().__init__()
146
+ self.feature_extractor = feature_extractor
147
+ self.sequence_analyzer = sequence_analyzer
148
+ self.anomaly_detector = anomaly_detector
149
+
150
+ @classmethod
151
+ def from_config(cls, config) -> "ThermalPatternPipeline":
152
+ """Build the entire pipeline from a Config object."""
153
+ from src.models.feature_extractor import ThermalFeatureExtractor
154
+ from src.models.sequence_analyzer import SequenceAnalyzer
155
+
156
+ fe = ThermalFeatureExtractor.from_config(config)
157
+ sa = SequenceAnalyzer.from_config(config)
158
+ ad = AnomalyDetector.from_config(config)
159
+ return cls(fe, sa, ad)
160
+
161
+ def forward(self, sequences: torch.Tensor) -> dict:
162
+ """
163
+ End-to-end forward pass.
164
+
165
+ Args:
166
+ sequences: (B, T, 1, H, W)
167
+
168
+ Returns:
169
+ dict with anomaly_detector outputs + attention_weights.
170
+ """
171
+ # 1. Extract per-frame features → (B, T, D)
172
+ features = self.feature_extractor.extract_features_from_sequence(
173
+ sequences
174
+ )
175
+
176
+ # 2. Temporal analysis → (B, D), (B, T) attention
177
+ encoding, attn_weights = self.sequence_analyzer(features)
178
+
179
+ # 3. Anomaly detection
180
+ results = self.anomaly_detector(encoding)
181
+ results["attention_weights"] = attn_weights
182
+ results["encoding"] = encoding
183
+
184
+ return results