IsmatS committed
Commit e40c8b3 · 0 Parent(s)

Complete handwriting recognition project


- Analysis notebook with EDA and 5 detailed charts
- Training notebook for Google Colab GPU
- README with documentation
- Clean project structure

Files changed (5)
  1. .gitignore +33 -0
  2. README.md +121 -0
  3. analysis.ipynb +0 -0
  4. requirements.txt +8 -0
  5. train_colab.ipynb +524 -0
.gitignore ADDED
@@ -0,0 +1,33 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Model files
+ models/
+ *.pth
+ *.pkl
+
+ # Data
+ archive/
+ data/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
README.md ADDED
@@ -0,0 +1,121 @@
+ # Handwriting Recognition
+
+ Complete handwriting recognition system using CNN-BiLSTM-CTC on the IAM dataset.
+
+ ## 📁 Files
+
+ ### 1. **analysis.ipynb** - Dataset Analysis
+ - Exploratory Data Analysis (EDA)
+ - 5 detailed charts saved to `charts/` folder
+ - Run locally or on Colab (no GPU needed)
+
+ ### 2. **train_colab.ipynb** - Model Training (GPU)
+ - **⚡ Google Colab GPU compatible**
+ - Full training pipeline
+ - CNN-BiLSTM-CTC model (~9.1M parameters)
+ - Automatic model saving
+ - Download trained model for deployment
+
+ ## 🚀 Quick Start
+
+ ### Option 1: Analyze Dataset (Local/Colab)
+ ```bash
+ jupyter notebook analysis.ipynb
+ ```
+ - No GPU needed
+ - Generates 5 EDA charts
+ - Fast (~2 minutes)
+
+ ### Option 2: Train Model (Google Colab GPU)
+
+ 1. **Upload `train_colab.ipynb` to Google Colab**
+ 2. **Change runtime to GPU:**
+    - Runtime → Change runtime type → GPU (T4 recommended)
+ 3. **Run all cells**
+ 4. **Download trained model** (last cell)
+
+ **Training Time:** ~1-2 hours for 20 epochs on a T4 GPU
+
+ ## 📊 Charts Generated
+
+ From `analysis.ipynb`:
+ 1. `charts/01_sample_images.png` - 10 sample handwritten texts
+ 2. `charts/02_text_length_distribution.png` - Text statistics
+ 3. `charts/03_image_dimensions.png` - Image analysis
+ 4. `charts/04_character_frequency.png` - Character distribution
+ 5. `charts/05_summary_statistics.png` - Summary table
+
+ ## 🎯 Model Details
+
+ **Architecture:**
+ - **CNN**: 7 convolutional blocks (feature extraction)
+ - **BiLSTM**: 2 layers, 256 hidden units (sequence modeling)
+ - **CTC Loss**: Alignment-free training
+
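For the CNN stack above, the width of a line image determines how many time steps the BiLSTM and CTC loss see: the two 2×2 max-pools quarter the width, the (2, 1) pools only reduce height, and the final unpadded 2×2 conv trims one column. A small sketch of that bookkeeping, mirroring the `(w // 4) - 1` computation in the notebook's `collate_fn` (the helper name is ours):

```python
def ctc_input_length(width: int) -> int:
    """Number of CTC time steps the CRNN produces for an image of this width.

    Two 2x2 max-pools halve the width twice (// 4); the (2, 1) pools only
    reduce height; the final unpadded 2x2 conv trims one column (- 1).
    """
    return width // 4 - 1

print(ctc_input_length(1024))  # -> 255 time steps for a 1024-px-wide line
```

This is also why very short target texts on very narrow images can be problematic: CTC requires the input length to be at least the target length.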
+ **Dataset:** Teklia/IAM-line (Hugging Face)
+ - Train: 6,482 samples
+ - Validation: 976 samples
+ - Test: 2,915 samples
+
+ **Metrics:**
+ - **CER** (Character Error Rate)
+ - **WER** (Word Error Rate)
+
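CER and WER are computed on strings obtained by greedy CTC decoding: take the argmax character per time step, collapse consecutive repeats, then drop the blank symbol. A minimal sketch of that decode (the toy `idx2char` table stands in for the notebook's `CharacterMapper`):

```python
def ctc_greedy_decode(indices, idx2char, blank=0):
    """Collapse repeated indices, then drop CTC blanks."""
    out, prev = [], None
    for idx in indices:
        if idx != blank and idx != prev:
            out.append(idx2char[idx])
        prev = idx
    return ''.join(out)

idx2char = {1: 'h', 2: 'e', 3: 'l', 4: 'o'}
# Per-frame argmax over ten time steps, 0 = blank.
# The blank between the two 'l' frames is what allows a doubled letter.
print(ctc_greedy_decode([1, 1, 0, 2, 0, 3, 3, 0, 3, 4], idx2char))  # -> hello
```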
+ ## 💾 Model Files
+
+ After training in Colab:
+ - `best_model.pth` - Trained model weights
+ - `training_history.png` - Loss/CER/WER plots
+ - `predictions.png` - Sample predictions
+
+ ## 📦 Requirements
+
+ ```
+ torch>=2.0.0
+ datasets>=2.14.0
+ pillow>=9.5.0
+ numpy>=1.24.0
+ matplotlib>=3.7.0
+ seaborn>=0.13.0
+ jupyter>=1.0.0
+ jiwer>=3.0.0
+ ```
+
+ ## 🔧 Usage
+
+ ### Load Trained Model
+ ```python
+ import torch
+
+ # Load checkpoint (weights_only=False: the checkpoint pickles the character mapper)
+ checkpoint = torch.load('best_model.pth', map_location='cpu', weights_only=False)
+ char_mapper = checkpoint['char_mapper']
+
+ # Re-create the model (copy the CRNN class from train_colab.ipynb)
+ model = CRNN(num_chars=len(char_mapper.chars))
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.eval()
+
+ # Predict
+ # ... (preprocessing + inference)
+ ```
+
+ ## 📝 Notes
+
+ - **GPU strongly recommended** for training (use Colab T4)
+ - Training on CPU will be extremely slow (~20x slower)
+ - Colab free tier: 12-hour limit, sufficient for 20 epochs
+ - Model checkpoint includes the character mapper for deployment
+
+ ## 🎓 Training Tips
+
+ 1. **Start with fewer epochs** (5-10) to test
+ 2. **Monitor CER/WER** - stop if not improving
+ 3. **Increase epochs** if still improving (up to 50)
+ 4. **Save checkpoints** before Colab disconnects
+ 5. **Download the model immediately** after training
+
+ ## 📄 License
+
+ Dataset: IAM Database (research use)
analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.0.0
+ datasets>=2.14.0
+ pillow>=9.5.0
+ numpy>=1.24.0
+ matplotlib>=3.7.0
+ seaborn>=0.13.0
+ jupyter>=1.0.0
+ jiwer>=3.0.0
train_colab.ipynb ADDED
@@ -0,0 +1,524 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Handwriting Recognition Training (Google Colab GPU)\n",
+     "## CNN-BiLSTM-CTC Model on IAM Dataset\n",
+     "\n",
+     "**Runtime:** GPU (Runtime → Change runtime type → GPU)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 1. Setup & Installations"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Install required packages\n",
+     "!pip install -q datasets transformers jiwer\n",
+     "\n",
+     "import torch\n",
+     "import torch.nn as nn\n",
+     "import torch.optim as optim\n",
+     "from torch.utils.data import Dataset, DataLoader\n",
+     "from datasets import load_dataset\n",
+     "import numpy as np\n",
+     "from PIL import Image\n",
+     "from tqdm import tqdm\n",
+     "from jiwer import cer, wer\n",
+     "import matplotlib.pyplot as plt\n",
+     "\n",
+     "# Check GPU\n",
+     "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+     "print(f\"✓ Using device: {device}\")\n",
+     "if torch.cuda.is_available():\n",
+     "    print(f\"  GPU: {torch.cuda.get_device_name(0)}\")\n",
+     "    print(f\"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 2. Model Architecture"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "class CRNN(nn.Module):\n",
+     "    \"\"\"CNN-BiLSTM-CTC for Handwriting Recognition\"\"\"\n",
+     "\n",
+     "    def __init__(self, img_height=128, num_chars=80, hidden_size=256, num_layers=2):\n",
+     "        super(CRNN, self).__init__()\n",
+     "\n",
+     "        # CNN Feature Extractor\n",
+     "        self.cnn = nn.Sequential(\n",
+     "            nn.Conv2d(1, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),\n",
+     "            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),\n",
+     "            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),\n",
+     "            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d((2, 1)),\n",
+     "            nn.Conv2d(256, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),\n",
+     "            nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(), nn.MaxPool2d((2, 1)),\n",
+     "            nn.Conv2d(512, 512, 2), nn.BatchNorm2d(512), nn.ReLU(),\n",
+     "        )\n",
+     "\n",
+     "        self.map2seq = nn.Linear(512 * 7, hidden_size)\n",
+     "        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers, bidirectional=True,\n",
+     "                           dropout=0.3 if num_layers > 1 else 0, batch_first=True)\n",
+     "        self.fc = nn.Linear(hidden_size * 2, num_chars + 1)\n",
+     "\n",
+     "    def forward(self, x):\n",
+     "        conv = self.cnn(x)\n",
+     "        b, c, h, w = conv.size()\n",
+     "        conv = conv.permute(0, 3, 1, 2).reshape(b, w, c * h)\n",
+     "        seq = self.map2seq(conv)\n",
+     "        rnn_out, _ = self.rnn(seq)\n",
+     "        output = self.fc(rnn_out)\n",
+     "        return torch.nn.functional.log_softmax(output, dim=2)\n",
+     "\n",
+     "print(\"✓ Model architecture defined\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 3. Character Mapper"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "class CharacterMapper:\n",
+     "    def __init__(self):\n",
+     "        chars = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,;:!?\\'\"()-')\n",
+     "        self.chars = sorted(list(chars))\n",
+     "        self.char2idx = {c: i+1 for i, c in enumerate(self.chars)}\n",
+     "        self.idx2char = {i+1: c for i, c in enumerate(self.chars)}\n",
+     "        self.idx2char[0] = ''  # CTC blank\n",
+     "        self.num_classes = len(self.chars) + 1\n",
+     "\n",
+     "    def encode(self, text):\n",
+     "        return [self.char2idx[c] for c in text if c in self.char2idx]\n",
+     "\n",
+     "    def decode(self, indices):\n",
+     "        chars, prev = [], None\n",
+     "        for idx in indices:\n",
+     "            if idx != 0 and idx != prev and idx in self.idx2char:\n",
+     "                chars.append(self.idx2char[idx])\n",
+     "            prev = idx\n",
+     "        return ''.join(chars)\n",
+     "\n",
+     "char_mapper = CharacterMapper()\n",
+     "print(f\"✓ Character mapper: {char_mapper.num_classes} classes\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 4. Dataset & DataLoader"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "class IAMDataset(Dataset):\n",
+     "    def __init__(self, split='train', img_height=128):\n",
+     "        self.img_height = img_height\n",
+     "        self.dataset = load_dataset(\"Teklia/IAM-line\", split=split)\n",
+     "        print(f\"  Loaded {len(self.dataset)} samples\")\n",
+     "\n",
+     "    def __len__(self):\n",
+     "        return len(self.dataset)\n",
+     "\n",
+     "    def __getitem__(self, idx):\n",
+     "        sample = self.dataset[idx]\n",
+     "        img = sample['image'].convert('L')\n",
+     "        text = sample['text']\n",
+     "\n",
+     "        # Resize\n",
+     "        w, h = img.size\n",
+     "        new_w = int(self.img_height * (w / h))\n",
+     "        img = img.resize((new_w, self.img_height), Image.Resampling.LANCZOS)\n",
+     "\n",
+     "        # Normalize\n",
+     "        img = np.array(img, dtype=np.float32) / 255.0\n",
+     "        img = (img - 0.5) / 0.5\n",
+     "        img = torch.FloatTensor(img).unsqueeze(0)\n",
+     "\n",
+     "        target = torch.LongTensor(char_mapper.encode(text))\n",
+     "        return img, target, len(target), text\n",
+     "\n",
+     "def collate_fn(batch):\n",
+     "    images, targets, target_lengths, texts = zip(*batch)\n",
+     "    max_w = max(img.shape[2] for img in images)\n",
+     "    b, h = len(images), images[0].shape[1]\n",
+     "\n",
+     "    padded_imgs = torch.zeros(b, 1, h, max_w)\n",
+     "    input_lengths = []\n",
+     "\n",
+     "    for i, img in enumerate(images):\n",
+     "        w = img.shape[2]\n",
+     "        padded_imgs[i, :, :, :w] = img\n",
+     "        input_lengths.append((w // 4) - 1)\n",
+     "\n",
+     "    return {\n",
+     "        'images': padded_imgs,\n",
+     "        'targets': torch.cat(targets) if targets[0].numel() > 0 else torch.LongTensor([]),\n",
+     "        'target_lengths': torch.LongTensor(target_lengths),\n",
+     "        'input_lengths': torch.LongTensor(input_lengths),\n",
+     "        'texts': texts\n",
+     "    }\n",
+     "\n",
+     "print(\"Loading datasets...\")\n",
+     "train_dataset = IAMDataset('train')\n",
+     "val_dataset = IAMDataset('validation')\n",
+     "\n",
+     "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, num_workers=2)\n",
+     "val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, num_workers=2)\n",
+     "\n",
+     "print(f\"✓ Train batches: {len(train_loader)}, Val batches: {len(val_loader)}\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 5. Training Functions"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def decode_predictions(outputs, char_mapper):\n",
+     "    _, max_indices = torch.max(outputs, dim=2)\n",
+     "    return [char_mapper.decode(idx.cpu().numpy().tolist()) for idx in max_indices]\n",
+     "\n",
+     "def compute_metrics(preds, truths):\n",
+     "    valid = [(p, g) for p, g in zip(preds, truths) if p and g]\n",
+     "    if not valid:\n",
+     "        return 0.0, 0.0\n",
+     "    preds, truths = zip(*valid)\n",
+     "    try:\n",
+     "        return cer(truths, preds), wer(truths, preds)\n",
+     "    except Exception:\n",
+     "        return 0.0, 0.0\n",
+     "\n",
+     "def train_epoch(model, loader, criterion, optimizer, device, epoch):\n",
+     "    model.train()\n",
+     "    total_loss = 0\n",
+     "    pbar = tqdm(loader, desc=f\"Epoch {epoch}\")\n",
+     "\n",
+     "    for batch in pbar:\n",
+     "        images = batch['images'].to(device)\n",
+     "        targets = batch['targets'].to(device)\n",
+     "        input_lengths = batch['input_lengths']\n",
+     "        target_lengths = batch['target_lengths']\n",
+     "\n",
+     "        outputs = model(images)\n",
+     "        outputs = outputs.permute(1, 0, 2)  # CTC format\n",
+     "\n",
+     "        loss = criterion(outputs, targets, input_lengths, target_lengths)\n",
+     "\n",
+     "        optimizer.zero_grad()\n",
+     "        loss.backward()\n",
+     "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
+     "        optimizer.step()\n",
+     "\n",
+     "        total_loss += loss.item()\n",
+     "        pbar.set_postfix({'loss': f'{loss.item():.4f}'})\n",
+     "\n",
+     "    return total_loss / len(loader)\n",
+     "\n",
+     "def validate(model, loader, criterion, device):\n",
+     "    model.eval()\n",
+     "    total_loss = 0\n",
+     "    all_preds, all_truths = [], []\n",
+     "\n",
+     "    with torch.no_grad():\n",
+     "        for batch in tqdm(loader, desc=\"Validating\"):\n",
+     "            images = batch['images'].to(device)\n",
+     "            targets = batch['targets'].to(device)\n",
+     "            input_lengths = batch['input_lengths']\n",
+     "            target_lengths = batch['target_lengths']\n",
+     "            texts = batch['texts']\n",
+     "\n",
+     "            outputs = model(images)\n",
+     "            outputs_ctc = outputs.permute(1, 0, 2)\n",
+     "            loss = criterion(outputs_ctc, targets, input_lengths, target_lengths)\n",
+     "            total_loss += loss.item()\n",
+     "\n",
+     "            preds = decode_predictions(outputs, char_mapper)\n",
+     "            all_preds.extend(preds)\n",
+     "            all_truths.extend(texts)\n",
+     "\n",
+     "    avg_loss = total_loss / len(loader)\n",
+     "    cer_score, wer_score = compute_metrics(all_preds, all_truths)\n",
+     "\n",
+     "    # Show examples\n",
+     "    print(\"\\nExample predictions:\")\n",
+     "    for i in range(min(3, len(all_preds))):\n",
+     "        print(f\"  GT:   {all_truths[i]}\")\n",
+     "        print(f\"  Pred: {all_preds[i]}\")\n",
+     "\n",
+     "    return avg_loss, cer_score, wer_score\n",
+     "\n",
+     "print(\"✓ Training functions ready\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 6. Train Model"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Hyperparameters\n",
+     "EPOCHS = 20\n",
+     "LEARNING_RATE = 0.001\n",
+     "\n",
+     "# Create model\n",
+     "model = CRNN(img_height=128, num_chars=len(char_mapper.chars), hidden_size=256, num_layers=2)\n",
+     "model = model.to(device)\n",
+     "print(f\"Model: {sum(p.numel() for p in model.parameters()):,} parameters\")\n",
+     "\n",
+     "# Loss & Optimizer\n",
+     "criterion = nn.CTCLoss(blank=0, zero_infinity=True)\n",
+     "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
+     "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)\n",
+     "\n",
+     "# Training loop\n",
+     "history = {'train_loss': [], 'val_loss': [], 'val_cer': [], 'val_wer': []}\n",
+     "best_cer = float('inf')\n",
+     "\n",
+     "print(f\"\\n{'='*60}\")\n",
+     "print(f\"Starting training: {EPOCHS} epochs\")\n",
+     "print(f\"{'='*60}\\n\")\n",
+     "\n",
+     "for epoch in range(1, EPOCHS + 1):\n",
+     "    print(f\"\\nEpoch {epoch}/{EPOCHS}\")\n",
+     "    print(\"-\" * 60)\n",
+     "\n",
+     "    # Train\n",
+     "    train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch)\n",
+     "    print(f\"Train Loss: {train_loss:.4f}\")\n",
+     "\n",
+     "    # Validate\n",
+     "    val_loss, val_cer, val_wer = validate(model, val_loader, criterion, device)\n",
+     "    print(f\"Val Loss: {val_loss:.4f}, CER: {val_cer:.4f}, WER: {val_wer:.4f}\")\n",
+     "\n",
+     "    # Save history\n",
+     "    history['train_loss'].append(train_loss)\n",
+     "    history['val_loss'].append(val_loss)\n",
+     "    history['val_cer'].append(val_cer)\n",
+     "    history['val_wer'].append(val_wer)\n",
+     "\n",
+     "    # Scheduler\n",
+     "    scheduler.step(val_loss)\n",
+     "\n",
+     "    # Save best\n",
+     "    if val_cer < best_cer:\n",
+     "        best_cer = val_cer\n",
+     "        torch.save({\n",
+     "            'epoch': epoch,\n",
+     "            'model_state_dict': model.state_dict(),\n",
+     "            'optimizer_state_dict': optimizer.state_dict(),\n",
+     "            'val_cer': val_cer,\n",
+     "            'val_wer': val_wer,\n",
+     "            'char_mapper': char_mapper,\n",
+     "        }, 'best_model.pth')\n",
+     "        print(f\"✓ Saved best model (CER: {val_cer:.4f})\")\n",
+     "\n",
+     "print(f\"\\n{'='*60}\")\n",
+     "print(f\"Training Complete! Best CER: {best_cer:.4f}\")\n",
+     "print(f\"{'='*60}\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 7. Plot Training History"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
+     "\n",
+     "# Loss\n",
+     "axes[0].plot(history['train_loss'], label='Train', marker='o')\n",
+     "axes[0].plot(history['val_loss'], label='Val', marker='s')\n",
+     "axes[0].set_xlabel('Epoch')\n",
+     "axes[0].set_ylabel('Loss')\n",
+     "axes[0].set_title('Loss')\n",
+     "axes[0].legend()\n",
+     "axes[0].grid(alpha=0.3)\n",
+     "\n",
+     "# CER\n",
+     "axes[1].plot(history['val_cer'], label='CER', marker='o', color='green')\n",
+     "axes[1].set_xlabel('Epoch')\n",
+     "axes[1].set_ylabel('Character Error Rate')\n",
+     "axes[1].set_title('CER (lower is better)')\n",
+     "axes[1].legend()\n",
+     "axes[1].grid(alpha=0.3)\n",
+     "\n",
+     "# WER\n",
+     "axes[2].plot(history['val_wer'], label='WER', marker='s', color='orange')\n",
+     "axes[2].set_xlabel('Epoch')\n",
+     "axes[2].set_ylabel('Word Error Rate')\n",
+     "axes[2].set_title('WER (lower is better)')\n",
+     "axes[2].legend()\n",
+     "axes[2].grid(alpha=0.3)\n",
+     "\n",
+     "plt.tight_layout()\n",
+     "plt.savefig('training_history.png', dpi=150)\n",
+     "plt.show()\n",
+     "\n",
+     "print(f\"✓ Final metrics: CER={history['val_cer'][-1]:.4f}, WER={history['val_wer'][-1]:.4f}\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 8. Inference / Prediction"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Load best model (weights_only=False: the checkpoint pickles char_mapper)\n",
+     "checkpoint = torch.load('best_model.pth', weights_only=False)\n",
+     "model.load_state_dict(checkpoint['model_state_dict'])\n",
+     "model.eval()\n",
+     "\n",
+     "print(f\"✓ Loaded best model (Epoch {checkpoint['epoch']}, CER: {checkpoint['val_cer']:.4f})\")\n",
+     "\n",
+     "# Test on validation samples\n",
+     "test_batch = next(iter(val_loader))\n",
+     "\n",
+     "with torch.no_grad():\n",
+     "    images = test_batch['images'].to(device)\n",
+     "    outputs = model(images)\n",
+     "    predictions = decode_predictions(outputs, char_mapper)\n",
+     "\n",
+     "# Visualize predictions\n",
+     "fig, axes = plt.subplots(5, 1, figsize=(16, 15))\n",
+     "\n",
+     "for i in range(5):\n",
+     "    img = test_batch['images'][i, 0].cpu().numpy()\n",
+     "    img = (img * 0.5) + 0.5  # Denormalize\n",
+     "\n",
+     "    gt = test_batch['texts'][i]\n",
+     "    pred = predictions[i]\n",
+     "\n",
+     "    axes[i].imshow(img, cmap='gray')\n",
+     "    axes[i].set_title(f\"GT: {gt}\\nPrediction: {pred}\", fontsize=11, loc='left')\n",
+     "    axes[i].axis('off')\n",
+     "\n",
+     "plt.suptitle('Predictions on Validation Set', fontsize=16, fontweight='bold')\n",
+     "plt.tight_layout()\n",
+     "plt.savefig('predictions.png', dpi=150)\n",
+     "plt.show()\n",
+     "\n",
+     "print(\"\\n✓ Predictions saved to 'predictions.png'\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 9. Download Model (Optional)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Download model to local machine\n",
+     "from google.colab import files\n",
+     "\n",
+     "print(\"Downloading model...\")\n",
+     "files.download('best_model.pth')\n",
+     "print(\"\\n✓ Model downloaded! Use it for deployment.\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "---\n",
+     "## Summary\n",
+     "\n",
+     "### ✓ Training Complete!\n",
+     "\n",
+     "**Model:**\n",
+     "- Architecture: CNN-BiLSTM-CTC\n",
+     "- Parameters: ~9.1M\n",
+     "- Trained on: IAM-line dataset\n",
+     "\n",
+     "**Files Generated:**\n",
+     "- `best_model.pth` - Best model checkpoint\n",
+     "- `training_history.png` - Loss/CER/WER plots\n",
+     "- `predictions.png` - Sample predictions\n",
+     "\n",
+     "**Next Steps:**\n",
+     "1. Download `best_model.pth` for deployment\n",
+     "2. Use it in API/frontend applications\n",
+     "3. Fine-tune with more epochs if needed"
+    ]
+   }
+  ],
+  "metadata": {
+   "accelerator": "GPU",
+   "colab": {
+    "gpuType": "T4",
+    "provenance": []
+   },
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+ }