File size: 6,153 Bytes
a2b8a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""

Anomaly Detector — Cosine-similarity-based anomaly scoring.



Compares temporal pattern encodings against a learned "normal"

baseline to flag abnormal heat-distribution sequences.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional


class AnomalyDetector(nn.Module):
    """
    Anomaly detector using cosine similarity against a reference baseline.

    During training the baseline is maintained as an exponential moving
    average of the mean embedding of *normal* sequences.  At inference each
    embedding gets an anomaly score of ``1 - cosine_similarity``.

    Attributes:
        threshold: similarity below this value -> abnormal.
        momentum:  decay used by :meth:`update_baseline`; each new batch mean
                   is blended in with weight ``1 - momentum``.
        baseline:  registered buffer of shape (embedding_dim,); the
                   running-average normal embedding.  It starts as all zeros,
                   in which case every cosine similarity evaluates to 0.
        baseline_initialised: registered bool buffer; False until the first
                   call to :meth:`update_baseline` or :meth:`set_baseline`.
    """

    def __init__(
        self,
        embedding_dim: int = 256,
        threshold: float = 0.7,
        momentum: float = 0.99,
    ):
        super().__init__()
        self.threshold = threshold
        self.momentum = momentum

        # The normal-pattern baseline (non-trainable; registered as a buffer
        # so it is saved/loaded with the module's state_dict).
        self.register_buffer("baseline", torch.zeros(embedding_dim))
        # Lets the first update copy the batch mean instead of blending it
        # with the zero initialisation.
        self.register_buffer("baseline_initialised", torch.tensor(False))

    @classmethod
    def from_config(cls, config) -> "AnomalyDetector":
        """
        Construct from a Config object.

        Reads ``config.model.feature_extractor.embedding_dim`` and
        ``config.model.anomaly_detector.threshold``.  If
        ``config.model.anomaly_detector`` has a non-None ``momentum``
        attribute it is forwarded too; otherwise the constructor default
        is used.
        """
        ad = config.model.anomaly_detector
        fe = config.model.feature_extractor
        kwargs = {}
        # momentum is optional in the config; fall back to the default.
        momentum = getattr(ad, "momentum", None)
        if momentum is not None:
            kwargs["momentum"] = momentum
        return cls(
            embedding_dim=fe.embedding_dim,
            threshold=ad.threshold,
            **kwargs,
        )

    # ------------------------------------------------------------------
    # Baseline management
    # ------------------------------------------------------------------

    @torch.no_grad()
    def update_baseline(self, normal_embeddings: torch.Tensor) -> None:
        """
        Update the running-average baseline with new normal embeddings.

        The first call copies the batch mean directly; later calls apply
        ``baseline = momentum * baseline + (1 - momentum) * batch_mean``.

        Args:
            normal_embeddings: (N, D) embeddings from normal sequences.
                A single (D,) embedding is treated as a batch of one.
                An empty batch (N == 0) leaves the baseline unchanged.
        """
        if normal_embeddings.dim() == 1:
            # mean(dim=0) of a (D,) vector would collapse to a scalar that
            # copy_/add_ then broadcast over every baseline component.
            normal_embeddings = normal_embeddings.unsqueeze(0)
        if normal_embeddings.shape[0] == 0:
            # The mean of an empty batch is NaN and would poison the baseline.
            return

        # Match the buffer's device/dtype: copy_ converts implicitly but
        # add_ does not, so do it once for both branches.
        batch_mean = normal_embeddings.mean(dim=0).to(self.baseline)
        if not self.baseline_initialised:
            self.baseline.copy_(batch_mean)
            self.baseline_initialised.fill_(True)
        else:
            self.baseline.mul_(self.momentum).add_(
                batch_mean, alpha=1.0 - self.momentum
            )

    @torch.no_grad()
    def set_baseline(self, baseline: torch.Tensor) -> None:
        """
        Directly set the baseline embedding and mark it as initialised.

        Runs under ``torch.no_grad()`` so copying a tensor that requires grad
        does not attach autograd history to the buffer.
        """
        self.baseline.copy_(baseline)
        self.baseline_initialised.fill_(True)

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def compute_similarity(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Cosine similarity between each embedding and the baseline.

        Args:
            embeddings: (B, D)

        Returns:
            similarities: (B,) in range [-1, 1].  While the baseline is still
            all zeros (never set), every similarity is 0.
        """
        baseline = self.baseline.unsqueeze(0)  # (1, D), broadcasts over B
        return F.cosine_similarity(embeddings, baseline, dim=1)

    def compute_anomaly_score(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Anomaly score = 1 - similarity, in range [0, 2].

        Higher score -> more abnormal.
        """
        return 1.0 - self.compute_similarity(embeddings)

    def forward(self, embeddings: torch.Tensor) -> dict:
        """
        Full anomaly detection inference.

        Args:
            embeddings: (B, D) temporal pattern encodings.

        Returns:
            dict with keys:
                similarity_score: (B,) cosine similarity to the baseline
                anomaly_score:    (B,) 1 - similarity_score
                is_normal:        (B,) bool, similarity_score >= threshold
                confidence:       (B,) |similarity_score - threshold|
        """
        similarity = self.compute_similarity(embeddings)
        anomaly_score = 1.0 - similarity
        is_normal = similarity >= self.threshold
        # Distance from the decision boundary, used as a confidence proxy.
        confidence = torch.abs(similarity - self.threshold)

        return {
            "similarity_score": similarity,
            "anomaly_score": anomaly_score,
            "is_normal": is_normal,
            "confidence": confidence,
        }


class ThermalPatternPipeline(nn.Module):
    """
    End-to-end model chaining three stages:

        1. ThermalFeatureExtractor  -- per-frame CNN features
        2. SequenceAnalyzer         -- LSTM + attention over time
        3. AnomalyDetector          -- cosine similarity to a baseline

    Takes raw image sequences and returns anomaly predictions.
    """

    def __init__(self, feature_extractor, sequence_analyzer, anomaly_detector):
        super().__init__()
        # Assigned in pipeline order so submodule registration follows it.
        self.feature_extractor = feature_extractor
        self.sequence_analyzer = sequence_analyzer
        self.anomaly_detector = anomaly_detector

    @classmethod
    def from_config(cls, config) -> "ThermalPatternPipeline":
        """Assemble every stage of the pipeline from a single Config object."""
        # Imported lazily, inside the method, as in the original design.
        from src.models.feature_extractor import ThermalFeatureExtractor
        from src.models.sequence_analyzer import SequenceAnalyzer

        # Keyword arguments are evaluated left to right, so the stages are
        # built in the same order: extractor, analyzer, detector.
        return cls(
            feature_extractor=ThermalFeatureExtractor.from_config(config),
            sequence_analyzer=SequenceAnalyzer.from_config(config),
            anomaly_detector=AnomalyDetector.from_config(config),
        )

    def forward(self, sequences: torch.Tensor) -> dict:
        """
        Run all three stages on a batch of sequences.

        Args:
            sequences: (B, T, 1, H, W)

        Returns:
            The dict produced by the anomaly detector, extended in place with
            ``attention_weights`` and ``encoding``.
        """
        # Per-frame features, (B, T, D).
        frame_features = self.feature_extractor.extract_features_from_sequence(
            sequences
        )
        # Sequence-level encoding (B, D) plus attention weights (B, T).
        temporal_encoding, attention = self.sequence_analyzer(frame_features)

        detection = self.anomaly_detector(temporal_encoding)
        detection.update(attention_weights=attention, encoding=temporal_encoding)
        return detection