import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class IAMDataset(Dataset):
    """IAM handwritten-word dataset for CTC-style text recognition.

    Each item is a triple ``(image, text_tensor, text_length)`` where
    ``text_tensor`` holds 1-based character indices (index 0 is reserved
    for the CTC blank).
    """

    def __init__(self, data_dir, csv_file, transform=None):
        """
        Args:
            data_dir (str): Path to directory containing IAM word images.
            csv_file (str): Path to CSV file containing 'filename' and 'text'.
            transform (callable, optional): Optional transform to be applied.
        """
        self.data_dir = data_dir
        # Assuming CSV has columns: 'filename' and 'text'
        self.annotations = pd.read_csv(csv_file)
        self.transform = transform

        # Build vocabulary; character indices start at 1 because
        # 0 is reserved for the CTC blank.
        self.vocab = self._build_vocab()
        self.char_to_idx = {char: idx + 1 for idx, char in enumerate(self.vocab)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.num_classes = len(self.vocab) + 1  # +1 for CTC blank

    def _build_vocab(self):
        """Return the sorted list of unique characters across all labels."""
        chars = set()
        for text in self.annotations['text']:
            if pd.notna(text):
                chars.update(str(text))
        return sorted(chars)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Single row lookup instead of two separate .iloc calls.
        row = self.annotations.iloc[idx]
        img_name = os.path.join(self.data_dir, str(row['filename']))
        try:
            image = Image.open(img_name).convert('L')  # Convert to grayscale
        except FileNotFoundError:
            # Handle missing files gracefully in a real scenario:
            # fall back to a blank white placeholder image.
            image = Image.new('L', (1024, 32), color=255)

        # BUG FIX: test for NaN on the raw value *before* str() conversion.
        # The original did str(...) first, so a missing label became the
        # literal string "nan" and the pd.isna() guard could never trigger
        # (the label would then be encoded as the characters 'n','a','n').
        raw_text = row['text']
        text = "" if pd.isna(raw_text) else str(raw_text)

        if self.transform:
            image = self.transform(image)

        # Convert text to a tensor of vocabulary indices; characters not in
        # the vocabulary are silently dropped.
        encoded_text = [self.char_to_idx[char] for char in text if char in self.char_to_idx]
        text_tensor = torch.tensor(encoded_text, dtype=torch.long)

        return image, text_tensor, len(encoded_text)


# Collate function for DataLoader to handle variable length sequences
def collate_fn(batch):
    """Collate (image, text, length) triples into batched tensors.

    Texts are padded to the longest sequence in the batch with 0, which
    matches the CTC blank index. Assumes all images in the batch share the
    same shape (required by torch.stack).
    """
    images, texts, text_lengths = zip(*batch)
    # Stack images
    images = torch.stack(images)
    # Pad texts to max length in batch
    texts_padded = torch.nn.utils.rnn.pad_sequence(
        texts, batch_first=True, padding_value=0
    )
    text_lengths = torch.tensor(text_lengths, dtype=torch.long)
    return images, texts_padded, text_lengths


if __name__ == "__main__":
    print("Dataset module ready.")