# Gradio demo app: handwritten text recognition with a custom CRNN + CTC model.
import gradio as gr
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import pandas as pd
import sys
import os
import matplotlib.pyplot as plt
# Import preprocessing and model
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from src.utils.preprocessing import preprocess_image, deskew
from src.models.crnn import CRNN
# Select the compute device: GPU when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the character vocabulary straight from labels.csv — no images needed.
# Index 0 is reserved for the CTC blank, so characters map to 1..len(vocab).
try:
    labels_df = pd.read_csv('data/labels.csv')
    unique_chars = set()
    for entry in labels_df['text'].dropna():
        unique_chars.update(str(entry))
    vocab = sorted(unique_chars)
    idx_to_char = {i+1: c for i, c in enumerate(vocab)}
    num_classes = len(vocab) + 1
    print(f"Loaded vocabulary with {len(vocab)} characters")
except Exception as e:
    print(f"Could not load vocabulary from labels.csv: {e}")
    # Fallback to standard IAM vocab if dataset not available
    vocab = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!? ")
    idx_to_char = {i+1: c for i, c in enumerate(vocab)}
    num_classes = len(vocab) + 1
# Instantiate the CRNN and move it to the selected device.
# img_height=32 / img_width=1024 must match the 32x1024 resize used by the
# training pipeline (see `transform` below); num_class includes the CTC blank.
model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=num_classes).to(device)
import glob
def get_latest_checkpoint(weights_dir='weights'):
    """Return the checkpoint path with the highest epoch number, or None.

    Looks for files named ``crnn_baseline_epoch_<N>.pth`` inside *weights_dir*.
    """
    pattern = os.path.join(weights_dir, 'crnn_baseline_epoch_*.pth')
    checkpoints = glob.glob(pattern)
    if not checkpoints:
        return None

    def epoch_of(path):
        # "crnn_baseline_epoch_12.pth" -> 12
        stem = os.path.basename(path).rsplit('.', 1)[0]
        return int(stem.rsplit('_', 1)[-1])

    return max(checkpoints, key=epoch_of)
# Locate and load the most recent checkpoint, if any.
weights_path = get_latest_checkpoint()
if weights_path and os.path.exists(weights_path):
    print(f"Loading trained weights from {weights_path}...")
    # Load the checkpoint ONCE and reuse it for both the strict and the
    # lenient attempt (the original re-read the file from disk on mismatch).
    state_dict = torch.load(weights_path, map_location=device)
    try:
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading weights perfectly (might be minor mismatch): {e}")
        # strict=False tolerates missing/unexpected keys so partially
        # matching checkpoints still load.
        model.load_state_dict(state_dict, strict=False)
else:
    print(f"Warning: Could not find any weights in weights/. Model will output random predictions.")
