File size: 6,153 Bytes
a2b8a0b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | """
Anomaly Detector — Cosine-similarity-based anomaly scoring.
Compares temporal pattern encodings against a learned "normal"
baseline to flag abnormal heat-distribution sequences.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional
class AnomalyDetector(nn.Module):
    """
    Anomaly detector using cosine similarity against a
    reference baseline embedding.

    During training, the baseline is updated as an exponential moving
    average of embeddings from *normal* sequences (see ``update_baseline``).
    At inference, an anomaly score is produced: 1 − cosine_similarity.

    Attributes:
        threshold: similarity below this → abnormal.
        momentum: weight kept on the previous baseline in each
            ``update_baseline`` call after the first.
        baseline: registered buffer; running-average normal embedding.
        baseline_initialised: registered boolean buffer; False until the
            first non-empty ``update_baseline`` or any ``set_baseline`` call.
    """

    def __init__(
        self,
        embedding_dim: int = 256,
        threshold: float = 0.7,
        momentum: float = 0.99,
    ):
        super().__init__()
        self.threshold = threshold
        self.momentum = momentum
        # The normal-pattern baseline (non-trainable, persisted with model)
        self.register_buffer(
            "baseline", torch.zeros(embedding_dim)
        )
        self.register_buffer(
            "baseline_initialised", torch.tensor(False)
        )

    @classmethod
    def from_config(cls, config) -> "AnomalyDetector":
        """Construct from a Config object.

        Reads ``config.model.feature_extractor.embedding_dim`` and
        ``config.model.anomaly_detector.threshold``. ``momentum`` is read from
        ``config.model.anomaly_detector.momentum`` when that attribute exists;
        otherwise the constructor default (0.99) is used.
        """
        ad = config.model.anomaly_detector
        fe = config.model.feature_extractor
        return cls(
            embedding_dim=fe.embedding_dim,
            threshold=ad.threshold,
            # momentum is optional in the config; fall back to the default.
            momentum=getattr(ad, "momentum", 0.99),
        )

    # ------------------------------------------------------------------
    # Baseline management
    # ------------------------------------------------------------------

    @torch.no_grad()
    def update_baseline(self, normal_embeddings: torch.Tensor):
        """
        Update the running-average baseline with new normal embeddings.

        The first non-empty call copies the batch mean directly. Later calls
        blend it in as ``baseline = m * baseline + (1 - m) * batch_mean``,
        where ``m = self.momentum``.

        Args:
            normal_embeddings: (N, D) embeddings from normal sequences.
                A single (D,) embedding is treated as a batch of one.
                An empty batch is ignored, because its mean would be NaN and
                would corrupt the baseline permanently.
        """
        if normal_embeddings.dim() == 1:
            # Without this, mean(dim=0) would collapse a lone (D,) vector to a
            # scalar that then broadcasts over the whole baseline.
            normal_embeddings = normal_embeddings.unsqueeze(0)
        if normal_embeddings.numel() == 0:
            return
        # Match the buffer's device/dtype so the in-place ops below are valid.
        batch_mean = normal_embeddings.mean(dim=0).to(self.baseline)
        if not self.baseline_initialised:
            self.baseline.copy_(batch_mean)
            self.baseline_initialised.fill_(True)
        else:
            self.baseline.mul_(self.momentum).add_(
                batch_mean, alpha=1.0 - self.momentum
            )

    @torch.no_grad()
    def set_baseline(self, baseline: torch.Tensor):
        """Directly set the baseline embedding and mark it initialised.

        Runs under ``torch.no_grad()`` so that copying from a tensor that
        requires grad does not attach the persistent buffer to that tensor's
        autograd graph.
        """
        self.baseline.copy_(baseline)
        self.baseline_initialised.fill_(True)

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def compute_similarity(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Cosine similarity between each embedding and the baseline.

        While the baseline is still all zeros (never set), the dot product is
        zero, so every similarity is 0.

        Args:
            embeddings: (B, D)

        Returns:
            similarities: (B,) in range [-1, 1].
        """
        # Align the buffer with the input's device/dtype before comparing.
        baseline = self.baseline.to(embeddings).unsqueeze(0)  # (1, D)
        return F.cosine_similarity(embeddings, baseline, dim=1)

    def compute_anomaly_score(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Anomaly score = 1 − similarity.

        Higher score → more abnormal.
        """
        return 1.0 - self.compute_similarity(embeddings)

    def forward(self, embeddings: torch.Tensor) -> dict:
        """
        Full anomaly detection inference.

        Args:
            embeddings: (B, D) temporal pattern encodings.

        Returns:
            dict with keys:
                similarity_score: (B,)
                anomaly_score: (B,)  1 − similarity_score
                is_normal: (B,) boolean, similarity_score >= threshold
                confidence: (B,) absolute distance from threshold
        """
        similarity = self.compute_similarity(embeddings)
        anomaly_score = 1.0 - similarity
        is_normal = similarity >= self.threshold
        confidence = torch.abs(similarity - self.threshold)
        return {
            "similarity_score": similarity,
            "anomaly_score": anomaly_score,
            "is_normal": is_normal,
            "confidence": confidence,
        }
class ThermalPatternPipeline(nn.Module):
    """
    Three-stage model that turns raw thermal frame sequences into anomaly
    predictions.

    Stages, applied in order:
        1. ``feature_extractor`` — ThermalFeatureExtractor (CNN), per frame.
        2. ``sequence_analyzer`` — SequenceAnalyzer (LSTM + Attention) over time.
        3. ``anomaly_detector`` — AnomalyDetector (cosine similarity).
    """

    def __init__(self, feature_extractor, sequence_analyzer, anomaly_detector):
        super().__init__()
        # Assignment order determines submodule / state_dict key order.
        self.feature_extractor = feature_extractor
        self.sequence_analyzer = sequence_analyzer
        self.anomaly_detector = anomaly_detector

    @classmethod
    def from_config(cls, config) -> "ThermalPatternPipeline":
        """Assemble all three stages from a single Config object."""
        # Imported inside the method rather than at module top
        # (presumably to avoid an import cycle — TODO confirm).
        from src.models.feature_extractor import ThermalFeatureExtractor
        from src.models.sequence_analyzer import SequenceAnalyzer

        # Stages are built in pipeline order, so any random weight
        # initialisation happens in the same sequence every time.
        extractor = ThermalFeatureExtractor.from_config(config)
        analyzer = SequenceAnalyzer.from_config(config)
        detector = AnomalyDetector.from_config(config)
        return cls(extractor, analyzer, detector)

    def forward(self, sequences: torch.Tensor) -> dict:
        """
        Run the whole pipeline on a batch of frame sequences.

        Args:
            sequences: (B, T, 1, H, W) raw frames.

        Returns:
            The anomaly detector's output dict, extended in place with
            ``attention_weights`` and then ``encoding``.
        """
        frame_feats = self.feature_extractor.extract_features_from_sequence(
            sequences
        )  # (B, T, D)
        seq_encoding, attention = self.sequence_analyzer(frame_feats)  # (B, D), (B, T)
        outputs = self.anomaly_detector(seq_encoding)
        outputs.update(attention_weights=attention, encoding=seq_encoding)
        return outputs
|