triumphh77 committed on
Commit
f9a156f
·
verified ·
1 Parent(s): 0be6807

Upload 13 files

Browse files
data/labels.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ torchaudio
4
+ opencv-python-headless
5
+ numpy
6
+ pandas
7
+ matplotlib
8
+ tqdm
9
+ scikit-learn
10
+ gradio
11
+ streamlit
12
+ pillow
13
+ albumentations
src/data/dataset.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import pandas as pd
4
+ import os
5
+ from PIL import Image
6
+
7
class IAMDataset(Dataset):
    """IAM handwritten-word dataset backed by a CSV of (filename, text) rows.

    Each transcription is encoded as a tensor of character indices suitable
    for CTC training; index 0 is reserved for the CTC blank symbol.
    """

    def __init__(self, data_dir, csv_file, transform=None):
        """
        Args:
            data_dir (str): Path to directory containing IAM word images.
            csv_file (str): Path to CSV file containing 'filename' and 'text'.
            transform (callable, optional): Optional transform to be applied.
        """
        self.data_dir = data_dir
        # CSV is expected to have columns: 'filename' and 'text'
        self.annotations = pd.read_csv(csv_file)
        self.transform = transform

        # Build character vocabulary from all transcriptions.
        self.vocab = self._build_vocab()
        # 0 is reserved for the CTC blank, so real characters start at 1.
        self.char_to_idx = {char: idx + 1 for idx, char in enumerate(self.vocab)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.num_classes = len(self.vocab) + 1  # +1 for CTC blank

    def _build_vocab(self):
        """Return the sorted list of characters used across all non-null labels."""
        chars = set()
        for text in self.annotations['text']:
            if pd.notna(text):
                chars.update(str(text))
        return sorted(chars)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return (image, encoded_text_tensor, target_length) for sample idx."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Fetch the row once instead of performing two .iloc lookups.
        row = self.annotations.iloc[idx]
        img_name = os.path.join(self.data_dir, str(row['filename']))

        try:
            image = Image.open(img_name).convert('L')  # grayscale
        except FileNotFoundError:
            # Handle missing files gracefully: substitute a blank white canvas.
            image = Image.new('L', (1024, 32), color=255)

        # BUGFIX: check for NaN *before* str() — str(float('nan')) == 'nan', so
        # the previous order silently turned missing labels into the text "nan".
        raw_text = row['text']
        text = "" if pd.isna(raw_text) else str(raw_text)

        if self.transform:
            image = self.transform(image)

        # Convert text to a tensor of vocabulary indices (unknown chars dropped).
        encoded_text = [self.char_to_idx[char] for char in text if char in self.char_to_idx]
        text_tensor = torch.tensor(encoded_text, dtype=torch.long)

        return image, text_tensor, len(encoded_text)
61
+
62
def collate_fn(batch):
    """Batch (image, text, length) triples with variable-length targets.

    Images are stacked into a single tensor; text index sequences are
    right-padded with 0 (the CTC blank) to the longest sequence in the batch.
    """
    imgs, target_seqs, seq_lens = zip(*batch)

    stacked_imgs = torch.stack(imgs)
    padded_targets = torch.nn.utils.rnn.pad_sequence(
        target_seqs, batch_first=True, padding_value=0
    )
    lengths_tensor = torch.tensor(seq_lens, dtype=torch.long)

    return stacked_imgs, padded_targets, lengths_tensor


if __name__ == "__main__":
    print("Dataset module ready.")
src/data/download_dataset.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from datasets import load_dataset
4
+ from tqdm import tqdm
5
+
6
def download_and_prepare_iam():
    """Fetch the IAM-line dataset from Hugging Face and materialize it on disk.

    Writes every image of the 'train' split to data/iam_words/ as a grayscale
    PNG and records one (filename, text) row per image in data/labels.csv.
    """
    print("Downloading IAM-line dataset from Hugging Face...")
    # Loading the dataset from Hugging Face (approx 266 MB)
    dataset = load_dataset("Sj122702/IAM-line")

    data_dir = "data/iam_words"
    os.makedirs(data_dir, exist_ok=True)

    print(f"Saving images to {data_dir} and creating labels.csv...")

    # Only the 'train' split is processed here; extend to the 'validation'
    # and 'test' splits the same way if needed.
    split = 'train'
    rows = []

    for sample_idx, sample in enumerate(tqdm(dataset[split])):
        # Each record carries an 'image' (PIL) and its transcription 'text'.
        out_name = f"img_{split}_{sample_idx}.png"
        out_path = os.path.join(data_dir, out_name)

        # Images may arrive in various modes; normalize to grayscale.
        sample['image'].convert("L").save(out_path)

        rows.append({"filename": out_name, "text": sample['text']})

    # Persist the per-image labels alongside the images.
    csv_path = "data/labels.csv"
    pd.DataFrame(rows).to_csv(csv_path, index=False)

    print(f"\nDataset prepared successfully!")
    print(f"Total images saved: {len(rows)}")
    print(f"Images location: {data_dir}/")
    print(f"Labels CSV location: {csv_path}")


if __name__ == "__main__":
    download_and_prepare_iam()
src/models/crnn.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class CRNN(nn.Module):
    """CNN + BiLSTM recognizer producing per-timestep log-probabilities for CTC."""

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super(CRNN, self).__init__()

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        # Collapse the (channel, height) axes of the CNN feature map into a
        # per-timestep feature vector for the recurrent stack.
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True, batch_first=True)
        self.rnn2 = nn.LSTM(rnn_hidden * 2, rnn_hidden, bidirectional=True, batch_first=True)

        self.dense = nn.Linear(rnn_hidden * 2, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the 7-conv VGG-style feature extractor.

        Returns the nn.Sequential backbone together with the (C, H, W) shape
        of its output feature map for an input of the given size.
        """
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def add_conv(layer_idx, with_bn=False):
            # Module names (conv{i}/batchnorm{i}/relu{i}) are part of the
            # checkpoint layout — keep them stable.
            cnn.add_module(
                f'conv{layer_idx}',
                nn.Conv2d(channels[layer_idx], channels[layer_idx + 1],
                          kernel_sizes[layer_idx], strides[layer_idx],
                          paddings[layer_idx]))
            if with_bn:
                cnn.add_module(f'batchnorm{layer_idx}', nn.BatchNorm2d(channels[layer_idx + 1]))
            if leaky_relu:
                cnn.add_module(f'relu{layer_idx}', nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module(f'relu{layer_idx}', nn.ReLU(inplace=True))

        add_conv(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))  # -> 64 x H/2 x W/2
        add_conv(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))  # -> 128 x H/4 x W/4
        add_conv(2, with_bn=True)
        add_conv(3)
        # Halve height only; width is padded/preserved so the sequence stays long.
        cnn.add_module('pooling2', nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)))
        add_conv(4, with_bn=True)
        add_conv(5)
        cnn.add_module('pooling3', nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)))
        add_conv(6, with_bn=True)  # final 2x2 conv squeezes height to 1

        out_shape = (channels[-1], img_height // 16 - 1, img_width // 4 + 1)
        return cnn, out_shape

    def forward(self, images):
        """images: (batch, channel, height, width) -> (batch, seq_len, num_class) log-probs."""
        feat = self.cnn(images)
        n, c, h, w = feat.size()

        # (batch, C*H, W) -> (batch, W, C*H): image width becomes the time axis.
        seq = self.map_to_seq(feat.view(n, c * h, w).permute(0, 2, 1))

        out, _ = self.rnn1(seq)
        out, _ = self.rnn2(out)

        # Log-softmax so the result feeds PyTorch's CTCLoss directly; CTCLoss
        # wants (T, N, C), so callers permute as needed.
        return self.dense(out).log_softmax(2)


if __name__ == '__main__':
    # Quick shape sanity check; seq_len comes out as img_width // 4 + 1.
    sample = torch.randn(1, 1, 32, 1024)
    net = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=80)
    print(f"Output shape: {net(sample).shape}")
src/models/gan.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # Simple DCGAN-style architecture for generating word images (1x32x1024)
5
+
6
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 1x32x1024 word image.

    NOTE: the Sequential layout below is kept byte-compatible with existing
    checkpoints — do not reorder the layers.
    """

    def __init__(self, latent_dim=100, channels=1):
        super(Generator, self).__init__()

        # Project the latent vector onto an initial 128-channel 4x128 map,
        # then upsample 2x three times to reach 32x1024.
        self.init_size_h = 4
        self.init_size_w = 128
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size_h * self.init_size_w))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),                      # -> 8 x 256
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),                      # -> 16 x 512
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),                      # -> 32 x 1024
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),                                        # pixel range [-1, 1]
        )

    def forward(self, z):
        """z: (batch, latent_dim) -> image tensor (batch, channels, 32, 1024)."""
        flat = self.l1(z)
        grid = flat.view(flat.shape[0], 128, self.init_size_h, self.init_size_w)
        return self.conv_blocks(grid)
