"""Gradio dashboard for handwritten text recognition (HTR).

Loads a CRNN+CTC model, preprocesses an uploaded image (optional auto-crop,
deskew, Otsu binarization), greedily decodes the CTC output, and renders three
explainability views: a CNN activation overlay, the CTC probability heatmap,
and per-character confidence bars.
"""

import glob
import os
import sys

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

# Make the project root importable so the src.* packages resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from src.utils.preprocessing import preprocess_image, deskew
from src.models.crnn import CRNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the vocabulary directly from labels.csv without loading any images.
# Index 0 is reserved for the CTC blank; real characters start at index 1.
try:
    df = pd.read_csv('data/labels.csv')
    chars = set()
    for text in df['text']:
        if pd.notna(text):
            chars.update(list(str(text)))
    vocab = sorted(list(chars))
    idx_to_char = {i + 1: c for i, c in enumerate(vocab)}
    num_classes = len(vocab) + 1  # +1 for the CTC blank
    print(f"Loaded vocabulary with {len(vocab)} characters")
except Exception as e:
    print(f"Could not load vocabulary from labels.csv: {e}")
    # Fallback to a standard IAM-style vocabulary if the dataset is unavailable.
    vocab = list("abcdefghijklmnopqrstuvwxyz"
                 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                 "0123456789.,!? \n")
    idx_to_char = {i + 1: c for i, c in enumerate(vocab)}
    num_classes = len(vocab) + 1

# Model expects 32x1024 single-channel inputs (matches the training pipeline).
model = CRNN(img_channel=1, img_height=32, img_width=1024,
             num_class=num_classes).to(device)


def get_latest_checkpoint(weights_dir='weights'):
    """Return the checkpoint path with the highest epoch number, or None.

    Checkpoints are expected to be named ``crnn_baseline_epoch_<N>.pth``.
    """
    checkpoints = glob.glob(os.path.join(weights_dir, 'crnn_baseline_epoch_*.pth'))
    if not checkpoints:
        return None
    # Sort numerically by the epoch embedded in the filename.
    checkpoints.sort(key=lambda p: int(os.path.basename(p).split('_')[-1].split('.')[0]))
    return checkpoints[-1]


weights_path = get_latest_checkpoint()
if weights_path and os.path.exists(weights_path):
    print(f"Loading trained weights from {weights_path}...")
    try:
        model.load_state_dict(torch.load(weights_path, map_location=device))
    except Exception as e:
        # Retry non-strictly so a minor key/shape mismatch does not kill the app.
        print(f"Error loading weights perfectly (might be minor mismatch): {e}")
        model.load_state_dict(torch.load(weights_path, map_location=device),
                              strict=False)
else:
    print("Warning: Could not find any weights in weights/. "
          "Model will output random predictions.")
model.eval()

# Transform matching training exactly: squash to 32x1024, scale to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


def decode_predictions(preds, idx_to_char):
    """Greedy CTC decode.

    Args:
        preds: tensor of shape (seq_len, batch, num_classes) with per-step
            class scores (argmax is taken, so logits or log-probs both work).
        idx_to_char: mapping from class index (>= 1) to character.

    Returns:
        list[str]: one decoded string per batch element. Repeated indices are
        collapsed and the blank (index 0) is dropped, per standard CTC rules.
    """
    _, max_preds = torch.max(preds, 2)
    max_preds = max_preds.permute(1, 0)  # -> (batch, seq_len)
    decoded_texts = []
    for batch_idx in range(max_preds.size(0)):
        pred_seq = max_preds[batch_idx]
        decoded_seq = []
        for i in range(len(pred_seq)):
            # Keep a symbol only if it is not blank and not a repeat.
            if pred_seq[i] != 0 and (i == 0 or pred_seq[i] != pred_seq[i - 1]):
                char_idx = pred_seq[i].item()
                if char_idx in idx_to_char:
                    decoded_seq.append(idx_to_char[char_idx])
        decoded_texts.append("".join(decoded_seq))
    return decoded_texts


