| import gradio as gr |
| import cv2 |
| import numpy as np |
| import torch |
| from torchvision import transforms |
| from PIL import Image |
| import pandas as pd |
| import sys |
| import os |
| import matplotlib.pyplot as plt |
|
|
| |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) |
| from src.utils.preprocessing import preprocess_image, deskew |
| from src.models.crnn import CRNN |
|
|
| |
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| |
# Build the character vocabulary for CTC decoding from the training labels.
# Index 0 is implicitly reserved for the CTC blank, so real characters are
# mapped starting at index 1 (see decode_predictions, which skips index 0).
try:
    labels_df = pd.read_csv('data/labels.csv')
    unique_chars = set()
    for entry in labels_df['text']:
        if pd.notna(entry):
            unique_chars.update(str(entry))
    vocab = sorted(unique_chars)
    idx_to_char = {pos + 1: ch for pos, ch in enumerate(vocab)}
    num_classes = len(vocab) + 1
    print(f"Loaded vocabulary with {len(vocab)} characters")
except Exception as e:
    print(f"Could not load vocabulary from labels.csv: {e}")
    # Fallback: generic alphanumeric + basic punctuation vocabulary.
    vocab = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!? ")
    idx_to_char = {pos + 1: ch for pos, ch in enumerate(vocab)}
    num_classes = len(vocab) + 1
|
|
| |
# CRNN expects 1-channel 32x1024 inputs; num_classes = len(vocab) + 1 because
# index 0 is treated as the CTC blank by decode_predictions.
model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=num_classes).to(device)
|
|
| import glob |
def get_latest_checkpoint(weights_dir='weights'):
    """Return the path of the highest-epoch CRNN checkpoint, or None if absent.

    Checkpoints are expected to be named ``crnn_baseline_epoch_<N>.pth``;
    the numeric epoch suffix decides which file is "latest".
    """
    pattern = os.path.join(weights_dir, 'crnn_baseline_epoch_*.pth')
    found = glob.glob(pattern)
    if not found:
        return None

    def epoch_of(path):
        # 'crnn_baseline_epoch_12.pth' -> 12 (numeric, not lexicographic)
        stem = os.path.basename(path)
        return int(stem.split('_')[-1].split('.')[0])

    return max(found, key=epoch_of)
|
|
# Load the most recent checkpoint if one exists; otherwise the model keeps its
# random initialization and predictions will be meaningless.
weights_path = get_latest_checkpoint()
if weights_path and os.path.exists(weights_path):
    print(f"Loading trained weights from {weights_path}...")
    # Deserialize the checkpoint from disk ONCE and reuse it for the fallback
    # load (previously torch.load ran a second time in the except branch).
    state_dict = torch.load(weights_path, map_location=device)
    try:
        model.load_state_dict(state_dict)
    except Exception as e:
        # Architecture drift (e.g. a changed vocab size) causes key/shape
        # mismatches; retry non-strictly so the compatible layers still load.
        print(f"Error loading weights perfectly (might be minor mismatch): {e}")
        model.load_state_dict(state_dict, strict=False)
else:
    print(f"Warning: Could not find any weights in weights/. Model will output random predictions.")

