# NOTE(review): the following three lines are web-page residue from the file's
# upload source, preserved as comments so the module parses as valid Python.
# triumphh77's picture
# Upload 13 files
# f9a156f verified
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from PIL import Image
class IAMDataset(Dataset):
    """PyTorch Dataset for IAM handwritten-word images with text labels.

    Reads annotations from a CSV with columns 'filename' and 'text',
    builds a character vocabulary from the labels, and encodes each label
    as a tensor of character indices suitable for CTC training
    (index 0 is reserved for the CTC blank).
    """

    def __init__(self, data_dir, csv_file, transform=None):
        """
        Args:
            data_dir (str): Path to directory containing IAM word images.
            csv_file (str): Path to CSV file containing 'filename' and 'text'.
            transform (callable, optional): Optional transform to be applied.
        """
        self.data_dir = data_dir
        # Assuming CSV has columns: 'filename' and 'text'
        self.annotations = pd.read_csv(csv_file)
        self.transform = transform
        # Build vocabulary from every character seen in the labels.
        self.vocab = self._build_vocab()
        # Characters start at index 1; 0 is reserved for the CTC blank.
        self.char_to_idx = {char: idx + 1 for idx, char in enumerate(self.vocab)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.num_classes = len(self.vocab) + 1  # +1 for CTC blank

    def _build_vocab(self):
        """Return the sorted list of unique characters across all labels."""
        chars = set()
        for text in self.annotations['text']:
            if pd.notna(text):  # skip missing (NaN) labels
                chars.update(str(text))
        return sorted(chars)

    def _text_at(self, idx):
        """Return the label text at row *idx*, mapping missing (NaN) labels to ''."""
        raw = self.annotations.iloc[idx]['text']
        # BUG FIX: the NaN check must happen BEFORE str() conversion.
        # The original did str(...) first, and str(nan) == 'nan', which
        # pd.isna() never detects — so missing labels encoded as 'n','a','n'.
        return "" if pd.isna(raw) else str(raw)

    def __len__(self):
        """Number of annotated samples."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return (image, label_indices_tensor, label_length) for row *idx*."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.data_dir, str(self.annotations.iloc[idx]['filename']))
        try:
            image = Image.open(img_name).convert('L')  # Convert to grayscale
        except FileNotFoundError:
            # Handle missing files gracefully: blank white placeholder strip.
            image = Image.new('L', (1024, 32), color=255)
        text = self._text_at(idx)
        if self.transform:
            image = self.transform(image)
        # Encode label as indices; characters outside the vocab are dropped.
        encoded_text = [self.char_to_idx[char] for char in text if char in self.char_to_idx]
        text_tensor = torch.tensor(encoded_text, dtype=torch.long)
        return image, text_tensor, len(encoded_text)
# Collate function for DataLoader to handle variable length sequences
def collate_fn(batch):
    """Collate (image, label_tensor, label_length) triples into batch tensors.

    Images are stacked into a single tensor; variable-length label
    sequences are right-padded with 0 (the CTC blank index) to the
    longest label in the batch.

    Returns:
        (images, padded_labels, label_lengths) as torch tensors.
    """
    images, labels, lengths = zip(*batch)
    stacked_images = torch.stack(images)
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=0
    )
    length_tensor = torch.tensor(lengths, dtype=torch.long)
    return stacked_images, padded_labels, length_tensor
# Smoke message when the module is executed directly.
if __name__ == "__main__":
    print("Dataset module ready.")