File size: 7,125 Bytes
c679d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
PyTorch ``Dataset`` for the UCSD anomaly detection dataset (Ped1/Ped2 subsets).
"""

import re
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from src.data.video_transforms import transform


class UCSDDataset(Dataset):
    """
    UCSD Anomaly Detection Dataset.

    All selected clips are decoded into memory when the dataset is built and
    then indexed as sliding windows of ``window_size`` frames whose start
    frames are ``stride`` apart. Frames after the last full window of a clip
    are not covered by any window.

    Train: only normal clips (every label is 0).
    Test: clips with frame-level ground truth parsed from ``UCSD<subset>.m``.

    Args:
        root: Dataset root path (containing UCSDped1/, UCSDped2/)
        subset: 'ped1' or 'ped2' (case-insensitive)
        split: 'train' or 'test' (case-insensitive)
        window_size: Number of frames per sample (sliding window)
        stride: Stride between window start frames
        mode: 'prediction' makes ``__getitem__`` return
            (first window_size - 1 frames, last frame); any other value
            returns (frames, labels).
        transform: Optional callable applied to each whole (T, 1, H, W)
            window tensor
        clip_indices: Optional indices into the name-sorted clip directories
            of the split; only these clips are loaded. Test labels are looked
            up by these original indices, so they stay aligned with the clips.

    Raises:
        ValueError: ``subset`` is not ped1/ped2, or a clip has no ``.tif`` frames.
        FileNotFoundError: the ``<root>/UCSD<subset>/<Split>`` directory is missing.
        IndexError: an entry of ``clip_indices`` is out of range.
        RuntimeError: no clips were selected, or the annotation file lists
            fewer clips than a selected test clip index requires.
    """

    def __init__(
        self,
        root: str,
        subset: str = "Ped2",
        split: str = "train",
        window_size: int = 16,
        stride: int = 8,
        mode: str = "reconstruction",
        transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        clip_indices: Optional[List[int]] = None
    ):
        super().__init__()
        self.root = Path(root)
        self.subset = subset.lower()
        self.split = split.lower()
        self.window_size = window_size
        self.stride = stride
        self.mode = mode
        self.transform = transform
        self.clip_indices = clip_indices

        # Validate the *normalized* name so that e.g. the default "Ped2" is accepted.
        if self.subset not in ("ped1", "ped2"):
            raise ValueError(f"subset must be ped1 or ped2, got {subset}")

        # e.g. <root>/UCSDped2/Train
        self.subset_split = self.root / f"UCSD{self.subset}" / self.split.title()
        if not self.subset_split.exists():
            raise FileNotFoundError(f"Dataset path not found: {self.subset_split}")

        # Clip folders only; folders ending in "_gt" are ground-truth folders.
        all_clip_dirs = sorted(
            d for d in self.subset_split.iterdir()
            if d.is_dir() and not d.name.endswith("_gt")
        )

        # Position of every selected clip in the full sorted listing. Indexing a
        # range normalizes negative indices and raises IndexError when out of
        # bounds, exactly like indexing the directory list itself.
        positions = range(len(all_clip_dirs))
        if clip_indices is None:
            source_indices = list(positions)
        else:
            source_indices = [positions[i] for i in clip_indices]
        self.clip_dirs = [all_clip_dirs[i] for i in source_indices]

        if len(self.clip_dirs) == 0:
            raise RuntimeError(f"No clip directories found in {self.subset_split}")

        self.clips = [self._load_clip(clip_dir) for clip_dir in self.clip_dirs]

        if self.split == "test":
            m_file = self.subset_split / f"UCSD{self.subset}.m"
            gt_ranges = self._parse_gt_ranges(m_file.read_text())
            self.labels = []
            # Annotations are listed in original clip order, so look them up by
            # the clip's original index, not its position after filtering.
            for source_idx, frames in zip(source_indices, self.clips):
                if source_idx >= len(gt_ranges):
                    raise RuntimeError(
                        f"{m_file} has {len(gt_ranges)} annotated clips; "
                        f"no entry for clip index {source_idx}"
                    )
                self.labels.append(self._frame_labels(gt_ranges[source_idx], len(frames)))
        else:
            self.labels = None  # train: no ground truth

        # (clip_idx, start_frame) for every full window of every clip
        self.windows = [
            (clip_idx, start)
            for clip_idx, frames in enumerate(self.clips)
            for start in range(0, len(frames) - window_size + 1, stride)
        ]

    @staticmethod
    def _load_clip(clip_dir: Path) -> np.ndarray:
        """Read every ``*.tif`` frame of one clip, in name order, into one stacked array."""
        frame_paths = sorted(clip_dir.glob("*.tif"))
        if not frame_paths:
            raise ValueError(f"No .tif frames found in {clip_dir}")
        frames = []
        for path in frame_paths:
            # The context manager closes the file handle as soon as the frame is decoded.
            with Image.open(path) as img:
                frames.append(np.array(img))
        return np.stack(frames)

    @staticmethod
    def _parse_gt_ranges(text: str) -> List[List[Tuple[int, int]]]:
        """Extract per-clip anomalous frame ranges from the annotation file text.

        Every bracketed group that contains ``start:end`` pairs (``[61:180]`` or
        ``[5:90, 140:200]``) yields one clip's list of 1-indexed inclusive
        ranges, in file order. Bracketed groups without such pairs are ignored.
        """
        ranges_per_clip = []
        for group in re.findall(r"\[([^\[\]]*)\]", text):
            pairs = re.findall(r"(\d+)\s*:\s*(\d+)", group)
            if pairs:
                ranges_per_clip.append([(int(s), int(e)) for s, e in pairs])
        return ranges_per_clip

    @staticmethod
    def _frame_labels(ranges: List[Tuple[int, int]], n_frames: int) -> np.ndarray:
        """Return an int64 0/1 vector of length ``n_frames`` marking the given ranges."""
        label = np.zeros(n_frames, dtype=np.int64)
        for start, end in ranges:
            label[max(start - 1, 0):end] = 1  # 1-indexed inclusive -> 0-indexed slice
        return label

    def __len__(self) -> int:
        return len(self.windows)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return one sliding window.

        Returns:
            mode != 'prediction': (frames, labels), where frames is a float
                (T, 1, H, W) tensor scaled by 1/255 before ``transform`` and
                labels is a (T,) int64 tensor of 0/1 (all zeros for train).
            mode == 'prediction': (inputs, target), where inputs are the first
                T - 1 frames and target is the last frame of the window.
        """
        clip_idx, start = self.windows[idx]
        stop = start + self.window_size

        # (T, H, W) slice of the stored clip; assumes single-channel 8-bit frames
        window = self.clips[clip_idx][start:stop]

        if self.labels is not None:
            labels = torch.from_numpy(self.labels[clip_idx][start:stop])  # int64
        else:
            labels = torch.zeros(self.window_size, dtype=torch.long)

        # Scale to [0, 1] and add a channel axis: (T, H, W) -> (T, 1, H, W)
        window_tensor = (torch.from_numpy(window).float() / 255.0).unsqueeze(1)

        if self.transform is not None:
            window_tensor = self.transform(window_tensor)

        if self.mode == "prediction":
            input_frames = window_tensor[:-1]  # all but the last frame
            target_frame = window_tensor[-1]   # last frame is the prediction target
            return input_frames, target_frame
        return window_tensor, labels

