Spaces:
Runtime error
Runtime error
Initial upload of files
Browse files- README.md +196 -12
- models/handwritten_name_ocr_model.pth +3 -0
- requirements.txt +15 -3
- src/config.py +48 -0
- src/data_handler_ocr.py +165 -0
- src/model_ocr.py +286 -0
- src/streamlit_app.py +225 -38
- src/utils_ocr.py +83 -0
README.md
CHANGED
|
@@ -1,19 +1,203 @@
|
|
| 1 |
---
|
| 2 |
-
title: Handwritten Name Recognizer
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
app_port: 8501
|
| 8 |
-
tags:
|
| 9 |
-
- streamlit
|
| 10 |
pinned: false
|
| 11 |
-
|
| 12 |
---
|
| 13 |
|
| 14 |
-
#
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Handwritten Name Recognizer
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: gray
|
| 6 |
sdk: docker
|
|
|
|
|
|
|
|
|
|
| 7 |
pinned: false
|
| 8 |
+
license: apache-2.0
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Handwritten Name Recognition (OCR) App βπ»
|
| 12 |
|
| 13 |
+
*An end-to-end Streamlit application for training and predicting handwritten names using a CRNN model.*
|
| 14 |
|
| 15 |
+
[π Demo and Documentation](https://drive.google.com/drive/folders/1rOmwyTJkDCsU-Wuh-_CzvQ9sdb_ci_kX?usp=sharing)
|
| 16 |
+
[π GitHub Repository](https://github.com/marianeft/handwritten_name_ocr_app.git)
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## Table of Contents
|
| 21 |
+
|
| 22 |
+
* [Overview](#overview)
|
| 23 |
+
* [Quickstart](#quickstart)
|
| 24 |
+
* [Features](#features)
|
| 25 |
+
* [Project Structure](#project-structure)
|
| 26 |
+
* [Project Index](#project-index)
|
| 27 |
+
* [Roadmap](#roadmap)
|
| 28 |
+
* [Contribution](#contribution)
|
| 29 |
+
* [License](#license)
|
| 30 |
+
* [Acknowledgements](#acknowledgements)
|
| 31 |
+
|
| 32 |
+
---
|
| 33 |
+
|
| 34 |
+
## πΉοΈ Overview
|
| 35 |
+
|
| 36 |
+
This project implements a Handwritten Name Recognition (OCR) system using a Convolutional Recurrent Neural Network (CRNN) architecture built with PyTorch. The application is presented as an interactive web interface using Streamlit, allowing users to:
|
| 37 |
+
|
| 38 |
+
1. **Train** a new OCR model from a local dataset.
|
| 39 |
+
2. **Load** a pre-trained model.
|
| 40 |
+
3. **Predict** text from uploaded handwritten image files.
|
| 41 |
+
4. **Upload** the local dataset to the Hugging Face Hub for sharing and versioning.
|
| 42 |
+
|
| 43 |
+
The CRNN model combines a CNN backbone for feature extraction from images and a Bidirectional LSTM layer for sequence modeling, followed by a linear layer for character classification using CTC (Connectionist Temporal Classification) Loss.
|
| 44 |
+
|
| 45 |
+
---
|
| 46 |
+
|
| 47 |
+
## π© Quickstart
|
| 48 |
+
|
| 49 |
+
Follow these steps to get the application up and running on your local machine.
|
| 50 |
+
|
| 51 |
+
### Prerequisites
|
| 52 |
+
|
| 53 |
+
* Python 3.8+
|
| 54 |
+
* `pip` (Python package installer)
|
| 55 |
+
|
| 56 |
+
#### 1. Clone the Repository (or set up your project folder)
|
| 57 |
+
|
| 58 |
+
Ensure your project structure matches the expected layout (e.g., `app.py`, `config.py`, `data/`, `models/` etc.).
|
| 59 |
+
|
| 60 |
+
#### 2. Create and Activate a Virtual Environment
|
| 61 |
+
It's highly recommended to use a virtual environment to manage dependencies.
|
| 62 |
+
|
| 63 |
+
``` bash
|
| 64 |
+
# Navigate to your project root directory
|
| 65 |
+
cd path/to/your/handwritten_name_ocr_app
|
| 66 |
+
|
| 67 |
+
# Create a virtual environment named 'venvy'
|
| 68 |
+
python -m venv venvy
|
| 69 |
+
|
| 70 |
+
# Activate the virtual environment
|
| 71 |
+
# On Windows (Command Prompt):
|
| 72 |
+
.\venvy\Scripts\activate.bat
|
| 73 |
+
|
| 74 |
+
# On Windows (PowerShell):
|
| 75 |
+
.\venvy\Scripts\Activate.ps1
|
| 76 |
+
|
| 77 |
+
# On macOS/Linux:
|
| 78 |
+
source venvy/bin/activate
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
#### 3. Install Dependencies
|
| 82 |
+
With your virtual environment activated, install all required Python packages:
|
| 83 |
+
`pip install streamlit` `pandas` `numpy` `Pillow` `torch` `torchvision` `scikit-learn` `tqdm` `editdistance` `huggingface_hub`
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
*Note on PyTorch (torch and torchvision):
|
| 87 |
+
The command above installs the CPU-only version of PyTorch. If you have a CUDA-enabled GPU and want to leverage it for faster training, please refer to the official PyTorch website (pytorch.org/get-started/locally/) for specific installation commands tailored to your CUDA version.*
|
| 88 |
+
|
| 89 |
+
#### 4. Prepare Your Dataset
|
| 90 |
+
The application expects a dataset structured as follows:
|
| 91 |
+
``` bash
|
| 92 |
+
data/
|
| 93 |
+
βββ images/
|
| 94 |
+
β βββ train/
|
| 95 |
+
β β βββ image1.png
|
| 96 |
+
β β βββ image2.png
|
| 97 |
+
β β βββ ...
|
| 98 |
+
β βββ test/
|
| 99 |
+
β βββ image_test1.png
|
| 100 |
+
β βββ image_test2.png
|
| 101 |
+
β βββ ...
|
| 102 |
+
βββ train.csv
|
| 103 |
+
βββ test.csv
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
#### 5. Clear Python Cache *(Important!)*
|
| 107 |
+
After making code changes or installing new packages, it's crucial to clear Python's compiled cache to ensure the latest code is used.
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
find . -name "__pycache__" -exec rm -rf {} + # For macOS/Linux
|
| 111 |
+
|
| 112 |
+
Get-ChildItem -Path . -Include __pycache__ -Recurse | Remove-Item -Recurse -Force # For Windows PowerShell
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
### 6. Run the Streamlit Application
|
| 116 |
+
With your virtual environment activated and dependencies installed:
|
| 117 |
+
`streamlit run app.py`
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
*This will open the application in your web browser.*
|
| 121 |
+
|
| 122 |
+
## βοΈ Features
|
| 123 |
+
- **CRNN Model Architecture**: Utilizes a Convolutional Recurrent Neural Network for robust OCR.
|
| 124 |
+
- **CTC Loss**: Employs Connectionist Temporal Classification for sequence prediction.
|
| 125 |
+
**Model Training**: Train a new OCR model from your local image and CSV datasets.
|
| 126 |
+
- **Pre-trained Model Loading**: Load previously saved models to avoid retraining.
|
| 127 |
+
- **Handwritten Text Prediction**: Upload an image and get instant text recognition.
|
| 128 |
+
- **Training Progress Visualization**: Real-time updates and plots for training loss, CER, and accuracy.
|
| 129 |
+
- **Hugging Face Hub Integration**: Seamlessly upload your dataset to the Hugging Face Hub for easy sharing and version control.
|
| 130 |
+
- **Responsive UI**: Built with Streamlit for an intuitive and user-friendly experience.
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
## ποΈ Project Structure
|
| 134 |
+
```
|
| 135 |
+
handwritten_name_ocr_app/
|
| 136 |
+
βββ app.py # Main Streamlit application file
|
| 137 |
+
βββ config.py # Configuration settings (paths, model params, chars)
|
| 138 |
+
βββ data/ # Directory for datasets
|
| 139 |
+
β βββ images/
|
| 140 |
+
β β βββ train/ # Training images
|
| 141 |
+
β β βββ test/ # Testing images
|
| 142 |
+
β βββ train.csv # Training labels
|
| 143 |
+
β βββ test.csv # Testing labels
|
| 144 |
+
βββ data_handler_ocr.py # Custom PyTorch Dataset and DataLoader logic
|
| 145 |
+
βββ models/ # Directory to save/load trained models
|
| 146 |
+
β βββ handwritten_name_ocr_model.pth # Default model save path
|
| 147 |
+
βββ model_ocr.py # Defines the CRNN model architecture and training/evaluation functions
|
| 148 |
+
βββ utils_ocr.py # Utility functions for image preprocessing
|
| 149 |
+
βββ requirements.txt # List of Python dependencies
|
| 150 |
+
βββ venvy/ # Python virtual environment (created by `python -m venv venvy`)
|
| 151 |
+
βββ ...
|
| 152 |
+
````
|
| 153 |
+
|
| 154 |
+
## ποΈ Project Index
|
| 155 |
+
|
| 156 |
+
`app.py`: The central Streamlit application. Handles UI, triggers training/prediction, and integrates with Hugging Face Hub.
|
| 157 |
+
|
| 158 |
+
`config.py`: Stores global configuration variables such as file paths, image dimensions, character sets, and training hyperparameters.
|
| 159 |
+
|
| 160 |
+
`data_handler_ocr.py`: Contains the CharIndexer class for character-to-index mapping and the OCRDataset and ocr_collate_fn for efficient data loading and batching for PyTorch.
|
| 161 |
+
|
| 162 |
+
`model_ocr.py`: Defines the CNN_Backbone, BidirectionalLSTM, and CRNN (the main OCR model) classes. It also includes functions for train_ocr_model, evaluate_model, save_ocr_model, load_ocr_model, and ctc_greedy_decode.
|
| 163 |
+
|
| 164 |
+
``utils_ocr.py``: Provides helper functions for image preprocessing steps like binarization, resizing, and normalization, used before feeding images to the model.
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
## π Roadmap
|
| 169 |
+
- Advanced Data Augmentation: Implement more sophisticated augmentation techniques (e.g., elastic deformations, random noise) for training data.
|
| 170 |
+
- Beam Search Decoding: Replace greedy decoding with beam search for potentially more accurate predictions.
|
| 171 |
+
- Error Analysis Dashboard: Integrate a more detailed error analysis section to visualize common recognition mistakes.
|
| 172 |
+
- Support for Multiple Languages: Extend character sets and train on multilingual datasets.
|
| 173 |
+
- Deployment to Cloud Platforms: Provide instructions for deploying the Streamlit app to platforms like Hugging Face Spaces, Heroku, or AWS.
|
| 174 |
+
- Pre-trained Model Download: Allow users to download pre-trained models directly from Hugging Face Hub.
|
| 175 |
+
- Interactive Drawing Pad: Enable users to draw a name directly in the app for recognition.
|
| 176 |
+
|
| 177 |
+
## π Contribution
|
| 178 |
+
Contributions are welcome! If you have suggestions, bug reports, or want to contribute code, please feel free to *fork the repository.*
|
| 179 |
+
- Create a new branch (git checkout -b feature/your-feature-name).
|
| 180 |
+
Make your changes.
|
| 181 |
+
- Commit your changes (git commit -m 'Add new feature').
|
| 182 |
+
- Push to the branch (git push origin feature/your-feature-name).
|
| 183 |
+
- Open a Pull Request.
|
| 184 |
+
|
| 185 |
+
## βοΈ License
|
| 186 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
| 187 |
+
|
| 188 |
+
## β¨ Acknowledgements
|
| 189 |
+
**Streamlit**: For building interactive web applications with ease.
|
| 190 |
+
|
| 191 |
+
**PyTorch**: The open-source machine learning framework.
|
| 192 |
+
|
| 193 |
+
**Hugging** Face Hub: For model and dataset sharing.
|
| 194 |
+
|
| 195 |
+
**OpenCV**: For image processing utilities (implicitly used via utils_ocr).
|
| 196 |
+
|
| 197 |
+
**EditDistance**: For efficient calculation of character error rate.
|
| 198 |
+
|
| 199 |
+
**tqdm**: For progress bars during training.
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
*Built using Streamlit, PyTorch, OpenCV, and EditDistance Β© 2025 by **MFT***
|
models/handwritten_name_ocr_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c4dc43b9d30bca8a300c6b7ab1c379685e366e2d838fdca6d5daee27e141ae7b
|
| 3 |
+
size 21382549
|
requirements.txt
CHANGED
|
@@ -1,3 +1,15 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#requirements.txt
|
| 2 |
+
# This file lists all the Python libraries required to run the Handwritten Name OCR application.
|
| 3 |
+
# Install using: pip install -r requirements.txt
|
| 4 |
+
|
| 5 |
+
streamlit>=1.33.0
|
| 6 |
+
pandas>=2.2.2
|
| 7 |
+
numpy>=1.26.4
|
| 8 |
+
Pillow>=10.3.0
|
| 9 |
+
opencv-python-headless>=4.9.0.80
|
| 10 |
+
torch>=2.2.2
|
| 11 |
+
torchvision>=0.17.2
|
| 12 |
+
matplotlib>=3.8.4
|
| 13 |
+
tqdm>=4.66.2
|
| 14 |
+
editdistance>=0.8.1
|
| 15 |
+
scikit-learn>=1.4.2
|
src/config.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# --- Paths ---
|
| 6 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 7 |
+
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
| 8 |
+
MODELS_DIR = os.path.join(BASE_DIR, 'models')
|
| 9 |
+
|
| 10 |
+
TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
|
| 11 |
+
TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
|
| 12 |
+
|
| 13 |
+
TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
|
| 14 |
+
TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
|
| 15 |
+
|
| 16 |
+
MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
|
| 17 |
+
|
| 18 |
+
# --- Character Set and OCR Configuration ---
|
| 19 |
+
CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
|
| 20 |
+
BLANK_TOKEN_SYMBOL = 'Γ'
|
| 21 |
+
VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
|
| 22 |
+
NUM_CLASSES = len(VOCABULARY)
|
| 23 |
+
BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
|
| 24 |
+
|
| 25 |
+
# --- Sanity Checks ---
|
| 26 |
+
if BLANK_TOKEN == -1:
|
| 27 |
+
raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
|
| 28 |
+
if BLANK_TOKEN >= NUM_CLASSES:
|
| 29 |
+
raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")
|
| 30 |
+
|
| 31 |
+
print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
|
| 32 |
+
print(f"Vocabulary Length: {len(VOCABULARY)}")
|
| 33 |
+
print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# --- Image Preprocessing Parameters ---
|
| 37 |
+
IMG_HEIGHT = 32 # Target height for all input images to the model
|
| 38 |
+
MAX_IMG_WIDTH = 1024 # Adjust this value based on your typical image widths and available RAM
|
| 39 |
+
|
| 40 |
+
# --- Training Parameters ---
|
| 41 |
+
BATCH_SIZE = 10
|
| 42 |
+
|
| 43 |
+
# Dataset Limits
|
| 44 |
+
TRAIN_SAMPLES_LIMIT = 1000
|
| 45 |
+
TEST_SAMPLES_LIMIT = 1000
|
| 46 |
+
|
| 47 |
+
NUM_EPOCHS = 5
|
| 48 |
+
LEARNING_RATE = 0.001
|
src/data_handler_ocr.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data_handler_ocr.py
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
import os
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
# Import utility functions and config
|
| 13 |
+
from config import (
|
| 14 |
+
VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT,
|
| 15 |
+
TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
|
| 16 |
+
TRAIN_SAMPLES_LIMIT, TEST_SAMPLES_LIMIT
|
| 17 |
+
)
|
| 18 |
+
from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
|
| 19 |
+
|
| 20 |
+
class CharIndexer:
|
| 21 |
+
"""Manages character-to-index and index-to-character mappings."""
|
| 22 |
+
def __init__(self, vocabulary_string: str, blank_token_symbol: str):
|
| 23 |
+
self.chars = sorted(list(set(vocabulary_string)))
|
| 24 |
+
self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
|
| 25 |
+
self.idx_to_char = {i: char for i, char in enumerate(self.chars)}
|
| 26 |
+
|
| 27 |
+
if blank_token_symbol not in self.char_to_idx:
|
| 28 |
+
raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")
|
| 29 |
+
|
| 30 |
+
self.blank_token_idx = self.char_to_idx[blank_token_symbol]
|
| 31 |
+
self.num_classes = len(self.chars)
|
| 32 |
+
|
| 33 |
+
if self.blank_token_idx >= self.num_classes:
|
| 34 |
+
raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")
|
| 35 |
+
|
| 36 |
+
print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
|
| 37 |
+
print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")
|
| 38 |
+
|
| 39 |
+
def encode(self, text: str) -> list[int]:
|
| 40 |
+
"""Converts a text string to a list of integer indices."""
|
| 41 |
+
encoded_list = []
|
| 42 |
+
for char in text:
|
| 43 |
+
if char in self.char_to_idx:
|
| 44 |
+
encoded_list.append(self.char_to_idx[char])
|
| 45 |
+
else:
|
| 46 |
+
print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
|
| 47 |
+
encoded_list.append(self.blank_token_idx)
|
| 48 |
+
return encoded_list
|
| 49 |
+
|
| 50 |
+
def decode(self, indices: list[int]) -> str:
|
| 51 |
+
"""Converts a list of integer indices back to a text string."""
|
| 52 |
+
decoded_text = []
|
| 53 |
+
for i, idx in enumerate(indices):
|
| 54 |
+
if idx == self.blank_token_idx:
|
| 55 |
+
continue # Skip blank tokens
|
| 56 |
+
|
| 57 |
+
if i > 0 and indices[i-1] == idx:
|
| 58 |
+
continue
|
| 59 |
+
|
| 60 |
+
if idx in self.idx_to_char:
|
| 61 |
+
decoded_text.append(self.idx_to_char[idx])
|
| 62 |
+
else:
|
| 63 |
+
print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")
|
| 64 |
+
|
| 65 |
+
return "".join(decoded_text)
|
| 66 |
+
|
| 67 |
+
class OCRDataset(Dataset):
|
| 68 |
+
"""
|
| 69 |
+
Custom PyTorch Dataset for the Handwritten Name Recognition task.
|
| 70 |
+
Loads images and their corresponding text labels.
|
| 71 |
+
"""
|
| 72 |
+
def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
|
| 73 |
+
self.data = dataframe
|
| 74 |
+
self.char_indexer = char_indexer
|
| 75 |
+
self.image_dir = image_dir
|
| 76 |
+
|
| 77 |
+
if transform is None:
|
| 78 |
+
self.transform = transforms.Compose([
|
| 79 |
+
transforms.Lambda(lambda img: binarize_image(img)),
|
| 80 |
+
transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)), # Resize image to fixed height
|
| 81 |
+
transforms.ToTensor(), # Convert PIL Image to PyTorch Tensor (H, W) -> (1, H, W), scales to [0,1]
|
| 82 |
+
transforms.Lambda(normalize_image_for_model) # Normalize pixel values to [-1, 1]
|
| 83 |
+
])
|
| 84 |
+
else:
|
| 85 |
+
self.transform = transform
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def __len__(self) -> int:
|
| 89 |
+
return len(self.data)
|
| 90 |
+
|
| 91 |
+
def __getitem__(self, idx):
|
| 92 |
+
raw_filename_entry = self.data.loc[idx, 'FILENAME']
|
| 93 |
+
ground_truth_text = self.data.loc[idx, 'IDENTITY']
|
| 94 |
+
|
| 95 |
+
filename = raw_filename_entry.split(',')[0].strip()
|
| 96 |
+
img_path = os.path.join(self.image_dir, filename)
|
| 97 |
+
ground_truth_text = str(ground_truth_text)
|
| 98 |
+
|
| 99 |
+
try:
|
| 100 |
+
image = load_image_as_grayscale(img_path) # Returns PIL Image 'L'
|
| 101 |
+
except FileNotFoundError:
|
| 102 |
+
print(f"Error: Image file not found at {img_path}. Skipping this item.")
|
| 103 |
+
raise
|
| 104 |
+
|
| 105 |
+
if self.transform:
|
| 106 |
+
image = self.transform(image)
|
| 107 |
+
|
| 108 |
+
image_width = image.shape[2] # Assuming image is (C, H, W) after transform
|
| 109 |
+
|
| 110 |
+
text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
|
| 111 |
+
text_length = len(text_encoded)
|
| 112 |
+
|
| 113 |
+
return image, text_encoded, image_width, text_length
|
| 114 |
+
|
| 115 |
+
def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 116 |
+
"""
|
| 117 |
+
Custom collate function for the DataLoader to handle variable-width images
|
| 118 |
+
and variable-length text sequences for CTC loss.
|
| 119 |
+
"""
|
| 120 |
+
images, texts, image_widths, text_lengths = zip(*batch)
|
| 121 |
+
|
| 122 |
+
max_batch_width = max(image_widths)
|
| 123 |
+
padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
|
| 124 |
+
images_batch = torch.stack(padded_images, 0)
|
| 125 |
+
|
| 126 |
+
texts_batch = torch.cat(texts, 0)
|
| 127 |
+
text_lengths_tensor = torch.tensor(list(text_lengths), dtype=torch.long)
|
| 128 |
+
image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
|
| 129 |
+
|
| 130 |
+
return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
|
| 134 |
+
"""
|
| 135 |
+
Loads training and testing dataframes.
|
| 136 |
+
Assumes CSVs have 'FILENAME' and 'IDENTITY' columns.
|
| 137 |
+
Applies dataset limits from config.py.
|
| 138 |
+
"""
|
| 139 |
+
train_df = pd.read_csv(train_csv_path, encoding='ISO-8859-1')
|
| 140 |
+
test_df = pd.read_csv(test_csv_path, encoding='ISO-8859-1')
|
| 141 |
+
|
| 142 |
+
# Apply limits if they are set (not 0)
|
| 143 |
+
if TRAIN_SAMPLES_LIMIT > 0:
|
| 144 |
+
train_df = train_df.head(TRAIN_SAMPLES_LIMIT)
|
| 145 |
+
print(f"Limited training data to {TRAIN_SAMPLES_LIMIT} samples.")
|
| 146 |
+
if TEST_SAMPLES_LIMIT > 0:
|
| 147 |
+
test_df = test_df.head(TEST_SAMPLES_LIMIT)
|
| 148 |
+
print(f"Limited test data to {TEST_SAMPLES_LIMIT} samples.")
|
| 149 |
+
|
| 150 |
+
return train_df, test_df
|
| 151 |
+
|
| 152 |
+
def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
|
| 153 |
+
char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
|
| 154 |
+
"""
|
| 155 |
+
Creates PyTorch DataLoader objects for OCR training and testing datasets,
|
| 156 |
+
using specific image directories for train/test.
|
| 157 |
+
"""
|
| 158 |
+
train_dataset = OCRDataset(train_df, char_indexer, TRAIN_IMAGES_DIR)
|
| 159 |
+
test_dataset = OCRDataset(test_df, char_indexer, TEST_IMAGES_DIR)
|
| 160 |
+
|
| 161 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
|
| 162 |
+
num_workers=0, collate_fn=ocr_collate_fn)
|
| 163 |
+
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
|
| 164 |
+
num_workers=0, collate_fn=ocr_collate_fn)
|
| 165 |
+
return train_loader, test_loader
|
src/model_ocr.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model_ocr.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from sklearn.metrics import accuracy_score
|
| 10 |
+
import editdistance
|
| 11 |
+
|
| 12 |
+
# Import config and char_indexer
|
| 13 |
+
from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
|
| 14 |
+
from data_handler_ocr import CharIndexer
|
| 15 |
+
from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CNN_Backbone(nn.Module):
|
| 19 |
+
"""
|
| 20 |
+
CNN feature extractor for OCR. Designed to produce features suitable for RNN.
|
| 21 |
+
Output feature map should have height 1 after the final pooling/reduction.
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, input_channels=1, output_channels=512):
|
| 24 |
+
super(CNN_Backbone, self).__init__()
|
| 25 |
+
self.cnn = nn.Sequential(
|
| 26 |
+
# First block
|
| 27 |
+
nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
|
| 28 |
+
nn.ReLU(True),
|
| 29 |
+
nn.MaxPool2d(kernel_size=2, stride=2), # H: 32 -> 16, W: W_in -> W_in/2
|
| 30 |
+
|
| 31 |
+
# Second block
|
| 32 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
|
| 33 |
+
nn.ReLU(True),
|
| 34 |
+
nn.MaxPool2d(kernel_size=2, stride=2), # H: 16 -> 8, W: W_in/2 -> W_in/4
|
| 35 |
+
|
| 36 |
+
# Third block (with two conv layers)
|
| 37 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
|
| 38 |
+
nn.ReLU(True),
|
| 39 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
|
| 40 |
+
nn.ReLU(True),
|
| 41 |
+
# This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
|
| 42 |
+
nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
|
| 43 |
+
|
| 44 |
+
# Fourth block
|
| 45 |
+
nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
|
| 46 |
+
nn.ReLU(True),
|
| 47 |
+
# This AdaptiveAvgPool2d makes sure the height dimension becomes 1
|
| 48 |
+
# while preserving the width. This is crucial for RNN input.
|
| 49 |
+
nn.AdaptiveAvgPool2d((1, None)) # Output height 1, preserve width
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
# x: (N, C, H, W) e.g., (B, 1, 32, W_img)
|
| 54 |
+
|
| 55 |
+
# Pass through the CNN layers
|
| 56 |
+
conv_features = self.cnn(x) # Output: (N, cnn_out_channels, 1, W_prime)
|
| 57 |
+
|
| 58 |
+
# Squeeze the height dimension (which is 1)
|
| 59 |
+
# This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
|
| 60 |
+
conv_features = conv_features.squeeze(2)
|
| 61 |
+
|
| 62 |
+
# Permute for RNN input: (sequence_length, batch_size, input_size)
|
| 63 |
+
# This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
|
| 64 |
+
conv_features = conv_features.permute(2, 0, 1)
|
| 65 |
+
|
| 66 |
+
# Return the CNN features, ready for the RNN layer in CRNN
|
| 67 |
+
return conv_features
|
| 68 |
+
|
| 69 |
+
class BidirectionalLSTM(nn.Module):
|
| 70 |
+
"""Bidirectional LSTM layer for sequence modeling."""
|
| 71 |
+
def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
|
| 72 |
+
super(BidirectionalLSTM, self).__init__()
|
| 73 |
+
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
|
| 74 |
+
bidirectional=True, dropout=dropout, batch_first=False)
|
| 75 |
+
# batch_first=False expects input as (sequence_length, batch_size, input_size)
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
output, _ = self.lstm(x) # [0] returns the output, [1] returns (h_n, c_n)
|
| 79 |
+
return output
|
| 80 |
+
|
| 81 |
+
class CRNN(nn.Module):
|
| 82 |
+
"""
|
| 83 |
+
Convolutional Recurrent Neural Network for OCR.
|
| 84 |
+
Combines CNN for feature extraction, LSTMs for sequence modeling,
|
| 85 |
+
and a final linear layer for character prediction.
|
| 86 |
+
"""
|
| 87 |
+
def __init__(self, num_classes: int, cnn_output_channels: int = 512,
|
| 88 |
+
rnn_hidden_size: int = 256, rnn_num_layers: int = 2): # Corrected parameter name
|
| 89 |
+
super(CRNN, self).__init__()
|
| 90 |
+
self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
|
| 91 |
+
# Input to LSTM is the number of channels from the CNN output
|
| 92 |
+
self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers) # Corrected usage
|
| 93 |
+
# Output of bidirectional LSTM is hidden_size * 2
|
| 94 |
+
self.fc = nn.Linear(rnn_hidden_size * 2, num_classes) # Final linear layer for classes
|
| 95 |
+
|
| 96 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 97 |
+
# x: (N, C, H, W) e.g., (B, 1, 32, W_img)
|
| 98 |
+
|
| 99 |
+
# 1. Pass through the CNN to extract features
|
| 100 |
+
conv_features = self.cnn(x) # Output: (W_prime, N, C_out) after permute in CNN_Backbone
|
| 101 |
+
|
| 102 |
+
# 2. Pass CNN features through the RNN (LSTM)
|
| 103 |
+
rnn_features = self.rnn(conv_features) # Output: (W_prime, N, rnn_hidden_size * 2)
|
| 104 |
+
|
| 105 |
+
# 3. Pass RNN features through the final fully connected layer
|
| 106 |
+
# Apply the linear layer to each time step independently
|
| 107 |
+
# output will be (W_prime, N, num_classes)
|
| 108 |
+
output = self.fc(rnn_features)
|
| 109 |
+
|
| 110 |
+
return output
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# --- Decoding Function ---
|
| 114 |
+
def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
|
| 115 |
+
"""
|
| 116 |
+
Performs greedy decoding on the CTC output.
|
| 117 |
+
output: (sequence_length, batch_size, num_classes) - raw logits
|
| 118 |
+
"""
|
| 119 |
+
# Apply log_softmax to get probabilities for argmax
|
| 120 |
+
log_probs = F.log_softmax(output, dim=2)
|
| 121 |
+
|
| 122 |
+
# Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
|
| 123 |
+
predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
|
| 124 |
+
|
| 125 |
+
decoded_texts = []
|
| 126 |
+
for seq in predicted_indices:
|
| 127 |
+
# Use char_indexer's decode method, which handles blank removal and duplicate collapse
|
| 128 |
+
decoded_texts.append(char_indexer.decode(seq.tolist()))
|
| 129 |
+
return decoded_texts
|
| 130 |
+
|
| 131 |
+
# --- Evaluation Function ---
|
| 132 |
+
def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
|
| 133 |
+
model.eval()
|
| 134 |
+
criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
|
| 135 |
+
total_loss = 0
|
| 136 |
+
all_predictions = []
|
| 137 |
+
all_ground_truths = []
|
| 138 |
+
|
| 139 |
+
with torch.no_grad():
|
| 140 |
+
for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
|
| 141 |
+
inputs = inputs.to(device)
|
| 142 |
+
targets_padded = targets_padded.to(device)
|
| 143 |
+
target_lengths_tensor = target_lengths.to(device)
|
| 144 |
+
|
| 145 |
+
output = model(inputs)
|
| 146 |
+
|
| 147 |
+
outputs_seq_len_for_ctc = torch.full(
|
| 148 |
+
size=(output.shape[1],),
|
| 149 |
+
fill_value=output.shape[0],
|
| 150 |
+
dtype=torch.long,
|
| 151 |
+
device=device
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# CTC Loss calculation requires log_softmax on the output logits
|
| 155 |
+
log_probs_for_loss = F.log_softmax(output, dim=2)
|
| 156 |
+
|
| 157 |
+
# CTCLoss expects targets_padded as a 1D tensor and target_lengths_tensor as corresponding lengths
|
| 158 |
+
loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths_tensor)
|
| 159 |
+
total_loss += loss.item() * inputs.size(0)
|
| 160 |
+
|
| 161 |
+
decoded_preds = ctc_greedy_decode(output, char_indexer)
|
| 162 |
+
all_predictions.extend(decoded_preds)
|
| 163 |
+
|
| 164 |
+
ground_truths_batch = []
|
| 165 |
+
current_idx_in_concatenated_targets = 0
|
| 166 |
+
|
| 167 |
+
target_lengths_list = target_lengths.cpu().tolist()
|
| 168 |
+
|
| 169 |
+
for i in range(inputs.size(0)):
|
| 170 |
+
length = target_lengths_list[i]
|
| 171 |
+
|
| 172 |
+
current_target_segment = targets_padded[current_idx_in_concatenated_targets : current_idx_in_concatenated_targets + length].tolist()
|
| 173 |
+
ground_truths_batch.append(char_indexer.decode(current_target_segment))
|
| 174 |
+
current_idx_in_concatenated_targets += length
|
| 175 |
+
|
| 176 |
+
all_ground_truths.extend(ground_truths_batch)
|
| 177 |
+
|
| 178 |
+
avg_loss = total_loss / len(dataloader.dataset)
|
| 179 |
+
|
| 180 |
+
# Calculate Character Error Rate (CER)
|
| 181 |
+
cer_sum = 0
|
| 182 |
+
total_chars = 0
|
| 183 |
+
for pred, gt in zip(all_predictions, all_ground_truths):
|
| 184 |
+
cer_sum += editdistance.eval(pred, gt)
|
| 185 |
+
total_chars += len(gt)
|
| 186 |
+
char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
|
| 187 |
+
|
| 188 |
+
# Calculate Exact Match Accuracy (Word-level Accuracy)
|
| 189 |
+
exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
|
| 190 |
+
|
| 191 |
+
return avg_loss, char_error_rate, exact_match_accuracy
|
| 192 |
+
|
| 193 |
+
# --- Training Function ---
|
| 194 |
+
def train_ocr_model(model: nn.Module, train_loader: DataLoader,
|
| 195 |
+
test_loader: DataLoader, char_indexer: CharIndexer,
|
| 196 |
+
epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
|
| 197 |
+
"""
|
| 198 |
+
Trains the OCR model using CTC loss.
|
| 199 |
+
"""
|
| 200 |
+
# CTCLoss needs the blank token index
|
| 201 |
+
criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
|
| 202 |
+
optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
|
| 203 |
+
# Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
|
| 204 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5) # Removed verbose=True
|
| 205 |
+
|
| 206 |
+
model.to(device) # Ensure model is on the correct device
|
| 207 |
+
model.train() # Set model to training mode
|
| 208 |
+
|
| 209 |
+
training_history = {
|
| 210 |
+
'train_loss': [],
|
| 211 |
+
'test_loss': [],
|
| 212 |
+
'test_cer': [],
|
| 213 |
+
'test_exact_match_accuracy': []
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
for epoch in range(epochs):
|
| 217 |
+
running_loss = 0.0
|
| 218 |
+
pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
|
| 219 |
+
for images, texts_encoded, _, text_lengths in pbar_train:
|
| 220 |
+
images = images.to(device)
|
| 221 |
+
# Ensure target tensors are on the correct device for CTCLoss calculation
|
| 222 |
+
texts_encoded = texts_encoded.to(device)
|
| 223 |
+
text_lengths = text_lengths.to(device)
|
| 224 |
+
|
| 225 |
+
optimizer.zero_grad() # Clear gradients from previous step
|
| 226 |
+
outputs = model(images) # (sequence_length_from_cnn, batch_size, num_classes)
|
| 227 |
+
|
| 228 |
+
# `outputs.shape[0]` is the actual sequence length (T) produced by the model.
|
| 229 |
+
# CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
|
| 230 |
+
outputs_seq_len_for_ctc = torch.full(
|
| 231 |
+
size=(outputs.shape[1],), # batch_size
|
| 232 |
+
fill_value=outputs.shape[0], # actual sequence length (T) from model output
|
| 233 |
+
dtype=torch.long,
|
| 234 |
+
device=device
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# CTC Loss calculation requires log_softmax on the output logits
|
| 238 |
+
log_probs_for_loss = F.log_softmax(outputs, dim=2) # (T, N, C)
|
| 239 |
+
|
| 240 |
+
# Use outputs_seq_len_for_ctc for the input_lengths argument
|
| 241 |
+
loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
|
| 242 |
+
loss.backward() # Backpropagate
|
| 243 |
+
optimizer.step() # Update model weights
|
| 244 |
+
|
| 245 |
+
running_loss += loss.item() * images.size(0) # Multiply by batch size for correct average
|
| 246 |
+
pbar_train.set_postfix(loss=loss.item())
|
| 247 |
+
|
| 248 |
+
epoch_train_loss = running_loss / len(train_loader.dataset)
|
| 249 |
+
training_history['train_loss'].append(epoch_train_loss)
|
| 250 |
+
|
| 251 |
+
# Evaluate on test set using the dedicated function
|
| 252 |
+
# Ensure model is in eval mode before calling evaluate_model
|
| 253 |
+
model.eval()
|
| 254 |
+
test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
|
| 255 |
+
training_history['test_loss'].append(test_loss)
|
| 256 |
+
training_history['test_cer'].append(test_cer)
|
| 257 |
+
training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
|
| 258 |
+
|
| 259 |
+
# Adjust learning rate based on test loss
|
| 260 |
+
scheduler.step(test_loss)
|
| 261 |
+
|
| 262 |
+
print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
|
| 263 |
+
f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
|
| 264 |
+
|
| 265 |
+
if progress_callback:
|
| 266 |
+
# Update progress bar with current epoch and key metrics
|
| 267 |
+
progress_val = (epoch + 1) / epochs
|
| 268 |
+
progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
|
| 269 |
+
|
| 270 |
+
model.train() # Set model back to training mode after evaluation
|
| 271 |
+
|
| 272 |
+
return model, training_history
|
| 273 |
+
|
| 274 |
+
def save_ocr_model(model: nn.Module, path: str):
|
| 275 |
+
"""Saves the state dictionary of the trained OCR model."""
|
| 276 |
+
torch.save(model.state_dict(), path)
|
| 277 |
+
print(f"OCR model saved to {path}")
|
| 278 |
+
|
| 279 |
+
def load_ocr_model(model: nn.Module, path: str):
|
| 280 |
+
"""
|
| 281 |
+
Loads a trained OCR model's state dictionary.
|
| 282 |
+
Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
|
| 283 |
+
"""
|
| 284 |
+
model.load_state_dict(torch.load(path, map_location=torch.device('cpu'))) # Always load to CPU first
|
| 285 |
+
model.eval() # Set to evaluation mode
|
| 286 |
+
print(f"OCR model loaded from {path}")
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,227 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# app.py
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
# Disable Streamlit file watcher to prevent conflicts with PyTorch
|
| 6 |
+
os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
|
| 7 |
+
|
| 8 |
import streamlit as st
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import torchvision.transforms as transforms
|
| 15 |
+
import traceback
|
| 16 |
+
|
| 17 |
+
# Import all necessary configuration values from config.py
|
| 18 |
+
from config import (
|
| 19 |
+
IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
|
| 20 |
+
TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
|
| 21 |
+
MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Import classes and functions from data_handler_ocr.py and model_ocr.py
|
| 25 |
+
from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
|
| 26 |
+
from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
|
| 27 |
+
from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# --- Global Variables ---
|
| 31 |
+
ocr_model = None
|
| 32 |
+
char_indexer = None
|
| 33 |
+
training_history = None
|
| 34 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
+
|
| 36 |
+
# --- Streamlit App Setup ---
|
| 37 |
+
st.set_page_config(layout="wide", page_title="Handwritten Name OCR App",)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
st.title("π Handwritten Name Recognition (OCR) App")
|
| 41 |
+
st.markdown("""
|
| 42 |
+
This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
|
| 43 |
+
Optical Character Recognition (OCR) on handwritten names. You can upload an image
|
| 44 |
+
of a handwritten name for prediction or train a new model using the provided dataset.
|
| 45 |
+
|
| 46 |
+
**Note:** Training a robust OCR model can be time-consuming.
|
| 47 |
+
""")
|
| 48 |
+
|
| 49 |
+
# --- Initialize CharIndexer ---
|
| 50 |
+
# This initializes char_indexer once when the script starts
|
| 51 |
+
char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
|
| 52 |
+
|
| 53 |
+
# --- Model Loading / Initialization ---
|
| 54 |
+
@st.cache_resource # Cache the model to prevent reloading on every rerun
|
| 55 |
+
def get_and_load_ocr_model_cached(num_classes, model_path):
|
| 56 |
+
"""
|
| 57 |
+
Initializes the OCR model and attempts to load a pre-trained model.
|
| 58 |
+
If no pre-trained model exists, a new model instance is returned.
|
| 59 |
+
"""
|
| 60 |
+
model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
|
| 61 |
+
|
| 62 |
+
if os.path.exists(model_path):
|
| 63 |
+
st.sidebar.info("Loading pre-trained OCR model...")
|
| 64 |
+
try:
|
| 65 |
+
# Load model to CPU first, then move to device
|
| 66 |
+
model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
| 67 |
+
st.sidebar.success("OCR model loaded successfully!")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
|
| 70 |
+
# If loading fails, re-initialize an untrained model
|
| 71 |
+
model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
|
| 72 |
+
else:
|
| 73 |
+
st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
|
| 74 |
+
|
| 75 |
+
return model_instance
|
| 76 |
+
|
| 77 |
+
# Get the model instance and assign it to the global 'ocr_model'
|
| 78 |
+
ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
|
| 79 |
+
# Ensure the model is on the correct device for inference
|
| 80 |
+
ocr_model.to(device)
|
| 81 |
+
ocr_model.eval() # Set model to evaluation mode for inference by default
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# --- Sidebar for Model Training ---
|
| 85 |
+
st.sidebar.header("Train OCR Model")
|
| 86 |
+
st.sidebar.write("Click the button below to start training the OCR model.")
|
| 87 |
+
|
| 88 |
+
# Progress bar and label for training in the sidebar
|
| 89 |
+
progress_bar_sidebar = st.sidebar.progress(0)
|
| 90 |
+
progress_label_sidebar = st.sidebar.empty()
|
| 91 |
+
|
| 92 |
+
def update_progress_callback_sidebar(value, text):
|
| 93 |
+
progress_bar_sidebar.progress(int(value * 100))
|
| 94 |
+
progress_label_sidebar.text(text)
|
| 95 |
+
|
| 96 |
+
if st.sidebar.button("π Start Training"):
|
| 97 |
+
progress_bar_sidebar.progress(0)
|
| 98 |
+
progress_label_sidebar.empty()
|
| 99 |
+
st.empty()
|
| 100 |
+
|
| 101 |
+
if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
|
| 102 |
+
st.sidebar.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
|
| 103 |
+
elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
|
| 104 |
+
st.sidebar.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
|
| 105 |
+
"Evaluation might be affected or skipped. Please ensure all data paths are correct.")
|
| 106 |
+
else:
|
| 107 |
+
st.sidebar.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
|
| 108 |
+
|
| 109 |
+
try:
|
| 110 |
+
train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
|
| 111 |
+
st.sidebar.success("Training and Test DataFrames loaded successfully.")
|
| 112 |
+
|
| 113 |
+
st.sidebar.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
|
| 114 |
+
|
| 115 |
+
train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, BATCH_SIZE)
|
| 116 |
+
st.sidebar.success("DataLoaders created successfully.")
|
| 117 |
+
|
| 118 |
+
ocr_model.train()
|
| 119 |
+
|
| 120 |
+
st.sidebar.write("Training in progress... This may take a while.")
|
| 121 |
+
ocr_model, training_history = train_ocr_model(
|
| 122 |
+
model=ocr_model,
|
| 123 |
+
train_loader=train_loader,
|
| 124 |
+
test_loader=test_loader,
|
| 125 |
+
char_indexer=char_indexer,
|
| 126 |
+
epochs=NUM_EPOCHS,
|
| 127 |
+
device=device,
|
| 128 |
+
progress_callback=update_progress_callback_sidebar
|
| 129 |
+
)
|
| 130 |
+
st.sidebar.success("OCR model training finished!")
|
| 131 |
+
update_progress_callback_sidebar(1.0, "Training complete!")
|
| 132 |
+
|
| 133 |
+
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
|
| 134 |
+
save_ocr_model(ocr_model, MODEL_SAVE_PATH)
|
| 135 |
+
st.sidebar.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
|
| 136 |
+
|
| 137 |
+
except Exception as e:
|
| 138 |
+
st.sidebar.error(f"An error occurred during training: {e}")
|
| 139 |
+
st.exception(e)
|
| 140 |
+
update_progress_callback_sidebar(0.0, "Training failed!")
|
| 141 |
+
|
| 142 |
+
# --- Sidebar for Model Loading ---
|
| 143 |
+
st.sidebar.header("Load Pre-trained Model")
|
| 144 |
+
st.sidebar.write("If you have a saved model, you can load it here instead of training.")
|
| 145 |
+
|
| 146 |
+
if st.sidebar.button("πΎ Load Model"):
|
| 147 |
+
if os.path.exists(MODEL_SAVE_PATH):
|
| 148 |
+
try:
|
| 149 |
+
loaded_model = CRNN(num_classes=char_indexer.num_classes)
|
| 150 |
+
load_ocr_model(loaded_model, MODEL_SAVE_PATH)
|
| 151 |
+
loaded_model.to(device)
|
| 152 |
+
|
| 153 |
+
st.sidebar.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
|
| 154 |
+
except Exception as e:
|
| 155 |
+
st.sidebar.error(f"Error loading model: {e}")
|
| 156 |
+
st.exception(e)
|
| 157 |
+
else:
|
| 158 |
+
st.sidebar.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
|
| 159 |
+
|
| 160 |
+
# --- Main Content: Prediction Section and Training History ---
|
| 161 |
+
|
| 162 |
+
# Display training history chart
|
| 163 |
+
if training_history:
|
| 164 |
+
st.subheader("Training History Plots")
|
| 165 |
+
history_df = pd.DataFrame({
|
| 166 |
+
'Epoch': range(1, len(training_history['train_loss']) + 1),
|
| 167 |
+
'Train Loss': training_history['train_loss'],
|
| 168 |
+
'Test Loss': training_history['test_loss'],
|
| 169 |
+
'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
|
| 170 |
+
'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
|
| 171 |
+
})
|
| 172 |
+
|
| 173 |
+
st.markdown("**Loss over Epochs**")
|
| 174 |
+
st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
|
| 175 |
+
st.caption("Lower loss indicates better model performance.")
|
| 176 |
+
|
| 177 |
+
st.markdown("**Character Error Rate (CER) over Epochs**")
|
| 178 |
+
st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
|
| 179 |
+
st.caption("Lower CER indicates fewer character errors (0% is perfect).")
|
| 180 |
+
|
| 181 |
+
st.markdown("**Exact Match Accuracy over Epochs**")
|
| 182 |
+
st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
|
| 183 |
+
st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
|
| 184 |
+
|
| 185 |
+
st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
|
| 186 |
+
st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
|
| 187 |
+
st.caption("CER should decrease, Accuracy should increase.")
|
| 188 |
+
st.write("---") # Separator after charts
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Predict on a New Image
|
| 192 |
+
|
| 193 |
+
if ocr_model is None:
|
| 194 |
+
st.warning("Please train or load a model before attempting prediction.")
|
| 195 |
+
else:
|
| 196 |
+
uploaded_file = st.file_uploader("πΌοΈ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
|
| 197 |
+
|
| 198 |
+
if uploaded_file is not None:
|
| 199 |
+
try:
|
| 200 |
+
image_pil = Image.open(uploaded_file).convert('L')
|
| 201 |
+
st.image(image_pil, caption="Uploaded Image", use_container_width=True)
|
| 202 |
+
st.write("---")
|
| 203 |
+
st.write("Processing and Recognizing...")
|
| 204 |
+
|
| 205 |
+
processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
|
| 206 |
+
|
| 207 |
+
ocr_model.eval()
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
output = ocr_model(processed_image_tensor)
|
| 210 |
+
|
| 211 |
+
predicted_texts = ctc_greedy_decode(output, char_indexer)
|
| 212 |
+
predicted_text = predicted_texts[0]
|
| 213 |
+
|
| 214 |
+
st.success(f"Recognized Text: **{predicted_text}**")
|
| 215 |
+
|
| 216 |
+
except Exception as e:
|
| 217 |
+
st.error(f"Error processing image or recognizing text: {e}")
|
| 218 |
+
st.info("π‘ **Tips for best results:**\n"
|
| 219 |
+
"- Ensure the handwritten text is clear and on a clean background.\n"
|
| 220 |
+
"- Only include one name/word per image.\n"
|
| 221 |
+
"- The model is trained on specific characters. Unusual symbols might not be recognized.")
|
| 222 |
+
st.exception(e)
|
| 223 |
|
| 224 |
+
st.markdown("""
|
| 225 |
+
---
|
| 226 |
+
*Built using Streamlit, PyTorch, OpenCV, and EditDistance Β©2025 by MFT*
|
| 227 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils_ocr.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#utils_ocr.py
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# Import config for IMG_HEIGHT and MAX_IMG_WIDTH
|
| 11 |
+
from config import IMG_HEIGHT, MAX_IMG_WIDTH
|
| 12 |
+
|
| 13 |
+
# --- Image Preprocessing Functions ---
|
| 14 |
+
|
| 15 |
+
def load_image_as_grayscale(image_path: str) -> Image.Image:
|
| 16 |
+
"""Loads an image from path and converts it to grayscale PIL Image."""
|
| 17 |
+
if not os.path.exists(image_path):
|
| 18 |
+
raise FileNotFoundError(f"Image not found at: {image_path}")
|
| 19 |
+
return Image.open(image_path).convert('L') # 'L' for grayscale
|
| 20 |
+
|
| 21 |
+
def binarize_image(img: Image.Image) -> Image.Image:
|
| 22 |
+
"""
|
| 23 |
+
Binarizes a grayscale PIL Image using Otsu's method.
|
| 24 |
+
Returns a PIL Image.
|
| 25 |
+
"""
|
| 26 |
+
# Convert PIL Image to OpenCV format (numpy array)
|
| 27 |
+
img_np = np.array(img)
|
| 28 |
+
|
| 29 |
+
# Apply Otsu's binarization
|
| 30 |
+
_, binary_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
| 31 |
+
|
| 32 |
+
# Convert back to PIL Image
|
| 33 |
+
return Image.fromarray(binary_img)
|
| 34 |
+
|
| 35 |
+
def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
|
| 36 |
+
"""
|
| 37 |
+
Resizes a PIL Image to a fixed height while maintaining aspect ratio.
|
| 38 |
+
Also ensures the width does not exceed MAX_IMG_WIDTH.
|
| 39 |
+
"""
|
| 40 |
+
width, height = img.size
|
| 41 |
+
|
| 42 |
+
# Calculate new width based on target height, maintaining aspect ratio
|
| 43 |
+
new_width = int(width * (img_height / height))
|
| 44 |
+
|
| 45 |
+
if new_width > MAX_IMG_WIDTH:
|
| 46 |
+
new_width = MAX_IMG_WIDTH
|
| 47 |
+
resized_img = img.resize((new_width, img_height), Image.Resampling.LANCZOS)
|
| 48 |
+
if resized_img.width > MAX_IMG_WIDTH:
|
| 49 |
+
# Crop the image from the left to MAX_IMG_WIDTH
|
| 50 |
+
resized_img = resized_img.crop((0, 0, MAX_IMG_WIDTH, img_height))
|
| 51 |
+
return resized_img
|
| 52 |
+
|
| 53 |
+
return img.resize((new_width, img_height), Image.Resampling.LANCZOS) # Use LANCZOS for high-quality downsampling
|
| 54 |
+
|
| 55 |
+
def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
|
| 56 |
+
"""
|
| 57 |
+
Normalizes a torch.Tensor image (grayscale) for input into the model.
|
| 58 |
+
Puts pixel values in range [-1, 1].
|
| 59 |
+
Assumes image is already a torch.Tensor with values in [0, 1] (e.g., after ToTensor).
|
| 60 |
+
"""
|
| 61 |
+
# Formula: (pixel_value - mean) / std_dev
|
| 62 |
+
# For [0, 1] to [-1, 1], mean = 0.5, std_dev = 0.5
|
| 63 |
+
img_tensor = (img_tensor - 0.5) / 0.5
|
| 64 |
+
return img_tensor
|
| 65 |
+
|
| 66 |
+
def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
|
| 67 |
+
"""
|
| 68 |
+
Applies all necessary preprocessing steps to a user-uploaded PIL Image
|
| 69 |
+
to prepare it for the OCR model.
|
| 70 |
+
"""
|
| 71 |
+
# Define a transformation pipeline similar to the dataset, but including ToTensor
|
| 72 |
+
transform_pipeline = transforms.Compose([
|
| 73 |
+
transforms.Lambda(lambda img: binarize_image(img)), # PIL Image -> PIL Image
|
| 74 |
+
# Use the updated resize function that also handles MAX_IMG_WIDTH
|
| 75 |
+
transforms.Lambda(lambda img: resize_image_for_ocr(img, target_height)), # PIL Image -> PIL Image
|
| 76 |
+
transforms.ToTensor(), # PIL Image -> Tensor [0, 1]
|
| 77 |
+
transforms.Lambda(normalize_image_for_model) # Tensor [0, 1] -> Tensor [-1, 1]
|
| 78 |
+
])
|
| 79 |
+
|
| 80 |
+
processed_image = transform_pipeline(image_pil)
|
| 81 |
+
|
| 82 |
+
# Add a batch dimension (C, H, W) -> (1, C, H, W) for single image inference
|
| 83 |
+
return processed_image.unsqueeze(0)
|