37
+
38
+
39
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring 1x32x1024 images as real/fake."""

    def __init__(self, channels=1):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            # One strided conv halves both spatial dims per block.
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),   # -> 16 x 16 x 512
            *discriminator_block(16, 32),                   # -> 32 x 8 x 256
            *discriminator_block(32, 64),                   # -> 64 x 4 x 128
            *discriminator_block(64, 128),                  # -> 128 x 2 x 64
        )

        # Spatial size after four stride-2 blocks (assumes a 32x1024 input).
        ds_size_h = 32 // 2 ** 4
        ds_size_w = 1024 // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size_h * ds_size_w, 1), nn.Sigmoid())

    def forward(self, img):
        """img: (batch, channels, 32, 1024) -> real-probability of shape (batch, 1)."""
        features = self.model(img)
        return self.adv_layer(features.view(features.shape[0], -1))


if __name__ == "__main__":
    # Round-trip smoke test: latent -> generator -> discriminator.
    z = torch.randn(1, 100)
    G = Generator()
    D = Discriminator()
    fake_img = G(z)
    validity = D(fake_img)
    print(f"Generator output shape: {fake_img.shape}")
    print(f"Discriminator output shape: {validity.shape}")
src/training/train_crnn.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import os
8
+ import sys
9
+
10
+ # Add project root to path
11
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
12
+
13
+ from src.data.dataset import IAMDataset, collate_fn
14
+ from src.models.crnn import CRNN
15
+
16
+ # Define transforms
17
+ transform = transforms.Compose([
18
+ transforms.Resize((32, 1024)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize((0.5,), (0.5,))
21
+ ])
22
+
23
def train_baseline(model, dataloader, optimizer, criterion, device, epochs=10, start_epoch=0):
    """Supervised CTC training loop for the baseline CRNN.

    Args:
        model: network returning (batch, seq_len, num_classes) log-probabilities.
        dataloader: yields (images, padded_targets, target_lengths) batches.
        optimizer: optimizer over model parameters.
        criterion: nn.CTCLoss (blank index 0).
        device: torch.device to train on.
        epochs: exclusive upper bound on the epoch counter.
        start_epoch: epoch index to resume from.

    Returns:
        The trained model (also checkpointed to weights/ after each epoch).
    """
    model.train()

    for epoch in range(start_epoch, epochs):
        total_loss = 0
        for step, (images, texts, text_lengths) in enumerate(dataloader):
            images = images.to(device)
            texts = texts.to(device)

            optimizer.zero_grad()

            # Forward pass: (batch, seq_len, classes)
            preds = model(images)

            # CTCLoss expects (sequence_length, batch_size, num_classes)
            preds = preds.permute(1, 0, 2)

            # Every sample contributes the full output sequence length.
            input_lengths = torch.full(size=(preds.size(1),), fill_value=preds.size(0), dtype=torch.long)

            # CTCLoss wants targets concatenated into one 1-D tensor, with
            # padding stripped via text_lengths.
            # BUGFIX: the inner loop previously reused `i`, clobbering the
            # batch-step counter used for the progress log below.
            targets_concat = torch.cat(
                [texts[sample_idx][:text_lengths[sample_idx]]
                 for sample_idx in range(texts.size(0))]
            )

            loss = criterion(preds, targets_concat, input_lengths, text_lengths)

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if step % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], Loss: {loss.item():.4f}")

        print(f"Epoch {epoch+1} Average Loss: {total_loss/len(dataloader):.4f}")

        # Save a checkpoint after every epoch so training can resume.
        os.makedirs('weights', exist_ok=True)
        torch.save(model.state_dict(), f'weights/crnn_baseline_epoch_{epoch+1}.pth')

    return model
69
+
70
if __name__ == "__main__":
    print("Starting CRNN Baseline Training...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data pipeline
    data_dir = 'data/iam_words'
    csv_file = 'data/labels.csv'
    dataset = IAMDataset(data_dir=data_dir, csv_file=csv_file, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    # Model sized to the dataset's character vocabulary (+1 for the CTC blank).
    model = CRNN(img_channel=1, img_height=32, img_width=1024,
                 num_class=dataset.num_classes).to(device)

    # Resume from the newest checkpoint, if any exist.
    import glob
    start_epoch = 0
    epoch_of = lambda p: int(os.path.basename(p).split('_')[-1].split('.')[0])
    checkpoints = sorted(glob.glob('weights/crnn_baseline_epoch_*.pth'), key=epoch_of)
    if checkpoints:
        latest_checkpoint = checkpoints[-1]
        start_epoch = epoch_of(latest_checkpoint)
        print(f"Resuming training from {latest_checkpoint} (epoch {start_epoch})")
        model.load_state_dict(torch.load(latest_checkpoint, map_location=device))

    # Optimizer and CTC loss (blank index 0)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    train_baseline(model, dataloader, optimizer, criterion, device, epochs=30, start_epoch=start_epoch)
    print("Training complete!")
src/training/train_gan.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torchvision import transforms
6
+ import os
7
+ import sys
8
+
9
+ # Add project root to path
10
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
11
+
12
+ from src.data.dataset import IAMDataset, collate_fn
13
+ from src.models.gan import Generator, Discriminator
14
+
15
+ # Define transforms for GAN (needs to be slightly different, just standard normalization)
16
+ transform = transforms.Compose([
17
+ transforms.Resize((32, 1024)),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize((0.5,), (0.5,))
20
+ ])
21
+
22
def train_gan(generator, discriminator, dataloader, epochs, device, latent_dim=100, start_epoch=0):
    """Adversarial training loop (vanilla GAN with BCE loss).

    Alternates a generator step (try to fool D) and a discriminator step
    (separate real from fake) per batch; checkpoints both nets every epoch.
    """
    criterion = nn.BCELoss()

    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    generator.train()
    discriminator.train()

    for epoch in range(start_epoch, epochs):
        for batch_idx, (imgs, _, _) in enumerate(dataloader):

            n = imgs.size(0)
            # Adversarial ground-truth labels
            valid = torch.ones(n, 1, requires_grad=False).to(device)
            fake = torch.zeros(n, 1, requires_grad=False).to(device)

            real_imgs = imgs.to(device)

            # --- Generator step: make D label fakes as "valid" ---
            optimizer_G.zero_grad()
            z = torch.randn(n, latent_dim).to(device)
            gen_imgs = generator(z)
            g_loss = criterion(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # --- Discriminator step: real -> valid, fake (detached) -> fake ---
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs), valid)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            if batch_idx % 50 == 0:
                print(f"[Epoch {epoch+1}/{epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

        # Persist both networks each epoch so training can resume.
        os.makedirs('weights', exist_ok=True)
        torch.save(generator.state_dict(), f'weights/gan_generator_epoch_{epoch+1}.pth')
        torch.save(discriminator.state_dict(), f'weights/gan_discriminator_epoch_{epoch+1}.pth')
79
+
80
if __name__ == "__main__":
    print("Starting GAN Training...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data
    dataset = IAMDataset(data_dir='data/iam_words', csv_file='data/labels.csv', transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

    # Networks
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    # Resume both nets from the newest generator checkpoint, if present; the
    # matching discriminator checkpoint shares the epoch number.
    import glob
    start_epoch = 0
    epoch_of = lambda p: int(os.path.basename(p).split('_')[-1].split('.')[0])
    gen_checkpoints = sorted(glob.glob('weights/gan_generator_epoch_*.pth'), key=epoch_of)
    if gen_checkpoints:
        latest_gen_checkpoint = gen_checkpoints[-1]
        start_epoch = epoch_of(latest_gen_checkpoint)
        latest_disc_checkpoint = f'weights/gan_discriminator_epoch_{start_epoch}.pth'

        print(f"Resuming GAN training from epoch {start_epoch}")
        generator.load_state_dict(torch.load(latest_gen_checkpoint, map_location=device))
        discriminator.load_state_dict(torch.load(latest_disc_checkpoint, map_location=device))

    train_gan(generator, discriminator, dataloader, epochs=50, device=device, start_epoch=start_epoch)
    print("GAN Training complete!")
src/training/train_ssl.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torchvision import transforms
6
+ import os
7
+ import sys
8
+ import glob
9
+
10
+ # Add project root to path
11
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
12
+
13
+ from src.data.dataset import IAMDataset, collate_fn
14
+ from src.models.crnn import CRNN
15
+ from src.models.gan import Generator
16
+
17
+ # Define transforms matching training exactly
18
+ transform = transforms.Compose([
19
+ transforms.Resize((32, 1024)),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize((0.5,), (0.5,))
22
+ ])
23
+
24
def decode_pseudo_labels(preds):
    """Greedy CTC decode of network outputs into target sequences.

    Args:
        preds: tensor of shape (seq_len, batch, classes).

    Returns:
        (targets_list, target_lengths): per-sample 1-D LongTensors of decoded
        class indices (blanks and repeats collapsed) and their lengths.
    """
    # Best class per timestep, then batch-major: (batch, seq_len)
    best = preds.argmax(2).permute(1, 0)

    targets_list = []
    target_lengths = []

    for row in best:
        ids = row.tolist()
        # CTC collapse rule: drop blanks (0) and merge consecutive duplicates.
        decoded = [c for pos, c in enumerate(ids)
                   if c != 0 and (pos == 0 or c != ids[pos - 1])]
        targets_list.append(torch.tensor(decoded, dtype=torch.long))
        target_lengths.append(len(decoded))

    return targets_list, target_lengths
44
+
45
def train_ssl(model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8, latent_dim=100):
    """
    Pseudo-labeling approach for Semi-Supervised Learning.
    Combines real labeled data with synthetic unlabeled data generated dynamically by the GAN.

    Args:
        model: CRNN recognizer returning per-timestep log-probabilities of
            shape (batch, seq_len, classes) — torch.exp() below assumes
            log-softmax outputs.
        generator: trained GAN generator used as an unlabeled-data source.
        dataloader: yields (images, padded_targets, target_lengths) batches.
        optimizer: optimizer over the recognizer's parameters.
        criterion: nn.CTCLoss with blank index 0.
        device: torch.device for all tensors.
        epochs (int): number of SSL epochs to run.
        threshold (float): minimum average per-timestep confidence a synthetic
            image must reach for its pseudo-label to be trained on.
        latent_dim (int): size of the generator's latent input vector.
    """
    model.train()
    generator.eval()  # Generator is fixed during this phase

    for epoch in range(epochs):
        total_loss_real = 0
        total_loss_fake = 0

        for step, (labeled_imgs, labeled_texts, labeled_lengths) in enumerate(dataloader):
            labeled_imgs = labeled_imgs.to(device)
            labeled_texts = labeled_texts.to(device)
            batch_size = labeled_imgs.size(0)

            optimizer.zero_grad()

            # ==============================================================
            # 1. Train on Real Labeled Data
            # ==============================================================
            preds_l = model(labeled_imgs)
            preds_l = preds_l.permute(1, 0, 2)  # (seq_len, batch, classes) for CTCLoss

            # Every sample contributes the full output sequence length.
            input_lengths_l = torch.full(size=(preds_l.size(1),), fill_value=preds_l.size(0), dtype=torch.long)

            # CTCLoss wants a single 1-D tensor of concatenated targets with
            # the per-sample padding stripped off.
            targets_list_l = []
            for i in range(labeled_texts.size(0)):
                targets_list_l.append(labeled_texts[i][:labeled_lengths[i]])
            targets_concat_l = torch.cat(targets_list_l)

            loss_real = criterion(preds_l, targets_concat_l, input_lengths_l, labeled_lengths)

            # ==============================================================
            # 2. Train on Synthetic GAN Data (Pseudo-Labeling)
            # ==============================================================
            # Generate fake images (no gradients flow into the frozen generator)
            with torch.no_grad():
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z)  # Shape: (batch, 1, 32, 1024), range [-1, 1]

            # Get pseudo-labels: score the fakes in eval mode, then keep only
            # the ones the model itself is confident about.
            model.eval()
            preds_fake_eval = model(fake_imgs)
            probs = torch.exp(preds_fake_eval)  # recover probabilities from log-softmax
            max_probs, _ = torch.max(probs, dim=2)
            avg_confidence = max_probs.mean(dim=1)  # mean top-class prob per sample

            # Mask confident predictions (only these get a training signal)
            mask = avg_confidence > threshold

            model.train()
            loss_fake = torch.tensor(0.0).to(device)

            if mask.sum() > 0:
                confident_imgs = fake_imgs[mask]

                # Fresh forward pass in train mode; this one carries gradients.
                preds_fake = model(confident_imgs)
                preds_fake_perm = preds_fake.permute(1, 0, 2)

                # Decode the pseudo-labels into CTC targets (detached so the
                # targets themselves are treated as constants).
                targets_list_u, target_lengths_u = decode_pseudo_labels(preds_fake_perm.detach())

                # Filter out empty pseudo-labels (CTC cannot train on length 0)
                valid_idx = [i for i, length in enumerate(target_lengths_u) if length > 0]

                if valid_idx:
                    valid_preds_fake_perm = preds_fake_perm[:, valid_idx, :]
                    valid_targets_list = [targets_list_u[i].to(device) for i in valid_idx]
                    valid_target_lengths = torch.tensor([target_lengths_u[i] for i in valid_idx], dtype=torch.long).to(device)

                    valid_targets_concat = torch.cat(valid_targets_list)
                    input_lengths_u = torch.full(size=(valid_preds_fake_perm.size(1),), fill_value=valid_preds_fake_perm.size(0), dtype=torch.long).to(device)

                    loss_fake = criterion(valid_preds_fake_perm, valid_targets_concat, input_lengths_u, valid_target_lengths)
                    # Scale down the fake loss slightly so it doesn't overwhelm real data
                    loss_fake = loss_fake * 0.5

            # Total loss: one optimizer step covers both the supervised term
            # and the (possibly zero) pseudo-label term.
            total_loss = loss_real + loss_fake
            total_loss.backward()
            optimizer.step()

            total_loss_real += loss_real.item()
            total_loss_fake += loss_fake.item() if loss_fake > 0 else 0

            if step % 20 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], Real Loss: {loss_real.item():.4f}, Fake Loss: {loss_fake.item() if loss_fake > 0 else 0:.4f}, Confident Fakes: {mask.sum().item()}/{batch_size}")

        print(f"Epoch {epoch+1} Average Real Loss: {total_loss_real/len(dataloader):.4f}, Average Fake Loss: {total_loss_fake/len(dataloader):.4f}")

        # Save checkpoints
        os.makedirs('weights', exist_ok=True)
        torch.save(model.state_dict(), f'weights/crnn_ssl_epoch_{epoch+1}.pth')
140
+
141
if __name__ == "__main__":
    print("Starting Semi-Supervised Learning (SSL) Training Phase...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Dataset / loader
    dataset = IAMDataset(data_dir='data/iam_words', csv_file='data/labels.csv', transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    # Helper: epoch number encoded in a checkpoint filename.
    epoch_of = lambda p: int(os.path.basename(p).split('_')[-1].split('.')[0])

    # 2. Baseline CRNN — must already be trained by train_crnn.py
    crnn_model = CRNN(img_channel=1, img_height=32, img_width=1024,
                      num_class=dataset.num_classes).to(device)
    checkpoints_crnn = sorted(glob.glob('weights/crnn_baseline_epoch_*.pth'), key=epoch_of)
    if not checkpoints_crnn:
        print("Error: Could not find baseline CRNN weights.")
        sys.exit(1)
    latest_crnn = checkpoints_crnn[-1]
    print(f"Loading Baseline CRNN from {latest_crnn}")
    crnn_model.load_state_dict(torch.load(latest_crnn, map_location=device))

    # 3. Trained GAN generator for synthetic (unlabeled) images
    generator = Generator(latent_dim=100).to(device)
    checkpoints_gan = sorted(glob.glob('weights/gan_generator_epoch_*.pth'), key=epoch_of)
    if not checkpoints_gan:
        print("Error: Could not find GAN Generator weights.")
        sys.exit(1)
    latest_gan = checkpoints_gan[-1]
    print(f"Loading GAN Generator from {latest_gan}")
    generator.load_state_dict(torch.load(latest_gan, map_location=device))

    # 4. Fine-tuning optimizer (smaller learning rate) and CTC loss
    optimizer = optim.Adam(crnn_model.parameters(), lr=0.0001)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    # 5. Run the SSL loop
    train_ssl(crnn_model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8)
    print("SSL Training complete!")
src/utils/preprocessing.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
def preprocess_image(image_path_or_array, target_size=(1024, 32)):
    """Prepare a handwriting image for recognition.

    Steps: load as grayscale (path or array input), boost contrast with CLAHE,
    resize preserving aspect ratio, then paste left-aligned and vertically
    centered onto a white canvas of target_size (width, height).

    Returns a uint8 array of shape (target_h, target_w) with white background.
    """
    if isinstance(image_path_or_array, str):
        img = cv2.imread(image_path_or_array, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError(f"Could not read image at {image_path_or_array}")
    elif len(image_path_or_array.shape) == 3:
        img = cv2.cvtColor(image_path_or_array, cv2.COLOR_BGR2GRAY)
    else:
        img = image_path_or_array.copy()

    # Contrast Limited Adaptive Histogram Equalization. Works well on smooth
    # grayscale scans, but note it can amplify noise on already-binarized input.
    img = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(img)

    target_w, target_h = target_size
    h, w = img.shape

    # Uniform scale factor that fits the image inside the target box.
    scale = min(target_w / w, target_h / h)
    new_w, new_h = int(w * scale), int(h * scale)

    # Degenerate input: return an all-white canvas.
    if new_w == 0 or new_h == 0:
        return np.full((target_h, target_w), 255, dtype=np.uint8)

    img_resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # White canvas; paste left-aligned (usually better for CTC sequence
    # models) and vertically centered. No inversion — training data uses
    # white backgrounds.
    canvas = np.full((target_h, target_w), 255, dtype=np.uint8)
    pad_y = (target_h - new_h) // 2
    canvas[pad_y:pad_y + new_h, :new_w] = img_resized
    return canvas
59
+
60
def deskew(img):
    """Remove horizontal shear from a grayscale image using its moments.

    The skew is estimated as mu11 / mu02 and undone with an affine warp;
    near-zero vertical variance returns an untouched copy instead.
    """
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()

    skew = m['mu11'] / m['mu02']
    rows, cols = img.shape[0], img.shape[1]
    # Shear matrix with a translation term that keeps the image centered.
    warp = np.float32([[1, skew, -0.5 * rows * skew],
                       [0, 1, 0]])
    return cv2.warpAffine(img, warp, (cols, rows),
                          flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)


if __name__ == "__main__":
    # Simple test
    print("Preprocessing module ready.")
src/web/app.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import pandas as pd
8
+ import sys
9
+ import os
10
+ import matplotlib.pyplot as plt
11
+
12
+ # Import preprocessing and model
13
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
14
+ from src.utils.preprocessing import preprocess_image, deskew
15
+ from src.models.crnn import CRNN
16
+
17
# Define device
# All model weights and input tensors below are moved to this device;
# falls back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
# Build vocabulary directly from labels.csv without loading images
try:
    labels_df = pd.read_csv('data/labels.csv')
    # Collect every distinct character across all non-null labels.
    char_pool = {ch
                 for entry in labels_df['text']
                 if pd.notna(entry)
                 for ch in str(entry)}
    vocab = sorted(char_pool)
    # Index 0 is reserved for the CTC blank, so characters start at 1.
    idx_to_char = {pos + 1: ch for pos, ch in enumerate(vocab)}
    num_classes = len(vocab) + 1
    print(f"Loaded vocabulary with {len(vocab)} characters")
except Exception as e:
    print(f"Could not load vocabulary from labels.csv: {e}")
    # Fallback to standard IAM vocab if dataset not available
    vocab = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!? ")
    idx_to_char = {pos + 1: ch for pos, ch in enumerate(vocab)}
    num_classes = len(vocab) + 1
37
+
38
# Load Model
# num_class includes the CTC blank (index 0) on top of the vocabulary size;
# input geometry (1 x 32 x 1024 grayscale) must match the transform below.
model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=num_classes).to(device)

import glob  # used by get_latest_checkpoint to discover saved weights
42
def get_latest_checkpoint(weights_dir='weights'):
    """Return the checkpoint path with the highest epoch number, or None.

    Looks for files named ``crnn_baseline_epoch_<N>.pth`` inside
    ``weights_dir`` and picks the one with the largest ``<N>``.
    """
    candidates = glob.glob(os.path.join(weights_dir, 'crnn_baseline_epoch_*.pth'))
    if not candidates:
        return None

    def epoch_number(path):
        # Filenames end with '..._epoch_<N>.pth'; extract the integer N.
        tail = os.path.basename(path).rsplit('_', 1)[-1]
        return int(tail.split('.')[0])

    return max(candidates, key=epoch_number)
49
+
50
# Load the most recent checkpoint if one exists; otherwise the model keeps
# its random initialization and predictions will be meaningless.
weights_path = get_latest_checkpoint()
if weights_path and os.path.exists(weights_path):
    print(f"Loading trained weights from {weights_path}...")
    # BUGFIX: load the state dict from disk once (the file is ~30 MB) and
    # reuse it for both the strict and non-strict attempts — the original
    # called torch.load a second time inside the except branch.
    state_dict = torch.load(weights_path, map_location=device)
    try:
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading weights perfectly (might be minor mismatch): {e}")
        # Best-effort: load whatever keys match the current architecture.
        model.load_state_dict(state_dict, strict=False)
else:
    print(f"Warning: Could not find any weights in weights/. Model will output random predictions.")

# Inference mode: disables dropout and freezes batch-norm statistics.
model.eval()
62
+
63
# Transform matching training exactly
# NOTE(review): Resize((32, 1024)) squashes to a fixed 32x1024 canvas
# regardless of the input aspect ratio — assumed to mirror the training
# pipeline exactly; confirm against the dataset loader's transforms.
transform = transforms.Compose([
    transforms.Resize((32, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # maps [0, 1] pixels to [-1, 1]
])
69
+
70
def decode_predictions(preds, idx_to_char):
    """Greedy CTC decode: collapse repeats, drop blanks (index 0).

    Args:
        preds: tensor of shape (seq_len, batch, num_classes).
        idx_to_char: mapping from class index (>= 1) to character.

    Returns:
        List of decoded strings, one per batch element.
    """
    best_path = preds.argmax(2).permute(1, 0)  # (batch, seq_len)

    texts = []
    for seq in best_path:
        pieces = []
        prev_idx = None
        for idx in seq.tolist():
            # Emit only on a transition away from blank / the previous class.
            if idx != 0 and idx != prev_idx:
                if idx in idx_to_char:
                    pieces.append(idx_to_char[idx])
            prev_idx = idx
        texts.append("".join(pieces))
    return texts
85
+
86
def auto_crop_image(gray_img):
    """Crop a grayscale image down to its handwritten content.

    Pipeline: blur, Otsu-threshold (inverted so ink is white), find the
    external contours, keep those that are neither specks nor huge
    objects, union their bounding boxes, pad the crop, and finally
    right-pad with white so the result is at least 16:1 wide — matching
    the training data's average aspect ratio before the fixed 32x1024
    squash. Returns the input unchanged when nothing usable is found.
    """
    smoothed = cv2.GaussianBlur(gray_img, (5, 5), 0)
    # Inverted Otsu: dark ink on light paper becomes white-on-black mask.
    _, ink_mask = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    contours, _ = cv2.findContours(ink_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return gray_img

    # Reject dust-sized specks and oversized shapes (e.g. a pen in frame).
    total_area = gray_img.shape[0] * gray_img.shape[1]
    kept = [c for c in contours if 20 < cv2.contourArea(c) < total_area * 0.4]
    if not kept:
        return gray_img  # filtering removed everything; fall back to original

    # Union of all kept bounding boxes.
    boxes = [cv2.boundingRect(c) for c in kept]
    x_min = min(bx for bx, _, _, _ in boxes)
    y_min = min(by for _, by, _, _ in boxes)
    x_max = max(bx + bw for bx, _, bw, _ in boxes)
    y_max = max(by + bh for _, by, _, bh in boxes)

    # Generous margins around the text, clamped to the image bounds.
    pad_y = int((y_max - y_min) * 0.2)
    pad_x = int((x_max - x_min) * 0.05)
    x_min = max(0, x_min - pad_x)
    y_min = max(0, y_min - pad_y)
    x_max = min(gray_img.shape[1], x_max + pad_x)
    y_max = min(gray_img.shape[0], y_max + pad_y)

    cropped = gray_img[y_min:y_max, x_min:x_max]

    # A short word squashed straight to 32x1024 (a 32:1 stretch) would be
    # destroyed, so pad with white on the right until the crop reaches the
    # training average of ~16:1 before the fixed-size resize happens.
    h, w = cropped.shape
    if w / h < 16.0:
        extra = int(h * 16.0) - w
        cropped = cv2.copyMakeBorder(cropped, 0, 0, 0, extra,
                                     cv2.BORDER_CONSTANT, value=255)

    return cropped
151
+
152
def process_and_predict(image, apply_auto_crop=True):
    """Run the full HTR pipeline on an uploaded image.

    Args:
        image: PIL image (or numpy array) from the Gradio input widget.
        apply_auto_crop: when True, binarize, auto-crop and deskew before
            recognition; when False, mimic the raw dataset-loading path so
            dataset images reproduce training behavior exactly.

    Returns:
        Tuple of (preprocessed display image, decoded text, CTC-heatmap
        figure, confidence bar-chart figure, CNN-activation overlay image).
    """
    if image is None:
        return None, "Please upload an image.", None, None, None

    # Gradio normally hands us a PIL Image; coerce numpy arrays just in case.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    gray_image = image.convert('L')

    # OpenCV copies for preprocessing / display.
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    gray_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)

    # Binarize (Otsu) to strip shadows, lighting gradients and colored paper
    # backgrounds that the model was never trained on.
    blurred = cv2.GaussianBlur(gray_cv, (5, 5), 0)
    _, binarized = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    if not apply_auto_crop:
        # Bypass all fancy preprocessing to precisely match the dataset
        # loading behavior, so dataset images work perfectly.
        gray_image_pil = Image.fromarray(gray_cv)
        img_tensor = transform(gray_image_pil).unsqueeze(0).to(device)
        # For display, show what the network actually sees (squashed).
        display_processed_img = np.array(gray_image_pil.resize((1024, 32), Image.BILINEAR))
    else:
        # Auto-crop (on the clean binarized image), deskew, then resize/pad
        # onto the model's 1024x32 canvas.
        processed_base = auto_crop_image(binarized)
        deskewed_img = deskew(processed_base)
        display_processed_img = preprocess_image(deskewed_img, target_size=(1024, 32))
        gray_image_cropped = Image.fromarray(display_processed_img)
        # Must use exactly the training transform, applied to a PIL image.
        img_tensor = transform(gray_image_cropped).unsqueeze(0).to(device)

    # Predict and extract intermediate features for the visualizations.
    with torch.no_grad():
        # CNN feature maps feed the activation overlay below.
        cnn_features = model.cnn(img_tensor)  # shape: (1, C, 1, seq_len)
        preds = model(img_tensor)

    preds = preds.permute(1, 0, 2)  # (seq_len, batch, num_classes)
    decoded_text = decode_predictions(preds, idx_to_char)[0]

    # Model output is LogSoftmax; exponentiate to recover probabilities.
    probs = torch.exp(preds[:, 0, :])  # shape: (seq_len, num_classes)

    if not decoded_text.strip():
        decoded_text = "[Model returned blank - Needs more training epochs]"

    # 1. CTC probability-matrix heatmap (character index vs. time frame).
    probs_np = probs.cpu().numpy().T  # shape: (num_classes, seq_len)
    fig_heatmap, ax1 = plt.subplots(figsize=(10, 4))
    cax = ax1.imshow(probs_np, aspect='auto', cmap='viridis')
    ax1.set_title("CTC Probability Matrix Heatmap")
    ax1.set_xlabel("Time Frame (Sequence Steps)")
    ax1.set_ylabel("Vocabulary Character Index")
    fig_heatmap.colorbar(cax, ax=ax1, fraction=0.046, pad=0.04, label="Probability")
    plt.tight_layout()

    # 2. Per-character confidence bar chart along the greedy CTC path.
    max_probs, max_idx = torch.max(probs, dim=1)
    chars = []
    confidences = []
    for i in range(len(max_idx)):
        if max_idx[i] != 0 and (i == 0 or max_idx[i] != max_idx[i-1]):
            char_idx = max_idx[i].item()
            if char_idx in idx_to_char:
                chars.append(idx_to_char[char_idx])
                confidences.append(max_probs[i].item())

    # Widen the figure as the number of predicted characters grows.
    fig_bar, ax2 = plt.subplots(figsize=(max(8, len(chars)*0.4), 4))
    if chars:
        bars = ax2.bar(range(len(chars)), confidences, color='#FF9900')
        ax2.set_xticks(range(len(chars)))
        ax2.set_xticklabels(chars)
        ax2.set_ylim(0, 1.1)
        ax2.set_title("Character Confidence Scores")
        ax2.set_ylabel("Confidence Probability")

        # Percentage labels above each bar.
        for bar in bars:
            yval = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2.0, yval + 0.02,
                     f'{yval*100:.0f}%', va='bottom', ha='center', fontsize=8, rotation=45)
    else:
        ax2.text(0.5, 0.5, "No characters predicted", ha='center', va='center')

    plt.tight_layout()

    # 3. CNN feature-activation overlay: average features across channels.
    activation = torch.mean(cnn_features, dim=1).squeeze().cpu().numpy()

    # NOTE(review): the feature-map height appears to be 1 (see the shape
    # comment above), so squeeze() can yield a 1-D array, which cv2.resize
    # rejects — restore a row dimension before resizing.
    if activation.ndim == 1:
        activation = activation[np.newaxis, :]

    # Normalize activation to 0-255.
    activation = (activation - activation.min()) / (activation.max() - activation.min() + 1e-8)
    activation = (activation * 255).astype(np.uint8)

    # BUGFIX: resize against display_processed_img, which is defined on both
    # branches above — the original referenced processed_img_np, which was
    # only assigned when auto-crop was enabled (NameError otherwise).
    heatmap_img = cv2.resize(activation, (display_processed_img.shape[1], display_processed_img.shape[0]))

    # Apply color map and blend 50/50 over the grayscale input.
    heatmap_color = cv2.applyColorMap(heatmap_img, cv2.COLORMAP_JET)
    original_bgr = cv2.cvtColor(display_processed_img, cv2.COLOR_GRAY2BGR)
    overlay_img = cv2.addWeighted(heatmap_color, 0.5, original_bgr, 0.5, 0)
    # Convert BGR to RGB for Gradio display.
    overlay_img = cv2.cvtColor(overlay_img, cv2.COLOR_BGR2RGB)

    return display_processed_img, decoded_text, fig_heatmap, fig_bar, overlay_img
274
+
275
# Redesign UI with Gradio Blocks for a proper Dashboard layout
with gr.Blocks(title="Handwritten Text Recognition (HTR)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Handwritten Text Recognition (HTR) Dashboard</h1>")
    gr.Markdown("Upload an image of handwritten text. The system will preprocess it and extract the text using our trained custom CRNN model.")

    # Left column: input + controls; right column: preprocessed view + result.
    with gr.Row():
        with gr.Column(scale=1):
            # Editor tool allows manual cropping in UI before sending
            input_image = gr.Image(type="pil", label="Upload Handwritten Text Image")
            auto_crop_checkbox = gr.Checkbox(label="✨ Auto-Crop Background (Smart Vision)", value=True, info="Automatically zooms in on the text and removes giant background objects/pens.")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column(scale=1):
            output_image = gr.Image(type="numpy", label="Preprocessed (1024 x 32)")
            gr.Markdown("<p style='font-size: 12px; color: gray;'>Grayscale, aspect-ratio preserved, padded to 32x1024</p>")
            output_text = gr.Textbox(label="Predicted Text", lines=2)

    gr.Markdown("---")
    gr.Markdown("### 📊 Model Insights & Analytics (Explainable AI)")

    with gr.Accordion("📖 How to read these graphs (Interpretation Guide)", open=False):
        gr.Markdown("""
        **1. CNN Feature Activation Overlay:** Shows exactly where the model's 'eyes' are focusing on the image. Red/hot areas indicate regions with strong visual features (like complex curves or sharp lines) that the Convolutional Neural Network detected.

        **2. CTC Probability Matrix Heatmap:** Shows *when* the model made a decision. The X-axis is the timeline (reading left-to-right), and the Y-axis contains all possible characters. Yellow dots indicate the exact moment the AI identified a specific letter.

        **3. Character Confidence Scores:** Shows *how sure* the model is about each letter it predicted. If the model misreads a word, this chart usually shows a low confidence score for the incorrect letter, proving it was uncertain.
        """)

    # Explainability panels populated by process_and_predict.
    with gr.Row():
        cnn_activation_image = gr.Image(type="numpy", label="1. CNN Feature Activation Overlay")

    with gr.Row():
        heatmap_plot = gr.Plot(label="2. CTC Probability Heatmap")

    with gr.Row():
        confidence_plot = gr.Plot(label="3. Character Confidence Scores")

    # Wire Submit to the inference pipeline; outputs map 1:1 to the 5
    # return values of process_and_predict.
    submit_btn.click(
        fn=process_and_predict,
        inputs=[input_image, auto_crop_checkbox],
        outputs=[output_image, output_text, heatmap_plot, confidence_plot, cnn_activation_image]
    )

    # Clear resets all 7 components (the lambda's list matches the outputs
    # list positionally; the checkbox returns to its default True).
    clear_btn.click(
        fn=lambda: [None, True, None, "", None, None, None],
        inputs=[],
        outputs=[input_image, auto_crop_checkbox, output_image, output_text, heatmap_plot, confidence_plot, cnn_activation_image]
    )

if __name__ == "__main__":
    # share=True exposes a temporary public Gradio URL.
    demo.launch(share=True)
weights/crnn_baseline_epoch_30.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f203e852eb08710b520beed65b3bbf0edb5c8fb66ac34e61936bb9660ed2dec7
3
+ size 31473673