# text2sign / dataset.py
# xiaruize's picture
# upd
# 234a70c
"""
Dataset for loading text-GIF pairs for sign language generation
"""
import os
import glob
import random
from typing import Dict, List, Optional, Tuple
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from torchvision import transforms
class SignLanguageDataset(Dataset):
    """Dataset of (GIF clip, caption) pairs for text-to-sign-language generation.

    Every ``<name>.gif`` directly inside ``data_dir`` that has a sibling
    ``<name>.txt`` forms one sample. The pairs are shuffled with a fixed seed
    and split into a training part and a validation part; one instance
    exposes exactly one of the two parts.
    """

    def __init__(
        self,
        data_dir: str,
        image_size: int = 64,
        num_frames: int = 16,
        train: bool = True,
        train_ratio: float = 0.9,
    ):
        """
        Args:
            data_dir: Directory containing .gif and .txt files
            image_size: Side length (pixels) every frame is resized to
            num_frames: Number of frames returned for each GIF
            train: If True expose the training part of the split,
                otherwise the validation part
            train_ratio: Fraction of the pairs assigned to the training part
        """
        super().__init__()
        self.data_dir = data_dir
        self.image_size = image_size
        self.num_frames = num_frames
        self.train = train

        # Find all pairs (sorted, so the split below is reproducible)
        self.pairs = self._find_pairs()

        # Split into train/val. A private RNG is used instead of
        # random.seed(42) so that building a dataset does not reset the
        # process-wide random state; the resulting permutation is the same.
        indices = list(range(len(self.pairs)))
        random.Random(42).shuffle(indices)
        split_idx = int(len(indices) * train_ratio)
        self.indices = indices[:split_idx] if train else indices[split_idx:]

        # Image transforms
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [-1, 1]
        ])

        print(f"Loaded {len(self.indices)} {'training' if train else 'validation'} samples")

    def _find_pairs(self) -> List[Tuple[str, str]]:
        """Return ``(gif_path, txt_path)`` for every GIF that has a caption file.

        The result is sorted by GIF path so that it does not depend on the
        order in which the filesystem lists the directory.
        """
        # Escape the directory so characters such as "[" in its name are not
        # interpreted as glob wildcards.
        pattern = os.path.join(glob.escape(self.data_dir), "*.gif")
        pairs = []
        for gif_path in sorted(glob.glob(pattern)):
            # Swap only the extension; str.replace would also rewrite any
            # ".gif" that happens to appear in a directory name.
            txt_path = os.path.splitext(gif_path)[0] + ".txt"
            if os.path.exists(txt_path):
                pairs.append((gif_path, txt_path))
        return pairs

    def _load_gif(self, gif_path: str) -> torch.Tensor:
        """Decode a GIF into a tensor of shape (num_frames, C, H, W).

        GIFs with at least ``num_frames`` frames are sampled uniformly over
        their length; shorter ones are padded by repeating the last frame.
        If anything goes wrong the error is printed and random noise of the
        expected shape is returned, so one bad file does not abort a run.
        """
        try:
            frames = []
            # The context manager closes the underlying file handle.
            with Image.open(gif_path) as gif:
                try:
                    while True:
                        # Convert to RGB, then resize/normalise
                        frames.append(self.transform(gif.convert("RGB")))
                        gif.seek(gif.tell() + 1)
                except EOFError:
                    # Seeking past the last frame ends the loop.
                    pass

            if not frames:
                raise ValueError(f"No frames found in {gif_path}")

            if len(frames) >= self.num_frames:
                # Uniform sampling
                keep = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
                frames = [frames[i] for i in keep]
            else:
                # Pad by repeating last frame
                frames.extend([frames[-1]] * (self.num_frames - len(frames)))

            # Stack frames: (num_frames, C, H, W)
            return torch.stack(frames)
        except Exception as e:
            # Deliberately broad: this is a best-effort loader.
            print(f"Error loading {gif_path}: {e}")
            # Return random noise as fallback
            return torch.randn(self.num_frames, 3, self.image_size, self.image_size)

    def _load_text(self, txt_path: str) -> str:
        """Return the stripped UTF-8 content of ``txt_path``.

        Unreadable or undecodable files are reported on stdout and yield an
        empty string.
        """
        try:
            with open(txt_path, "r", encoding="utf-8") as f:
                return f.read().strip()
        except (OSError, ValueError) as e:  # UnicodeDecodeError is a ValueError
            print(f"Error loading {txt_path}: {e}")
            return ""

    def __len__(self) -> int:
        """Number of samples in this part of the split."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """Return ``{"video": tensor (T, C, H, W), "text": caption string}``."""
        # Map the position within this split to a position in self.pairs.
        real_idx = self.indices[idx]
        gif_path, txt_path = self.pairs[real_idx]
        return {
            "video": self._load_gif(gif_path),  # (T, C, H, W)
            "text": self._load_text(txt_path),
        }
class SimpleTokenizer:
    """Minimal character-level tokenizer producing fixed-length ID sequences.

    IDs 0, 1 and 2 are reserved for BOS, EOS and padding. Every character of
    the lower-cased text is mapped to ``ord(char) % (vocab_size - 3) + 3``,
    so code points beyond the vocabulary wrap around and may collide.
    """

    def __init__(self, vocab_size: int = 49408, max_length: int = 77):
        """
        Args:
            vocab_size: Total number of token IDs, including the 3 reserved ones
            max_length: Length of every encoded sequence

        Raises:
            ValueError: If ``vocab_size`` leaves no room for character tokens
        """
        if vocab_size <= 3:
            # With 3 or fewer IDs the modulo below would divide by zero or
            # emit IDs outside the vocabulary.
            raise ValueError(
                f"vocab_size must be greater than 3 (got {vocab_size})"
            )
        self.vocab_size = vocab_size
        self.max_length = max_length
        # Reserved special-token IDs
        self.bos_token_id = 0
        self.eos_token_id = 1
        self.pad_token_id = 2

    def encode(self, text: str) -> torch.Tensor:
        """Encode one string into a 1-D long tensor of length ``max_length``.

        Layout: BOS, character tokens, EOS, then padding. Text that does not
        fit is truncated so that EOS is still included whenever
        ``max_length >= 2``.
        """
        # Reserve two slots for BOS and EOS before taking any characters.
        char_budget = max(self.max_length - 2, 0)
        char_range = self.vocab_size - 3

        tokens = [self.bos_token_id]
        tokens.extend(
            ord(char) % char_range + 3 for char in text.lower()[:char_budget]
        )
        tokens.append(self.eos_token_id)
        # Pad to max_length (a non-positive count adds nothing)
        tokens.extend([self.pad_token_id] * (self.max_length - len(tokens)))
        return torch.tensor(tokens[:self.max_length], dtype=torch.long)

    def __call__(self, texts: List[str]) -> torch.Tensor:
        """Batch-encode ``texts`` into a tensor of shape (len(texts), max_length)."""
        if not texts:
            # torch.stack rejects an empty list; return an empty batch instead.
            return torch.empty((0, self.max_length), dtype=torch.long)
        return torch.stack([self.encode(text) for text in texts])
def collate_fn(
    batch: List[Dict],
    tokenizer: Optional["SimpleTokenizer"] = None,
) -> Dict[str, object]:
    """Collate dataset samples into one batch.

    Args:
        batch: Samples, each a dict with a "video" tensor and a "text" string
        tokenizer: Callable that turns a list of strings into a tensor of
            token IDs. Defaults to a ``SimpleTokenizer`` with default
            settings. Bind a different one with ``functools.partial`` when
            passing this function to a ``DataLoader``.

    Returns:
        Dict with "video" (stacked tensor), "tokens" (tokenizer output) and
        "text" (list of the raw strings).
    """
    if tokenizer is None:
        tokenizer = SimpleTokenizer()
    texts = [item["text"] for item in batch]
    return {
        "video": torch.stack([item["video"] for item in batch]),  # (B, T, C, H, W)
        "tokens": tokenizer(texts),  # (B, max_length)
        "text": texts,  # List of strings
    }
def get_dataloader(
    data_dir: str,
    batch_size: int = 4,
    image_size: int = 64,
    num_frames: int = 16,
    num_workers: int = 4,
    train: bool = True,
) -> DataLoader:
    """Build a DataLoader over the training or validation part of ``data_dir``.

    For ``train=True`` the loader shuffles and drops the last incomplete
    batch; for ``train=False`` it does neither.
    """
    dataset_options = {
        "data_dir": data_dir,
        "image_size": image_size,
        "num_frames": num_frames,
        "train": train,
    }
    loader_options = {
        "batch_size": batch_size,
        "shuffle": train,
        "num_workers": num_workers,
        "collate_fn": collate_fn,
        "pin_memory": True,
        "drop_last": train,
    }
    return DataLoader(SignLanguageDataset(**dataset_options), **loader_options)
if __name__ == "__main__":
    # Manual smoke test: build the training split and inspect its first sample.
    demo_config = {
        "data_dir": "text2sign/training_data",
        "image_size": 64,
        "num_frames": 16,
        "train": True,
    }
    demo_set = SignLanguageDataset(**demo_config)
    print(f"Dataset size: {len(demo_set)}")

    first = demo_set[0]
    video, caption = first["video"], first["text"]
    print(f"Video shape: {video.shape}")
    print(f"Text: {caption}")