File size: 7,031 Bytes
c679d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Module for per-frame anomaly scoring on UCSD test split.

Pipeline: model reconstruction -> per-frame error -> overlapping-window
averaging -> per-clip frame-aligned anomaly scores.
"""

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from scipy.ndimage import gaussian_filter1d
from src.data.ucsd_loader import UCSDDataset
from src.models.autoencoder import AutoEncoder
from src.data.video_transforms import transform


def smooth_scores(scores: np.ndarray, sigma: float = 2.0) -> np.ndarray:
    """
    Apply a 1-D Gaussian filter along the frame axis of one clip's scores.

    Args:
        scores: per-frame anomaly scores of a single clip.
        sigma: standard deviation of the Gaussian kernel, in frames.

    Returns:
        The filtered scores as produced by ``gaussian_filter1d``.
    """
    smoothed = gaussian_filter1d(scores, sigma)
    return smoothed


def compute_frame_errors(model: nn.Module, dataset: UCSDDataset, device: str) -> dict:
    """
    Score every test frame by its mean reconstruction error.

    Each window is reconstructed by the model, its squared error is reduced
    to one value per frame, and the values of all windows that overlap a
    frame are averaged. The averaged scores of each clip are then smoothed.

    Returns:
        dict mapping clip_idx -> (scores, labels)
        - scores: np.ndarray, smoothed average reconstruction error of every
          frame that was covered by at least one window
        - labels: np.ndarray, ground-truth entries of ``dataset.labels`` for
          those same frames
    """
    model.eval()

    # Per-clip running sum of errors and number of windows touching each frame
    clip_sizes = {i: len(dataset.clips[i]) for i in range(len(dataset.clips))}
    error_sum = {i: np.zeros(size, dtype=np.float64) for i, size in clip_sizes.items()}
    count = {i: np.zeros(size, dtype=np.float64) for i, size in clip_sizes.items()}

    # Sequential, unbatched iteration so the loader position indexes dataset.windows
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    with torch.no_grad():
        for win_no, (frames, _) in enumerate(loader):
            # frames is expected to be (1, T, C, H, W) because batch_size is 1
            frames = frames.to(device)

            # Some models return a tuple; the reconstruction is its first item
            output = model(frames)
            if isinstance(output, tuple):
                recon = output[0]
            else:
                recon = output

            # Squared error reduced over every axis except the time axis -> (T,)
            squared_diff = (frames - recon) ** 2
            per_frame_err = squared_diff.mean(dim=(0, 2, 3, 4)).cpu().numpy()

            # Locate this window inside its clip and accumulate over its span
            clip_idx, first_frame = dataset.windows[win_no]
            span = slice(first_frame, first_frame + dataset.window_size)
            error_sum[clip_idx][span] += per_frame_err
            count[clip_idx][span] += 1

    # Average the accumulated errors and align the ground truth
    results = {}
    for clip_idx, summed in error_sum.items():
        counts = count[clip_idx]

        # Report how many frames no window ever touched
        print(f"clip {clip_idx}: {(counts==0).sum()} frames with no window coverage")

        # Only frames seen at least once have a defined average
        covered = counts > 0
        mean_err = summed[covered] / counts[covered]

        # Smooth within the clip, then pick the labels with the same mask
        results[clip_idx] = (
            smooth_scores(mean_err, sigma=1.0),
            dataset.labels[clip_idx][covered],
        )

    return results


def aggregate_all(results: dict) -> tuple:
    """
    Join the per-clip (scores, labels) pairs into two flat arrays.

    Clips are concatenated in the iteration order of ``results`` so that
    scores and labels stay aligned for a global AUC computation.

    Returns:
        all_scores: np.ndarray (total_frames,)
        all_labels: np.ndarray (total_frames,)
    """
    per_clip = list(results.values())

    # Split the pairs into two parallel lists, preserving clip order
    score_chunks = [clip_scores for clip_scores, _ in per_clip]
    label_chunks = [clip_labels for _, clip_labels in per_clip]

    return np.concatenate(score_chunks), np.concatenate(label_chunks)


def compute_prediction_errors(
    model: nn.Module,
    dataset: UCSDDataset,
    device: str,
    n_input_frames: int = 15,
) -> dict:
    """
    Per-frame prediction error for M3.

    Each window (``n_input_frames`` inputs -> 1 target) scores ONE frame:
    the target frame at index ``start_frame + n_input_frames`` in its clip.

    Args:
        model: predictor called as ``model(inputs)``; its output is compared
            against the target frame with a mean squared error.
        dataset: yields ``(inputs, target)`` pairs and exposes ``clips``,
            ``windows`` (``(clip_idx, start_frame)`` per item) and ``labels``.
        device: device the tensors are moved to before the forward pass.
        n_input_frames: number of input frames preceding the target inside a
            window. Defaults to 15, the value that was previously hard-coded;
            it must match how the dataset builds its windows, otherwise the
            errors are attributed to the wrong frames.

    Returns:
        dict mapping clip_idx -> (scores, labels), restricted to frames that
        were a prediction target at least once:
        - scores: np.ndarray, smoothed average prediction error per frame
        - labels: np.ndarray, matching entries of ``dataset.labels``
    """
    model.eval()

    # Per-clip accumulator. Many frames are never a target:
    # the first n_input_frames of each clip are always inputs, never predicted.
    error_sum = {}
    count = {}
    for clip_idx in range(len(dataset.clips)):
        n_frames = len(dataset.clips[clip_idx])
        error_sum[clip_idx] = np.zeros(n_frames, dtype=np.float64)
        count[clip_idx] = np.zeros(n_frames, dtype=np.float64)

    # Sequential, unbatched iteration so the loader position indexes dataset.windows
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    with torch.no_grad():
        for idx, (inputs, target) in enumerate(loader):
            # Expected shapes with the default setup (not checked here):
            # inputs (1,15,1,H,W), target (1,1,H,W)
            inputs, target = inputs.to(device), target.to(device)
            pred = model(inputs)

            # Single target frame -> one scalar error (mean over all elements)
            err = ((pred - target) ** 2).mean().item()

            # Which clip / which target frame does this window predict?
            clip_idx, start_frame = dataset.windows[idx]
            target_idx = start_frame + n_input_frames  # frame right after the inputs
            error_sum[clip_idx][target_idx] += err
            count[clip_idx][target_idx] += 1

    # Average + align ground truth using a count>0 mask.
    # NOTE: leading input-only frames and any uncovered frames have count==0
    # and are masked out.
    results = {}
    for clip_idx in error_sum:
        counts = count[clip_idx]
        errs = error_sum[clip_idx]

        # Log frames with no coverage (expected: at least the leading inputs)
        print(f"clip {clip_idx}: {(counts==0).sum()} frames with no prediction coverage")

        # Keep only frames that were predicted at least once
        valid = counts > 0

        scores = errs[valid] / counts[valid]          # average error per frame
        scores = smooth_scores(scores, sigma=1.0)     # clip-level temporal smoothing
        labels = dataset.labels[clip_idx][valid]      # same mask -> alignment

        results[clip_idx] = (scores, labels)

    return results


if __name__ == "__main__":
    # Prefer the GPU when one is visible to torch
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Load the model (the best checkpoint from training)
    model = AutoEncoder().to(device)
    checkpoint = torch.load("checkpoints/ae_best.pt", map_location=device)
    model.load_state_dict(checkpoint["model_state"])

    # Test dataset — no clip_indices given (all 12 clips), split="test"
    test_set = UCSDDataset(
        root="data/ucsd/raw",
        subset="ped2",
        split="test",
        transform=transform,
    )

    # Score every clip, then flatten to global arrays
    per_clip = compute_frame_errors(model, test_set, device)
    all_scores, all_labels = aggregate_all(per_clip)

    # Sanity check: both arrays should be 1D and of equal length
    print(f"shape: {all_scores.shape}, {all_labels.shape}")
    print(f"anomaly frames: {all_labels.sum()}/{len(all_labels)}")

    # Anomalous frames are expected to reconstruct worse than normal ones
    is_normal = all_labels == 0
    is_anomaly = all_labels == 1
    normal_mean = all_scores[is_normal].mean()
    anomaly_mean = all_scores[is_anomaly].mean()
    print(f"normal mean error:  {normal_mean:.6f}")
    print(f"anomaly mean error: {anomaly_mean:.6f}")