| import torch |
| from torch.utils.data import Dataset, DataLoader |
| import pandas as pd |
| import os |
| from PIL import Image |
|
|
class IAMDataset(Dataset):
    """Map-style dataset pairing IAM word images with encoded transcriptions.

    Each item is ``(image, text_tensor, text_length)``. Characters are mapped
    to integer ids starting at 1; id 0 is never assigned to a character
    (``collate_fn`` pads label sequences with 0).
    NOTE(review): id 0 presumably doubles as the CTC blank -- confirm against
    the model/loss configuration.
    """

    # Size (width, height) and fill of the blank image substituted when an
    # image file is missing on disk.
    _PLACEHOLDER_SIZE = (1024, 32)
    _PLACEHOLDER_FILL = 255

    def __init__(self, data_dir, csv_file, transform=None):
        """
        Args:
            data_dir (str): Path to directory containing IAM word images.
            csv_file (str): Path to CSV file containing 'filename' and 'text'.
            transform (callable, optional): Optional transform to be applied
                to the loaded PIL image. When omitted, the PIL image itself
                is returned by ``__getitem__``.
        """
        super().__init__()
        self.data_dir = data_dir

        # Read labels verbatim: pandas' default NA parsing would turn real
        # words such as "null", "NA", "nan" or "None" into missing values,
        # and numeric inference could strip leading zeros ("007" -> 7).
        # Only genuinely empty cells are treated as missing.
        self.annotations = pd.read_csv(
            csv_file,
            dtype={"filename": str, "text": str},
            keep_default_na=False,
            na_values=[""],
        )
        self.transform = transform

        self.vocab = self._build_vocab()
        # Ids start at 1 so that 0 stays free for padding.
        self.char_to_idx = {char: idx + 1 for idx, char in enumerate(self.vocab)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.num_classes = len(self.vocab) + 1

    def _build_vocab(self):
        """Return the sorted list of distinct characters in all labels."""
        chars = set()
        for text in self.annotations['text'].dropna():
            chars.update(str(text))
        return sorted(chars)

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.annotations)

    @staticmethod
    def _clean_label(raw):
        """Return the label as a string, mapping a missing value to ''."""
        # The missing-value check must happen BEFORE str(): str(NaN) is the
        # literal "nan", which would otherwise be encoded as a real label.
        return "" if pd.isna(raw) else str(raw)

    def _encode_label(self, text):
        """Map each known character of *text* to its id; skip unknown ones."""
        lookup = self.char_to_idx
        return [lookup[char] for char in text if char in lookup]

    def _load_image(self, path):
        """Load *path* as a grayscale PIL image, or a blank placeholder."""
        try:
            # The context manager guarantees the file handle is released;
            # convert() returns an independent in-memory image.
            with Image.open(path) as img:
                return img.convert('L')
        except FileNotFoundError:
            # Best-effort behaviour is kept (no crash), but the substitution
            # is no longer silent: a blank image paired with a real label
            # is worth knowing about.
            import logging

            logging.getLogger(__name__).warning(
                "Image not found, substituting blank placeholder: %s", path
            )
            return Image.new(
                'L', self._PLACEHOLDER_SIZE, color=self._PLACEHOLDER_FILL
            )

    def __getitem__(self, idx):
        """Return ``(image, text_tensor, text_length)`` for sample *idx*.

        ``image`` is the (optionally transformed) grayscale image,
        ``text_tensor`` a 1-D ``torch.long`` tensor of character ids and
        ``text_length`` the number of ids in it.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Fetch the row once instead of re-indexing the frame per column.
        row = self.annotations.iloc[idx]

        image = self._load_image(
            os.path.join(self.data_dir, str(row['filename']))
        )
        text = self._clean_label(row['text'])

        if self.transform:
            image = self.transform(image)

        encoded_text = self._encode_label(text)
        text_tensor = torch.tensor(encoded_text, dtype=torch.long)

        return image, text_tensor, len(encoded_text)
|
|
| |
def collate_fn(batch):
    """Merge ``(image, text_tensor, text_length)`` samples into one batch.

    Args:
        batch: Sequence of 3-tuples as produced by the dataset.

    Returns:
        A tuple ``(images, texts_padded, text_lengths)``: the images stacked
        along a new leading dimension (``torch.stack`` requires them to share
        one shape), the label tensors right-padded with 0 to the longest
        label in the batch (batch-first), and the label lengths as a
        ``torch.long`` tensor.
    """
    # Transpose the list of samples into three parallel per-field lists.
    image_list, label_list, length_list = map(list, zip(*batch))

    stacked_images = torch.stack(image_list)
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        label_list,
        batch_first=True,
        padding_value=0,
    )
    lengths = torch.tensor(length_list, dtype=torch.long)

    return stacked_images, padded_labels, lengths
|
|
# Running this file directly only confirms that the module imports cleanly.
if __name__ == "__main__":
    print("Dataset module ready.")
|
|