if __name__ == "__main__":
    # Sanity check on Ped2: Train-split clips 0-12 are used for training and
    # clips 13-15 for validation; the Test split provides labelled clips.
    root = "data/ucsd/raw"
    train_clips = list(range(13))  # 13 clips
    val_clips = [13, 14, 15]       # 3 clips

    # Train
    ds_train = UCSDDataset(root=root, subset="ped2", clip_indices=train_clips,
                           transform=transform, split="train")
    print(f"Train: {len(ds_train.clips)} clips, {len(ds_train)} windows")
    print(f"First clip shape: {ds_train.clips[0].shape}")

    # Validation: held-out clips of the Train split, so no ground truth
    ds_val = UCSDDataset(root=root, subset="ped2", clip_indices=val_clips,
                         transform=transform, split="train")
    print(f"Val: {len(ds_val.clips)} clips, {len(ds_val)} windows")
    print(f"Val labels: {ds_val.labels}")  # Should be None

    # Test
    ds_test = UCSDDataset(root=root, subset="ped2", split="test", transform=transform)
    print(f"Test: {len(ds_test.clips)} clips, {len(ds_test)} windows")
    print(f"First label sum: {ds_test.labels[0].sum()}/{len(ds_test.labels[0])}")

    # __getitem__ on the first window of every split
    for split_name, ds in (("train", ds_train), ("val", ds_val), ("test", ds_test)):
        sample, label = ds[0]
        print(f"\nSample 0 ({split_name}):")
        print(f"  Sample shape: {sample.shape}, dtype: {sample.dtype}")
        print(f"  Sample range: [{sample.min():.3f}, {sample.max():.3f}]")
        print(f"  Label shape: {label.shape}, sum: {label.sum()}")

    # Middle train sample, after the transform
    sample, label = ds_train[len(ds_train) // 2]
    print(f"\nMiddle train sample shape: {sample.shape}")  # expected torch.Size([16, 1, 128, 128])

    # Prediction mode
    ds_pred = UCSDDataset(root=root, subset="ped2", split="train",
                          clip_indices=train_clips, transform=transform, mode="prediction")
    inp, tgt = ds_pred[0]
    print(f"input: {inp.shape}")    # expected (15, 1, 128, 128)
    print(f"target: {tgt.shape}")   # expected (1, 128, 128)