Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # app.py | |
| import os | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| import traceback # Ensure this is imported | |
# Pull every setting the app needs from config.py. The import is guarded so a
# broken or missing config is reported in the Streamlit UI and the script run
# is halted, instead of failing later with an unrelated NameError.
try:
    from config import (
        IMG_HEIGHT,
        NUM_CLASSES,
        BLANK_TOKEN,
        VOCABULARY,
        BLANK_TOKEN_SYMBOL,
        TRAIN_CSV_PATH,
        TEST_CSV_PATH,
        TRAIN_IMAGES_DIR,
        TEST_IMAGES_DIR,
        MODEL_SAVE_PATH,
        BATCH_SIZE,
        NUM_EPOCHS,
    )
except Exception as config_error:
    st.error(f"FATAL ERROR: Could not load config.py. Please check your config.py file for errors. Details: {config_error}")
    st.stop()  # Nothing below can work without the configuration values.
# Project-local building blocks: data handling, the CRNN model with its
# train/save/load/decode helpers, and image preprocessing utilities.
# Any failure while importing them is shown in the UI and stops the script run.
try:
    from data_handler_ocr import (
        CharIndexer,
        OCRDataset,
        ocr_collate_fn,
        load_ocr_dataframes,
        create_ocr_dataloaders,
    )
    from model_ocr import (
        CRNN,
        train_ocr_model,
        save_ocr_model,
        load_ocr_model,
        ctc_greedy_decode,
    )
    from utils_ocr import (
        preprocess_user_image_for_ocr,
        binarize_image,
        resize_image_for_ocr,
        normalize_image_for_model,
    )
except Exception as module_error:
    st.error(f"FATAL ERROR: Could not load core modules (data_handler_ocr.py, model_ocr.py, utils_ocr.py). Please check these files for errors. Details: {module_error}")
    st.stop()  # The app cannot run without its core modules.
# --- Global Variables ---
# The training history is kept in Streamlit's session state so it is still
# available after the script is re-executed.
if 'training_history' not in st.session_state:
    st.session_state['training_history'] = None

# Both are placeholders here; they are assigned real objects further down.
ocr_model = char_indexer = None

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Streamlit App Setup ---
st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")

# Main title, centred by placing it in the wide middle column of three.
title_columns = st.columns([1, 3, 1])
main_title_col1, main_title_col2, main_title_col3 = title_columns
with main_title_col2:
    st.title("๐ Handwritten Name Recognition (OCR) App")

# --- Initialize CharIndexer ---
# Built from the configured vocabulary and blank-token symbol; a failure here
# is reported in the UI and stops the script run.
try:
    char_indexer = CharIndexer(
        vocabulary_string=VOCABULARY,
        blank_token_symbol=BLANK_TOKEN_SYMBOL,
    )
except Exception as indexer_error:
    st.error(f"FATAL ERROR: Could not initialize CharIndexer. Check config.py (VOCABULARY, BLANK_TOKEN_SYMBOL) and data_handler_ocr.py (CharIndexer class). Details: {indexer_error}")
    st.stop()
# --- Model Loading / Initialization ---
def get_and_load_ocr_model_cached_internal(num_classes, model_path):
    """Create a CRNN and try to restore saved weights from ``model_path``.

    The function never touches the Streamlit UI; it reports the outcome via
    its return value so the caller can decide how to display it.

    Note: despite the name, no caching decorator is applied to this function
    in the code shown here, so it does its work on every call.

    Args:
        num_classes: Number of output classes passed to ``CRNN``.
        model_path: Filesystem path of the saved state dict.

    Returns:
        A ``(model_instance, message_type, message_text)`` tuple where
        ``message_type`` is one of ``"warning"`` (no file at ``model_path``),
        ``"success"`` (weights loaded) or ``"error"`` (loading raised; a
        freshly constructed model is returned instead).
    """
    def _new_crnn():
        # Single place for the architecture arguments used by this function.
        return CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)

    network = _new_crnn()

    if not os.path.exists(model_path):
        return (
            network,
            "warning",
            "No pre-trained OCR model found. Please train a model using the 'Train & Evaluate' tab.",
        )

    try:
        # Weights are mapped to CPU here; the caller moves the model to its device.
        saved_state = torch.load(model_path, map_location=torch.device('cpu'))
        network.load_state_dict(saved_state)
    except Exception as load_error:
        # Hand back a brand-new model rather than the one the failed load touched.
        return (
            _new_crnn(),
            "error",
            f"Error loading model from '{model_path}' during app startup: {load_error}. A new model will be initialized.",
        )

    return network, "success", "OCR model loaded successfully!"
