| """ |
| Memory-Augmented Autoencoder (MemAE) for video anomaly detection. |
| |
| Encoder/decoder backbone = M1 AutoEncoder, unchanged. A memory module is |
| inserted between them: the decoder can only reconstruct from stored normal |
| prototypes, so anomalies (absent from memory) reconstruct poorly. |
| |
| Ref: Gong et al. 2019, "Memorizing Normality to Detect Anomaly" (ICCV). |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class MemoryModule(nn.Module):
    """
    Memory bank with sparse attention-based addressing.

    Forward: bottleneck feature -> one query per spatio-temporal position ->
    cosine-similarity addressing over the memory items -> softmax ->
    hard shrinkage + re-normalization -> reconstructed feature (+ attention
    weights for the entropy loss / sparsity inspection).

    Args:
        n_slots: N, number of memory items; must be positive.
        feat_dim: C, dimension of each memory item. Must equal the channel
            dimension of the bottleneck passed to ``forward`` (16 for M1).
        shrink_thres: lambda, hard-shrinkage threshold applied to the softmax
            weights. ``None`` (default) means 1 / n_slots.
        temperature: positive divisor applied to the cosine similarities
            before the softmax. Cosine similarity lies in [-1, 1], so any two
            softmax weights differ by at most a factor exp(2 / temperature);
            with the default 1.0 and many slots the addressing is close to
            uniform. Values < 1 sharpen it. The default keeps the original
            behavior exactly.

    Raises:
        ValueError: if ``n_slots`` or ``temperature`` is not positive.
    """

    def __init__(self, n_slots: int, feat_dim: int = 16, shrink_thres: float = None,
                 temperature: float = 1.0):
        super().__init__()
        if n_slots <= 0:
            raise ValueError(f"n_slots must be positive, got {n_slots}")
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        self.n_slots = n_slots
        self.feat_dim = feat_dim
        # Default lambda = 1/N, i.e. the mean weight of a softmax over N slots.
        self.shrink_thres = 1.0 / n_slots if shrink_thres is None else shrink_thres
        self.temperature = temperature

        # Learnable memory items, one row per slot: (N, C).
        self.memory = nn.Parameter(torch.randn(size=(self.n_slots, self.feat_dim)))

    def forward(self, z: torch.Tensor):
        """
        Address the memory with every position of the bottleneck feature.

        Args:
            z: bottleneck feature, shape (B, C, T, H, W), e.g. (B, 16, 4, 16, 16)
               for the M1 encoder; C must equal ``feat_dim``.
        Returns:
            z_hat: reconstructed feature, same shape as z
            attn: attention weights, shape (B, T*H*W, N) -- for loss/viz.
                Each row sums to 1, unless every weight in it was shrunk to
                zero (possible if shrink_thres > 1/N or the softmax is exactly
                uniform); such a row stays all-zero and that position
                reconstructs to a zero vector.
        Raises:
            ValueError: if z is not 5-D or its channel dim != feat_dim.
        """
        if z.dim() != 5:
            raise ValueError(f"expected a 5-D tensor (B, C, T, H, W), got shape {tuple(z.shape)}")
        B, C, T, H, W = z.shape
        if C != self.feat_dim:
            raise ValueError(f"channel dim {C} does not match memory feat_dim {self.feat_dim}")
        n_queries = T * H * W

        # (B, C, T, H, W) -> (B, T, H, W, C) -> (B, T*H*W, C): one C-dim query per position.
        query = z.permute(0, 2, 3, 4, 1).reshape(B, n_queries, C)

        # Cosine similarity between every query and every memory item: (B, n_queries, N).
        query_n = F.normalize(query, dim=-1)
        memory_n = F.normalize(self.memory, dim=-1)
        sim = query_n @ memory_n.t()

        attn = F.softmax(sim / self.temperature, dim=-1)

        # Hard shrinkage in its continuous ReLU form (Gong et al. 2019):
        # weights <= lambda become 0, weights > lambda are (almost) unchanged.
        # eps keeps the division finite when a weight sits exactly at lambda.
        eps = 1e-12
        shifted = attn - self.shrink_thres
        attn = F.relu(shifted) * attn / (torch.abs(shifted) + eps)

        # Re-normalize the surviving weights so each row sums to 1 (L1).
        attn = attn / (attn.sum(dim=-1, keepdim=True) + eps)

        # Reconstruct from the raw (un-normalized) memory items: (B, n_queries, C).
        z_hat_flat = attn @ self.memory

        # (B, T*H*W, C) -> (B, T, H, W, C) -> (B, C, T, H, W), matching the input layout.
        z_hat = z_hat_flat.reshape(B, T, H, W, C).permute(0, 4, 1, 2, 3)

        return z_hat, attn
|
|
|
|
class MemoryAE(nn.Module):
    """
    M1 encoder + MemoryModule + M1 decoder.

    The encoder/decoder backbone is the M1 AutoEncoder, kept unchanged so the
    memory module is the only difference (clean ablation).

    Args:
        n_slots: number of memory items N in the memory bank.
        shrink_thres: hard-shrinkage threshold lambda; ``None`` means 1 / n_slots.
    """

    def __init__(self, n_slots: int, shrink_thres: float = None):
        super().__init__()

        threshold = shrink_thres if shrink_thres is not None else 1.0 / n_slots

        # Encoder stages, each Conv3d(k=3, pad=1) -> GroupNorm(8) -> LeakyReLU,
        # given as (out_channels, stride). For a (B, 1, 16, 128, 128) input the
        # output is (B, 16, 4, 16, 16).
        encoder_stages = (
            (16, (1, 2, 2)),
            (32, (2, 2, 2)),
            (64, (2, 2, 2)),
            (16, (1, 1, 1)),
        )
        encoder_layers = []
        channels = 1
        for out_channels, stride in encoder_stages:
            encoder_layers.extend([
                nn.Conv3d(channels, out_channels, (3, 3, 3), stride=stride, padding=1),
                nn.GroupNorm(8, out_channels),
                nn.LeakyReLU(),
            ])
            channels = out_channels
        # Kept as one flat Sequential so parameter names (encoder.0, encoder.1, ...) match M1.
        self.encoder = nn.Sequential(*encoder_layers)

        self.memory = MemoryModule(n_slots=n_slots, feat_dim=16, shrink_thres=threshold)

        # Decoder: two upsampling ConvTranspose3d -> GroupNorm(8) -> LeakyReLU
        # stages given as (in, out, stride, output_padding), then a final
        # projection back to 1 channel squashed into [-1, 1] by Tanh.
        decoder_stages = (
            (16, 32, (2, 2, 2), (1, 1, 1)),
            (32, 16, (2, 2, 2), (1, 1, 1)),
        )
        decoder_layers = []
        for c_in, c_out, stride, out_pad in decoder_stages:
            decoder_layers.extend([
                nn.ConvTranspose3d(c_in, c_out, (3, 3, 3), stride=stride, padding=1,
                                   output_padding=out_pad),
                nn.GroupNorm(8, c_out),
                nn.LeakyReLU(),
            ])
        decoder_layers.extend([
            nn.ConvTranspose3d(16, 1, (3, 3, 3), stride=(1, 2, 2), padding=1,
                               output_padding=(0, 1, 1)),
            nn.Tanh(),
        ])
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x: torch.Tensor):
        """
        Reconstruct a clip through the memory bottleneck.

        Args:
            x: (B, T, C, H, W) -- loader format (same as M1)
        Returns:
            recon: (B, T, C, H, W)
            attn: (B, n_queries, N) addressing weights from the memory module
        """
        # Conv3d wants channels before time: (B, T, C, H, W) -> (B, C, T, H, W).
        latent = self.encoder(x.permute(0, 2, 1, 3, 4))
        latent_hat, attn = self.memory(latent)
        # Decode, then return to loader layout: (B, C, T, H, W) -> (B, T, C, H, W).
        recon = self.decoder(latent_hat).permute(0, 2, 1, 3, 4)
        return recon, attn
| |
|
|
if __name__ == "__main__":
    # Smoke test: trace tensor shapes through the model and report how sparse
    # the memory addressing is. Nothing is trained here.
    torch.manual_seed(0)  # reproducible memory init and input
    model = MemoryAE(n_slots=2000)
    model.eval()
    x = torch.randn(2, 16, 1, 128, 128)  # (B, T, C, H, W) loader format

    # Inference only: without no_grad every intermediate (including the
    # (2, 4096, 2000) attention tensors of each pass) is kept for a backward
    # pass that never happens.
    with torch.no_grad():
        # Run the stages by hand first to show the intermediate shapes.
        xp = x.permute(0, 2, 1, 3, 4)
        z = model.encoder(xp)
        print("bottleneck:", z.shape)

        z_hat, attn = model.memory(z)
        print("z_hat:", z_hat.shape, "attn:", attn.shape)

        recon, attn = model(x)

    print("recon:", recon.shape)
    if recon.shape != x.shape:
        raise RuntimeError(
            f"reconstruction shape {tuple(recon.shape)} != input shape {tuple(x.shape)}"
        )

    active_frac = (attn > 0).float().mean()
    active_per_query = (attn > 0).float().sum(dim=-1).mean()
    print(f"shrink_thres (lambda): {model.memory.shrink_thres}")
    print(f"active slot fraction: {active_frac:.4f}")
    print(f"avg active slots/query: {active_per_query:.1f} / {model.memory.n_slots}")