|
|
""" |
|
|
Dataset for loading text-GIF pairs for sign language generation |
|
|
""" |
|
|
|
|
|
import os |
|
|
import glob |
|
|
import random |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
class SignLanguageDataset(Dataset):
    """Dataset of (GIF video, text caption) pairs for text-to-sign-language generation.

    Expects ``data_dir`` to contain matching ``<name>.gif`` / ``<name>.txt``
    files; GIFs without a sibling ``.txt`` file are ignored.
    """

    def __init__(
        self,
        data_dir: str,
        image_size: int = 64,
        num_frames: int = 16,
        train: bool = True,
        train_ratio: float = 0.9,
    ):
        """
        Args:
            data_dir: Directory containing .gif and .txt files
            image_size: Size to resize frames to
            num_frames: Number of frames to sample from each GIF
            train: Whether this is training set
            train_ratio: Ratio of data to use for training
        """
        self.data_dir = data_dir
        self.image_size = image_size
        self.num_frames = num_frames
        self.train = train

        self.pairs = self._find_pairs()

        # Deterministic train/val split. Use a dedicated RNG instance so that
        # constructing a dataset does not clobber the global `random` state
        # for the rest of the program (random.Random(42).shuffle yields the
        # same permutation as random.seed(42); random.shuffle).
        rng = random.Random(42)
        indices = list(range(len(self.pairs)))
        rng.shuffle(indices)
        split_idx = int(len(indices) * train_ratio)

        if train:
            self.indices = indices[:split_idx]
        else:
            self.indices = indices[split_idx:]

        # Maps a PIL frame to a normalized (3, H, W) tensor in [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        print(f"Loaded {len(self.indices)} {'training' if train else 'validation'} samples")

    def _find_pairs(self) -> List[Tuple[str, str]]:
        """Find all GIF-text pairs in the data directory.

        Returns:
            Sorted list of (gif_path, txt_path) tuples. Sorting matters:
            glob order is OS-dependent, and the seeded train/val split in
            __init__ is only reproducible over a stable file ordering.
        """
        pairs = []

        gif_files = sorted(glob.glob(os.path.join(self.data_dir, "*.gif")))

        for gif_path in gif_files:
            # splitext swaps only the final extension; str.replace(".gif", ...)
            # would also corrupt any ".gif" appearing earlier in the path.
            txt_path = os.path.splitext(gif_path)[0] + ".txt"

            if os.path.exists(txt_path):
                pairs.append((gif_path, txt_path))

        return pairs

    def _load_gif(self, gif_path: str) -> torch.Tensor:
        """Load a GIF and return a (num_frames, 3, H, W) tensor in [-1, 1].

        Frames are uniformly subsampled when the GIF has more than
        ``num_frames`` frames, and the last frame is repeated when it has
        fewer. On any load error a random tensor of the right shape is
        returned so one corrupt file does not abort a whole training run.
        """
        try:
            frames = []
            # Context manager closes the underlying file handle even when
            # decoding raises (the original code leaked it).
            with Image.open(gif_path) as gif:
                try:
                    while True:
                        frame = gif.convert("RGB")
                        frames.append(self.transform(frame))
                        # Advance to the next frame; PIL signals the end of
                        # the animation by raising EOFError.
                        gif.seek(gif.tell() + 1)
                except EOFError:
                    pass

            if len(frames) == 0:
                raise ValueError(f"No frames found in {gif_path}")

            if len(frames) >= self.num_frames:
                # Uniform temporal subsampling over the full clip.
                indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
                frames = [frames[i] for i in indices]
            else:
                # Pad short clips by repeating the final frame.
                while len(frames) < self.num_frames:
                    frames.append(frames[-1])

            return torch.stack(frames)

        except Exception as e:
            # Deliberate best-effort fallback: log and keep batch shapes valid.
            print(f"Error loading {gif_path}: {e}")
            return torch.randn(self.num_frames, 3, self.image_size, self.image_size)

    def _load_text(self, txt_path: str) -> str:
        """Load the caption from *txt_path*, stripped; '' on any read error."""
        try:
            with open(txt_path, "r", encoding="utf-8") as f:
                return f.read().strip()
        except Exception as e:
            print(f"Error loading {txt_path}: {e}")
            return ""

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return {'video': (num_frames, 3, H, W) tensor, 'text': caption str}."""
        real_idx = self.indices[idx]
        gif_path, txt_path = self.pairs[real_idx]

        video = self._load_gif(gif_path)
        text = self._load_text(txt_path)

        return {
            "video": video,
            "text": text,
        }
|
|
|
|
|
|
|
|
class SimpleTokenizer:
    """Character-level tokenizer producing fixed-length token sequences.

    Each character maps deterministically to an id derived from its code
    point; ids 0/1/2 are reserved for BOS/EOS/PAD respectively.
    """

    def __init__(self, vocab_size: int = 49408, max_length: int = 77):
        self.vocab_size = vocab_size
        self.max_length = max_length

        # Reserved special-token ids.
        self.bos_token_id = 0
        self.eos_token_id = 1
        self.pad_token_id = 2

    def encode(self, text: str) -> torch.Tensor:
        """Encode *text* into a LongTensor of exactly ``max_length`` ids."""
        ids = [self.bos_token_id]

        # Character ids occupy [3, vocab_size); 0-2 stay reserved.
        span = self.vocab_size - 3
        for ch in text.lower():
            ids.append(ord(ch) % span + 3)
            if len(ids) >= self.max_length - 1:
                # Truncate, leaving room for the EOS marker.
                break

        ids.append(self.eos_token_id)

        # Right-pad up to the fixed length (no-op when already full).
        ids.extend([self.pad_token_id] * (self.max_length - len(ids)))

        return torch.tensor(ids[:self.max_length], dtype=torch.long)

    def __call__(self, texts: List[str]) -> torch.Tensor:
        """Encode a batch of strings into a (len(texts), max_length) tensor."""
        return torch.stack([self.encode(t) for t in texts])
|
|
|
|
|
|
|
|
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate dataset samples into a batch of videos, token ids and raw texts."""
    # The tokenizer is stateless, so building one per call is cheap.
    tokenizer = SimpleTokenizer()

    captions = [sample["text"] for sample in batch]
    stacked_videos = torch.stack([sample["video"] for sample in batch])

    return {
        "video": stacked_videos,
        "tokens": tokenizer(captions),
        "text": captions,
    }
|
|
|
|
|
|
|
|
def get_dataloader(
    data_dir: str,
    batch_size: int = 4,
    image_size: int = 64,
    num_frames: int = 16,
    num_workers: int = 4,
    train: bool = True,
) -> DataLoader:
    """Build a DataLoader over a SignLanguageDataset.

    Shuffling and last-batch dropping are enabled only in training mode;
    pinned memory is always on for faster host-to-GPU transfer.
    """
    dataset = SignLanguageDataset(
        data_dir=data_dir,
        image_size=image_size,
        num_frames=num_frames,
        train=train,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        drop_last=train,
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the training split and inspect a single sample.
    demo = SignLanguageDataset(
        data_dir="text2sign/training_data",
        image_size=64,
        num_frames=16,
        train=True,
    )

    print(f"Dataset size: {len(demo)}")

    first = demo[0]
    print(f"Video shape: {first['video'].shape}")
    print(f"Text: {first['text']}")