# --- Load the model at startup and report the outcome ---
# The loader itself stays UI-free; its status message is displayed here.
try:
    loaded_model_instance, load_msg_type, load_msg_text = get_and_load_ocr_model_cached_internal(
        char_indexer.num_classes, MODEL_SAVE_PATH
    )

    # Assign to the module-level ocr_model used by the tabs below.
    ocr_model = loaded_model_instance

    # st.toast validates its icon argument, so each value must be a real
    # emoji. The previous values were mis-encoded byte residue, which would be
    # rejected and send every startup into the "FATAL ERROR" handler below.
    toast_icons = {
        "success": "✅",
        "warning": "⚠️",
        "error": "🚨",
    }
    toast_icon = toast_icons.get(load_msg_type)
    if toast_icon is not None:  # Unknown message types show no toast, as before.
        st.toast(load_msg_text, icon=toast_icon)

    if ocr_model is not None:
        ocr_model.to(device)
        ocr_model.eval()  # Inference mode is the default state of the app.
    else:
        st.error("Model instance is None after cached load. Prediction will not be available.")
except Exception as e:
    st.error(f"FATAL ERROR: Could not initialize or load OCR model during app startup (outer block). Check model_ocr.py (CRNN class) or your saved model file. Details: {e}")
    st.stop()
# --- Define Tabs ---
# The tab strip lives in the wide middle column of a three-column layout.
tab_labels = [" ๐จ๏ธ Project Description", " ๐ Predict Name", " ๐ Train & Evaluate"]
tabs_col1, tabs_col2, tabs_col3 = st.columns([1, 3, 1])
with tabs_col2:
    tab1, tab2, tab3 = st.tabs(tab_labels)

# --- Tab 1: Project Description ---
# Static markdown only: a short overview followed by external resource links.
project_overview_md = """
This application implements a Handwritten Name Recognition (OCR) system using a Convolutional Recurrent Neural Network (CRNN) built with PyTorch.
Its core aim is to accurately convert handwritten text from images into digital format, providing a user-friendly interface via Streamlit.
Here are some helpful resources related to this project:
"""
resource_links_md = """
**[๐ Project Documentation ](https://drive.google.com/file/d/1HBrQT_UnzNLdEsouW9wMk4alAeCsQxZb/view?usp=sharing)**
**[๐๏ธ Demo Presentation ](https://drive.google.com/file/d/1j_S8cijxy6zxIn3cWg6tuLPNWB_7nwdI/view?usp=sharing)**
**[๐ Dataset (from Kaggle)](https://www.kaggle.com/datasets/landlord/handwriting-recognition)**
**[๐ Github Repository ](https://github.com/marianeft/handwritten_name_ocr_app)**
"""
with tab1:
    st.markdown(project_overview_md)
    st.markdown(resource_links_md)