def auto_crop_image(gray_img):
    """Crop a grayscale image tightly around the detected text region.

    Uses Otsu thresholding plus contour filtering to locate ink, crops to the
    union bounding box with padding, then pads the crop with white on the
    right until it reaches a ~16:1 aspect ratio so the later blind squash to
    32x1024 does not over-stretch short words.

    Returns the original image unchanged if no usable contours are found.
    """
    # Blur to suppress sensor noise before thresholding.
    blurred = cv2.GaussianBlur(gray_img, (5, 5), 0)
    # Otsu separates dark ink from a light background; INV makes ink white.
    _, thresh = cv2.threshold(blurred, 0, 255,
                              cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return gray_img

    # Keep contours larger than a speck of dust but smaller than ~half the
    # image (excludes giant objects like a pen lying across the page).
    img_area = gray_img.shape[0] * gray_img.shape[1]
    valid_contours = [c for c in contours
                     if 20 < cv2.contourArea(c) < img_area * 0.4]
    if not valid_contours:
        return gray_img  # fall back to original if filtering removed everything

    # Union bounding box of all remaining text contours.
    x_min = y_min = float('inf')
    x_max = y_max = 0
    for c in valid_contours:
        x, y, w, h = cv2.boundingRect(c)
        x_min = min(x_min, x)
        y_min = min(y_min, y)
        x_max = max(x_max, x + w)
        y_max = max(y_max, y + h)

    # Generous padding around the text (more vertically, for ascenders/descenders).
    pad_y = int((y_max - y_min) * 0.2)
    pad_x = int((x_max - x_min) * 0.05)
    x_min = max(0, x_min - pad_x)
    y_min = max(0, y_min - pad_y)
    x_max = min(gray_img.shape[1], x_max + pad_x)
    y_max = min(gray_img.shape[0], y_max + pad_y)
    cropped = gray_img[y_min:y_max, x_min:x_max]

    # CRITICAL FIX for out-of-distribution aspect ratios:
    # The training data (IAM dataset) averages ~16:1, and the training pipeline
    # blindly squashes images to 32x1024 (32:1). A short crop (e.g. a 3:1
    # "THANK YOU") would be stretched ~10x horizontally, destroying the
    # letters — so pad with white on the right until the crop matches the
    # training-average aspect ratio BEFORE squashing.
    h, w = cropped.shape
    target_aspect_ratio = 16.0
    if w / h < target_aspect_ratio:
        pad_width = int(h * target_aspect_ratio) - w
        cropped = cv2.copyMakeBorder(cropped, 0, 0, 0, pad_width,
                                     cv2.BORDER_CONSTANT, value=255)
    return cropped


def _build_ctc_heatmap(probs):
    """Return a matplotlib figure showing the (num_classes, seq_len) CTC matrix."""
    probs_np = probs.cpu().numpy().T  # -> (num_classes, seq_len)
    fig, ax = plt.subplots(figsize=(10, 4))
    cax = ax.imshow(probs_np, aspect='auto', cmap='viridis')
    ax.set_title("CTC Probability Matrix Heatmap")
    ax.set_xlabel("Time Frame (Sequence Steps)")
    ax.set_ylabel("Vocabulary Character Index")
    fig.colorbar(cax, ax=ax, fraction=0.046, pad=0.04, label="Probability")
    plt.tight_layout()
    return fig


def _build_confidence_chart(probs):
    """Return a bar-chart figure of per-character confidences after CTC collapse."""
    max_probs, max_idx = torch.max(probs, dim=1)
    chars, confidences = [], []
    for i in range(len(max_idx)):
        # Same collapse rule as decoding: skip blanks and repeats.
        if max_idx[i] != 0 and (i == 0 or max_idx[i] != max_idx[i - 1]):
            char_idx = max_idx[i].item()
            if char_idx in idx_to_char:
                chars.append(idx_to_char[char_idx])
                confidences.append(max_probs[i].item())

    # Widen the figure for long predictions.
    fig, ax = plt.subplots(figsize=(max(8, len(chars) * 0.4), 4))
    if chars:
        bars = ax.bar(range(len(chars)), confidences, color='#FF9900')
        ax.set_xticks(range(len(chars)))
        ax.set_xticklabels(chars)
        ax.set_ylim(0, 1.1)
        ax.set_title("Character Confidence Scores")
        ax.set_ylabel("Confidence Probability")
        # Percentage labels above each bar.
        for bar in bars:
            yval = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2.0, yval + 0.02,
                    f'{yval*100:.0f}%', va='bottom', ha='center',
                    fontsize=8, rotation=45)
    else:
        ax.text(0.5, 0.5, "No characters predicted", ha='center', va='center')
    plt.tight_layout()
    return fig


def _build_activation_overlay(cnn_features, display_img):
    """Blend a JET colormap of mean CNN activations over the processed image.

    BUGFIX: the original referenced ``processed_img_np`` here, which was
    undefined when auto-crop was disabled (NameError); the display image is
    now passed in explicitly so both code paths work.
    """
    # Average CNN features across channels -> 1-D activation over sequence steps.
    activation = torch.mean(cnn_features, dim=1).squeeze().cpu().numpy()
    # Normalize to 0..255 (epsilon guards a constant activation map).
    activation = (activation - activation.min()) / (activation.max() - activation.min() + 1e-8)
    activation = (activation * 255).astype(np.uint8)
    # Treat the 1-D activation as a single-row image, stretched to image size.
    heatmap_img = cv2.resize(activation.reshape(1, -1),
                             (display_img.shape[1], display_img.shape[0]))
    heatmap_color = cv2.applyColorMap(heatmap_img, cv2.COLORMAP_JET)
    # Grayscale original to BGR so the two can be alpha-blended 50/50.
    original_bgr = cv2.cvtColor(display_img, cv2.COLOR_GRAY2BGR)
    overlay = cv2.addWeighted(heatmap_color, 0.5, original_bgr, 0.5, 0)
    return cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)  # BGR -> RGB for Gradio


