lucvantien1211's picture
Upload src folder, which contains modules and scripts
b20701a verified
'''
Custom dataset class definition for Vietnamese sign language data
'''
import json
import random
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from src.data_utils import read_video, nfc_normalize
def collate_fn(batch):
    '''
    Merge a list of dataset samples into a single batch dictionary.

    Every sample must carry a "frames" tensor (all of identical shape); they
    are stacked along a new leading dimension under "frames". Whether the
    batch also gets "labels" (a tensor built from the samples' labels) or
    "paths" (a plain Python list) is decided by inspecting only the first
    sample.
    '''
    stacked = torch.stack([sample["frames"] for sample in batch])
    collated = {"frames": stacked}

    head = batch[0]
    has_labels = "label" in head and head["label"] is not None
    if has_labels:
        collated["labels"] = torch.tensor([sample["label"] for sample in batch])
    if "path" in head:
        collated["paths"] = [sample["path"] for sample in batch]

    return collated
class VSLDataset(Dataset):
    '''
    Dataset of Vietnamese sign language videos.

    Each sample is read with ``read_video``, optionally passed through
    ``transform``, resampled to exactly ``target_frames`` frames and
    normalized per channel with ``norm_stats``.

    Args:
        paths: Sequence of video paths (``pathlib.Path`` or ``str``). Outside
            ``"test"`` mode the name of each video's parent directory, after
            ``nfc_normalize``, must be a key of the label mapping.
        label_mapping_path: JSON file mapping label name -> class id.
        mode: One of ``"train"``, ``"validation"``, ``"test"``. In ``"test"``
            mode no labels are looked up and samples carry their path instead.
        transform: Optional callable applied to the raw frames before
            resampling.
        norm_stats: Dict with ``"mean"`` and ``"std"`` lists (one value per
            channel). Defaults to the commonly used ImageNet statistics.
        target_frames: Number of frames every sample is resampled to.
    '''
    def __init__(
        self, paths, label_mapping_path,
        mode="train", transform=None,
        norm_stats=None, target_frames=32
    ):
        assert mode in ["train", "validation", "test"], "Invalid value for dataset mode"
        super().__init__()
        self.paths = paths
        self.mode = mode
        self.transform = transform
        # Build a fresh dict per instance: a dict literal as the default
        # argument would be one object shared by every dataset instance.
        if norm_stats is None:
            norm_stats = {
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225]
            }
        self.norm_stats = norm_stats
        self.target_frames = target_frames
        with open(label_mapping_path, "r", encoding="utf-8") as f:
            self.label2id = json.load(f)
        if mode != "test":
            # The class label is the (NFC-normalized) name of the directory
            # containing the video; Path() makes plain string paths work too.
            self.labels = [
                self.label2id[nfc_normalize(Path(video_path).parent.name)]
                for video_path in paths
            ]
        else:
            self.labels = [None] * len(paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        '''
        Load, transform, resample and normalize one video.

        Returns a dict with "frames" plus "label" (train/validation) or
        "path" (test).
        '''
        video_path = self.paths[idx]
        label = self.labels[idx]
        frames = read_video(video_path)
        if self.transform is not None:
            frames = self.transform(frames)
        frames = self._resample_frames(frames)
        frames = self._normalize(frames)
        if self.mode != "test":
            return {"frames": frames, "label": label}
        return {"frames": frames, "path": video_path}

    def _resample_frames(self, frames):
        '''
        Return exactly ``self.target_frames`` frames along dim 0.

        Longer clips are subsampled at evenly spaced indices between the first
        and last frame; shorter clips are padded by repeating the last frame.

        Raises:
            ValueError: if ``frames`` has no frames at all.
        '''
        total = frames.shape[0]
        if total == 0:
            # Fail loudly: the IndexError the padding branch would otherwise
            # raise is treated by Python's legacy __getitem__ iteration
            # protocol as "end of dataset" and silently stops iteration.
            raise ValueError("Cannot resample a video with no frames")
        if total >= self.target_frames:
            indices = torch.linspace(0, total - 1, self.target_frames).long()
        else:
            pad = self.target_frames - total
            indices = torch.cat([
                torch.arange(total),
                torch.full((pad,), total - 1, dtype=torch.long),
            ])
        return frames[indices]

    def _normalize(self, frames):
        '''
        Convert channels-last frames ``(T, H, W, C)`` to standardized
        ``(T, C, H, W)`` floats.

        Pixel values are divided by 255 (assumes 0-255 input) and then
        standardized with ``self.norm_stats`` per channel.
        '''
        frames = frames.permute(0, 3, 1, 2).float() / 255.0
        # view(1, -1, 1, 1) broadcasts over the channel dim for any channel count
        mean = torch.tensor(self.norm_stats["mean"]).view(1, -1, 1, 1)
        std = torch.tensor(self.norm_stats["std"]).view(1, -1, 1, 1)
        return (frames - mean) / std
class VideoAugmentation:
    '''
    Custom class for video data augmentation. Random parameters are drawn once
    per call, so the same transformation is applied to every frame of a video.

    Frames are expected as a channels-last ``(T, H, W, C)`` tensor with 0-255
    pixel values (presumably uint8 from the video reader - confirm). Resized
    or augmented outputs are uint8 of size ``output_size``.

    In ``"train"`` mode the pipeline is: speed change (temporal resampling)
    -> random crop + resize -> color jitter. In ``"validation"``/``"test"``
    mode frames are only resized (and returned untouched if already the right
    size).

    Args:
        mode: One of ``"train"``, ``"validation"``, ``"test"``.
        output_size: ``(height, width)`` of the returned frames.
        crop_scale: Range of the crop side-length fraction; the same fraction
            is applied to height and width.
        brightness, contrast, saturation: Maximum deviation of the
            corresponding multiplicative factor from 1.0.
        speed_range: Range of speed factors; values > 1 shorten the clip.
    '''
    def __init__(
        self, mode,
        output_size=(224, 224),
        crop_scale=(0.85, 1.0),
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        speed_range=(0.9, 1.1)
    ):
        assert mode in ["train", "validation", "test"], "Invalid value for augmentation mode"
        self.mode = mode
        self.output_size = output_size
        # Stored in every mode (they are only used in "train") so the object
        # is always fully initialized.
        self.crop_scale = crop_scale
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.speed_range = speed_range

    def __call__(self, frames):
        if self.mode == "train":
            frames = self._speed_augment(frames)
            frames = self._random_resized_crop(frames)
            frames = self._color_jitter(frames)
        else:
            # Only resize for validation and test data
            frames = self._resize(frames)
        return frames

    def _speed_augment(self, frames):
        '''
        Change video speed by resampling frames at evenly spaced indices.

        The new length is ``int(T / speed)`` for a speed drawn from
        ``speed_range``, but never fewer than 4 frames. Empty clips are
        returned unchanged.
        '''
        T = frames.shape[0]
        if T == 0:
            # linspace over an empty range would yield invalid indices
            return frames
        speed = random.uniform(self.speed_range[0], self.speed_range[1])
        new_T = max(int(T / speed), 4)
        if new_T == T:
            return frames
        indices = torch.linspace(0, T - 1, new_T).long()
        indices = torch.clamp(indices, 0, T - 1)
        return frames[indices]

    def _interpolate(self, frames):
        '''Bilinearly resize (T, H, W, C) frames to ``output_size`` as uint8.'''
        # (T, H, W, C) -> (T, C, H, W) as expected by F.interpolate
        frames = frames.permute(0, 3, 1, 2).float()
        frames = F.interpolate(frames, size=self.output_size, mode='bilinear', align_corners=False)
        # Round instead of truncating when going back to uint8; truncation
        # biases every pixel downward by about half an intensity level.
        frames = frames.round().clamp(0, 255)
        # (T, C, H, W) -> (T, H, W, C)
        return frames.permute(0, 2, 3, 1).to(torch.uint8)

    def _resize(self, frames):
        '''Resize all frames to ``output_size``; no-op if already that size.'''
        H, W = frames.shape[1], frames.shape[2]
        output_H, output_W = self.output_size
        if H != output_H or W != output_W:
            frames = self._interpolate(frames)
        return frames

    def _random_resized_crop(self, frames):
        '''Crop the same random window from every frame, then resize to ``output_size``.'''
        _, H, W, _ = frames.shape
        scale = random.uniform(self.crop_scale[0], self.crop_scale[1])
        # At least one pixel so tiny inputs never produce an empty crop
        crop_h = max(1, int(H * scale))
        crop_w = max(1, int(W * scale))
        top = random.randint(0, H - crop_h)
        left = random.randint(0, W - crop_w)
        frames = frames[:, top:top + crop_h, left:left + crop_w, :]
        return self._interpolate(frames)

    def _color_jitter(self, frames):
        '''Apply one random brightness/contrast/saturation change to all frames.'''
        brightness_factor = 1.0 + random.uniform(-self.brightness, self.brightness)
        contrast_factor = 1.0 + random.uniform(-self.contrast, self.contrast)
        saturation_factor = 1.0 + random.uniform(-self.saturation, self.saturation)
        frames = frames.float()
        # Brightness: global multiplicative scaling
        frames = frames * brightness_factor
        # Contrast: scale around each frame's per-channel spatial mean
        mean = frames.mean(dim=(1, 2), keepdim=True)
        frames = (frames - mean) * contrast_factor + mean
        # Saturation: blend with the unweighted channel average ("gray")
        gray = frames.mean(dim=-1, keepdim=True)
        frames = gray + (frames - gray) * saturation_factor
        # Clamp once at the end, and round (not truncate) so float error such
        # as 254.99998 still maps back to 255
        frames = torch.clamp(frames, 0, 255).round()
        return frames.to(torch.uint8)