# --- Tab 2: Predict Name (Main Content: Prediction Section) ---
# Flow: upload an image -> convert to grayscale -> preprocess to a tensor ->
# run the model without gradients -> greedy CTC decode -> show the first result.
with tab2:
    st.markdown("Upload a clear image of a single handwritten name or word for recognition.")

    # Prediction is only offered when the module-level model is available.
    if ocr_model is not None:
        uploaded_file = st.file_uploader(
            "๐ผ๏ธ Choose an image...",
            type=["png", "jpg", "jpeg", "jfif"],
        )
        if uploaded_file is not None:
            try:
                # 'L' mode forces a single-channel grayscale image.
                image_pil = Image.open(uploaded_file).convert('L')
                st.image(image_pil, caption="Uploaded Image", use_container_width=True)
                st.write("---")
                st.write("Processing and Recognizing...")

                processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT)
                processed_image_tensor = processed_image_tensor.to(device)

                ocr_model.eval()  # Make sure the model is in eval mode for prediction.
                with torch.no_grad():
                    output = ocr_model(processed_image_tensor)

                predicted_texts = ctc_greedy_decode(output, char_indexer)
                predicted_text = predicted_texts[0]  # One image in, so take the first decoded string.
                st.success(f"Recognized Text: **{predicted_text}**")
            except Exception as ocr_error:
                st.error(f"Error processing image or recognizing text: {ocr_error}")
                tip_lines = [
                    "๐ก **Tips for best results:**",
                    "- Ensure the handwritten text is clear and on a clean background.",
                    "- Only include one name/word per image.",
                    "- The model is trained on specific characters. Unusual symbols might not be recognized.",
                ]
                st.info("\n".join(tip_lines))
                st.exception(ocr_error)  # Full traceback for debugging.
    else:
        st.warning("Model not loaded. Please train or load a model in the 'Train & Evaluate' tab before attempting prediction.")