# Inference mode: disables dropout and freezes batch-norm statistics.
model.eval()
# Input transform — must match the training pipeline exactly:
# squash to 32x1024, convert to a [0,1] tensor, then normalize to [-1,1].
transform = transforms.Compose([
    transforms.Resize((32, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
def decode_predictions(preds, idx_to_char):
    """Greedy CTC decode: take the best path, collapse repeats, drop blanks.

    Args:
        preds: tensor of shape (seq_len, batch, num_classes).
        idx_to_char: mapping from class index (>=1) to character; 0 is blank.

    Returns:
        List of decoded strings, one per batch item.
    """
    best_path = preds.argmax(dim=2).permute(1, 0)  # -> (batch, seq_len)
    results = []
    for seq in best_path:
        pieces = []
        prev = None
        for idx in seq.tolist():
            # Keep a symbol only when it is not blank and differs from the
            # previous frame (standard CTC collapse rule).
            if idx != 0 and idx != prev:
                if idx in idx_to_char:
                    pieces.append(idx_to_char[idx])
            prev = idx
        results.append("".join(pieces))
    return results
def auto_crop_image(gray_img):
    """Crop a grayscale image down to its text region, then pad to ~16:1.

    Thresholds the image (Otsu, inverted so ink is foreground), keeps
    medium-sized contours (rejecting dust specks and huge objects such as
    pens), crops to their combined bounding box with padding, and finally
    pads short crops with white on the right so the aspect ratio matches the
    ~16:1 average of the training data before the 32x1024 squash.

    Returns the original image unchanged when no usable contours are found.
    """
    img_h, img_w = gray_img.shape[0], gray_img.shape[1]
    # Smooth out sensor noise before thresholding.
    blurred = cv2.GaussianBlur(gray_img, (5, 5), 0)
    # Inverted Otsu threshold: dark ink becomes white blobs on black.
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return gray_img
    # Keep contours bigger than a speck of dust but smaller than 40% of the image.
    max_area = img_h * img_w * 0.4
    valid_contours = [c for c in contours if 20 < cv2.contourArea(c) < max_area]
    if not valid_contours:
        return gray_img  # Fallback to original if filtering removes everything
    # Union of all remaining bounding boxes.
    boxes = [cv2.boundingRect(c) for c in valid_contours]
    x_min = min(x for x, y, w, h in boxes)
    y_min = min(y for x, y, w, h in boxes)
    x_max = max(x + w for x, y, w, h in boxes)
    y_max = max(y + h for x, y, w, h in boxes)
    # Generous padding so ascenders/descenders are not clipped.
    pad_y = int((y_max - y_min) * 0.2)
    pad_x = int((x_max - x_min) * 0.05)
    x_min = max(0, x_min - pad_x)
    y_min = max(0, y_min - pad_y)
    x_max = min(img_w, x_max + pad_x)
    y_max = min(img_h, y_max + pad_y)
    cropped = gray_img[y_min:y_max, x_min:x_max]
    # The model squashes everything to 32x1024 (32:1) while the training data
    # averages ~16:1. A short word (e.g. 3:1) would be stretched ~10x
    # horizontally and destroyed, so pad with white on the right up to 16:1
    # BEFORE the squash.
    h, w = cropped.shape
    target_aspect_ratio = 16.0
    if w / h < target_aspect_ratio:
        pad_right = int(h * target_aspect_ratio) - w
        cropped = cv2.copyMakeBorder(cropped, 0, 0, 0, pad_right,
                                     cv2.BORDER_CONSTANT, value=255)
    return cropped
def _plot_ctc_heatmap(probs):
    """Render the CTC probability matrix (num_classes x seq_len) as a heatmap."""
    probs_np = probs.cpu().numpy().T  # shape: (num_classes, seq_len)
    fig, ax = plt.subplots(figsize=(10, 4))
    cax = ax.imshow(probs_np, aspect='auto', cmap='viridis')
    ax.set_title("CTC Probability Matrix Heatmap")
    ax.set_xlabel("Time Frame (Sequence Steps)")
    ax.set_ylabel("Vocabulary Character Index")
    fig.colorbar(cax, ax=ax, fraction=0.046, pad=0.04, label="Probability")
    plt.tight_layout()
    return fig


def _plot_confidence_bars(probs):
    """Bar chart with one confidence bar per decoded (CTC-collapsed) character."""
    max_probs, max_idx = torch.max(probs, dim=1)
    chars = []
    confidences = []
    for i in range(len(max_idx)):
        # Same collapse rule as decoding: skip blanks (0) and repeated frames.
        if max_idx[i] != 0 and (i == 0 or max_idx[i] != max_idx[i-1]):
            char_idx = max_idx[i].item()
            if char_idx in idx_to_char:
                chars.append(idx_to_char[char_idx])
                confidences.append(max_probs[i].item())
    # Widen the figure as the number of characters grows.
    fig, ax = plt.subplots(figsize=(max(8, len(chars)*0.4), 4))
    if chars:
        bars = ax.bar(range(len(chars)), confidences, color='#FF9900')
        ax.set_xticks(range(len(chars)))
        ax.set_xticklabels(chars)
        ax.set_ylim(0, 1.1)
        ax.set_title("Character Confidence Scores")
        ax.set_ylabel("Confidence Probability")
        # Percentage label above each bar.
        for bar in bars:
            yval = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2.0, yval + 0.02,
                    f'{yval*100:.0f}%', va='bottom', ha='center', fontsize=8, rotation=45)
    else:
        ax.text(0.5, 0.5, "No characters predicted", ha='center', va='center')
    plt.tight_layout()
    return fig


def _activation_overlay(cnn_features, base_img):
    """Blend the channel-averaged CNN activation map over the preprocessed image (RGB)."""
    activation = torch.mean(cnn_features, dim=1).squeeze().cpu().numpy()
    # Normalize to 0-255; epsilon guards against division by zero on flat maps.
    activation = (activation - activation.min()) / (activation.max() - activation.min() + 1e-8)
    activation = (activation * 255).astype(np.uint8)
    heatmap_img = cv2.resize(activation, (base_img.shape[1], base_img.shape[0]))
    heatmap_color = cv2.applyColorMap(heatmap_img, cv2.COLORMAP_JET)
    # Grayscale -> BGR so the 50/50 alpha blend has matching channel counts.
    base_bgr = cv2.cvtColor(base_img, cv2.COLOR_GRAY2BGR)
    overlay = cv2.addWeighted(heatmap_color, 0.5, base_bgr, 0.5, 0)
    return cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)  # Gradio expects RGB


def process_and_predict(image, apply_auto_crop=True):
    """Full inference pipeline for one uploaded image.

    Args:
        image: PIL image (or numpy array) from Gradio, or None.
        apply_auto_crop: when True, binarize + auto-crop + deskew before
            resizing; when False, feed the raw grayscale image exactly as
            the training data loader would.

    Returns:
        Tuple of (preprocessed display image, decoded text, CTC heatmap
        figure, confidence bar figure, CNN activation overlay image);
        placeholders when no image was supplied.
    """
    if image is None:
        return None, "Please upload an image.", None, None, None
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # BUG FIX: force 3 channels via convert('RGB') first — a grayscale or
    # RGBA upload yields a 2- or 4-channel array that makes COLOR_RGB2BGR fail.
    img_cv = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)
    gray_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    # Otsu binarization forces black text on white, removing shadows,
    # lighting gradients, and colored paper the model was never trained on.
    blurred = cv2.GaussianBlur(gray_cv, (5, 5), 0)
    _, binarized = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    if not apply_auto_crop:
        # Bypass all fancy preprocessing to precisely match the dataset
        # loading behavior, so dataset images work perfectly.
        gray_image_pil = Image.fromarray(gray_cv)
        img_tensor = transform(gray_image_pil).unsqueeze(0).to(device)
        # For display, show what the network sees (squashed to 1024x32).
        display_processed_img = np.array(gray_image_pil.resize((1024, 32), Image.BILINEAR))
    else:
        # Crop on the clean binarized image, then deskew and resize.
        cropped = auto_crop_image(binarized)
        deskewed = deskew(cropped)
        display_processed_img = preprocess_image(deskewed, target_size=(1024, 32))
        # Same transform as training, applied to a PIL image.
        img_tensor = transform(Image.fromarray(display_processed_img)).unsqueeze(0).to(device)
    with torch.no_grad():
        cnn_features = model.cnn(img_tensor)  # CNN feature map, (1, 512, 1, seq_len)
        preds = model(img_tensor)
    preds = preds.permute(1, 0, 2)  # -> (seq_len, batch, num_classes)
    decoded_text = decode_predictions(preds, idx_to_char)[0]
    probs = torch.exp(preds[:, 0, :])  # LogSoftmax -> probabilities, (seq_len, num_classes)
    if not decoded_text.strip():
        decoded_text = "[Model returned blank - Needs more training epochs]"
    fig_heatmap = _plot_ctc_heatmap(probs)
    fig_bar = _plot_confidence_bars(probs)
    # BUG FIX: the original resized the activation map against
    # `processed_img_np`, which is only assigned on the auto-crop branch —
    # with apply_auto_crop=False that raised NameError. Use
    # display_processed_img, which exists on both paths.
    overlay_img = _activation_overlay(cnn_features, display_processed_img)
    return display_processed_img, decoded_text, fig_heatmap, fig_bar, overlay_img
# Dashboard layout built with Gradio Blocks:
#   top row    -> input controls (left) and preprocessed output + text (right)
#   bottom     -> three explainability visualizations
with gr.Blocks(title="Handwritten Text Recognition (HTR)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Handwritten Text Recognition (HTR) Dashboard</h1>")
    gr.Markdown("Upload an image of handwritten text. The system will preprocess it and extract the text using our trained custom CRNN model.")
    with gr.Row():
        with gr.Column(scale=1):
            # Editor tool allows manual cropping in UI before sending
            input_image = gr.Image(type="pil", label="Upload Handwritten Text Image")
            auto_crop_checkbox = gr.Checkbox(label="โœจ Auto-Crop Background (Smart Vision)", value=True, info="Automatically zooms in on the text and removes giant background objects/pens.")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(type="numpy", label="Preprocessed (1024 x 32)")
            gr.Markdown("<p style='font-size: 12px; color: gray;'>Grayscale, aspect-ratio preserved, padded to 32x1024</p>")
            output_text = gr.Textbox(label="Predicted Text", lines=2)
    gr.Markdown("---")
    gr.Markdown("### ๐Ÿ“Š Model Insights & Analytics (Explainable AI)")
    # Collapsible interpretation guide for the three plots below.
    with gr.Accordion("๐Ÿ“– How to read these graphs (Interpretation Guide)", open=False):
        gr.Markdown("""
**1. CNN Feature Activation Overlay:** Shows exactly where the model's 'eyes' are focusing on the image. Red/hot areas indicate regions with strong visual features (like complex curves or sharp lines) that the Convolutional Neural Network detected.
**2. CTC Probability Matrix Heatmap:** Shows *when* the model made a decision. The X-axis is the timeline (reading left-to-right), and the Y-axis contains all possible characters. Yellow dots indicate the exact moment the AI identified a specific letter.
**3. Character Confidence Scores:** Shows *how sure* the model is about each letter it predicted. If the model misreads a word, this chart usually shows a low confidence score for the incorrect letter, proving it was uncertain.
""")
    with gr.Row():
        cnn_activation_image = gr.Image(type="numpy", label="1. CNN Feature Activation Overlay")
    with gr.Row():
        heatmap_plot = gr.Plot(label="2. CTC Probability Heatmap")
    with gr.Row():
        confidence_plot = gr.Plot(label="3. Character Confidence Scores")
    # Submit runs the full pipeline; output order matches process_and_predict's
    # return tuple.
    submit_btn.click(
        fn=process_and_predict,
        inputs=[input_image, auto_crop_checkbox],
        outputs=[output_image, output_text, heatmap_plot, confidence_plot, cnn_activation_image]
    )
    # Clear resets every widget (checkbox back to its default True).
    clear_btn.click(
        fn=lambda: [None, True, None, "", None, None, None],
        inputs=[],
        outputs=[input_image, auto_crop_checkbox, output_image, output_text, heatmap_plot, confidence_plot, cnn_activation_image]
    )

# share=True additionally exposes a temporary public URL.
if __name__ == "__main__":
    demo.launch(share=True)