"""
Dataset for loading text-GIF pairs for sign language generation
"""
import os
import glob
import random
from typing import Dict, List, Optional, Tuple
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from torchvision import transforms


class SignLanguageDataset(Dataset):
    """Dataset for text-to-sign-language video generation."""

    def __init__(
        self,
        data_dir: str,
        image_size: int = 64,
        num_frames: int = 16,
        train: bool = True,
        train_ratio: float = 0.9,
    ):
        """
        Args:
            data_dir: Directory containing paired .gif and .txt files
            image_size: Side length to resize frames to
            num_frames: Number of frames to sample from each GIF
            train: Whether this is the training split
            train_ratio: Fraction of the data used for training
        """
        self.data_dir = data_dir
        self.image_size = image_size
        self.num_frames = num_frames
        self.train = train

        # Find all GIF-text pairs
        self.pairs = self._find_pairs()

        # Deterministic train/val split; a local RNG avoids reseeding
        # Python's global random state as a side effect.
        indices = list(range(len(self.pairs)))
        random.Random(42).shuffle(indices)
        split_idx = int(len(indices) * train_ratio)
        if train:
            self.indices = indices[:split_idx]
        else:
            self.indices = indices[split_idx:]

        # Image transforms: resize, convert to tensor, scale to [-1, 1]
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
        ])

        print(f"Loaded {len(self.indices)} {'training' if train else 'validation'} samples")

    def _find_pairs(self) -> List[Tuple[str, str]]:
        """Find all GIF-text pairs in the data directory."""
        pairs = []
        gif_files = glob.glob(os.path.join(self.data_dir, "*.gif"))
        for gif_path in gif_files:
            # Pair each GIF with a .txt file of the same base name;
            # splitext is safer than str.replace in case ".gif" occurs
            # elsewhere in the path.
            txt_path = os.path.splitext(gif_path)[0] + ".txt"
            if os.path.exists(txt_path):
                pairs.append((gif_path, txt_path))
        return pairs

    def _load_gif(self, gif_path: str) -> torch.Tensor:
        """Load a GIF and sample a fixed number of frames."""
        try:
            gif = Image.open(gif_path)

            # Collect every frame: PIL exposes GIF frames via seek(),
            # which raises EOFError once past the last frame.
            frames = []
            try:
                while True:
                    frame = gif.convert("RGB")
                    frame = self.transform(frame)
                    frames.append(frame)
                    gif.seek(gif.tell() + 1)
            except EOFError:
                pass

            if len(frames) == 0:
                raise ValueError(f"No frames found in {gif_path}")

            if len(frames) >= self.num_frames:
                # Uniformly sample num_frames indices across the clip
                indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
                frames = [frames[i] for i in indices]
            else:
                # Pad short clips by repeating the last frame
                while len(frames) < self.num_frames:
                    frames.append(frames[-1])

            # Stack frames: (num_frames, C, H, W)
            return torch.stack(frames)
        except Exception as e:
            print(f"Error loading {gif_path}: {e}")
            # Fall back to random noise so a bad file doesn't halt training
            return torch.randn(self.num_frames, 3, self.image_size, self.image_size)

    def _load_text(self, txt_path: str) -> str:
        """Load the caption text from a file."""
        try:
            with open(txt_path, "r", encoding="utf-8") as f:
                return f.read().strip()
        except Exception as e:
            print(f"Error loading {txt_path}: {e}")
            return ""

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        real_idx = self.indices[idx]
        gif_path, txt_path = self.pairs[real_idx]
        video = self._load_gif(gif_path)  # (T, C, H, W)
        text = self._load_text(txt_path)
        return {
            "video": video,
            "text": text,
        }


class SimpleTokenizer:
    """Simple hash-based, character-level tokenizer for text encoding."""

    def __init__(self, vocab_size: int = 49408, max_length: int = 77):
        self.vocab_size = vocab_size
        self.max_length = max_length
        # IDs 0-2 are reserved for special tokens
        self.bos_token_id = 0
        self.eos_token_id = 1
        self.pad_token_id = 2

    def encode(self, text: str) -> torch.Tensor:
        """Encode text to a fixed-length tensor of token IDs."""
        tokens = [self.bos_token_id]
        for char in text.lower():
            # Hash each character into the non-special ID range [3, vocab_size)
            token_id = (ord(char) % (self.vocab_size - 3)) + 3
            tokens.append(token_id)
            if len(tokens) >= self.max_length - 1:
                break
        tokens.append(self.eos_token_id)

        # Pad to max_length
        while len(tokens) < self.max_length:
            tokens.append(self.pad_token_id)
        return torch.tensor(tokens[:self.max_length], dtype=torch.long)

    def __call__(self, texts: List[str]) -> torch.Tensor:
        """Batch-encode a list of texts into a (B, max_length) tensor."""
        return torch.stack([self.encode(text) for text in texts])
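

# Example (hypothetical input): SimpleTokenizer().encode("hi") yields a
# (77,) LongTensor laid out as [BOS, hash('h'), hash('i'), EOS, PAD, ...],
# where hash(c) = (ord(c) % 49405) + 3.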


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Custom collate function for batching text-video pairs."""
    # Constructing the tokenizer per batch is cheap: it only stores a few
    # integers and holds no learned vocabulary.
    tokenizer = SimpleTokenizer()
    videos = torch.stack([item["video"] for item in batch])
    texts = [item["text"] for item in batch]
    tokens = tokenizer(texts)
    return {
        "video": videos,   # (B, T, C, H, W)
        "tokens": tokens,  # (B, max_length)
        "text": texts,     # List of strings
    }


def get_dataloader(
    data_dir: str,
    batch_size: int = 4,
    image_size: int = 64,
    num_frames: int = 16,
    num_workers: int = 4,
    train: bool = True,
) -> DataLoader:
    """Create a dataloader for training or validation."""
    dataset = SignLanguageDataset(
        data_dir=data_dir,
        image_size=image_size,
        num_frames=num_frames,
        train=train,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,    # shuffle only during training
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        drop_last=train,  # drop the ragged final batch only during training
    )


if __name__ == "__main__":
    # Smoke-test the dataset
    dataset = SignLanguageDataset(
        data_dir="text2sign/training_data",
        image_size=64,
        num_frames=16,
        train=True,
    )
    print(f"Dataset size: {len(dataset)}")

    sample = dataset[0]
    print(f"Video shape: {sample['video'].shape}")
    print(f"Text: {sample['text']}")