"""
Memory-Augmented Autoencoder (MemAE) for video anomaly detection.
The encoder/decoder backbone is the M1 AutoEncoder, unchanged. A memory module
is inserted between them: the decoder can only reconstruct from stored normal
prototypes, so anomalies (absent from memory) reconstruct poorly.
Ref: Gong et al. 2019, "Memorizing Normality to Detect Anomaly" (ICCV).
"""
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class MemoryModule(nn.Module):
    """Memory bank with sparse, attention-based addressing.

    Every spatio-temporal position of the bottleneck feature acts as one
    query. Queries are compared with the memory items by cosine similarity,
    the similarities are turned into attention weights with a softmax, small
    weights are zeroed by hard shrinkage, and the feature is rebuilt as the
    attention-weighted sum of the (unnormalized) memory items.

    Args:
        n_slots: N, number of memory items. Must be positive.
        feat_dim: C, size of one memory item; must equal the channel
            dimension of the bottleneck feature.
        shrink_thres: lambda, hard-shrinkage threshold for sparse addressing.
            ``None`` selects ``1 / n_slots``.

    Raises:
        ValueError: if ``n_slots`` or ``feat_dim`` is not positive.
    """

    def __init__(
        self,
        n_slots: int,
        feat_dim: int = 16,
        shrink_thres: Optional[float] = None,
    ):
        super().__init__()
        # Fail early with a clear message instead of a ZeroDivisionError
        # below or an opaque tensor-size error inside torch.randn.
        if n_slots <= 0:
            raise ValueError(f"n_slots must be positive, got {n_slots}")
        if feat_dim <= 0:
            raise ValueError(f"feat_dim must be positive, got {feat_dim}")

        self.n_slots = n_slots
        self.feat_dim = feat_dim

        # Default lambda depends on N: a perfectly uniform attention row has
        # every weight equal to 1 / N.
        if shrink_thres is None:
            shrink_thres = 1.0 / n_slots
        self.shrink_thres = shrink_thres

        # Memory bank: learnable (N, C) matrix, trained end-to-end via backprop.
        self.memory = nn.Parameter(torch.randn(size=(self.n_slots, self.feat_dim)))

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rebuild the bottleneck feature from the memory items.

        Args:
            z: bottleneck feature of shape (B, C, T, H, W) with
                ``C == feat_dim``, e.g. (B, 16, 4, 16, 16).

        Returns:
            z_hat: reconstructed feature, same shape as ``z``.
            attn: attention weights of shape (B, T*H*W, N); rows are sparse
                and sum to ~1 (or to 0 if every weight was shrunk away).
                Useful for an entropy loss or for inspecting sparsity.

        Raises:
            ValueError: if ``z`` is not 5-D or its channel dimension differs
                from ``feat_dim``.
        """
        if z.dim() != 5 or z.shape[1] != self.feat_dim:
            raise ValueError(
                f"expected z of shape (B, {self.feat_dim}, T, H, W), "
                f"got {tuple(z.shape)}"
            )

        B, C, T, H, W = z.shape
        n_queries = T * H * W

        # (B, C, T, H, W) -> (B, T, H, W, C) -> queries (B, n_queries, C)
        query = z.permute(0, 2, 3, 4, 1).reshape(B, n_queries, C)

        # Cosine similarity: each query vs. each memory slot.
        query_n = F.normalize(query, dim=-1)            # (B, n_queries, C)
        memory_n = F.normalize(self.memory, dim=-1)     # (N, C)
        sim = query_n @ memory_n.t()                    # (B, n_queries, N)

        # Softmax over the N memory slots.
        attn = F.softmax(sim, dim=-1)                   # (B, n_queries, N)

        # Sparse addressing: hard shrinkage, then renormalize.
        # relu(w - lambda) * w / (|w - lambda| + eps) keeps w where
        # w > lambda and yields 0 elsewhere, while staying differentiable.
        eps = 1e-12
        shifted = attn - self.shrink_thres
        attn = F.relu(shifted) * attn / (torch.abs(shifted) + eps)
        # eps keeps an all-zero row at 0 instead of producing NaN.
        attn = attn / (attn.sum(dim=-1, keepdim=True) + eps)

        # Weighted sum over the plain (unnormalized) memory items.
        z_hat_flat = attn @ self.memory                 # (B, n_queries, C)

        # Undo the flattening: (B, n_queries, C) -> (B, T, H, W, C) -> (B, C, T, H, W)
        z_hat = z_hat_flat.reshape(B, T, H, W, C).permute(0, 4, 1, 2, 3)
        return z_hat, attn
class MemoryAE(nn.Module):
    """M1 encoder + MemoryModule + M1 decoder.

    The encoder/decoder backbone is unchanged from M1 (clean ablation); the
    only addition is the memory module sitting on the bottleneck feature.

    Args:
        n_slots: N, number of memory items.
        shrink_thres: lambda, sparse-addressing threshold. ``None`` is passed
            through so that ``MemoryModule`` applies its own default
            (``1 / n_slots``); the default lives in one place only.
    """

    # Channel width of the bottleneck feature == size of one memory item.
    _BOTTLENECK_CHANNELS = 16
    # Number of groups used by every GroupNorm layer in the backbone.
    _NORM_GROUPS = 8

    def __init__(self, n_slots: int, shrink_thres: Optional[float] = None):
        super().__init__()
        feat_dim = self._BOTTLENECK_CHANNELS

        # The helpers return flat layer lists that are unpacked into one flat
        # nn.Sequential, so parameter names stay encoder.0, encoder.1, ...
        # and existing checkpoints keep loading.
        self.encoder = nn.Sequential(
            # enc1: stride 2 on H and W only
            *self._down_block(1, 16, stride=(1, 2, 2)),
            # enc2: stride 2 on T, H and W
            *self._down_block(16, 32, stride=(2, 2, 2)),
            # enc3: stride 2 on T, H and W
            *self._down_block(32, 64, stride=(2, 2, 2)),
            # bottleneck (encoder's last piece): channel reduction, stride 1
            *self._down_block(64, feat_dim, stride=(1, 1, 1)),
        )

        # Memory layer operating on the bottleneck feature.
        self.memory = MemoryModule(
            n_slots=n_slots, feat_dim=feat_dim, shrink_thres=shrink_thres
        )

        self.decoder = nn.Sequential(
            # dec1: stride 2 on T, H and W
            *self._up_block(feat_dim, 32, stride=(2, 2, 2), output_padding=(1, 1, 1)),
            # dec2: stride 2 on T, H and W
            *self._up_block(32, 16, stride=(2, 2, 2), output_padding=(1, 1, 1)),
            # dec3: stride 2 on H and W only; Tanh output, no norm
            nn.ConvTranspose3d(
                16, 1, (3, 3, 3), stride=(1, 2, 2), padding=1, output_padding=(0, 1, 1)
            ),
            nn.Tanh(),
        )

    @classmethod
    def _down_block(cls, in_ch: int, out_ch: int, stride: Tuple[int, int, int]) -> list:
        """Return [Conv3d(3x3x3, padding 1), GroupNorm, LeakyReLU] for the encoder."""
        return [
            nn.Conv3d(in_ch, out_ch, (3, 3, 3), stride=stride, padding=1),
            nn.GroupNorm(cls._NORM_GROUPS, out_ch),
            nn.LeakyReLU(),
        ]

    @classmethod
    def _up_block(
        cls,
        in_ch: int,
        out_ch: int,
        stride: Tuple[int, int, int],
        output_padding: Tuple[int, int, int],
    ) -> list:
        """Return [ConvTranspose3d(3x3x3, padding 1), GroupNorm, LeakyReLU] for the decoder."""
        return [
            nn.ConvTranspose3d(
                in_ch, out_ch, (3, 3, 3),
                stride=stride, padding=1, output_padding=output_padding,
            ),
            nn.GroupNorm(cls._NORM_GROUPS, out_ch),
            nn.LeakyReLU(),
        ]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Reconstruct a clip through the memory-constrained bottleneck.

        Args:
            x: clip in loader format (B, T, C, H, W), same as M1; the first
                convolution expects C == 1.

        Returns:
            recon: reconstruction in loader format (B, T, C, H, W).
            attn: memory attention weights, shape (B, n_queries, N).
        """
        # Loader format (B, T, C, H, W) -> Conv3d format (B, C, T, H, W).
        x = x.permute(0, 2, 1, 3, 4)

        # Encoder -> bottleneck; (B, 16, 4, 16, 16) for a 16-frame 128x128 clip.
        z = self.encoder(x)

        # Memory addressing: same shape as z, plus the attention weights.
        z_hat, attn = self.memory(z)

        # Decoder output is in Conv3d format (B, C, T, H, W).
        recon = self.decoder(z_hat)

        # Back to loader format (B, T, C, H, W).
        recon = recon.permute(0, 2, 1, 3, 4)
        return recon, attn
