File size: 6,832 Bytes
b20701a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 | '''
Dataset, batch-collation, and video-augmentation utilities for Vietnamese Sign Language (VSL) video data.
'''
import json
import random
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from src.data_utils import read_video, nfc_normalize
def collate_fn(batch):
    """Merge a list of dataset samples into one batch dictionary.

    Args:
        batch: List of sample dicts. Each sample must hold a ``"frames"``
            tensor; these are stacked along a new leading batch dimension.

    Returns:
        Dict with ``"frames"`` (the stacked tensor). It also contains:

        * ``"labels"``: a tensor of every sample's ``"label"``, added only
          when the first sample has a non-None ``"label"``;
        * ``"paths"``: a list of every sample's ``"path"``, added only when
          the first sample has a ``"path"`` key.
    """
    collated = {"frames": torch.stack([sample["frames"] for sample in batch])}

    # The first sample decides which optional fields the whole batch carries.
    head = batch[0]
    if head.get("label") is not None:
        collated["labels"] = torch.tensor([sample["label"] for sample in batch])
    if "path" in head:
        collated["paths"] = [sample["path"] for sample in batch]

    return collated
class VSLDataset(Dataset):
    """Video dataset for Vietnamese sign language (VSL) clips.

    Each item is produced by: ``read_video`` -> optional ``transform`` ->
    temporal resampling to ``target_frames`` -> scaling to ``[0, 1]`` and
    per-channel mean/std normalization. Normalized frames come out as
    ``(T, C, H, W)`` float tensors. The input is assumed to be
    ``(T, H, W, C)`` pixel values in ``[0, 255]``.

    Args:
        paths: Indexable sequence of video paths (``str`` or ``pathlib.Path``).
        label_mapping_path: JSON file mapping label names to class ids.
        mode: One of ``"train"``, ``"validation"`` or ``"test"``.
            - Outside ``"test"``, each video's label is the name of its
              parent directory, passed through ``nfc_normalize`` and looked
              up in the mapping.
            - In ``"test"`` mode no labels are resolved, and items carry
              the video path instead.
        transform: Optional callable applied to the raw frames before
            resampling (e.g. ``VideoAugmentation``).
        norm_stats: Dict with ``"mean"`` and ``"std"`` sequences, one value
            per channel. ``None`` (the default) uses a fresh copy of
            ``DEFAULT_NORM_STATS``.
        target_frames: Number of frames every item is resampled to.

    Raises:
        AssertionError: If ``mode`` is not one of the supported values.
        KeyError: If a video's directory name is missing from the mapping.
    """

    # Default per-channel statistics (the standard ImageNet values).
    # Presumably frames are RGB-ordered -- confirm against read_video.
    DEFAULT_NORM_STATS = {
        "mean": (0.485, 0.456, 0.406),
        "std": (0.229, 0.224, 0.225),
    }

    def __init__(
        self, paths, label_mapping_path,
        mode="train", transform=None,
        norm_stats=None, target_frames=32
    ):
        assert mode in ["train", "validation", "test"], "Invalid value for dataset mode"
        super().__init__()
        self.paths = paths
        self.mode = mode
        self.transform = transform
        if norm_stats is None:
            # Build a new dict per instance; a dict literal as the default
            # argument would be one object shared by every dataset.
            norm_stats = {
                key: list(values) for key, values in self.DEFAULT_NORM_STATS.items()
            }
        self.norm_stats = norm_stats
        self.target_frames = target_frames
        with open(label_mapping_path, "r", encoding="utf-8") as f:
            self.label2id = json.load(f)
        if mode == "test":
            self.labels = [None] * len(paths)
        else:
            self.labels = [self._label_id(video_path) for video_path in paths]

    def _label_id(self, video_path):
        """Return the class id for a video, taken from its parent directory name."""
        # NOTE(review): nfc_normalize presumably applies Unicode NFC so that
        # differently-composed Vietnamese diacritics match -- confirm.
        label = nfc_normalize(Path(video_path).parent.name)
        try:
            return self.label2id[label]
        except KeyError:
            raise KeyError(
                f"Label {label!r} (from {video_path}) not found in label mapping"
            ) from None

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Load, transform, resample and normalize one video.

        Returns:
            ``{"frames", "label"}`` in train/validation mode, or
            ``{"frames", "path"}`` in test mode.
        """
        video_path = self.paths[idx]
        label = self.labels[idx]
        frames = read_video(video_path)
        if self.transform is not None:
            frames = self.transform(frames)
        frames = self._normalize(self._resample_frames(frames))
        if self.mode == "test":
            return {"frames": frames, "path": video_path}
        return {"frames": frames, "label": label}

    def _resample_frames(self, frames):
        """Resample along dim 0 to exactly ``target_frames`` frames.

        - Longer clips are subsampled at evenly spaced indices.
        - Shorter clips keep every frame and are padded by repeating the
          last one.

        Raises:
            ValueError: If ``frames`` has no frames to sample from.
        """
        total = frames.shape[0]
        if total == 0:
            raise ValueError("Cannot resample a video with no frames")
        if total >= self.target_frames:
            indices = torch.linspace(0, total - 1, self.target_frames).long()
        else:
            pad = self.target_frames - total
            indices = torch.cat([
                torch.arange(total),
                torch.full((pad,), total - 1, dtype=torch.long),
            ])
        return frames[indices]

    def _normalize(self, frames):
        """Convert ``(T, H, W, C)`` pixels in ``[0, 255]`` to normalized ``(T, C, H, W)`` floats."""
        frames = frames.permute(0, 3, 1, 2).float() / 255.0
        # view(1, -1, 1, 1) broadcasts one statistic per channel (dim 1).
        mean = torch.tensor(self.norm_stats["mean"]).view(1, -1, 1, 1)
        std = torch.tensor(self.norm_stats["std"]).view(1, -1, 1, 1)
        return (frames - mean) / std
class VideoAugmentation:
    """Clip-level video augmentation.

    Random parameters (speed factor, crop window, color-jitter factors) are
    drawn once per call, so every frame of one video gets the same transform.
    Frames are indexed as ``(T, H, W, C)``, with values treated as the
    ``[0, 255]`` pixel range.

    Pipeline by mode:

    * ``"train"``: speed change -> random resized crop -> color jitter.
    * ``"validation"`` / ``"test"``: resize to ``output_size`` only. Frames
      that already have that size are returned unchanged.

    Randomness comes from Python's ``random`` module, so seed that module
    for reproducible augmentation.

    Args:
        mode: One of ``"train"``, ``"validation"`` or ``"test"``.
        output_size: ``(height, width)`` of the output frames.
        crop_scale: Range the crop side-length fraction is drawn from. It is
            applied to both height and width, and clamped to the frame size.
        brightness, contrast, saturation: Each factor is drawn uniformly
            from ``[1 - value, 1 + value]``.
        speed_range: Range the playback-speed factor is drawn from.
        min_frames: Lower bound on the clip length after speed augmentation.
    """

    def __init__(
        self, mode,
        output_size=(224, 224),
        crop_scale=(0.85, 1.0),
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        speed_range=(0.9, 1.1),
        min_frames=4,
    ):
        assert mode in ["train", "validation", "test"], "Invalid value for augmentation mode"
        self.mode = mode
        self.output_size = output_size
        # Store these in every mode so the helper methods work regardless of mode.
        self.crop_scale = crop_scale
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.speed_range = speed_range
        self.min_frames = min_frames

    def __call__(self, frames):
        """Augment (train) or resize (validation/test) a ``(T, H, W, C)`` clip."""
        if self.mode == "train":
            frames = self._speed_augment(frames)
            frames = self._random_resized_crop(frames)
            frames = self._color_jitter(frames)
        else:
            frames = self._resize(frames)
        return frames

    @staticmethod
    def _to_uint8(frames):
        """Round to the nearest integer, clamp to [0, 255] and cast to uint8.

        A bare ``.to(torch.uint8)`` truncates toward zero. That biases pixels
        downward and turns float noise such as 99.9999 into 99.
        """
        return frames.round().clamp(0, 255).to(torch.uint8)

    def _interpolate(self, frames):
        """Bilinearly resize ``(T, H, W, C)`` frames to ``output_size`` and return uint8."""
        # F.interpolate expects channels-first (N, C, H, W); time acts as the batch dim.
        resized = F.interpolate(
            frames.permute(0, 3, 1, 2).float(),
            size=self.output_size, mode='bilinear', align_corners=False,
        )
        return self._to_uint8(resized.permute(0, 2, 3, 1))

    def _speed_augment(self, frames):
        """Change playback speed by resampling frames along the time axis.

        The new length is ``int(T / speed)``, but never fewer than
        ``min_frames``.
        - A speed above 1 drops frames.
        - A speed below 1 repeats frames.
        """
        num_frames = frames.shape[0]
        if num_frames == 0:
            return frames
        speed = random.uniform(self.speed_range[0], self.speed_range[1])
        new_len = max(int(num_frames / speed), self.min_frames)
        if new_len == num_frames:
            return frames
        # linspace endpoints are 0 and num_frames - 1, so the indices stay in range.
        indices = torch.linspace(0, num_frames - 1, new_len).long()
        return frames[indices]

    def _resize(self, frames):
        """Resize frames to ``output_size``; a no-op when they already match."""
        height, width = frames.shape[1], frames.shape[2]
        output_h, output_w = self.output_size
        if height == output_h and width == output_w:
            return frames
        return self._interpolate(frames)

    def _random_resized_crop(self, frames):
        """Crop the same random window from every frame, then resize it to the desired output size."""
        _, height, width, _ = frames.shape
        scale = random.uniform(self.crop_scale[0], self.crop_scale[1])
        # Clamp so that a scale > 1 (or a tiny frame) cannot produce an invalid window.
        crop_h = min(height, max(1, int(height * scale)))
        crop_w = min(width, max(1, int(width * scale)))
        top = random.randint(0, height - crop_h)
        left = random.randint(0, width - crop_w)
        cropped = frames[:, top:top + crop_h, left:left + crop_w, :]
        return self._interpolate(cropped)

    def _color_jitter(self, frames):
        """Apply brightness, contrast and saturation jitter, with factors shared by all frames."""
        brightness_factor = 1.0 + random.uniform(-self.brightness, self.brightness)
        contrast_factor = 1.0 + random.uniform(-self.contrast, self.contrast)
        saturation_factor = 1.0 + random.uniform(-self.saturation, self.saturation)
        frames = frames.float() * brightness_factor
        # Contrast: scale deviations from each frame's per-channel spatial mean.
        mean = frames.mean(dim=(1, 2), keepdim=True)
        frames = (frames - mean) * contrast_factor + mean
        # Saturation: blend away from / toward the unweighted channel average.
        gray = frames.mean(dim=-1, keepdim=True)
        frames = gray + (frames - gray) * saturation_factor
        return self._to_uint8(frames)