marianeft committed
Commit 442eace · Parent(s): f18c275

Initial upload of files
README.md CHANGED
@@ -1,19 +1,203 @@
1
  ---
2
- title: Handwritten Name Recognizer V2
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
  sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
  pinned: false
11
- short_description: Streamlit template space
12
  ---
13
 
14
- # Welcome to Streamlit!
15
 
16
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
17
 
18
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
- forums](https://discuss.streamlit.io).
1
  ---
2
+ title: Handwritten Name Recognizer
3
+ emoji: 📊
4
+ colorFrom: indigo
5
+ colorTo: gray
6
  sdk: docker
7
  pinned: false
8
+ license: apache-2.0
9
  ---
10
 
11
+ # Handwritten Name Recognition (OCR) App ✍🏻
12
 
13
+ *An end-to-end Streamlit application for training and predicting handwritten names using a CRNN model.*
14
 
15
+ [📃 Demo and Documentation](https://drive.google.com/drive/folders/1rOmwyTJkDCsU-Wuh-_CzvQ9sdb_ci_kX?usp=sharing)
+ [📂 GitHub Repository](https://github.com/marianeft/handwritten_name_ocr_app.git)
17
+
18
+ ---
19
+
20
+ ## Table of Contents
21
+
22
+ * [Overview](#overview)
23
+ * [Quickstart](#quickstart)
24
+ * [Features](#features)
25
+ * [Project Structure](#project-structure)
26
+ * [Project Index](#project-index)
27
+ * [Roadmap](#roadmap)
28
+ * [Contribution](#contribution)
29
+ * [License](#license)
30
+ * [Acknowledgements](#acknowledgements)
31
+
32
+ ---
33
+
34
+ ## 🕹️ Overview
35
+
36
+ This project implements a Handwritten Name Recognition (OCR) system using a Convolutional Recurrent Neural Network (CRNN) built with PyTorch. The application provides an interactive Streamlit web interface that lets users:
37
+
38
+ 1. **Train** a new OCR model from a local dataset.
39
+ 2. **Load** a pre-trained model.
40
+ 3. **Predict** text from uploaded handwritten image files.
41
+ 4. **Upload** the local dataset to the Hugging Face Hub for sharing and versioning.
42
+
43
+ The CRNN model combines a CNN backbone for feature extraction from images with a bidirectional LSTM for sequence modeling, followed by a linear layer that predicts a character class at each time step; the network is trained with CTC (Connectionist Temporal Classification) loss.
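+
+ For intuition, here is a minimal shape-check sketch (an illustration, assuming it is run from the source directory so that `config.py` and `model_ocr.py` are importable; the width of 256 is an arbitrary example):
+
+ ```python
+ import torch
+ from config import IMG_HEIGHT, NUM_CLASSES  # IMG_HEIGHT = 32
+ from model_ocr import CRNN
+
+ model = CRNN(num_classes=NUM_CLASSES)
+ dummy = torch.randn(2, 1, IMG_HEIGHT, 256)  # (batch, channels, height, width)
+ logits = model(dummy)
+ # CTC-style output: (sequence_length, batch_size, num_classes)
+ print(logits.shape)
+ ```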
44
+
45
+ ---
46
+
47
+ ## 🚩 Quickstart
48
+
49
+ Follow these steps to get the application up and running on your local machine.
50
+
51
+ ### Prerequisites
52
+
53
+ * Python 3.9+ (the code uses built-in generic type hints such as `list[int]`)
54
+ * `pip` (Python package installer)
55
+
56
+ #### 1. Clone the Repository (or set up your project folder)
57
+
58
+ Ensure your project structure matches the expected layout (`app.py`, `config.py`, `data/`, `models/`, etc.).
59
+
60
+ #### 2. Create and Activate a Virtual Environment
61
+ It's highly recommended to use a virtual environment to manage dependencies.
62
+
63
+ ```bash
64
+ # Navigate to your project root directory
65
+ cd path/to/your/handwritten_name_ocr_app
66
+
67
+ # Create a virtual environment named 'venvy'
68
+ python -m venv venvy
69
+
70
+ # Activate the virtual environment
71
+ # On Windows (Command Prompt):
72
+ .\venvy\Scripts\activate.bat
73
+
74
+ # On Windows (PowerShell):
75
+ .\venvy\Scripts\Activate.ps1
76
+
77
+ # On macOS/Linux:
78
+ source venvy/bin/activate
79
+ ```
80
+
81
+ #### 3. Install Dependencies
+ With your virtual environment activated, install all required Python packages:
+
+ `pip install streamlit pandas numpy Pillow torch torchvision scikit-learn tqdm editdistance huggingface_hub`
+
+ Alternatively, install the pinned versions with `pip install -r requirements.txt`.
+
+ *Note on PyTorch (torch and torchvision):
+ The command above installs the CPU-only version of PyTorch. If you have a CUDA-enabled GPU and want to leverage it for faster training, please refer to the official PyTorch website (pytorch.org/get-started/locally/) for installation commands tailored to your CUDA version.*
88
+
89
+ #### 4. Prepare Your Dataset
90
+ The application expects a dataset structured as follows:
91
+ ```
+ data/
+ ├── images/
+ │   ├── train/
+ │   │   ├── image1.png
+ │   │   ├── image2.png
+ │   │   └── ...
+ │   └── test/
+ │       ├── image_test1.png
+ │       ├── image_test2.png
+ │       └── ...
+ ├── train.csv
+ └── test.csv
+ ```
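+
+ Each CSV maps image filenames to their transcriptions via `FILENAME` and `IDENTITY` columns (the column names read by `data_handler_ocr.py`). A minimal example with placeholder rows:
+
+ ```
+ FILENAME,IDENTITY
+ image1.png,BALTHAZAR
+ image2.png,SIMONE
+ ```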
105
+
106
+ #### 5. Clear Python Cache
+ If you run into stale-code behavior after editing modules or updating packages, clear Python's compiled bytecode cache (`__pycache__`) so the latest code is used:
108
+
109
+ ```bash
110
+ find . -name "__pycache__" -exec rm -rf {} + # For macOS/Linux
111
+
112
+ Get-ChildItem -Path . -Include __pycache__ -Recurse | Remove-Item -Recurse -Force # For Windows PowerShell
113
+ ```
114
+
115
+ #### 6. Run the Streamlit Application
116
+ With your virtual environment activated and dependencies installed:
117
+ `streamlit run app.py`
118
+
119
+
120
+ *This will open the application in your web browser.*
121
+
122
+ ## ✏️ Features
123
+ - **CRNN Model Architecture**: Utilizes a Convolutional Recurrent Neural Network for robust OCR.
124
+ - **CTC Loss**: Employs Connectionist Temporal Classification for sequence prediction.
125
+ - **Model Training**: Train a new OCR model from your local image and CSV datasets.
126
+ - **Pre-trained Model Loading**: Load previously saved models to avoid retraining.
127
+ - **Handwritten Text Prediction**: Upload an image and get instant text recognition.
128
+ - **Training Progress Visualization**: Real-time updates and plots for training loss, CER, and accuracy.
129
+ - **Hugging Face Hub Integration**: Seamlessly upload your dataset to the Hugging Face Hub for easy sharing and version control (see the sketch below).
130
+ - **Responsive UI**: Built with Streamlit for an intuitive and user-friendly experience.
131
+
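+ The Hub upload can be driven by the `huggingface_hub` client. A minimal sketch, assuming a dataset repo you own (`your-username/handwritten-names` is a placeholder) and a token stored via `huggingface-cli login`:
+
+ ```python
+ from huggingface_hub import HfApi
+
+ api = HfApi()  # picks up the token stored by `huggingface-cli login`
+ api.upload_folder(
+     folder_path="data",                         # local dataset directory
+     repo_id="your-username/handwritten-names",  # placeholder repo id
+     repo_type="dataset",
+ )
+ ```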
132
+
133
+ ## 🏗️ Project Structure
+ ```
+ handwritten_name_ocr_app/
+ ├── app.py                # Main Streamlit application file
+ ├── config.py             # Configuration settings (paths, model params, chars)
+ ├── data/                 # Directory for datasets
+ │   ├── images/
+ │   │   ├── train/        # Training images
+ │   │   └── test/         # Testing images
+ │   ├── train.csv         # Training labels
+ │   └── test.csv          # Testing labels
+ ├── data_handler_ocr.py   # Custom PyTorch Dataset and DataLoader logic
+ ├── models/               # Directory to save/load trained models
+ │   └── handwritten_name_ocr_model.pth  # Default model save path
+ ├── model_ocr.py          # CRNN model architecture and training/evaluation functions
+ ├── utils_ocr.py          # Utility functions for image preprocessing
+ ├── requirements.txt      # List of Python dependencies
+ └── venvy/                # Python virtual environment (created by `python -m venv venvy`)
+     └── ...
+ ```
153
+
154
+ ## 🗃️ Project Index
155
+
156
+ `app.py`: The central Streamlit application. Handles UI, triggers training/prediction, and integrates with Hugging Face Hub.
157
+
158
+ `config.py`: Stores global configuration variables such as file paths, image dimensions, character sets, and training hyperparameters.
159
+
160
+ `data_handler_ocr.py`: Contains the `CharIndexer` class for character-to-index mapping, plus `OCRDataset` and `ocr_collate_fn` for efficient data loading and batching in PyTorch.
161
+
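+ For intuition, a minimal round-trip sketch (assuming it is run next to `config.py` and `data_handler_ocr.py`):
+
+ ```python
+ from config import VOCABULARY, BLANK_TOKEN_SYMBOL
+ from data_handler_ocr import CharIndexer
+
+ indexer = CharIndexer(VOCABULARY, BLANK_TOKEN_SYMBOL)
+ ids = indexer.encode("ANNA")   # one integer index per character
+ # decode() drops blanks and merges adjacent duplicate indices, so the
+ # repeated 'N' collapses to one: this is exactly why CTC inserts blank
+ # tokens between genuine repeated characters.
+ print(indexer.decode(ids))     # -> "ANA"
+ ```
+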
162
+ `model_ocr.py`: Defines the `CNN_Backbone`, `BidirectionalLSTM`, and `CRNN` (the main OCR model) classes. It also includes `train_ocr_model`, `evaluate_model`, `save_ocr_model`, `load_ocr_model`, and `ctc_greedy_decode`.
163
+
164
+ `utils_ocr.py`: Provides helper functions for image preprocessing steps like binarization, resizing, and normalization, used before feeding images to the model.
165
+
166
+
167
+
168
+ ## 📌 Roadmap
169
+ - Advanced Data Augmentation: Implement more sophisticated augmentation techniques (e.g., elastic deformations, random noise) for training data.
170
+ - Beam Search Decoding: Replace greedy decoding with beam search for potentially more accurate predictions.
171
+ - Error Analysis Dashboard: Integrate a more detailed error analysis section to visualize common recognition mistakes.
172
+ - Support for Multiple Languages: Extend character sets and train on multilingual datasets.
173
+ - Deployment to Cloud Platforms: Provide instructions for deploying the Streamlit app to platforms like Hugging Face Spaces, Heroku, or AWS.
174
+ - Pre-trained Model Download: Allow users to download pre-trained models directly from Hugging Face Hub.
175
+ - Interactive Drawing Pad: Enable users to draw a name directly in the app for recognition.
176
+
177
+ ## 🎁 Contribution
+ Contributions are welcome! If you have suggestions, bug reports, or want to contribute code:
+ - Fork the repository.
+ - Create a new branch (`git checkout -b feature/your-feature-name`).
+ - Make your changes.
+ - Commit your changes (`git commit -m 'Add new feature'`).
+ - Push to the branch (`git push origin feature/your-feature-name`).
+ - Open a Pull Request.
184
+
185
+ ## ⚖️ License
+ This project is licensed under the Apache 2.0 License (matching the `license: apache-2.0` metadata above); see the LICENSE file for details.
187
+
188
+ ## ✨ Acknowledgements
189
+ **Streamlit**: For building interactive web applications with ease.
190
+
191
+ **PyTorch**: The open-source machine learning framework.
192
+
193
+ **Hugging Face Hub**: For model and dataset sharing.
194
+
195
+ **OpenCV**: For image processing utilities (implicitly used via utils_ocr).
196
+
197
+ **EditDistance**: For efficient calculation of character error rate.
198
+
199
+ **tqdm**: For progress bars during training.
200
+
201
+ ---
202
+
203
+ *Built using Streamlit, PyTorch, OpenCV, and EditDistance © 2025 by **MFT***
models/handwritten_name_ocr_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4dc43b9d30bca8a300c6b7ab1c379685e366e2d838fdca6d5daee27e141ae7b
3
+ size 21382549
requirements.txt CHANGED
@@ -1,3 +1,15 @@
1
- altair
2
- pandas
3
- streamlit
1
+ # requirements.txt
2
+ # This file lists all the Python libraries required to run the Handwritten Name OCR application.
3
+ # Install using: pip install -r requirements.txt
4
+
5
+ streamlit>=1.33.0
6
+ pandas>=2.2.2
7
+ numpy>=1.26.4
8
+ Pillow>=10.3.0
9
+ opencv-python-headless>=4.9.0.80
10
+ torch>=2.2.2
11
+ torchvision>=0.17.2
12
+ matplotlib>=3.8.4
13
+ tqdm>=4.66.2
14
+ editdistance>=0.8.1
15
+ scikit-learn>=1.4.2
src/config.py ADDED
@@ -0,0 +1,48 @@
1
+ # config.py
2
+
3
+ import os
4
+
5
+ # --- Paths ---
6
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7
+ DATA_DIR = os.path.join(BASE_DIR, 'data')
8
+ MODELS_DIR = os.path.join(BASE_DIR, 'models')
9
+
10
+ TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
11
+ TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
12
+
13
+ TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
14
+ TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
15
+
16
+ MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
17
+
18
+ # --- Character Set and OCR Configuration ---
19
+ CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
20
+ BLANK_TOKEN_SYMBOL = 'Þ'
21
+ VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
22
+ NUM_CLASSES = len(VOCABULARY)
23
+ BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
24
+
25
+ # --- Sanity Checks ---
26
+ if BLANK_TOKEN == -1:
27
+ raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
28
+ if BLANK_TOKEN >= NUM_CLASSES:
29
+ raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")
30
+
31
+ print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
32
+ print(f"Vocabulary Length: {len(VOCABULARY)}")
33
+ print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
34
+
35
+
36
+ # --- Image Preprocessing Parameters ---
37
+ IMG_HEIGHT = 32 # Target height for all input images to the model
38
+ MAX_IMG_WIDTH = 1024 # Adjust this value based on your typical image widths and available RAM
39
+
40
+ # --- Training Parameters ---
41
+ BATCH_SIZE = 10
42
+
43
+ # Dataset Limits
44
+ TRAIN_SAMPLES_LIMIT = 1000
45
+ TEST_SAMPLES_LIMIT = 1000
46
+
47
+ NUM_EPOCHS = 5
48
+ LEARNING_RATE = 0.001
src/data_handler_ocr.py ADDED
@@ -0,0 +1,165 @@
1
+ # data_handler_ocr.py
2
+
3
+ import pandas as pd
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torchvision import transforms
7
+ import os
8
+ from PIL import Image
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+
12
+ # Import utility functions and config
13
+ from config import (
14
+ VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT,
15
+ TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
16
+ TRAIN_SAMPLES_LIMIT, TEST_SAMPLES_LIMIT
17
+ )
18
+ from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
19
+
20
+ class CharIndexer:
21
+ """Manages character-to-index and index-to-character mappings."""
22
+ def __init__(self, vocabulary_string: str, blank_token_symbol: str):
23
+ self.chars = sorted(list(set(vocabulary_string)))
24
+ self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
25
+ self.idx_to_char = {i: char for i, char in enumerate(self.chars)}
26
+
27
+ if blank_token_symbol not in self.char_to_idx:
28
+ raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")
29
+
30
+ self.blank_token_idx = self.char_to_idx[blank_token_symbol]
31
+ self.num_classes = len(self.chars)
32
+
33
+ if self.blank_token_idx >= self.num_classes:
34
+ raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")
35
+
36
+ print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
37
+ print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")
38
+
39
+ def encode(self, text: str) -> list[int]:
40
+ """Converts a text string to a list of integer indices."""
41
+ encoded_list = []
42
+ for char in text:
43
+ if char in self.char_to_idx:
44
+ encoded_list.append(self.char_to_idx[char])
45
+ else:
46
+ print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
47
+ encoded_list.append(self.blank_token_idx)
48
+ return encoded_list
49
+
50
+ def decode(self, indices: list[int]) -> str:
51
+ """Converts a list of integer indices back to a text string."""
52
+ decoded_text = []
53
+ for i, idx in enumerate(indices):
54
+ if idx == self.blank_token_idx:
55
+ continue # Skip blank tokens
56
+
57
+ if i > 0 and indices[i-1] == idx:
58
+ continue
59
+
60
+ if idx in self.idx_to_char:
61
+ decoded_text.append(self.idx_to_char[idx])
62
+ else:
63
+ print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")
64
+
65
+ return "".join(decoded_text)
66
+
67
+ class OCRDataset(Dataset):
68
+ """
69
+ Custom PyTorch Dataset for the Handwritten Name Recognition task.
70
+ Loads images and their corresponding text labels.
71
+ """
72
+ def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
73
+ self.data = dataframe
74
+ self.char_indexer = char_indexer
75
+ self.image_dir = image_dir
76
+
77
+ if transform is None:
78
+ self.transform = transforms.Compose([
79
+ transforms.Lambda(lambda img: binarize_image(img)),
80
+ transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)), # Resize image to fixed height
81
+ transforms.ToTensor(), # Convert PIL Image to PyTorch Tensor (H, W) -> (1, H, W), scales to [0,1]
82
+ transforms.Lambda(normalize_image_for_model) # Normalize pixel values to [-1, 1]
83
+ ])
84
+ else:
85
+ self.transform = transform
86
+
87
+
88
+ def __len__(self) -> int:
89
+ return len(self.data)
90
+
91
+ def __getitem__(self, idx):
92
+ raw_filename_entry = self.data.loc[idx, 'FILENAME']
93
+ ground_truth_text = self.data.loc[idx, 'IDENTITY']
94
+
95
+ filename = raw_filename_entry.split(',')[0].strip()
96
+ img_path = os.path.join(self.image_dir, filename)
97
+ ground_truth_text = str(ground_truth_text)
98
+
99
+ try:
100
+ image = load_image_as_grayscale(img_path) # Returns PIL Image 'L'
101
+ except FileNotFoundError:
102
+ print(f"Error: Image file not found at {img_path}.")
+ raise
104
+
105
+ if self.transform:
106
+ image = self.transform(image)
107
+
108
+ image_width = image.shape[2] # Assuming image is (C, H, W) after transform
109
+
110
+ text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
111
+ text_length = len(text_encoded)
112
+
113
+ return image, text_encoded, image_width, text_length
114
+
115
+ def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
116
+ """
117
+ Custom collate function for the DataLoader to handle variable-width images
118
+ and variable-length text sequences for CTC loss.
119
+ """
120
+ images, texts, image_widths, text_lengths = zip(*batch)
121
+
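+ # Images are right-padded with zeros to the widest sample so the batch can
+ # be stacked into a single tensor; the targets are concatenated into one
+ # 1D tensor because nn.CTCLoss consumes flattened targets together with a
+ # per-sample length tensor.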
122
+ max_batch_width = max(image_widths)
123
+ padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
124
+ images_batch = torch.stack(padded_images, 0)
125
+
126
+ texts_batch = torch.cat(texts, 0)
127
+ text_lengths_tensor = torch.tensor(list(text_lengths), dtype=torch.long)
128
+ image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
129
+
130
+ return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
131
+
132
+
133
+ def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
134
+ """
135
+ Loads training and testing dataframes.
136
+ Assumes CSVs have 'FILENAME' and 'IDENTITY' columns.
137
+ Applies dataset limits from config.py.
138
+ """
139
+ train_df = pd.read_csv(train_csv_path, encoding='ISO-8859-1')
140
+ test_df = pd.read_csv(test_csv_path, encoding='ISO-8859-1')
141
+
142
+ # Apply limits if they are set (not 0)
143
+ if TRAIN_SAMPLES_LIMIT > 0:
144
+ train_df = train_df.head(TRAIN_SAMPLES_LIMIT)
145
+ print(f"Limited training data to {TRAIN_SAMPLES_LIMIT} samples.")
146
+ if TEST_SAMPLES_LIMIT > 0:
147
+ test_df = test_df.head(TEST_SAMPLES_LIMIT)
148
+ print(f"Limited test data to {TEST_SAMPLES_LIMIT} samples.")
149
+
150
+ return train_df, test_df
151
+
152
+ def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
153
+ char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
154
+ """
155
+ Creates PyTorch DataLoader objects for OCR training and testing datasets,
156
+ using specific image directories for train/test.
157
+ """
158
+ train_dataset = OCRDataset(train_df, char_indexer, TRAIN_IMAGES_DIR)
159
+ test_dataset = OCRDataset(test_df, char_indexer, TEST_IMAGES_DIR)
160
+
161
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
162
+ num_workers=0, collate_fn=ocr_collate_fn)
163
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
164
+ num_workers=0, collate_fn=ocr_collate_fn)
165
+ return train_loader, test_loader
src/model_ocr.py ADDED
@@ -0,0 +1,286 @@
1
+ # model_ocr.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.optim as optim
7
+ from torch.utils.data import DataLoader
8
+ from tqdm import tqdm
9
+ from sklearn.metrics import accuracy_score
10
+ import editdistance
11
+
12
+ # Import config and char_indexer
13
+ from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
14
+ from data_handler_ocr import CharIndexer
15
+ from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model
16
+
17
+
18
+ class CNN_Backbone(nn.Module):
19
+ """
20
+ CNN feature extractor for OCR. Designed to produce features suitable for RNN.
21
+ Output feature map should have height 1 after the final pooling/reduction.
22
+ """
23
+ def __init__(self, input_channels=1, output_channels=512):
24
+ super(CNN_Backbone, self).__init__()
25
+ self.cnn = nn.Sequential(
26
+ # First block
27
+ nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
28
+ nn.ReLU(True),
29
+ nn.MaxPool2d(kernel_size=2, stride=2), # H: 32 -> 16, W: W_in -> W_in/2
30
+
31
+ # Second block
32
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
33
+ nn.ReLU(True),
34
+ nn.MaxPool2d(kernel_size=2, stride=2), # H: 16 -> 8, W: W_in/2 -> W_in/4
35
+
36
+ # Third block (with two conv layers)
37
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
38
+ nn.ReLU(True),
39
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
40
+ nn.ReLU(True),
41
+ # This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
42
+ nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
43
+
44
+ # Fourth block
45
+ nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
46
+ nn.ReLU(True),
47
+ # This AdaptiveAvgPool2d makes sure the height dimension becomes 1
48
+ # while preserving the width. This is crucial for RNN input.
49
+ nn.AdaptiveAvgPool2d((1, None)) # Output height 1, preserve width
50
+ )
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
54
+
55
+ # Pass through the CNN layers
56
+ conv_features = self.cnn(x) # Output: (N, cnn_out_channels, 1, W_prime)
57
+
58
+ # Squeeze the height dimension (which is 1)
59
+ # This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
60
+ conv_features = conv_features.squeeze(2)
61
+
62
+ # Permute for RNN input: (sequence_length, batch_size, input_size)
63
+ # This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
64
+ conv_features = conv_features.permute(2, 0, 1)
65
+
66
+ # Return the CNN features, ready for the RNN layer in CRNN
67
+ return conv_features
68
+
69
+ class BidirectionalLSTM(nn.Module):
70
+ """Bidirectional LSTM layer for sequence modeling."""
71
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
72
+ super(BidirectionalLSTM, self).__init__()
73
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
74
+ bidirectional=True, dropout=dropout, batch_first=False)
75
+ # batch_first=False expects input as (sequence_length, batch_size, input_size)
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ output, _ = self.lstm(x) # [0] returns the output, [1] returns (h_n, c_n)
79
+ return output
80
+
81
+ class CRNN(nn.Module):
82
+ """
83
+ Convolutional Recurrent Neural Network for OCR.
84
+ Combines CNN for feature extraction, LSTMs for sequence modeling,
85
+ and a final linear layer for character prediction.
86
+ """
87
+ def __init__(self, num_classes: int, cnn_output_channels: int = 512,
88
+ rnn_hidden_size: int = 256, rnn_num_layers: int = 2):
89
+ super(CRNN, self).__init__()
90
+ self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
91
+ # Input to LSTM is the number of channels from the CNN output
92
+ self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)
93
+ # Output of bidirectional LSTM is hidden_size * 2
94
+ self.fc = nn.Linear(rnn_hidden_size * 2, num_classes) # Final linear layer for classes
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
98
+
99
+ # 1. Pass through the CNN to extract features
100
+ conv_features = self.cnn(x) # Output: (W_prime, N, C_out) after permute in CNN_Backbone
101
+
102
+ # 2. Pass CNN features through the RNN (LSTM)
103
+ rnn_features = self.rnn(conv_features) # Output: (W_prime, N, rnn_hidden_size * 2)
104
+
105
+ # 3. Pass RNN features through the final fully connected layer
106
+ # Apply the linear layer to each time step independently
107
+ # output will be (W_prime, N, num_classes)
108
+ output = self.fc(rnn_features)
109
+
110
+ return output
111
+
112
+
113
+ # --- Decoding Function ---
114
+ def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
115
+ """
116
+ Performs greedy decoding on the CTC output.
117
+ output: (sequence_length, batch_size, num_classes) - raw logits
118
+ """
119
+ # Apply log_softmax to get probabilities for argmax
120
+ log_probs = F.log_softmax(output, dim=2)
121
+
122
+ # Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
123
+ predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
124
+
125
+ decoded_texts = []
126
+ for seq in predicted_indices:
127
+ # Use char_indexer's decode method, which handles blank removal and duplicate collapse
128
+ decoded_texts.append(char_indexer.decode(seq.tolist()))
129
+ return decoded_texts
130
+
131
+ # --- Evaluation Function ---
132
+ def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
133
+ model.eval()
134
+ criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
135
+ total_loss = 0
136
+ all_predictions = []
137
+ all_ground_truths = []
138
+
139
+ with torch.no_grad():
140
+ for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
141
+ inputs = inputs.to(device)
142
+ targets_padded = targets_padded.to(device)
143
+ target_lengths_tensor = target_lengths.to(device)
144
+
145
+ output = model(inputs)
146
+
147
+ outputs_seq_len_for_ctc = torch.full(
148
+ size=(output.shape[1],),
149
+ fill_value=output.shape[0],
150
+ dtype=torch.long,
151
+ device=device
152
+ )
153
+
154
+ # CTC Loss calculation requires log_softmax on the output logits
155
+ log_probs_for_loss = F.log_softmax(output, dim=2)
156
+
157
+ # CTCLoss expects targets_padded as a 1D tensor and target_lengths_tensor as corresponding lengths
158
+ loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths_tensor)
159
+ total_loss += loss.item() * inputs.size(0)
160
+
161
+ decoded_preds = ctc_greedy_decode(output, char_indexer)
162
+ all_predictions.extend(decoded_preds)
163
+
164
+ ground_truths_batch = []
165
+ current_idx_in_concatenated_targets = 0
166
+
167
+ target_lengths_list = target_lengths.cpu().tolist()
168
+
169
+ for i in range(inputs.size(0)):
170
+ length = target_lengths_list[i]
171
+
172
+ current_target_segment = targets_padded[current_idx_in_concatenated_targets : current_idx_in_concatenated_targets + length].tolist()
173
+ ground_truths_batch.append(char_indexer.decode(current_target_segment))
174
+ current_idx_in_concatenated_targets += length
175
+
176
+ all_ground_truths.extend(ground_truths_batch)
177
+
178
+ avg_loss = total_loss / len(dataloader.dataset)
179
+
180
+ # Calculate Character Error Rate (CER)
181
+ cer_sum = 0
182
+ total_chars = 0
183
+ for pred, gt in zip(all_predictions, all_ground_truths):
184
+ cer_sum += editdistance.eval(pred, gt)
185
+ total_chars += len(gt)
186
+ char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
187
+
188
+ # Calculate Exact Match Accuracy (Word-level Accuracy)
189
+ exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
190
+
191
+ return avg_loss, char_error_rate, exact_match_accuracy
192
+
193
+ # --- Training Function ---
194
+ def train_ocr_model(model: nn.Module, train_loader: DataLoader,
195
+ test_loader: DataLoader, char_indexer: CharIndexer,
196
+ epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
197
+ """
198
+ Trains the OCR model using CTC loss.
199
+ """
200
+ # CTCLoss needs the blank token index
201
+ criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
202
+ optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
203
+ # Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
204
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)
205
+
206
+ model.to(device) # Ensure model is on the correct device
207
+ model.train() # Set model to training mode
208
+
209
+ training_history = {
210
+ 'train_loss': [],
211
+ 'test_loss': [],
212
+ 'test_cer': [],
213
+ 'test_exact_match_accuracy': []
214
+ }
215
+
216
+ for epoch in range(epochs):
217
+ running_loss = 0.0
218
+ pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
219
+ for images, texts_encoded, _, text_lengths in pbar_train:
220
+ images = images.to(device)
221
+ # Ensure target tensors are on the correct device for CTCLoss calculation
222
+ texts_encoded = texts_encoded.to(device)
223
+ text_lengths = text_lengths.to(device)
224
+
225
+ optimizer.zero_grad() # Clear gradients from previous step
226
+ outputs = model(images) # (sequence_length_from_cnn, batch_size, num_classes)
227
+
228
+ # `outputs.shape[0]` is the actual sequence length (T) produced by the model.
229
+ # CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
230
+ outputs_seq_len_for_ctc = torch.full(
231
+ size=(outputs.shape[1],), # batch_size
232
+ fill_value=outputs.shape[0], # actual sequence length (T) from model output
233
+ dtype=torch.long,
234
+ device=device
235
+ )
236
+
237
+ # CTC Loss calculation requires log_softmax on the output logits
238
+ log_probs_for_loss = F.log_softmax(outputs, dim=2) # (T, N, C)
239
+
240
+ # Use outputs_seq_len_for_ctc for the input_lengths argument
241
+ loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
242
+ loss.backward() # Backpropagate
243
+ optimizer.step() # Update model weights
244
+
245
+ running_loss += loss.item() * images.size(0) # Multiply by batch size for correct average
246
+ pbar_train.set_postfix(loss=loss.item())
247
+
248
+ epoch_train_loss = running_loss / len(train_loader.dataset)
249
+ training_history['train_loss'].append(epoch_train_loss)
250
+
251
+ # Evaluate on test set using the dedicated function
252
+ # Ensure model is in eval mode before calling evaluate_model
253
+ model.eval()
254
+ test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
255
+ training_history['test_loss'].append(test_loss)
256
+ training_history['test_cer'].append(test_cer)
257
+ training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
258
+
259
+ # Adjust learning rate based on test loss
260
+ scheduler.step(test_loss)
261
+
262
+ print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
263
+ f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
264
+
265
+ if progress_callback:
266
+ # Update progress bar with current epoch and key metrics
267
+ progress_val = (epoch + 1) / epochs
268
+ progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
269
+
270
+ model.train() # Set model back to training mode after evaluation
271
+
272
+ return model, training_history
273
+
274
+ def save_ocr_model(model: nn.Module, path: str):
275
+ """Saves the state dictionary of the trained OCR model."""
276
+ torch.save(model.state_dict(), path)
277
+ print(f"OCR model saved to {path}")
278
+
279
+ def load_ocr_model(model: nn.Module, path: str):
280
+ """
281
+ Loads a trained OCR model's state dictionary.
282
+ Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
283
+ """
284
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu'))) # Always load to CPU first
285
+ model.eval() # Set to evaluation mode
286
+ print(f"OCR model loaded from {path}")
src/streamlit_app.py CHANGED
@@ -1,40 +1,227 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # -*- coding: utf-8 -*-
2
+ # app.py
3
+
4
+ import os
5
+ # Disable Streamlit file watcher to prevent conflicts with PyTorch
6
+ os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
7
+
8
  import streamlit as st
9
+ import pandas as pd
10
+ import numpy as np
11
+ from PIL import Image
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torchvision.transforms as transforms
15
+ import traceback
16
+
17
+ # Import all necessary configuration values from config.py
18
+ from config import (
19
+ IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
20
+ TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
21
+ MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
22
+ )
23
+
24
+ # Import classes and functions from data_handler_ocr.py and model_ocr.py
25
+ from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
26
+ from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
27
+ from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
28
+
29
+
30
+ # --- Global Variables ---
31
+ ocr_model = None
32
+ char_indexer = None
33
+ training_history = None
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ # --- Streamlit App Setup ---
37
+ st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")
38
+
39
+
40
+ st.title("📝 Handwritten Name Recognition (OCR) App")
41
+ st.markdown("""
42
+ This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
43
+ Optical Character Recognition (OCR) on handwritten names. You can upload an image
44
+ of a handwritten name for prediction or train a new model using the provided dataset.
45
+
46
+ **Note:** Training a robust OCR model can be time-consuming.
47
+ """)
48
+
49
+ # --- Initialize CharIndexer ---
50
+ # This initializes char_indexer once when the script starts
51
+ char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
52
+
53
+ # --- Model Loading / Initialization ---
54
+ @st.cache_resource # Cache the model to prevent reloading on every rerun
55
+ def get_and_load_ocr_model_cached(num_classes, model_path):
56
+ """
57
+ Initializes the OCR model and attempts to load a pre-trained model.
58
+ If no pre-trained model exists, a new model instance is returned.
59
+ """
60
+ model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
61
+
62
+ if os.path.exists(model_path):
63
+ st.sidebar.info("Loading pre-trained OCR model...")
64
+ try:
65
+ # Load model to CPU first, then move to device
66
+ model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
67
+ st.sidebar.success("OCR model loaded successfully!")
68
+ except Exception as e:
69
+ st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
70
+ # If loading fails, re-initialize an untrained model
71
+ model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
72
+ else:
73
+ st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
74
+
75
+ return model_instance
76
+
77
+ # Get the model instance and assign it to the global 'ocr_model'
78
+ ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
79
+ # Ensure the model is on the correct device for inference
80
+ ocr_model.to(device)
81
+ ocr_model.eval() # Set model to evaluation mode for inference by default
82
+
83
+
84
+ # --- Sidebar for Model Training ---
85
+ st.sidebar.header("Train OCR Model")
86
+ st.sidebar.write("Click the button below to start training the OCR model.")
87
+
88
+ # Progress bar and label for training in the sidebar
89
+ progress_bar_sidebar = st.sidebar.progress(0)
90
+ progress_label_sidebar = st.sidebar.empty()
91
+
92
+ def update_progress_callback_sidebar(value, text):
93
+ progress_bar_sidebar.progress(int(value * 100))
94
+ progress_label_sidebar.text(text)
95
+
96
+ if st.sidebar.button("📊 Start Training"):
97
+ progress_bar_sidebar.progress(0)
98
+ progress_label_sidebar.empty()
99
+ st.empty()
100
+
101
+ if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
102
+ st.sidebar.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
103
+ elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
104
+ st.sidebar.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
105
+ "Evaluation might be affected or skipped. Please ensure all data paths are correct.")
106
+ else:
107
+ st.sidebar.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
108
+
109
+ try:
110
+ train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
111
+ st.sidebar.success("Training and Test DataFrames loaded successfully.")
112
+
113
+ st.sidebar.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
114
+
115
+ train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, BATCH_SIZE)
116
+ st.sidebar.success("DataLoaders created successfully.")
117
+
118
+ ocr_model.train()
119
+
120
+ st.sidebar.write("Training in progress... This may take a while.")
121
+ ocr_model, training_history = train_ocr_model(
122
+ model=ocr_model,
123
+ train_loader=train_loader,
124
+ test_loader=test_loader,
125
+ char_indexer=char_indexer,
126
+ epochs=NUM_EPOCHS,
127
+ device=device,
128
+ progress_callback=update_progress_callback_sidebar
129
+ )
130
+ st.sidebar.success("OCR model training finished!")
131
+ update_progress_callback_sidebar(1.0, "Training complete!")
132
+
133
+ os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
134
+ save_ocr_model(ocr_model, MODEL_SAVE_PATH)
135
+ st.sidebar.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
136
+
137
+ except Exception as e:
138
+ st.sidebar.error(f"An error occurred during training: {e}")
139
+ st.exception(e)
140
+ update_progress_callback_sidebar(0.0, "Training failed!")
141
+
142
+ # --- Sidebar for Model Loading ---
143
+ st.sidebar.header("Load Pre-trained Model")
144
+ st.sidebar.write("If you have a saved model, you can load it here instead of training.")
145
+
146
+ if st.sidebar.button("💾 Load Model"):
147
+ if os.path.exists(MODEL_SAVE_PATH):
148
+ try:
149
+ loaded_model = CRNN(num_classes=char_indexer.num_classes)
150
+ load_ocr_model(loaded_model, MODEL_SAVE_PATH)
151
+ loaded_model.to(device)
152
+
153
+ st.sidebar.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
154
+ except Exception as e:
155
+ st.sidebar.error(f"Error loading model: {e}")
156
+ st.exception(e)
157
+ else:
158
+ st.sidebar.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
159
+
160
+ # --- Main Content: Prediction Section and Training History ---
161
+
162
+ # Display training history chart
163
+ if training_history:
164
+ st.subheader("Training History Plots")
165
+ history_df = pd.DataFrame({
166
+ 'Epoch': range(1, len(training_history['train_loss']) + 1),
167
+ 'Train Loss': training_history['train_loss'],
168
+ 'Test Loss': training_history['test_loss'],
169
+ 'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
170
+ 'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
171
+ })
172
+
173
+ st.markdown("**Loss over Epochs**")
174
+ st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
175
+ st.caption("Lower loss indicates better model performance.")
176
+
177
+ st.markdown("**Character Error Rate (CER) over Epochs**")
178
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
179
+ st.caption("Lower CER indicates fewer character errors (0% is perfect).")
180
+
181
+ st.markdown("**Exact Match Accuracy over Epochs**")
182
+ st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
183
+ st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
184
+
185
+ st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
186
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
187
+ st.caption("CER should decrease, Accuracy should increase.")
188
+ st.write("---") # Separator after charts
189
+
190
+
191
+ # Predict on a New Image
192
+
193
+ if ocr_model is None:
194
+ st.warning("Please train or load a model before attempting prediction.")
195
+ else:
196
+ uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
197
+
198
+ if uploaded_file is not None:
199
+ try:
200
+ image_pil = Image.open(uploaded_file).convert('L')
201
+ st.image(image_pil, caption="Uploaded Image", use_container_width=True)
202
+ st.write("---")
203
+ st.write("Processing and Recognizing...")
204
+
205
+ processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
206
+
207
+ ocr_model.eval()
208
+ with torch.no_grad():
209
+ output = ocr_model(processed_image_tensor)
210
+
211
+ predicted_texts = ctc_greedy_decode(output, char_indexer)
212
+ predicted_text = predicted_texts[0]
213
+
214
+ st.success(f"Recognized Text: **{predicted_text}**")
215
+
216
+ except Exception as e:
217
+ st.error(f"Error processing image or recognizing text: {e}")
218
+ st.info("💡 **Tips for best results:**\n"
219
+ "- Ensure the handwritten text is clear and on a clean background.\n"
220
+ "- Only include one name/word per image.\n"
221
+ "- The model is trained on specific characters. Unusual symbols might not be recognized.")
222
+ st.exception(e)
223
 
224
+ st.markdown("""
225
+ ---
226
+ *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
227
+ """)
src/utils_ocr.py ADDED
@@ -0,0 +1,83 @@
1
+ # utils_ocr.py
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ import os
9
+
10
+ # Import config for IMG_HEIGHT and MAX_IMG_WIDTH
11
+ from config import IMG_HEIGHT, MAX_IMG_WIDTH
12
+
13
+ # --- Image Preprocessing Functions ---
14
+
15
+ def load_image_as_grayscale(image_path: str) -> Image.Image:
16
+ """Loads an image from path and converts it to grayscale PIL Image."""
17
+ if not os.path.exists(image_path):
18
+ raise FileNotFoundError(f"Image not found at: {image_path}")
19
+ return Image.open(image_path).convert('L') # 'L' for grayscale
20
+
21
+ def binarize_image(img: Image.Image) -> Image.Image:
22
+ """
23
+ Binarizes a grayscale PIL Image using Otsu's method.
24
+ Returns a PIL Image.
25
+ """
26
+ # Convert PIL Image to OpenCV format (numpy array)
27
+ img_np = np.array(img)
28
+
29
+ # Apply Otsu's binarization
30
+ _, binary_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
31
+
32
+ # Convert back to PIL Image
33
+ return Image.fromarray(binary_img)
34
+
35
+ def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
36
+ """
37
+ Resizes a PIL Image to a fixed height while maintaining aspect ratio.
38
+ Also ensures the width does not exceed MAX_IMG_WIDTH.
39
+ """
40
+ width, height = img.size
41
+
42
+ # Calculate new width based on target height, maintaining aspect ratio
43
+ new_width = int(width * (img_height / height))
44
+
45
+ if new_width > MAX_IMG_WIDTH:
46
+ new_width = MAX_IMG_WIDTH
47
+ resized_img = img.resize((new_width, img_height), Image.Resampling.LANCZOS)
48
+ if resized_img.width > MAX_IMG_WIDTH:
49
+ # Crop the image from the left to MAX_IMG_WIDTH
50
+ resized_img = resized_img.crop((0, 0, MAX_IMG_WIDTH, img_height))
51
+ return resized_img
52
+
53
+ return img.resize((new_width, img_height), Image.Resampling.LANCZOS) # Use LANCZOS for high-quality downsampling
54
+
55
+ def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
56
+ """
57
+ Normalizes a torch.Tensor image (grayscale) for input into the model.
58
+ Puts pixel values in range [-1, 1].
59
+ Assumes image is already a torch.Tensor with values in [0, 1] (e.g., after ToTensor).
60
+ """
61
+ # Formula: (pixel_value - mean) / std_dev
62
+ # For [0, 1] to [-1, 1], mean = 0.5, std_dev = 0.5
63
+ img_tensor = (img_tensor - 0.5) / 0.5
64
+ return img_tensor
65
+
66
+ def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
67
+ """
68
+ Applies all necessary preprocessing steps to a user-uploaded PIL Image
69
+ to prepare it for the OCR model.
70
+ """
71
+ # Define a transformation pipeline similar to the dataset, but including ToTensor
72
+ transform_pipeline = transforms.Compose([
73
+ transforms.Lambda(lambda img: binarize_image(img)), # PIL Image -> PIL Image
74
+ # Use the updated resize function that also handles MAX_IMG_WIDTH
75
+ transforms.Lambda(lambda img: resize_image_for_ocr(img, target_height)), # PIL Image -> PIL Image
76
+ transforms.ToTensor(), # PIL Image -> Tensor [0, 1]
77
+ transforms.Lambda(normalize_image_for_model) # Tensor [0, 1] -> Tensor [-1, 1]
78
+ ])
79
+
80
+ processed_image = transform_pipeline(image_pil)
81
+
82
+ # Add a batch dimension (C, H, W) -> (1, C, H, W) for single image inference
83
+ return processed_image.unsqueeze(0)