# Inference only: disables dropout and uses running batch-norm statistics.
model.eval()
|
|
| |
# Inference-time transform: the model takes a fixed 32x1024 single-channel
# image, scaled to [-1, 1] by Normalize(mean=0.5, std=0.5).
transform = transforms.Compose([
    transforms.Resize((32, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
|
|
def decode_predictions(preds, idx_to_char):
    """Greedy CTC decode: collapse repeated labels and drop blanks (index 0).

    Args:
        preds: tensor of shape (seq_len, batch, num_classes) with per-step scores.
        idx_to_char: mapping from class index (>= 1) to character.

    Returns:
        A list of decoded strings, one per batch element.
    """
    best_path = preds.argmax(2).permute(1, 0)  # -> (batch, seq_len)

    texts = []
    for seq in best_path:
        prev = None
        chars = []
        for step in seq.tolist():
            # Emit only on label transitions, and never emit the blank (0).
            if step != 0 and step != prev:
                if step in idx_to_char:
                    chars.append(idx_to_char[step])
            prev = step
        texts.append("".join(chars))
    return texts
|
|
def auto_crop_image(gray_img):
    """Crop a grayscale image tightly around the detected text region.

    Finds small-to-medium contours (ignoring specks and huge background
    blobs), takes the union of their bounding boxes with some padding, and
    right-pads the crop with white until the aspect ratio is at least 16:1
    (w:h) to match the model's wide input format. Returns the input
    unchanged when no usable contours are found.
    """
    # Smooth, then Otsu-binarize (inverted: ink becomes white) for contours.
    smoothed = cv2.GaussianBlur(gray_img, (5, 5), 0)
    _, mask = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return gray_img

    # Discard noise (< 20 px) and anything covering more than 40% of the
    # frame (e.g. a pen or a table edge dominating the photo).
    total_area = gray_img.shape[0] * gray_img.shape[1]
    kept = [c for c in found if 20 < cv2.contourArea(c) < total_area * 0.4]
    if not kept:
        return gray_img

    # Union of bounding rectangles over the kept contours.
    boxes = [cv2.boundingRect(c) for c in kept]
    x_lo = min(x for x, _, _, _ in boxes)
    y_lo = min(y for _, y, _, _ in boxes)
    x_hi = max(x + w for x, _, w, _ in boxes)
    y_hi = max(y + h for _, y, _, h in boxes)

    # Generous vertical padding (ascenders/descenders), light horizontal.
    pad_v = int((y_hi - y_lo) * 0.2)
    pad_h = int((x_hi - x_lo) * 0.05)
    x_lo = max(0, x_lo - pad_h)
    y_lo = max(0, y_lo - pad_v)
    x_hi = min(gray_img.shape[1], x_hi + pad_h)
    y_hi = min(gray_img.shape[0], y_hi + pad_v)

    cropped = gray_img[y_lo:y_hi, x_lo:x_hi]

    # Right-pad with white until the crop is at least 16x wider than tall.
    h, w = cropped.shape
    target_aspect_ratio = 16.0
    if w / h < target_aspect_ratio:
        extra = int(h * target_aspect_ratio) - w
        cropped = cv2.copyMakeBorder(cropped, 0, 0, 0, extra, cv2.BORDER_CONSTANT, value=255)

    return cropped
|
|
def process_and_predict(image, apply_auto_crop=True):
    """Run the full HTR pipeline on an uploaded image.

    Args:
        image: PIL image (or numpy array) from Gradio, or None.
        apply_auto_crop: when True, binarize/auto-crop/deskew around the
            detected text before recognition; when False, feed the raw
            grayscale image resized to the model's input.

    Returns a 5-tuple matching the Gradio outputs:
        (preprocessed 32x1024 image, predicted text, CTC heatmap figure,
         per-character confidence figure, CNN activation overlay image).
    """
    if image is None:
        return None, "Please upload an image.", None, None, None

    # Gradio may hand us a numpy array depending on component config.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # OpenCV grayscale copy used by both pipelines.
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    gray_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)

    if not apply_auto_crop:
        # Raw path: just resize the grayscale image to the model's input size.
        gray_image_pil = Image.fromarray(gray_cv)
        img_tensor = transform(gray_image_pil).unsqueeze(0).to(device)
        display_processed_img = np.array(gray_image_pil.resize((1024, 32), Image.BILINEAR))
    else:
        # Smart path: Otsu-binarize, crop around the text, deskew, then
        # resize/pad to the 1024x32 canvas expected by the CRNN.
        blurred = cv2.GaussianBlur(gray_cv, (5, 5), 0)
        _, binarized = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        processed_base = auto_crop_image(binarized)
        deskewed_img = deskew(processed_base)
        display_processed_img = preprocess_image(deskewed_img, target_size=(1024, 32))

        gray_image_cropped = Image.fromarray(display_processed_img)
        img_tensor = transform(gray_image_cropped).unsqueeze(0).to(device)

    with torch.no_grad():
        # CNN feature maps kept separately for the activation overlay below.
        cnn_features = model.cnn(img_tensor)

        preds = model(img_tensor)
        preds = preds.permute(1, 0, 2)  # -> (seq_len, batch, num_classes)
        decoded_text = decode_predictions(preds, idx_to_char)[0]

        # Per-timestep class probabilities for the single batch item.
        # NOTE(review): exp() assumes the model emits log-probabilities
        # (e.g. log_softmax) — confirm against CRNN.forward.
        probs = torch.exp(preds[:, 0, :])

    if not decoded_text.strip():
        decoded_text = "[Model returned blank - Needs more training epochs]"

    # --- 1) CTC probability matrix heatmap (characters x time) ---
    probs_np = probs.cpu().numpy().T
    fig_heatmap, ax1 = plt.subplots(figsize=(10, 4))
    cax = ax1.imshow(probs_np, aspect='auto', cmap='viridis')
    ax1.set_title("CTC Probability Matrix Heatmap")
    ax1.set_xlabel("Time Frame (Sequence Steps)")
    ax1.set_ylabel("Vocabulary Character Index")
    fig_heatmap.colorbar(cax, ax=ax1, fraction=0.046, pad=0.04, label="Probability")
    plt.tight_layout()

    # --- 2) Per-character confidences along the greedy path, collapsed the
    # same way decode_predictions collapses (skip blanks and repeats) ---
    max_probs, max_idx = torch.max(probs, dim=1)
    chars = []
    confidences = []
    for i in range(len(max_idx)):
        if max_idx[i] != 0 and (i == 0 or max_idx[i] != max_idx[i-1]):
            char_idx = max_idx[i].item()
            if char_idx in idx_to_char:
                chars.append(idx_to_char[char_idx])
                confidences.append(max_probs[i].item())

    fig_bar, ax2 = plt.subplots(figsize=(max(8, len(chars)*0.4), 4))
    if chars:
        bars = ax2.bar(range(len(chars)), confidences, color='#FF9900')
        ax2.set_xticks(range(len(chars)))
        ax2.set_xticklabels(chars)
        ax2.set_ylim(0, 1.1)
        ax2.set_title("Character Confidence Scores")
        ax2.set_ylabel("Confidence Probability")
        # Percentage label above each bar.
        for bar in bars:
            yval = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2.0, yval + 0.02,
                     f'{yval*100:.0f}%', va='bottom', ha='center', fontsize=8, rotation=45)
    else:
        ax2.text(0.5, 0.5, "No characters predicted", ha='center', va='center')
    plt.tight_layout()

    # --- 3) CNN activation overlay: where the feature extractor "looked" ---
    # Average over the channel dimension, then min-max normalize to uint8.
    activation = torch.mean(cnn_features, dim=1).squeeze().cpu().numpy()
    activation = (activation - activation.min()) / (activation.max() - activation.min() + 1e-8)
    activation = (activation * 255).astype(np.uint8)

    # BUG FIX: previously this resized against `processed_img_np`, which is
    # only bound on the auto-crop branch — with apply_auto_crop=False the
    # function raised NameError here. `display_processed_img` is the same
    # array on the crop path and is defined on both paths.
    heatmap_img = cv2.resize(activation, (display_processed_img.shape[1], display_processed_img.shape[0]))
    heatmap_color = cv2.applyColorMap(heatmap_img, cv2.COLORMAP_JET)
    original_bgr = cv2.cvtColor(display_processed_img, cv2.COLOR_GRAY2BGR)
    overlay_img = cv2.addWeighted(heatmap_color, 0.5, original_bgr, 0.5, 0)
    overlay_img = cv2.cvtColor(overlay_img, cv2.COLOR_BGR2RGB)

    return display_processed_img, decoded_text, fig_heatmap, fig_bar, overlay_img
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: input column (image upload + auto-crop toggle) on the left,
# prediction outputs on the right, followed by three explainability views
# wired to process_and_predict.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Handwritten Text Recognition (HTR)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Handwritten Text Recognition (HTR) Dashboard</h1>")
    gr.Markdown("Upload an image of handwritten text. The system will preprocess it and extract the text using our trained custom CRNN model.")

    with gr.Row():
        with gr.Column(scale=1):
            # Input side: image, preprocessing toggle, action buttons.
            input_image = gr.Image(type="pil", label="Upload Handwritten Text Image")
            auto_crop_checkbox = gr.Checkbox(label="โจ Auto-Crop Background (Smart Vision)", value=True, info="Automatically zooms in on the text and removes giant background objects/pens.")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column(scale=1):
            # Output side: the exact image fed to the model plus the decoded text.
            output_image = gr.Image(type="numpy", label="Preprocessed (1024 x 32)")
            gr.Markdown("<p style='font-size: 12px; color: gray;'>Grayscale, aspect-ratio preserved, padded to 32x1024</p>")
            output_text = gr.Textbox(label="Predicted Text", lines=2)

    gr.Markdown("---")
    gr.Markdown("### ๐ Model Insights & Analytics (Explainable AI)")

    # Collapsible help text explaining the three visualizations below.
    with gr.Accordion("๐ How to read these graphs (Interpretation Guide)", open=False):
        gr.Markdown("""
        **1. CNN Feature Activation Overlay:** Shows exactly where the model's 'eyes' are focusing on the image. Red/hot areas indicate regions with strong visual features (like complex curves or sharp lines) that the Convolutional Neural Network detected.

        **2. CTC Probability Matrix Heatmap:** Shows *when* the model made a decision. The X-axis is the timeline (reading left-to-right), and the Y-axis contains all possible characters. Yellow dots indicate the exact moment the AI identified a specific letter.

        **3. Character Confidence Scores:** Shows *how sure* the model is about each letter it predicted. If the model misreads a word, this chart usually shows a low confidence score for the incorrect letter, proving it was uncertain.
        """)

    with gr.Row():
        cnn_activation_image = gr.Image(type="numpy", label="1. CNN Feature Activation Overlay")

    with gr.Row():
        heatmap_plot = gr.Plot(label="2. CTC Probability Heatmap")

    with gr.Row():
        confidence_plot = gr.Plot(label="3. Character Confidence Scores")

    # Submit runs the full pipeline; output order must match the function's
    # 5-tuple return (image, text, heatmap fig, confidence fig, overlay).
    submit_btn.click(
        fn=process_and_predict,
        inputs=[input_image, auto_crop_checkbox],
        outputs=[output_image, output_text, heatmap_plot, confidence_plot, cnn_activation_image]
    )

    # Clear resets every component; the checkbox goes back to its default (True).
    clear_btn.click(
        fn=lambda: [None, True, None, "", None, None, None],
        inputs=[],
        outputs=[input_image, auto_crop_checkbox, output_image, output_text, heatmap_plot, confidence_plot, cnn_activation_image]
    )
|
|
if __name__ == "__main__":
    # share=True additionally exposes a temporary public Gradio link.
    demo.launch(share=True)