Spaces:
Build error
Build error
Training Model Complete
Browse files- app.py +227 -219
- config.py +13 -77
- data_handler_ocr.py +165 -151
- model_ocr.py +285 -286
- utils_ocr.py +60 -161
app.py
CHANGED
|
@@ -1,219 +1,227 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
# app.py
|
| 3 |
-
|
| 4 |
-
import os
|
| 5 |
-
#
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
import
|
| 10 |
-
import
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
-
import
|
| 15 |
-
import
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
from
|
| 27 |
-
from
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# --- Global Variables ---
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
st.
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
st.sidebar.
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
#
|
| 80 |
-
ocr_model
|
| 81 |
-
#
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
st.sidebar.
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
st.
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# app.py
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
# Disable Streamlit file watcher to prevent conflicts with PyTorch
|
| 6 |
+
os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import torchvision.transforms as transforms
|
| 15 |
+
import traceback
|
| 16 |
+
|
| 17 |
+
# Import all necessary configuration values from config.py
|
| 18 |
+
from config import (
|
| 19 |
+
IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
|
| 20 |
+
TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
|
| 21 |
+
MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Import classes and functions from data_handler_ocr.py and model_ocr.py
|
| 25 |
+
from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
|
| 26 |
+
from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
|
| 27 |
+
from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# --- Global Variables ---
|
| 31 |
+
ocr_model = None
|
| 32 |
+
char_indexer = None
|
| 33 |
+
training_history = None
|
| 34 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
+
|
| 36 |
+
# --- Streamlit App Setup ---
|
| 37 |
+
st.set_page_config(layout="wide", page_title="Handwritten Name OCR App",)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
st.title("π Handwritten Name Recognition (OCR) App")
|
| 41 |
+
st.markdown("""
|
| 42 |
+
This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
|
| 43 |
+
Optical Character Recognition (OCR) on handwritten names. You can upload an image
|
| 44 |
+
of a handwritten name for prediction or train a new model using the provided dataset.
|
| 45 |
+
|
| 46 |
+
**Note:** Training a robust OCR model can be time-consuming.
|
| 47 |
+
""")
|
| 48 |
+
|
| 49 |
+
# --- Initialize CharIndexer ---
|
| 50 |
+
# This initializes char_indexer once when the script starts
|
| 51 |
+
char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
|
| 52 |
+
|
| 53 |
+
# --- Model Loading / Initialization ---
|
| 54 |
+
@st.cache_resource # Cache the model to prevent reloading on every rerun
|
| 55 |
+
def get_and_load_ocr_model_cached(num_classes, model_path):
|
| 56 |
+
"""
|
| 57 |
+
Initializes the OCR model and attempts to load a pre-trained model.
|
| 58 |
+
If no pre-trained model exists, a new model instance is returned.
|
| 59 |
+
"""
|
| 60 |
+
model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
|
| 61 |
+
|
| 62 |
+
if os.path.exists(model_path):
|
| 63 |
+
st.sidebar.info("Loading pre-trained OCR model...")
|
| 64 |
+
try:
|
| 65 |
+
# Load model to CPU first, then move to device
|
| 66 |
+
model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
| 67 |
+
st.sidebar.success("OCR model loaded successfully!")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
|
| 70 |
+
# If loading fails, re-initialize an untrained model
|
| 71 |
+
model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
|
| 72 |
+
else:
|
| 73 |
+
st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
|
| 74 |
+
|
| 75 |
+
return model_instance
|
| 76 |
+
|
| 77 |
+
# Get the model instance and assign it to the global 'ocr_model'
|
| 78 |
+
ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
|
| 79 |
+
# Ensure the model is on the correct device for inference
|
| 80 |
+
ocr_model.to(device)
|
| 81 |
+
ocr_model.eval() # Set model to evaluation mode for inference by default
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# --- Sidebar for Model Training ---
|
| 85 |
+
st.sidebar.header("Train OCR Model")
|
| 86 |
+
st.sidebar.write("Click the button below to start training the OCR model.")
|
| 87 |
+
|
| 88 |
+
# Progress bar and label for training in the sidebar
|
| 89 |
+
progress_bar_sidebar = st.sidebar.progress(0)
|
| 90 |
+
progress_label_sidebar = st.sidebar.empty()
|
| 91 |
+
|
| 92 |
+
def update_progress_callback_sidebar(value, text):
|
| 93 |
+
progress_bar_sidebar.progress(int(value * 100))
|
| 94 |
+
progress_label_sidebar.text(text)
|
| 95 |
+
|
| 96 |
+
if st.sidebar.button("π Start Training"):
|
| 97 |
+
progress_bar_sidebar.progress(0)
|
| 98 |
+
progress_label_sidebar.empty()
|
| 99 |
+
st.empty()
|
| 100 |
+
|
| 101 |
+
if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
|
| 102 |
+
st.sidebar.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
|
| 103 |
+
elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
|
| 104 |
+
st.sidebar.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
|
| 105 |
+
"Evaluation might be affected or skipped. Please ensure all data paths are correct.")
|
| 106 |
+
else:
|
| 107 |
+
st.sidebar.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
|
| 108 |
+
|
| 109 |
+
try:
|
| 110 |
+
train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
|
| 111 |
+
st.sidebar.success("Training and Test DataFrames loaded successfully.")
|
| 112 |
+
|
| 113 |
+
st.sidebar.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
|
| 114 |
+
|
| 115 |
+
train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, BATCH_SIZE)
|
| 116 |
+
st.sidebar.success("DataLoaders created successfully.")
|
| 117 |
+
|
| 118 |
+
ocr_model.train()
|
| 119 |
+
|
| 120 |
+
st.sidebar.write("Training in progress... This may take a while.")
|
| 121 |
+
ocr_model, training_history = train_ocr_model(
|
| 122 |
+
model=ocr_model,
|
| 123 |
+
train_loader=train_loader,
|
| 124 |
+
test_loader=test_loader,
|
| 125 |
+
char_indexer=char_indexer,
|
| 126 |
+
epochs=NUM_EPOCHS,
|
| 127 |
+
device=device,
|
| 128 |
+
progress_callback=update_progress_callback_sidebar
|
| 129 |
+
)
|
| 130 |
+
st.sidebar.success("OCR model training finished!")
|
| 131 |
+
update_progress_callback_sidebar(1.0, "Training complete!")
|
| 132 |
+
|
| 133 |
+
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
|
| 134 |
+
save_ocr_model(ocr_model, MODEL_SAVE_PATH)
|
| 135 |
+
st.sidebar.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
|
| 136 |
+
|
| 137 |
+
except Exception as e:
|
| 138 |
+
st.sidebar.error(f"An error occurred during training: {e}")
|
| 139 |
+
st.exception(e)
|
| 140 |
+
update_progress_callback_sidebar(0.0, "Training failed!")
|
| 141 |
+
|
| 142 |
+
# --- Sidebar for Model Loading ---
|
| 143 |
+
st.sidebar.header("Load Pre-trained Model")
|
| 144 |
+
st.sidebar.write("If you have a saved model, you can load it here instead of training.")
|
| 145 |
+
|
| 146 |
+
if st.sidebar.button("πΎ Load Model"):
|
| 147 |
+
if os.path.exists(MODEL_SAVE_PATH):
|
| 148 |
+
try:
|
| 149 |
+
loaded_model = CRNN(num_classes=char_indexer.num_classes)
|
| 150 |
+
load_ocr_model(loaded_model, MODEL_SAVE_PATH)
|
| 151 |
+
loaded_model.to(device)
|
| 152 |
+
|
| 153 |
+
st.sidebar.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
|
| 154 |
+
except Exception as e:
|
| 155 |
+
st.sidebar.error(f"Error loading model: {e}")
|
| 156 |
+
st.exception(e)
|
| 157 |
+
else:
|
| 158 |
+
st.sidebar.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
|
| 159 |
+
|
| 160 |
+
# --- Main Content: Prediction Section and Training History ---
|
| 161 |
+
|
| 162 |
+
# Display training history chart
|
| 163 |
+
if training_history:
|
| 164 |
+
st.subheader("Training History Plots")
|
| 165 |
+
history_df = pd.DataFrame({
|
| 166 |
+
'Epoch': range(1, len(training_history['train_loss']) + 1),
|
| 167 |
+
'Train Loss': training_history['train_loss'],
|
| 168 |
+
'Test Loss': training_history['test_loss'],
|
| 169 |
+
'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
|
| 170 |
+
'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
|
| 171 |
+
})
|
| 172 |
+
|
| 173 |
+
st.markdown("**Loss over Epochs**")
|
| 174 |
+
st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
|
| 175 |
+
st.caption("Lower loss indicates better model performance.")
|
| 176 |
+
|
| 177 |
+
st.markdown("**Character Error Rate (CER) over Epochs**")
|
| 178 |
+
st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
|
| 179 |
+
st.caption("Lower CER indicates fewer character errors (0% is perfect).")
|
| 180 |
+
|
| 181 |
+
st.markdown("**Exact Match Accuracy over Epochs**")
|
| 182 |
+
st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
|
| 183 |
+
st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
|
| 184 |
+
|
| 185 |
+
st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
|
| 186 |
+
st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
|
| 187 |
+
st.caption("CER should decrease, Accuracy should increase.")
|
| 188 |
+
st.write("---") # Separator after charts
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Predict on a New Image
|
| 192 |
+
|
| 193 |
+
if ocr_model is None:
|
| 194 |
+
st.warning("Please train or load a model before attempting prediction.")
|
| 195 |
+
else:
|
| 196 |
+
uploaded_file = st.file_uploader("πΌοΈ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
|
| 197 |
+
|
| 198 |
+
if uploaded_file is not None:
|
| 199 |
+
try:
|
| 200 |
+
image_pil = Image.open(uploaded_file).convert('L')
|
| 201 |
+
st.image(image_pil, caption="Uploaded Image", use_container_width=True)
|
| 202 |
+
st.write("---")
|
| 203 |
+
st.write("Processing and Recognizing...")
|
| 204 |
+
|
| 205 |
+
processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
|
| 206 |
+
|
| 207 |
+
ocr_model.eval()
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
output = ocr_model(processed_image_tensor)
|
| 210 |
+
|
| 211 |
+
predicted_texts = ctc_greedy_decode(output, char_indexer)
|
| 212 |
+
predicted_text = predicted_texts[0]
|
| 213 |
+
|
| 214 |
+
st.success(f"Recognized Text: **{predicted_text}**")
|
| 215 |
+
|
| 216 |
+
except Exception as e:
|
| 217 |
+
st.error(f"Error processing image or recognizing text: {e}")
|
| 218 |
+
st.info("π‘ **Tips for best results:**\n"
|
| 219 |
+
"- Ensure the handwritten text is clear and on a clean background.\n"
|
| 220 |
+
"- Only include one name/word per image.\n"
|
| 221 |
+
"- The model is trained on specific characters. Unusual symbols might not be recognized.")
|
| 222 |
+
st.exception(e)
|
| 223 |
+
|
| 224 |
+
st.markdown("""
|
| 225 |
+
---
|
| 226 |
+
*Built using Streamlit, PyTorch, OpenCV, and EditDistance Β©2025 by MFT*
|
| 227 |
+
""")
|
config.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
<<<<<<< HEAD
|
| 2 |
# config.py
|
| 3 |
|
| 4 |
import os
|
|
@@ -8,8 +7,8 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
| 8 |
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
| 9 |
MODELS_DIR = os.path.join(BASE_DIR, 'models')
|
| 10 |
|
| 11 |
-
TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images'
|
| 12 |
-
TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images'
|
| 13 |
|
| 14 |
TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
|
| 15 |
TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
|
|
@@ -17,26 +16,13 @@ TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
|
|
| 17 |
MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
|
| 18 |
|
| 19 |
# --- Character Set and OCR Configuration ---
|
| 20 |
-
# This character set MUST cover all characters present in your dataset.
|
| 21 |
-
# Add any special characters if needed.
|
| 22 |
-
# The order here is crucial as it defines the indices for your characters.
|
| 23 |
CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
|
| 24 |
-
|
| 25 |
-
# Define the character for the blank token. It MUST NOT be in CHARS.
|
| 26 |
-
BLANK_TOKEN_SYMBOL = 'Γ'
|
| 27 |
-
|
| 28 |
-
# Construct the full vocabulary string. It's conventional to put the blank token last.
|
| 29 |
-
# This VOCABULARY string is what you pass to CharIndexer.
|
| 30 |
VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
|
| 31 |
-
|
| 32 |
-
# NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
|
| 33 |
NUM_CLASSES = len(VOCABULARY)
|
| 34 |
-
|
| 35 |
-
# BLANK_TOKEN is the actual index of the blank symbol within the VOCABULARY.
|
| 36 |
-
# Since we appended it last, its index will be len(CHARS).
|
| 37 |
BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
|
| 38 |
|
| 39 |
-
# --- Sanity Checks
|
| 40 |
if BLANK_TOKEN == -1:
|
| 41 |
raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
|
| 42 |
if BLANK_TOKEN >= NUM_CLASSES:
|
|
@@ -48,65 +34,15 @@ print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
|
|
| 48 |
|
| 49 |
|
| 50 |
# --- Image Preprocessing Parameters ---
|
| 51 |
-
IMG_HEIGHT = 32
|
|
|
|
| 52 |
|
| 53 |
# --- Training Parameters ---
|
| 54 |
-
BATCH_SIZE =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
LEARNING_RATE = 0.001
|
| 56 |
-
=======
|
| 57 |
-
# config.py
|
| 58 |
-
|
| 59 |
-
import os
|
| 60 |
-
|
| 61 |
-
# --- Paths ---
|
| 62 |
-
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 63 |
-
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
| 64 |
-
MODELS_DIR = os.path.join(BASE_DIR, 'models')
|
| 65 |
-
|
| 66 |
-
TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'train')
|
| 67 |
-
TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'test')
|
| 68 |
-
|
| 69 |
-
TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
|
| 70 |
-
TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
|
| 71 |
-
|
| 72 |
-
MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
|
| 73 |
-
|
| 74 |
-
# --- Character Set and OCR Configuration ---
|
| 75 |
-
# This character set MUST cover all characters present in your dataset.
|
| 76 |
-
# Add any special characters if needed.
|
| 77 |
-
# The order here is crucial as it defines the indices for your characters.
|
| 78 |
-
CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
|
| 79 |
-
|
| 80 |
-
# Define the character for the blank token. It MUST NOT be in CHARS.
|
| 81 |
-
BLANK_TOKEN_SYMBOL = 'Γ'
|
| 82 |
-
|
| 83 |
-
# Construct the full vocabulary string. It's conventional to put the blank token last.
|
| 84 |
-
# This VOCABULARY string is what you pass to CharIndexer.
|
| 85 |
-
VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
|
| 86 |
-
|
| 87 |
-
# NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
|
| 88 |
-
NUM_CLASSES = len(VOCABULARY)
|
| 89 |
-
|
| 90 |
-
# BLANK_TOKEN is the actual index of the blank symbol within the VOCABULARY.
|
| 91 |
-
# Since we appended it last, its index will be len(CHARS).
|
| 92 |
-
BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
|
| 93 |
-
|
| 94 |
-
# --- Sanity Checks (Highly Recommended) ---
|
| 95 |
-
if BLANK_TOKEN == -1:
|
| 96 |
-
raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
|
| 97 |
-
if BLANK_TOKEN >= NUM_CLASSES:
|
| 98 |
-
raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")
|
| 99 |
-
|
| 100 |
-
print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
|
| 101 |
-
print(f"Vocabulary Length: {len(VOCABULARY)}")
|
| 102 |
-
print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
# --- Image Preprocessing Parameters ---
|
| 106 |
-
IMG_HEIGHT = 32
|
| 107 |
-
|
| 108 |
-
# --- Training Parameters ---
|
| 109 |
-
BATCH_SIZE = 64
|
| 110 |
-
LEARNING_RATE = 0.001
|
| 111 |
-
>>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
|
| 112 |
-
NUM_EPOCHS = 3
|
|
|
|
|
|
|
| 1 |
# config.py
|
| 2 |
|
| 3 |
import os
|
|
|
|
| 7 |
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
| 8 |
MODELS_DIR = os.path.join(BASE_DIR, 'models')
|
| 9 |
|
| 10 |
+
TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
|
| 11 |
+
TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
|
| 12 |
|
| 13 |
TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
|
| 14 |
TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
|
|
|
|
| 16 |
MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
|
| 17 |
|
| 18 |
# --- Character Set and OCR Configuration ---
|
|
|
|
|
|
|
|
|
|
| 19 |
CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
|
| 20 |
+
BLANK_TOKEN_SYMBOL = 'Γ'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
|
|
|
|
|
|
|
| 22 |
NUM_CLASSES = len(VOCABULARY)
|
|
|
|
|
|
|
|
|
|
| 23 |
BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
|
| 24 |
|
| 25 |
+
# --- Sanity Checks ---
|
| 26 |
if BLANK_TOKEN == -1:
|
| 27 |
raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
|
| 28 |
if BLANK_TOKEN >= NUM_CLASSES:
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
# --- Image Preprocessing Parameters ---
|
| 37 |
+
IMG_HEIGHT = 32 # Target height for all input images to the model
|
| 38 |
+
MAX_IMG_WIDTH = 1024 # Adjust this value based on your typical image widths and available RAM
|
| 39 |
|
| 40 |
# --- Training Parameters ---
|
| 41 |
+
BATCH_SIZE = 10
|
| 42 |
+
|
| 43 |
+
# NEW: Dataset Limits
|
| 44 |
+
TRAIN_SAMPLES_LIMIT = 1000
|
| 45 |
+
TEST_SAMPLES_LIMIT = 1000
|
| 46 |
+
|
| 47 |
+
NUM_EPOCHS = 5
|
| 48 |
LEARNING_RATE = 0.001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_handler_ocr.py
CHANGED
|
@@ -1,151 +1,165 @@
|
|
| 1 |
-
#data_handler_ocr.py
|
| 2 |
-
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import torch
|
| 5 |
-
from torch.utils.data import Dataset, DataLoader
|
| 6 |
-
from torchvision import transforms
|
| 7 |
-
import os
|
| 8 |
-
from PIL import Image
|
| 9 |
-
import numpy as np
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
|
| 12 |
-
# Import utility functions and config
|
| 13 |
-
from config import
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
"
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
if idx
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
ground_truth_text =
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#data_handler_ocr.py
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
import os
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
# Import utility functions and config
|
| 13 |
+
from config import (
|
| 14 |
+
VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT,
|
| 15 |
+
TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
|
| 16 |
+
TRAIN_SAMPLES_LIMIT, TEST_SAMPLES_LIMIT
|
| 17 |
+
)
|
| 18 |
+
from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
|
| 19 |
+
|
| 20 |
+
class CharIndexer:
|
| 21 |
+
"""Manages character-to-index and index-to-character mappings."""
|
| 22 |
+
def __init__(self, vocabulary_string: str, blank_token_symbol: str):
|
| 23 |
+
self.chars = sorted(list(set(vocabulary_string)))
|
| 24 |
+
self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
|
| 25 |
+
self.idx_to_char = {i: char for i, char in enumerate(self.chars)}
|
| 26 |
+
|
| 27 |
+
if blank_token_symbol not in self.char_to_idx:
|
| 28 |
+
raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")
|
| 29 |
+
|
| 30 |
+
self.blank_token_idx = self.char_to_idx[blank_token_symbol]
|
| 31 |
+
self.num_classes = len(self.chars)
|
| 32 |
+
|
| 33 |
+
if self.blank_token_idx >= self.num_classes:
|
| 34 |
+
raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")
|
| 35 |
+
|
| 36 |
+
print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
|
| 37 |
+
print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")
|
| 38 |
+
|
| 39 |
+
def encode(self, text: str) -> list[int]:
|
| 40 |
+
"""Converts a text string to a list of integer indices."""
|
| 41 |
+
encoded_list = []
|
| 42 |
+
for char in text:
|
| 43 |
+
if char in self.char_to_idx:
|
| 44 |
+
encoded_list.append(self.char_to_idx[char])
|
| 45 |
+
else:
|
| 46 |
+
print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
|
| 47 |
+
encoded_list.append(self.blank_token_idx)
|
| 48 |
+
return encoded_list
|
| 49 |
+
|
| 50 |
+
def decode(self, indices: list[int]) -> str:
|
| 51 |
+
"""Converts a list of integer indices back to a text string."""
|
| 52 |
+
decoded_text = []
|
| 53 |
+
for i, idx in enumerate(indices):
|
| 54 |
+
if idx == self.blank_token_idx:
|
| 55 |
+
continue # Skip blank tokens
|
| 56 |
+
|
| 57 |
+
if i > 0 and indices[i-1] == idx:
|
| 58 |
+
continue
|
| 59 |
+
|
| 60 |
+
if idx in self.idx_to_char:
|
| 61 |
+
decoded_text.append(self.idx_to_char[idx])
|
| 62 |
+
else:
|
| 63 |
+
print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")
|
| 64 |
+
|
| 65 |
+
return "".join(decoded_text)
|
| 66 |
+
|
| 67 |
+
class OCRDataset(Dataset):
|
| 68 |
+
"""
|
| 69 |
+
Custom PyTorch Dataset for the Handwritten Name Recognition task.
|
| 70 |
+
Loads images and their corresponding text labels.
|
| 71 |
+
"""
|
| 72 |
+
def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
|
| 73 |
+
self.data = dataframe
|
| 74 |
+
self.char_indexer = char_indexer
|
| 75 |
+
self.image_dir = image_dir
|
| 76 |
+
|
| 77 |
+
if transform is None:
|
| 78 |
+
self.transform = transforms.Compose([
|
| 79 |
+
transforms.Lambda(lambda img: binarize_image(img)),
|
| 80 |
+
transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)), # Resize image to fixed height
|
| 81 |
+
transforms.ToTensor(), # Convert PIL Image to PyTorch Tensor (H, W) -> (1, H, W), scales to [0,1]
|
| 82 |
+
transforms.Lambda(normalize_image_for_model) # Normalize pixel values to [-1, 1]
|
| 83 |
+
])
|
| 84 |
+
else:
|
| 85 |
+
self.transform = transform
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def __len__(self) -> int:
|
| 89 |
+
return len(self.data)
|
| 90 |
+
|
| 91 |
+
def __getitem__(self, idx):
|
| 92 |
+
raw_filename_entry = self.data.loc[idx, 'FILENAME']
|
| 93 |
+
ground_truth_text = self.data.loc[idx, 'IDENTITY']
|
| 94 |
+
|
| 95 |
+
filename = raw_filename_entry.split(',')[0].strip()
|
| 96 |
+
img_path = os.path.join(self.image_dir, filename)
|
| 97 |
+
ground_truth_text = str(ground_truth_text)
|
| 98 |
+
|
| 99 |
+
try:
|
| 100 |
+
image = load_image_as_grayscale(img_path) # Returns PIL Image 'L'
|
| 101 |
+
except FileNotFoundError:
|
| 102 |
+
print(f"Error: Image file not found at {img_path}. Skipping this item.")
|
| 103 |
+
raise
|
| 104 |
+
|
| 105 |
+
if self.transform:
|
| 106 |
+
image = self.transform(image)
|
| 107 |
+
|
| 108 |
+
image_width = image.shape[2] # Assuming image is (C, H, W) after transform
|
| 109 |
+
|
| 110 |
+
text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
|
| 111 |
+
text_length = len(text_encoded)
|
| 112 |
+
|
| 113 |
+
return image, text_encoded, image_width, text_length
|
| 114 |
+
|
| 115 |
+
def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 116 |
+
"""
|
| 117 |
+
Custom collate function for the DataLoader to handle variable-width images
|
| 118 |
+
and variable-length text sequences for CTC loss.
|
| 119 |
+
"""
|
| 120 |
+
images, texts, image_widths, text_lengths = zip(*batch)
|
| 121 |
+
|
| 122 |
+
max_batch_width = max(image_widths)
|
| 123 |
+
padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
|
| 124 |
+
images_batch = torch.stack(padded_images, 0)
|
| 125 |
+
|
| 126 |
+
texts_batch = torch.cat(texts, 0)
|
| 127 |
+
text_lengths_tensor = torch.tensor(list(text_lengths), dtype=torch.long)
|
| 128 |
+
image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
|
| 129 |
+
|
| 130 |
+
return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
|
| 134 |
+
"""
|
| 135 |
+
Loads training and testing dataframes.
|
| 136 |
+
Assumes CSVs have 'FILENAME' and 'IDENTITY' columns.
|
| 137 |
+
Applies dataset limits from config.py.
|
| 138 |
+
"""
|
| 139 |
+
train_df = pd.read_csv(train_csv_path, encoding='ISO-8859-1')
|
| 140 |
+
test_df = pd.read_csv(test_csv_path, encoding='ISO-8859-1')
|
| 141 |
+
|
| 142 |
+
# Apply limits if they are set (not 0)
|
| 143 |
+
if TRAIN_SAMPLES_LIMIT > 0:
|
| 144 |
+
train_df = train_df.head(TRAIN_SAMPLES_LIMIT)
|
| 145 |
+
print(f"Limited training data to {TRAIN_SAMPLES_LIMIT} samples.")
|
| 146 |
+
if TEST_SAMPLES_LIMIT > 0:
|
| 147 |
+
test_df = test_df.head(TEST_SAMPLES_LIMIT)
|
| 148 |
+
print(f"Limited test data to {TEST_SAMPLES_LIMIT} samples.")
|
| 149 |
+
|
| 150 |
+
return train_df, test_df
|
| 151 |
+
|
| 152 |
+
def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
|
| 153 |
+
char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
|
| 154 |
+
"""
|
| 155 |
+
Creates PyTorch DataLoader objects for OCR training and testing datasets,
|
| 156 |
+
using specific image directories for train/test.
|
| 157 |
+
"""
|
| 158 |
+
train_dataset = OCRDataset(train_df, char_indexer, TRAIN_IMAGES_DIR)
|
| 159 |
+
test_dataset = OCRDataset(test_df, char_indexer, TEST_IMAGES_DIR)
|
| 160 |
+
|
| 161 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
|
| 162 |
+
num_workers=0, collate_fn=ocr_collate_fn)
|
| 163 |
+
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
|
| 164 |
+
num_workers=0, collate_fn=ocr_collate_fn)
|
| 165 |
+
return train_loader, test_loader
|
model_ocr.py
CHANGED
|
@@ -1,286 +1,285 @@
|
|
| 1 |
-
# model_ocr.py
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn as
|
| 5 |
-
import torch.
|
| 6 |
-
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
from
|
| 14 |
-
from
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
nn.
|
| 28 |
-
nn.
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
nn.
|
| 33 |
-
nn.
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
nn.
|
| 38 |
-
nn.
|
| 39 |
-
nn.
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
nn.
|
| 46 |
-
|
| 47 |
-
#
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
#
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
#
|
| 106 |
-
#
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
model.
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
'
|
| 211 |
-
'
|
| 212 |
-
'
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
# `
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
loss
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
epoch_train_loss
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
#
|
| 252 |
-
|
| 253 |
-
model
|
| 254 |
-
test_loss
|
| 255 |
-
training_history['
|
| 256 |
-
training_history['
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
progress_val =
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
model.
|
| 285 |
-
|
| 286 |
-
print(f"OCR model loaded from {path}")
|
|
|
|
| 1 |
+
# model_ocr.py
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from sklearn.metrics import accuracy_score
|
| 9 |
+
import editdistance
|
| 10 |
+
|
| 11 |
+
# Import config and char_indexer
|
| 12 |
+
from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
|
| 13 |
+
from data_handler_ocr import CharIndexer
|
| 14 |
+
from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CNN_Backbone(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
CNN feature extractor for OCR. Designed to produce features suitable for RNN.
|
| 20 |
+
Output feature map should have height 1 after the final pooling/reduction.
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, input_channels=1, output_channels=512):
|
| 23 |
+
super(CNN_Backbone, self).__init__()
|
| 24 |
+
self.cnn = nn.Sequential(
|
| 25 |
+
# First block
|
| 26 |
+
nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
|
| 27 |
+
nn.ReLU(True),
|
| 28 |
+
nn.MaxPool2d(kernel_size=2, stride=2), # H: 32 -> 16, W: W_in -> W_in/2
|
| 29 |
+
|
| 30 |
+
# Second block
|
| 31 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
|
| 32 |
+
nn.ReLU(True),
|
| 33 |
+
nn.MaxPool2d(kernel_size=2, stride=2), # H: 16 -> 8, W: W_in/2 -> W_in/4
|
| 34 |
+
|
| 35 |
+
# Third block (with two conv layers)
|
| 36 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
|
| 37 |
+
nn.ReLU(True),
|
| 38 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
|
| 39 |
+
nn.ReLU(True),
|
| 40 |
+
# This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
|
| 41 |
+
nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
|
| 42 |
+
|
| 43 |
+
# Fourth block
|
| 44 |
+
nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
|
| 45 |
+
nn.ReLU(True),
|
| 46 |
+
# This AdaptiveAvgPool2d makes sure the height dimension becomes 1
|
| 47 |
+
# while preserving the width. This is crucial for RNN input.
|
| 48 |
+
nn.AdaptiveAvgPool2d((1, None)) # Output height 1, preserve width
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 52 |
+
# x: (N, C, H, W) e.g., (B, 1, 32, W_img)
|
| 53 |
+
|
| 54 |
+
# Pass through the CNN layers
|
| 55 |
+
conv_features = self.cnn(x) # Output: (N, cnn_out_channels, 1, W_prime)
|
| 56 |
+
|
| 57 |
+
# Squeeze the height dimension (which is 1)
|
| 58 |
+
# This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
|
| 59 |
+
conv_features = conv_features.squeeze(2)
|
| 60 |
+
|
| 61 |
+
# Permute for RNN input: (sequence_length, batch_size, input_size)
|
| 62 |
+
# This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
|
| 63 |
+
conv_features = conv_features.permute(2, 0, 1)
|
| 64 |
+
|
| 65 |
+
# Return the CNN features, ready for the RNN layer in CRNN
|
| 66 |
+
return conv_features
|
| 67 |
+
|
| 68 |
+
class BidirectionalLSTM(nn.Module):
|
| 69 |
+
"""Bidirectional LSTM layer for sequence modeling."""
|
| 70 |
+
def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
|
| 71 |
+
super(BidirectionalLSTM, self).__init__()
|
| 72 |
+
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
|
| 73 |
+
bidirectional=True, dropout=dropout, batch_first=False)
|
| 74 |
+
# batch_first=False expects input as (sequence_length, batch_size, input_size)
|
| 75 |
+
|
| 76 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 77 |
+
output, _ = self.lstm(x) # [0] returns the output, [1] returns (h_n, c_n)
|
| 78 |
+
return output
|
| 79 |
+
|
| 80 |
+
class CRNN(nn.Module):
|
| 81 |
+
"""
|
| 82 |
+
Convolutional Recurrent Neural Network for OCR.
|
| 83 |
+
Combines CNN for feature extraction, LSTMs for sequence modeling,
|
| 84 |
+
and a final linear layer for character prediction.
|
| 85 |
+
"""
|
| 86 |
+
def __init__(self, num_classes: int, cnn_output_channels: int = 512,
|
| 87 |
+
rnn_hidden_size: int = 256, rnn_num_layers: int = 2): # Corrected parameter name
|
| 88 |
+
super(CRNN, self).__init__()
|
| 89 |
+
self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
|
| 90 |
+
# Input to LSTM is the number of channels from the CNN output
|
| 91 |
+
self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers) # Corrected usage
|
| 92 |
+
# Output of bidirectional LSTM is hidden_size * 2
|
| 93 |
+
self.fc = nn.Linear(rnn_hidden_size * 2, num_classes) # Final linear layer for classes
|
| 94 |
+
|
| 95 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 96 |
+
# x: (N, C, H, W) e.g., (B, 1, 32, W_img)
|
| 97 |
+
|
| 98 |
+
# 1. Pass through the CNN to extract features
|
| 99 |
+
conv_features = self.cnn(x) # Output: (W_prime, N, C_out) after permute in CNN_Backbone
|
| 100 |
+
|
| 101 |
+
# 2. Pass CNN features through the RNN (LSTM)
|
| 102 |
+
rnn_features = self.rnn(conv_features) # Output: (W_prime, N, rnn_hidden_size * 2)
|
| 103 |
+
|
| 104 |
+
# 3. Pass RNN features through the final fully connected layer
|
| 105 |
+
# Apply the linear layer to each time step independently
|
| 106 |
+
# output will be (W_prime, N, num_classes)
|
| 107 |
+
output = self.fc(rnn_features)
|
| 108 |
+
|
| 109 |
+
return output
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# --- Decoding Function ---
|
| 113 |
+
def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
|
| 114 |
+
"""
|
| 115 |
+
Performs greedy decoding on the CTC output.
|
| 116 |
+
output: (sequence_length, batch_size, num_classes) - raw logits
|
| 117 |
+
"""
|
| 118 |
+
# Apply log_softmax to get probabilities for argmax
|
| 119 |
+
log_probs = F.log_softmax(output, dim=2)
|
| 120 |
+
|
| 121 |
+
# Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
|
| 122 |
+
predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
|
| 123 |
+
|
| 124 |
+
decoded_texts = []
|
| 125 |
+
for seq in predicted_indices:
|
| 126 |
+
# Use char_indexer's decode method, which handles blank removal and duplicate collapse
|
| 127 |
+
decoded_texts.append(char_indexer.decode(seq.tolist()))
|
| 128 |
+
return decoded_texts
|
| 129 |
+
|
| 130 |
+
# --- Evaluation Function ---
|
| 131 |
+
def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
|
| 132 |
+
model.eval()
|
| 133 |
+
criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
|
| 134 |
+
total_loss = 0
|
| 135 |
+
all_predictions = []
|
| 136 |
+
all_ground_truths = []
|
| 137 |
+
|
| 138 |
+
with torch.no_grad():
|
| 139 |
+
for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
|
| 140 |
+
inputs = inputs.to(device)
|
| 141 |
+
targets_padded = targets_padded.to(device)
|
| 142 |
+
target_lengths_tensor = target_lengths.to(device)
|
| 143 |
+
|
| 144 |
+
output = model(inputs)
|
| 145 |
+
|
| 146 |
+
outputs_seq_len_for_ctc = torch.full(
|
| 147 |
+
size=(output.shape[1],),
|
| 148 |
+
fill_value=output.shape[0],
|
| 149 |
+
dtype=torch.long,
|
| 150 |
+
device=device
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# CTC Loss calculation requires log_softmax on the output logits
|
| 154 |
+
log_probs_for_loss = F.log_softmax(output, dim=2)
|
| 155 |
+
|
| 156 |
+
# CTCLoss expects targets_padded as a 1D tensor and target_lengths_tensor as corresponding lengths
|
| 157 |
+
loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths_tensor)
|
| 158 |
+
total_loss += loss.item() * inputs.size(0)
|
| 159 |
+
|
| 160 |
+
decoded_preds = ctc_greedy_decode(output, char_indexer)
|
| 161 |
+
all_predictions.extend(decoded_preds)
|
| 162 |
+
|
| 163 |
+
ground_truths_batch = []
|
| 164 |
+
current_idx_in_concatenated_targets = 0
|
| 165 |
+
|
| 166 |
+
target_lengths_list = target_lengths.cpu().tolist()
|
| 167 |
+
|
| 168 |
+
for i in range(inputs.size(0)):
|
| 169 |
+
length = target_lengths_list[i]
|
| 170 |
+
|
| 171 |
+
current_target_segment = targets_padded[current_idx_in_concatenated_targets : current_idx_in_concatenated_targets + length].tolist()
|
| 172 |
+
ground_truths_batch.append(char_indexer.decode(current_target_segment))
|
| 173 |
+
current_idx_in_concatenated_targets += length
|
| 174 |
+
|
| 175 |
+
all_ground_truths.extend(ground_truths_batch)
|
| 176 |
+
|
| 177 |
+
avg_loss = total_loss / len(dataloader.dataset)
|
| 178 |
+
|
| 179 |
+
# Calculate Character Error Rate (CER)
|
| 180 |
+
cer_sum = 0
|
| 181 |
+
total_chars = 0
|
| 182 |
+
for pred, gt in zip(all_predictions, all_ground_truths):
|
| 183 |
+
cer_sum += editdistance.eval(pred, gt)
|
| 184 |
+
total_chars += len(gt)
|
| 185 |
+
char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
|
| 186 |
+
|
| 187 |
+
# Calculate Exact Match Accuracy (Word-level Accuracy)
|
| 188 |
+
exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
|
| 189 |
+
|
| 190 |
+
return avg_loss, char_error_rate, exact_match_accuracy
|
| 191 |
+
|
| 192 |
+
# --- Training Function ---
|
| 193 |
+
def train_ocr_model(model: nn.Module, train_loader: DataLoader,
|
| 194 |
+
test_loader: DataLoader, char_indexer: CharIndexer,
|
| 195 |
+
epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
|
| 196 |
+
"""
|
| 197 |
+
Trains the OCR model using CTC loss.
|
| 198 |
+
"""
|
| 199 |
+
# CTCLoss needs the blank token index
|
| 200 |
+
criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
|
| 201 |
+
optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
|
| 202 |
+
# Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
|
| 203 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5) # Removed verbose=True
|
| 204 |
+
|
| 205 |
+
model.to(device) # Ensure model is on the correct device
|
| 206 |
+
model.train() # Set model to training mode
|
| 207 |
+
|
| 208 |
+
training_history = {
|
| 209 |
+
'train_loss': [],
|
| 210 |
+
'test_loss': [],
|
| 211 |
+
'test_cer': [],
|
| 212 |
+
'test_exact_match_accuracy': []
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
for epoch in range(epochs):
|
| 216 |
+
running_loss = 0.0
|
| 217 |
+
pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
|
| 218 |
+
for images, texts_encoded, _, text_lengths in pbar_train:
|
| 219 |
+
images = images.to(device)
|
| 220 |
+
# Ensure target tensors are on the correct device for CTCLoss calculation
|
| 221 |
+
texts_encoded = texts_encoded.to(device)
|
| 222 |
+
text_lengths = text_lengths.to(device)
|
| 223 |
+
|
| 224 |
+
optimizer.zero_grad() # Clear gradients from previous step
|
| 225 |
+
outputs = model(images) # (sequence_length_from_cnn, batch_size, num_classes)
|
| 226 |
+
|
| 227 |
+
# `outputs.shape[0]` is the actual sequence length (T) produced by the model.
|
| 228 |
+
# CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
|
| 229 |
+
outputs_seq_len_for_ctc = torch.full(
|
| 230 |
+
size=(outputs.shape[1],), # batch_size
|
| 231 |
+
fill_value=outputs.shape[0], # actual sequence length (T) from model output
|
| 232 |
+
dtype=torch.long,
|
| 233 |
+
device=device
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# CTC Loss calculation requires log_softmax on the output logits
|
| 237 |
+
log_probs_for_loss = F.log_softmax(outputs, dim=2) # (T, N, C)
|
| 238 |
+
|
| 239 |
+
# Use outputs_seq_len_for_ctc for the input_lengths argument
|
| 240 |
+
loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
|
| 241 |
+
loss.backward() # Backpropagate
|
| 242 |
+
optimizer.step() # Update model weights
|
| 243 |
+
|
| 244 |
+
running_loss += loss.item() * images.size(0) # Multiply by batch size for correct average
|
| 245 |
+
pbar_train.set_postfix(loss=loss.item())
|
| 246 |
+
|
| 247 |
+
epoch_train_loss = running_loss / len(train_loader.dataset)
|
| 248 |
+
training_history['train_loss'].append(epoch_train_loss)
|
| 249 |
+
|
| 250 |
+
# Evaluate on test set using the dedicated function
|
| 251 |
+
# Ensure model is in eval mode before calling evaluate_model
|
| 252 |
+
model.eval()
|
| 253 |
+
test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
|
| 254 |
+
training_history['test_loss'].append(test_loss)
|
| 255 |
+
training_history['test_cer'].append(test_cer)
|
| 256 |
+
training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
|
| 257 |
+
|
| 258 |
+
# Adjust learning rate based on test loss
|
| 259 |
+
scheduler.step(test_loss)
|
| 260 |
+
|
| 261 |
+
print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
|
| 262 |
+
f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
|
| 263 |
+
|
| 264 |
+
if progress_callback:
|
| 265 |
+
# Update progress bar with current epoch and key metrics
|
| 266 |
+
progress_val = (epoch + 1) / epochs
|
| 267 |
+
progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
|
| 268 |
+
|
| 269 |
+
model.train() # Set model back to training mode after evaluation
|
| 270 |
+
|
| 271 |
+
return model, training_history
|
| 272 |
+
|
| 273 |
+
def save_ocr_model(model: nn.Module, path: str):
|
| 274 |
+
"""Saves the state dictionary of the trained OCR model."""
|
| 275 |
+
torch.save(model.state_dict(), path)
|
| 276 |
+
print(f"OCR model saved to {path}")
|
| 277 |
+
|
| 278 |
+
def load_ocr_model(model: nn.Module, path: str):
|
| 279 |
+
"""
|
| 280 |
+
Loads a trained OCR model's state dictionary.
|
| 281 |
+
Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
|
| 282 |
+
"""
|
| 283 |
+
model.load_state_dict(torch.load(path, map_location=torch.device('cpu'))) # Always load to CPU first
|
| 284 |
+
model.eval() # Set to evaluation mode
|
| 285 |
+
print(f"OCR model loaded from {path}")
|
|
|
utils_ocr.py
CHANGED
|
@@ -1,184 +1,83 @@
|
|
| 1 |
-
<<<<<<< HEAD
|
| 2 |
#utils_ocr.py
|
| 3 |
|
| 4 |
import cv2
|
| 5 |
-
from matplotlib.pylab import f
|
| 6 |
import numpy as np
|
| 7 |
from PIL import Image
|
| 8 |
import torch
|
| 9 |
-
|
|
|
|
| 10 |
|
| 11 |
-
#
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def load_image_as_grayscale(image_path: str) -> Image.Image:
|
| 14 |
"""Loads an image from path and converts it to grayscale PIL Image."""
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
return
|
| 18 |
-
|
| 19 |
-
def binarize_image(image_pil: Image.Image) -> Image.Image:
|
| 20 |
-
"""Binarizes a grayscale PIL Image (black and white)."""
|
| 21 |
-
# Convert PIL to OpenCV format (numpy array)
|
| 22 |
-
img_np = np.array(image_pil)
|
| 23 |
-
# Apply Otsu's thresholding for adaptive binarization
|
| 24 |
-
_, img_bin = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
| 25 |
-
# Invert colors: Handwritten text usually dark on light. OCR models often
|
| 26 |
-
# prefer light text on dark background. Check your training data's style.
|
| 27 |
-
# This example assumes dark text on light background and inverts to white text on black.
|
| 28 |
-
img_bin = 255 - img_bin
|
| 29 |
-
return Image.fromarray(img_bin)
|
| 30 |
|
| 31 |
-
def
|
| 32 |
"""
|
| 33 |
-
|
| 34 |
-
|
| 35 |
"""
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
def
|
| 43 |
"""
|
| 44 |
-
|
|
|
|
| 45 |
"""
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
def
|
| 55 |
"""
|
| 56 |
-
|
| 57 |
-
|
|
|
|
| 58 |
"""
|
| 59 |
-
#
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
# Binarize
|
| 63 |
-
img_bin = binarize_image(img_gray)
|
| 64 |
-
|
| 65 |
-
# Resize (maintain aspect ratio)
|
| 66 |
-
img_resized = resize_image_for_ocr(img_bin, target_height)
|
| 67 |
-
|
| 68 |
-
# Normalize and convert to tensor
|
| 69 |
-
img_tensor = normalize_image_for_model(img_resized)
|
| 70 |
-
|
| 71 |
-
# Add batch dimension: (C, H, W) -> (1, C, H, W)
|
| 72 |
-
img_tensor = img_tensor.unsqueeze(0)
|
| 73 |
-
|
| 74 |
return img_tensor
|
| 75 |
|
| 76 |
-
def
|
| 77 |
"""
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
Output tensor shape: (C, H, max_width)
|
| 81 |
"""
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
#
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
#
|
| 94 |
-
|
| 95 |
-
import cv2
|
| 96 |
-
from matplotlib.pylab import f
|
| 97 |
-
import numpy as np
|
| 98 |
-
from PIL import Image
|
| 99 |
-
import torch
|
| 100 |
-
from torchvision import transforms
|
| 101 |
-
|
| 102 |
-
# --- Image Preprocessing for OCR ---
|
| 103 |
-
|
| 104 |
-
def load_image_as_grayscale(image_path: str) -> Image.Image:
|
| 105 |
-
"""Loads an image from path and converts it to grayscale PIL Image."""
|
| 106 |
-
# Use PIL for robust image loading and conversion to grayscale 'L' mode
|
| 107 |
-
img = Image.open(image_path).convert('L')
|
| 108 |
-
return img
|
| 109 |
-
|
| 110 |
-
def binarize_image(image_pil: Image.Image) -> Image.Image:
|
| 111 |
-
"""Binarizes a grayscale PIL Image (black and white)."""
|
| 112 |
-
# Convert PIL to OpenCV format (numpy array)
|
| 113 |
-
img_np = np.array(image_pil)
|
| 114 |
-
# Apply Otsu's thresholding for adaptive binarization
|
| 115 |
-
_, img_bin = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
| 116 |
-
# Invert colors: Handwritten text usually dark on light. OCR models often
|
| 117 |
-
# prefer light text on dark background. Check your training data's style.
|
| 118 |
-
# This example assumes dark text on light background and inverts to white text on black.
|
| 119 |
-
img_bin = 255 - img_bin
|
| 120 |
-
return Image.fromarray(img_bin)
|
| 121 |
-
|
| 122 |
-
def resize_image_for_ocr(image_pil: Image.Image, target_height: int) -> Image.Image:
|
| 123 |
-
"""
|
| 124 |
-
Resizes a PIL Image to a target height while maintaining aspect ratio.
|
| 125 |
-
Pads width if necessary to avoid distortion.
|
| 126 |
-
"""
|
| 127 |
-
original_width, original_height = image_pil.size
|
| 128 |
-
# Calculate new width based on target height and original aspect ratio
|
| 129 |
-
new_width = int(original_width * (target_height / original_height))
|
| 130 |
-
resized_img = image_pil.resize((new_width, target_height), Image.LANCZOS)
|
| 131 |
-
return resized_img
|
| 132 |
-
|
| 133 |
-
def normalize_image_for_model(image_pil: Image.Image) -> torch.Tensor:
|
| 134 |
-
"""
|
| 135 |
-
Converts a PIL Image to a PyTorch Tensor and normalizes pixel values.
|
| 136 |
-
"""
|
| 137 |
-
# Convert to tensor (scales to 0-1 automatically)
|
| 138 |
-
tensor_transform = transforms.ToTensor()
|
| 139 |
-
img_tensor = tensor_transform(image_pil)
|
| 140 |
-
# For grayscale images, mean and std are single values.
|
| 141 |
-
# Adjust normalization values if your training data uses different ones.
|
| 142 |
-
img_tensor = transforms.Normalize((0.5,), (0.5,))(img_tensor) # Normalize to [-1, 1]
|
| 143 |
-
return img_tensor
|
| 144 |
-
|
| 145 |
-
def preprocess_user_image_for_ocr(uploaded_image_pil: Image.Image, target_height: int) -> torch.Tensor:
|
| 146 |
-
"""
|
| 147 |
-
Combines all preprocessing steps for a single user-uploaded image
|
| 148 |
-
to prepare it for the OCR model.
|
| 149 |
-
"""
|
| 150 |
-
# Ensure it's grayscale
|
| 151 |
-
img_gray = uploaded_image_pil.convert('L')
|
| 152 |
-
|
| 153 |
-
# Binarize
|
| 154 |
-
img_bin = binarize_image(img_gray)
|
| 155 |
-
|
| 156 |
-
# Resize (maintain aspect ratio)
|
| 157 |
-
img_resized = resize_image_for_ocr(img_bin, target_height)
|
| 158 |
-
|
| 159 |
-
# Normalize and convert to tensor
|
| 160 |
-
img_tensor = normalize_image_for_model(img_resized)
|
| 161 |
-
|
| 162 |
-
# Add batch dimension: (C, H, W) -> (1, C, H, W)
|
| 163 |
-
img_tensor = img_tensor.unsqueeze(0)
|
| 164 |
-
|
| 165 |
-
return img_tensor
|
| 166 |
-
|
| 167 |
-
def pad_image_tensor(image_tensor: torch.Tensor, max_width: int) -> torch.Tensor:
|
| 168 |
-
"""
|
| 169 |
-
Pads a single image tensor to a max_width with zeros.
|
| 170 |
-
Input tensor shape: (C, H, W)
|
| 171 |
-
Output tensor shape: (C, H, max_width)
|
| 172 |
-
"""
|
| 173 |
-
C, H, W = image_tensor.shape
|
| 174 |
-
if W > max_width:
|
| 175 |
-
# If image is wider than max_width, you might want to crop or resize it.
|
| 176 |
-
# For this example, we'll just return a warning or clip.
|
| 177 |
-
# A more robust solution might split text lines or use a different resizing strategy.
|
| 178 |
-
print(f"Warning: Image width {W} exceeds max_width {max_width}. Cropping.")
|
| 179 |
-
return image_tensor[:, :, :max_width] # Simple cropping
|
| 180 |
-
padding = max_width - W
|
| 181 |
-
# Pad on the right (P_left, P_right, P_top, P_bottom)
|
| 182 |
-
padded_tensor = f.pad(image_tensor, (0, padding), 'constant', 0)
|
| 183 |
-
>>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
|
| 184 |
-
return padded_tensor
|
|
|
|
|
|
|
| 1 |
#utils_ocr.py
|
| 2 |
|
| 3 |
import cv2
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
from PIL import Image
|
| 6 |
import torch
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
import os
|
| 9 |
|
| 10 |
+
# Import config for IMG_HEIGHT and MAX_IMG_WIDTH
|
| 11 |
+
from config import IMG_HEIGHT, MAX_IMG_WIDTH
|
| 12 |
+
|
| 13 |
+
# --- Image Preprocessing Functions ---
|
| 14 |
|
| 15 |
def load_image_as_grayscale(image_path: str) -> Image.Image:
|
| 16 |
"""Loads an image from path and converts it to grayscale PIL Image."""
|
| 17 |
+
if not os.path.exists(image_path):
|
| 18 |
+
raise FileNotFoundError(f"Image not found at: {image_path}")
|
| 19 |
+
return Image.open(image_path).convert('L') # 'L' for grayscale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
def binarize_image(img: Image.Image) -> Image.Image:
|
| 22 |
"""
|
| 23 |
+
Binarizes a grayscale PIL Image using Otsu's method.
|
| 24 |
+
Returns a PIL Image.
|
| 25 |
"""
|
| 26 |
+
# Convert PIL Image to OpenCV format (numpy array)
|
| 27 |
+
img_np = np.array(img)
|
| 28 |
+
|
| 29 |
+
# Apply Otsu's binarization
|
| 30 |
+
_, binary_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
| 31 |
+
|
| 32 |
+
# Convert back to PIL Image
|
| 33 |
+
return Image.fromarray(binary_img)
|
| 34 |
|
| 35 |
+
def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
|
| 36 |
"""
|
| 37 |
+
Resizes a PIL Image to a fixed height while maintaining aspect ratio.
|
| 38 |
+
Also ensures the width does not exceed MAX_IMG_WIDTH.
|
| 39 |
"""
|
| 40 |
+
width, height = img.size
|
| 41 |
+
|
| 42 |
+
# Calculate new width based on target height, maintaining aspect ratio
|
| 43 |
+
new_width = int(width * (img_height / height))
|
| 44 |
+
|
| 45 |
+
if new_width > MAX_IMG_WIDTH:
|
| 46 |
+
new_width = MAX_IMG_WIDTH
|
| 47 |
+
resized_img = img.resize((new_width, img_height), Image.Resampling.LANCZOS)
|
| 48 |
+
if resized_img.width > MAX_IMG_WIDTH:
|
| 49 |
+
# Crop the image from the left to MAX_IMG_WIDTH
|
| 50 |
+
resized_img = resized_img.crop((0, 0, MAX_IMG_WIDTH, img_height))
|
| 51 |
+
return resized_img
|
| 52 |
+
|
| 53 |
+
return img.resize((new_width, img_height), Image.Resampling.LANCZOS) # Use LANCZOS for high-quality downsampling
|
| 54 |
|
| 55 |
+
def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
|
| 56 |
"""
|
| 57 |
+
Normalizes a torch.Tensor image (grayscale) for input into the model.
|
| 58 |
+
Puts pixel values in range [-1, 1].
|
| 59 |
+
Assumes image is already a torch.Tensor with values in [0, 1] (e.g., after ToTensor).
|
| 60 |
"""
|
| 61 |
+
# Formula: (pixel_value - mean) / std_dev
|
| 62 |
+
# For [0, 1] to [-1, 1], mean = 0.5, std_dev = 0.5
|
| 63 |
+
img_tensor = (img_tensor - 0.5) / 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
return img_tensor
|
| 65 |
|
| 66 |
+
def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
|
| 67 |
"""
|
| 68 |
+
Applies all necessary preprocessing steps to a user-uploaded PIL Image
|
| 69 |
+
to prepare it for the OCR model.
|
|
|
|
| 70 |
"""
|
| 71 |
+
# Define a transformation pipeline similar to the dataset, but including ToTensor
|
| 72 |
+
transform_pipeline = transforms.Compose([
|
| 73 |
+
transforms.Lambda(lambda img: binarize_image(img)), # PIL Image -> PIL Image
|
| 74 |
+
# Use the updated resize function that also handles MAX_IMG_WIDTH
|
| 75 |
+
transforms.Lambda(lambda img: resize_image_for_ocr(img, target_height)), # PIL Image -> PIL Image
|
| 76 |
+
transforms.ToTensor(), # PIL Image -> Tensor [0, 1]
|
| 77 |
+
transforms.Lambda(normalize_image_for_model) # Tensor [0, 1] -> Tensor [-1, 1]
|
| 78 |
+
])
|
| 79 |
+
|
| 80 |
+
processed_image = transform_pipeline(image_pil)
|
| 81 |
+
|
| 82 |
+
# Add a batch dimension (C, H, W) -> (1, C, H, W) for single image inference
|
| 83 |
+
return processed_image.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|