def process_and_predict(image, apply_auto_crop=True):
    """Run OCR on an uploaded image and build the explainability outputs.

    Args:
        image: PIL image (or numpy array) from the Gradio input component.
        apply_auto_crop: when True, binarize + auto-crop + deskew before
            inference; when False, feed the raw grayscale image exactly as the
            training data loader would (ensures dataset images work perfectly).

    Returns:
        (display image, decoded text, CTC heatmap figure, confidence bar
        figure, CNN activation overlay image); all-None-ish tuple when no
        image was provided.
    """
    if image is None:
        return None, "Please upload an image.", None, None, None

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Grayscale OpenCV copy used by both preprocessing branches.
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    gray_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)

    # CRITICAL: Otsu binarization forces pure black text on a pure white
    # background, removing shadows, lighting gradients, and colored paper
    # backgrounds that the model was never trained on.
    blurred = cv2.GaussianBlur(gray_cv, (5, 5), 0)
    _, binarized = cv2.threshold(blurred, 0, 255,
                                 cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    if not apply_auto_crop:
        # Bypass all fancy preprocessing to precisely match dataset loading.
        gray_image_pil = Image.fromarray(gray_cv)
        img_tensor = transform(gray_image_pil).unsqueeze(0).to(device)
        # For display, show what the network sees (squashed to 32x1024).
        display_processed_img = np.array(
            gray_image_pil.resize((1024, 32), Image.BILINEAR))
    else:
        # Auto-crop on the binarized image for cleaner contours, then deskew.
        processed_base = auto_crop_image(binarized)
        deskewed_img = deskew(processed_base)
        display_processed_img = preprocess_image(deskewed_img, target_size=(1024, 32))
        # Back to PIL so we reuse exactly the same transform as training.
        img_tensor = (transform(Image.fromarray(display_processed_img))
                      .unsqueeze(0).to(device))

    with torch.no_grad():
        # CNN features for the activation map; shape (1, 512, 1, seq_len).
        cnn_features = model.cnn(img_tensor)
        preds = model(img_tensor)

    preds = preds.permute(1, 0, 2)  # -> (seq_len, batch, num_classes)
    decoded_text = decode_predictions(preds, idx_to_char)[0]
    # Model is assumed to emit LogSoftmax outputs, so exp() recovers
    # probabilities — TODO confirm against CRNN's forward().
    probs = torch.exp(preds[:, 0, :])  # (seq_len, num_classes)

    if not decoded_text.strip():
        decoded_text = "[Model returned blank - Needs more training epochs]"

    fig_heatmap = _build_ctc_heatmap(probs)
    fig_bar = _build_confidence_chart(probs)
    overlay_img = _build_activation_overlay(cnn_features, display_processed_img)

    return display_processed_img, decoded_text, fig_heatmap, fig_bar, overlay_img


# Dashboard layout built with Gradio Blocks.
# NOTE(review): the input components (image upload, auto-crop checkbox,
# submit/clear buttons, processed-image output) were referenced by the click
# handlers but missing from the original source; they are reconstructed here
# to match those handlers — verify labels/layout against the intended design.
with gr.Blocks(title="Handwritten Text Recognition (HTR)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ✍️ Handwritten Text Recognition (HTR)")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            auto_crop_checkbox = gr.Checkbox(value=True, label="Auto-crop & deskew")
            with gr.Row():
                submit_btn = gr.Button("Recognize", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(
                type="numpy",
                label="Processed Input (Grayscale, aspect-ratio preserved, padded to 32x1024)")
            output_text = gr.Textbox(label="Predicted Text", lines=2)

    gr.Markdown("---")
    gr.Markdown("### 📊 Model Insights & Analytics (Explainable AI)")
    with gr.Accordion("📖 How to read these graphs (Interpretation Guide)", open=False):
        gr.Markdown("""
        **1. CNN Feature Activation Overlay:** Shows exactly where the model's 'eyes' are focusing on the image. Red/hot areas indicate regions with strong visual features (like complex curves or sharp lines) that the Convolutional Neural Network detected.

        **2. CTC Probability Matrix Heatmap:** Shows *when* the model made a decision. The X-axis is the timeline (reading left-to-right), and the Y-axis contains all possible characters. Yellow dots indicate the exact moment the AI identified a specific letter.

        **3. Character Confidence Scores:** Shows *how sure* the model is about each letter it predicted. If the model misreads a word, this chart usually shows a low confidence score for the incorrect letter, proving it was uncertain.
        """)

    with gr.Row():
        cnn_activation_image = gr.Image(type="numpy", label="1. CNN Feature Activation Overlay")
    with gr.Row():
        heatmap_plot = gr.Plot(label="2. CTC Probability Heatmap")
    with gr.Row():
        confidence_plot = gr.Plot(label="3. Character Confidence Scores")

    submit_btn.click(
        fn=process_and_predict,
        inputs=[input_image, auto_crop_checkbox],
        outputs=[output_image, output_text, heatmap_plot,
                 confidence_plot, cnn_activation_image],
    )
    # Reset every component (checkbox back to its default of True).
    clear_btn.click(
        fn=lambda: [None, True, None, "", None, None, None],
        inputs=[],
        outputs=[input_image, auto_crop_checkbox, output_image, output_text,
                 heatmap_plot, confidence_plot, cnn_activation_image],
    )

if __name__ == "__main__":
    demo.launch(share=True)