File size: 2,677 Bytes
f9a156f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from PIL import Image

class IAMDataset(Dataset):
    """Dataset of IAM word images paired with CTC-ready integer label sequences.

    Reads a CSV of (filename, text) annotations, builds a character vocabulary
    from the transcriptions, and maps characters to indices 1..N, reserving
    index 0 for the CTC blank symbol.
    """

    def __init__(self, data_dir, csv_file, transform=None):
        """
        Args:
            data_dir (str): Path to directory containing IAM word images.
            csv_file (str): Path to CSV file containing 'filename' and 'text'.
            transform (callable, optional): Optional transform to be applied.
        """
        self.data_dir = data_dir
        # Assuming CSV has columns: 'filename' and 'text'
        self.annotations = pd.read_csv(csv_file)
        self.transform = transform

        # Build vocabulary from all transcriptions present in the CSV.
        self.vocab = self._build_vocab()
        # Index 0 is reserved for the CTC blank, so characters start at 1.
        self.char_to_idx = {char: idx + 1 for idx, char in enumerate(self.vocab)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.num_classes = len(self.vocab) + 1  # +1 for CTC blank

    def _build_vocab(self):
        """Return the sorted list of unique characters across all labels."""
        chars = set()
        for text in self.annotations['text']:
            if pd.notna(text):
                chars.update(list(str(text)))
        return sorted(list(chars))

    def __len__(self):
        """Number of annotated samples."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return (image, label_tensor, label_length) for sample *idx*.

        Characters absent from the vocabulary are silently dropped from the
        encoded label. Missing image files yield a blank white placeholder.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Fetch the annotation row once instead of indexing twice.
        row = self.annotations.iloc[idx]
        img_name = os.path.join(self.data_dir, str(row['filename']))

        try:
            image = Image.open(img_name).convert('L')  # Convert to grayscale
        except FileNotFoundError:
            # Handle missing files gracefully in a real scenario
            image = Image.new('L', (1024, 32), color=255)

        # BUG FIX: the NaN check must happen BEFORE str() — str(nan) is the
        # literal string "nan", which pd.isna() never flags, so missing labels
        # previously became the text "nan" instead of "".
        raw_text = row['text']
        text = "" if pd.isna(raw_text) else str(raw_text)

        if self.transform:
            image = self.transform(image)

        # Convert text to tensor of indices (unknown characters are skipped).
        encoded_text = [self.char_to_idx[char] for char in text if char in self.char_to_idx]
        text_tensor = torch.tensor(encoded_text, dtype=torch.long)

        return image, text_tensor, len(encoded_text)

def collate_fn(batch):
    """Collate (image, label, label_length) samples into batched tensors.

    Images are stacked along a new batch dimension; variable-length label
    sequences are right-padded with 0 (the CTC blank index) to the longest
    label in the batch.

    Returns:
        tuple: (images [B,...], padded labels [B, T_max], lengths [B]).
    """
    image_list = []
    label_list = []
    length_list = []
    for image, label, length in batch:
        image_list.append(image)
        label_list.append(label)
        length_list.append(length)

    batched_images = torch.stack(image_list)
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        label_list, batch_first=True, padding_value=0
    )
    lengths = torch.tensor(length_list, dtype=torch.long)

    return batched_images, padded_labels, lengths

if __name__ == "__main__":
    print("Dataset module ready.")