MSherbinii committed on
Commit
463a80f
·
verified ·
1 Parent(s): 0f5deb2

Add HF-compatible dataset loader

Browse files
Files changed (1) hide show
  1. dataset.py +203 -0
dataset.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ IPAD Dataset Loader for HuggingFace Infrastructure
3
+ Loads data from HF Hub and provides PyTorch DataLoader compatible interface
4
+ """
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ import cv2
8
+ import numpy as np
9
+ from pathlib import Path
10
+ import zipfile
11
+ from huggingface_hub import hf_hub_download
12
+ import os
13
+ from typing import List, Tuple, Optional
14
+ import random
15
+
16
class IPADVideoDataset(Dataset):
    """
    IPAD Video Anomaly Detection Dataset

    Indexes fixed-length clips of consecutive frames found under
    ``root_dir / device_name / split / "frames" / <video_dir>`` and decodes
    them lazily in ``__getitem__``.

    Args:
        root_dir: Path to extracted dataset
        device_name: Device ID (e.g., "S01", "S02", ..., "S12")
        split: "train" or "test"
        clip_length: Number of frames per clip (default: 16)
        frame_size: Tuple of (height, width) for resizing (default: (256, 256))
        stride: Offset, in frames, between the starts of consecutive clips
            (default: 1). Frames inside a clip are always consecutive.
        normalize: Whether to normalize frames to [-1, 1]; when False the
            frames are left in [0, 1]

    Raises:
        ValueError: If ``clip_length`` or ``stride`` is not positive, or if
            the frames directory for the device/split does not exist.
    """

    def __init__(
        self,
        root_dir: str,
        device_name: str = "S01",
        split: str = "train",
        clip_length: int = 16,
        frame_size: Tuple[int, int] = (256, 256),
        stride: int = 1,
        normalize: bool = True
    ):
        # Fail early with a clear message: stride == 0 would otherwise
        # surface as an opaque range() error, and clip_length <= 0 would
        # silently index empty clips.
        if clip_length <= 0:
            raise ValueError(f"clip_length must be positive, got {clip_length}")
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")

        self.root_dir = Path(root_dir)
        self.device_name = device_name
        self.split = split
        self.clip_length = clip_length
        self.frame_size = frame_size
        self.stride = stride
        self.normalize = normalize

        # Construct path to device frames
        self.device_path = self.root_dir / device_name / split / "frames"

        if not self.device_path.exists():
            raise ValueError(f"Dataset path not found: {self.device_path}")

        # Get all video directories
        self.video_dirs = sorted(d for d in self.device_path.iterdir() if d.is_dir())

        # Build index of all valid clips. Only paths are stored here; pixels
        # are decoded on demand in __getitem__.
        self.clips = []
        for video_dir in self.video_dirs:
            frames = sorted(list(video_dir.glob("*.jpg")) + list(video_dir.glob("*.png")))
            num_frames = len(frames)

            # Videos shorter than clip_length produce an empty range and
            # therefore contribute no clips.
            for start_idx in range(0, num_frames - clip_length + 1, stride):
                self.clips.append({
                    'video_dir': video_dir,
                    'start_idx': start_idx,
                    'frames': frames[start_idx:start_idx + clip_length]
                })

        print(f"Loaded {len(self.clips)} clips from {device_name}/{split}")

    def __len__(self) -> int:
        """Return the number of indexed clips."""
        return len(self.clips)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Load the clip at ``idx`` as a float32 tensor of shape [C, T, H, W].

        Raises:
            OSError: If a frame file cannot be decoded by OpenCV.
        """
        clip_info = self.clips[idx]

        # frame_size is (height, width), but cv2.resize expects its dsize
        # argument as (width, height).
        height, width = self.frame_size
        frames = []

        # Load and process each frame
        for frame_path in clip_info['frames']:
            frame = cv2.imread(str(frame_path))
            if frame is None:
                # cv2.imread signals failure by returning None, not raising.
                raise OSError(f"Failed to read frame: {frame_path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, (width, height))

            # Normalize to [0, 1]
            frame = frame.astype(np.float32) / 255.0

            # Normalize to [-1, 1] if requested
            if self.normalize:
                frame = (frame - 0.5) / 0.5

            frames.append(frame)

        # Convert to tensor: [T, H, W, C] -> [C, T, H, W]
        frames = np.stack(frames, axis=0)  # [T, H, W, C]
        frames = torch.from_numpy(frames).permute(3, 0, 1, 2)  # [C, T, H, W]

        return frames
100
+
101
+
102
def download_and_extract_dataset(cache_dir: str = "./cache") -> Path:
    """
    Download IPAD dataset from HF Hub and extract it

    If ``cache_dir / "ipad_dataset"`` already exists as a directory, nothing
    is downloaded and that path is returned immediately.

    Args:
        cache_dir: Directory used both as the HF Hub download cache and as
            the extraction target. Created if missing.

    Returns:
        Path to extracted dataset directory

    Raises:
        FileNotFoundError: If the archive was extracted but did not produce
            the expected ``ipad_dataset`` directory.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(exist_ok=True, parents=True)

    extracted_path = cache_dir / "ipad_dataset"

    # NOTE(review): an interrupted extraction leaves a partial directory that
    # this check treats as complete; delete it manually to force a re-extract.
    if extracted_path.is_dir():
        print(f"✅ Dataset already extracted at {extracted_path}")
        return extracted_path

    print("📥 Downloading dataset from HF Hub...")
    zip_path = hf_hub_download(
        repo_id="MSherbinii/ipad-industrial-anomaly",
        filename="ipad_dataset.zip",
        repo_type="dataset",
        cache_dir=str(cache_dir)
    )

    print("📦 Extracting dataset...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(cache_dir)

    # The return value assumes the archive has a top-level "ipad_dataset"
    # folder; verify that rather than hand back a path that does not exist.
    if not extracted_path.is_dir():
        raise FileNotFoundError(
            f"Archive {zip_path} did not contain the expected "
            f"'ipad_dataset' directory under {cache_dir}"
        )

    print(f"✅ Dataset extracted to {extracted_path}")
    return extracted_path
132
+
133
+
134
def create_dataloaders(
    dataset_path: str,
    device_name: str = "S01",
    batch_size: int = 4,
    num_workers: int = 4,
    clip_length: int = 16,
    frame_size: Tuple[int, int] = (256, 256)
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and test DataLoaders for a specific device

    The train loader shuffles, uses half-overlapping clips, and drops the
    final incomplete batch. The test loader keeps order, uses
    non-overlapping clips, and keeps every batch.

    Args:
        dataset_path: Path to extracted IPAD dataset
        device_name: Device ID (e.g., "S01")
        batch_size: Batch size for DataLoader
        num_workers: Number of worker processes
        clip_length: Frames per clip
        frame_size: Frame dimensions

    Returns:
        Tuple of (train_loader, test_loader)
    """
    # 50% overlap for training. Clamp to 1 so clip_length == 1 does not
    # produce a zero stride, which range() rejects.
    train_stride = max(1, clip_length // 2)

    train_dataset = IPADVideoDataset(
        root_dir=dataset_path,
        device_name=device_name,
        split="train",
        clip_length=clip_length,
        frame_size=frame_size,
        stride=train_stride
    )

    test_dataset = IPADVideoDataset(
        root_dir=dataset_path,
        device_name=device_name,
        split="test",
        clip_length=clip_length,
        frame_size=frame_size,
        stride=clip_length  # No overlap for testing
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False
    )

    return train_loader, test_loader
193
+
194
+
195
# Device identifiers, grouped by prefix: "S01".."S12" and "R01".."R04".
SYNTHETIC_DEVICES = ["S%02d" % number for number in range(1, 13)]
REAL_DEVICES = ["R%02d" % number for number in range(1, 5)]

# Every known device, S-series first and then R-series.
DEVICE_NAMES = SYNTHETIC_DEVICES + REAL_DEVICES