if __name__ == "__main__":
    # Smoke test: checks tensor shapes stage by stage, then reports how
    # sparse the memory addressing is on random input.
    model = MemoryAE(n_slots=2000)  # paper N
    x = torch.randn(2, 16, 1, 128, 128)  # loader format (B, T, C, H, W)

    # Only shapes and statistics are inspected, so skip building the
    # autograd graph (the attention tensor alone is B x 1024 x 2000).
    with torch.no_grad():
        # Check piece by piece.
        xp = x.permute(0, 2, 1, 3, 4)  # (2,16,1,128,128) -> (2,1,16,128,128)
        z = model.encoder(xp)
        print("bottleneck:", z.shape)  # should be (2, 16, 4, 16, 16)

        z_hat, attn = model.memory(z)
        print("z_hat:", z_hat.shape, "attn:", attn.shape)  # (2,16,4,16,16), (2,1024,2000)

        # Full forward pass through the public entry point.
        recon, attn = model(x)
        print("recon:", recon.shape)  # should match x: (2, 16, 1, 128, 128)

    # 1.0 where a memory slot received non-zero attention, else 0.0.
    active = (attn > 0).float()
    active_frac = active.mean()
    active_per_query = active.sum(dim=-1).mean()

    print(f"shrink_thres (lambda): {model.memory.shrink_thres}")
    print(f"active slot fraction: {active_frac:.4f}")
    print(f"avg active slots/query: {active_per_query:.1f} / {model.memory.n_slots}")