# --- Tab 3: Train & Evaluate ---
with tab3:
    # --- Model Training Section ---
    st.subheader("Train OCR Model")
    st.write("Click the button below to start training the OCR model.")

    # One message slot and one progress bar, both updated in place during training.
    progress_message_placeholder = st.empty()
    progress_bar_placeholder = st.progress(0)

    def update_progress_callback(value, text):
        """Show training progress in this tab.

        Args:
            value: Fraction of work done; it is multiplied by 100 and truncated
                to an int for the progress bar.
            text: Status message; it replaces whatever the message slot shows.
        """
        progress_bar_placeholder.progress(int(value * 100))
        progress_message_placeholder.info(text)

    if st.button("๐ Start Training"):
        progress_message_placeholder.empty()  # Clear previous messages
        progress_bar_placeholder.progress(0)  # Reset progress bar

        if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
            st.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
        elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
            # Training is not attempted on this branch, so say so explicitly
            # instead of implying that only evaluation is affected.
            st.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
                       "Training was not started because the test data is needed for evaluation. "
                       "Please ensure all data paths are correct and data is uploaded.")
        else:
            progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
            try:
                train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
                progress_message_placeholder.success("Training and Test DataFrames loaded successfully.")
                progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
                progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")

                # The message slot holds a single element, so collect every
                # problem first; otherwise the second error would replace the first.
                empty_split_errors = []
                if len(test_df) == 0:
                    empty_split_errors.append("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
                if len(train_df) == 0:
                    empty_split_errors.append("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
                if empty_split_errors:
                    progress_message_placeholder.error("\n\n".join(empty_split_errors))
                    # NOTE(review): st.stop() is expected to propagate past the
                    # `except Exception` below and end this script run - confirm
                    # against the Streamlit version in use.
                    st.stop()

                char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
                progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")

                train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
                progress_message_placeholder.success("DataLoaders created successfully.")

                # Train a fresh model instance rather than the one loaded at startup.
                ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
                ocr_model_for_training.to(device)
                ocr_model_for_training.train()  # Set to train mode before passing

                progress_message_placeholder.write("Training in progress... This may take a while.")
                ocr_model_for_training, history_result = train_ocr_model(
                    model=ocr_model_for_training,
                    train_loader=train_loader,
                    test_loader=test_loader,
                    char_indexer=char_indexer_for_training,
                    epochs=NUM_EPOCHS,
                    device=device,
                    progress_callback=update_progress_callback
                )
                st.session_state.training_history = history_result  # Save history to session state

                progress_message_placeholder.success("OCR model training finished!")
                update_progress_callback(1.0, "Training complete!")

                # os.path.dirname() returns "" for a bare file name and
                # os.makedirs("") raises, so only create a directory when the
                # configured path actually contains one.
                model_save_dir = os.path.dirname(MODEL_SAVE_PATH)
                if model_save_dir:
                    os.makedirs(model_save_dir, exist_ok=True)
                save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH)
                progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")

                # Point the module-level ocr_model at the newly trained network.
                ocr_model = ocr_model_for_training
                ocr_model.eval()  # Set to eval mode for subsequent predictions
            except Exception as e:
                # Reset the bar first: the callback writes to the same message
                # slot and would otherwise replace the error text shown below.
                update_progress_callback(0.0, "Training failed!")
                progress_message_placeholder.error(f"An error occurred during training: {e}")
                st.exception(e)  # Detailed traceback in the Streamlit UI
# --- Model Loading Section (part of the 'Train & Evaluate' tab) ---
# Re-entering the tab container appends these elements after the training section.
with tab3:
    st.write("---")
    st.subheader("Load Pre-trained Model")
    st.write("If you have a saved model, you can load it here instead of training.")

    if st.button("๐พ Load Model"):
        if not os.path.exists(MODEL_SAVE_PATH):
            st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
        else:
            try:
                # Build a network with the app's character set, then restore its weights.
                loaded_model_instance = CRNN(
                    num_classes=char_indexer.num_classes,
                    cnn_output_channels=512,
                    rnn_hidden_size=256,
                    rnn_num_layers=2,
                )
                load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
                loaded_model_instance.to(device)

                ocr_model = loaded_model_instance  # Update the module-level model reference.
                ocr_model.eval()  # Eval mode after loading.
                st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
                # Training history is only populated by a training run; it is
                # not restored here because it is not saved with the model.
            except Exception as load_error:
                st.error(f"Error loading model: {load_error}")
                st.exception(load_error)
# --- Training History Plots Section (part of the 'Train & Evaluate' tab) ---
# Re-entering the tab container appends these elements after the sections above.
with tab3:
    st.write("---")
    st.subheader("Training History Plots")

    recorded_history = st.session_state.training_history
    if not recorded_history:
        st.info("Train the model first to see training history plots here.")
    else:
        # One row per epoch; CER and accuracy values are scaled by 100 for display as percentages.
        epoch_numbers = range(1, len(recorded_history['train_loss']) + 1)
        history_df = pd.DataFrame({
            'Epoch': epoch_numbers,
            'Train Loss': recorded_history['train_loss'],
            'Test Loss': recorded_history['test_loss'],
            'Test CER (%)': [cer * 100 for cer in recorded_history['test_cer']],
            'Test Exact Match Accuracy (%)': [acc * 100 for acc in recorded_history['test_exact_match_accuracy']],
        })
        metrics_by_epoch = history_df.set_index('Epoch')

        # (heading, columns to plot, caption) for each chart, in display order.
        chart_specs = [
            ("**Loss over Epochs**",
             ['Train Loss', 'Test Loss'],
             "Lower loss indicates better model performance."),
            ("**Character Error Rate (CER) over Epochs**",
             ['Test CER (%)'],
             "Lower CER indicates fewer character errors (0% is perfect)."),
            ("**Exact Match Accuracy over Epochs**",
             ['Test Exact Match Accuracy (%)'],
             "Higher exact match accuracy indicates more perfectly recognized names."),
            ("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**",
             ['Test CER (%)', 'Test Exact Match Accuracy (%)'],
             "CER should decrease, Accuracy should increase."),
        ]
        for heading, plotted_columns, caption_text in chart_specs:
            st.markdown(heading)
            st.line_chart(metrics_by_epoch[plotted_columns])
            st.caption(caption_text)
# --- Final Footer (Centered) ---
# A horizontal rule plus a credits line, placed in the wide middle column.
footer_markdown = """
---
*Built using Streamlit, PyTorch, OpenCV, and EditDistance ยฉ2025 by MFT*
"""
footer_col1, footer_col2, footer_col3 = st.columns([1, 3, 1])
with footer_col2:
    st.markdown(footer_markdown)