# Digital-Image-Processing-OCR / src/final_diagnostic.py
# Committed by chiruu12 in 9543569 ("Initial commit of clean OCR application"); 4.82 kB.
import torch
import cv2
import numpy as np
import os
from PIL import Image, ImageDraw, ImageFont
import h5py
from src.crnn_model import CRNN
from train_crnn import decode_ctc_output
# Directory that contains this script; every model/data path below is joined onto it.
# NOTE(review): the repository path places this file under src/, while the CRNN
# import uses the package name "src.crnn_model" — confirm that models/, data/ and
# sample_documents/ really live inside src/ rather than in its parent directory.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Use Apple's Metal (MPS) backend when torch reports it available, else the CPU.
_MPS_AVAILABLE = torch.backends.mps.is_available()
DEVICE = torch.device("mps") if _MPS_AVAILABLE else torch.device("cpu")

# Fine-tuned CRNN weights and the HDF5 file that holds the character vocabulary.
MODEL_PATH = os.path.join(PROJECT_ROOT, "models/crnn_finetuned/crnn_book_model.pth")
DATA_FILE = os.path.join(PROJECT_ROOT, "data/line_dataset.h5")

# Font used for the synthetic "perfect" renders (this looks like a macOS system
# font path; it will presumably not exist on other platforms).
FONT_PATH = "/System/Library/Fonts/Supplemental/Times New Roman.ttf"

IMAGE_HEIGHT = 32  # pixel height of every line image fed to the model / rendered
FONT_SIZE = 28     # size passed to ImageFont.truetype for the synthetic renders

# Ground-truth text of the lines under test, paired by position with the crop
# boxes hard-coded in main().
TEST_LINES = [
    "Praise for Applied Machine Learning and AI for Engineers",
    "This book is a fantastic guide to machine learning and AI",
    "the concrete examples with working code show how to take",
]
def render_perfect_line(text, font):
    """Draw *text* black-on-white into a clean, synthetic grayscale line image.

    Intended to mimic the renders the model was trained on (that match cannot
    be verified from this file).

    Args:
        text: The string to render.
        font: A PIL ImageFont; main() passes one loaded at FONT_SIZE.

    Returns:
        A NumPy array of the "L"-mode canvas, IMAGE_HEIGHT pixels tall and
        (text bounding-box width + 10) pixels wide, with the text drawn in 0
        on a 255 background.
    """
    left, _top, right, _bottom = font.getbbox(text)
    margin = 5
    canvas = Image.new("L", (right - left + 2 * margin, IMAGE_HEIGHT), 255)
    # Shift the text origin down by half the height not taken by FONT_SIZE.
    top = (IMAGE_HEIGHT - FONT_SIZE) // 2
    ImageDraw.Draw(canvas).text((margin, top), text, font=font, fill=0)
    return np.array(canvas)
def preprocess_for_model(line_image, is_from_scan=False):
    """Convert a grayscale line image into a model-ready tensor on DEVICE.

    Args:
        line_image: 2-D grayscale image array of shape (height, width). When
            ``is_from_scan`` is true it should be uint8: OpenCV documents its
            Otsu thresholding mode for 8-bit single-channel images.
        is_from_scan: If true, binarize with Otsu's threshold first (used for
            crops of the rasterized page); synthetic renders pass through as-is.

    Returns:
        A float32 tensor of shape (1, 1, IMAGE_HEIGHT, new_w) on DEVICE, where
        new_w keeps the input aspect ratio (and is at least 1). Pixel values are
        divided by 255, so uint8 input maps to [0, 1].

    Raises:
        ValueError: if ``line_image`` is not 2-D, or has zero height or width.
    """
    if line_image.ndim != 2:
        raise ValueError(f"expected a 2-D grayscale image, got shape {line_image.shape}")
    h, w = line_image.shape
    if h == 0 or w == 0:
        # Previously this surfaced as a ZeroDivisionError or an OpenCV assertion.
        raise ValueError(f"cannot preprocess an empty image of shape {line_image.shape}")

    if is_from_scan:
        # With THRESH_OTSU the threshold is chosen automatically; the 0 is ignored.
        _, image_to_process = cv2.threshold(line_image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    else:
        image_to_process = line_image

    # Scale to the fixed model height, preserving aspect ratio. Clamp the width
    # so a very tall, narrow crop never asks cv2.resize for a zero-width image.
    scale_factor = IMAGE_HEIGHT / h
    new_w = max(1, int(w * scale_factor))
    resized_image = cv2.resize(image_to_process, (new_w, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA)

    normalized_image = (resized_image / 255.0).astype(np.float32)
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W).
    tensor = torch.from_numpy(normalized_image).unsqueeze(0).unsqueeze(0)
    return tensor.to(DEVICE)
def _load_model_and_vocab():
    """Load the character vocabulary from DATA_FILE and the CRNN weights from MODEL_PATH.

    Returns:
        (model, int_to_char): the CRNN in eval mode on DEVICE, and a dict that
        maps model output indices (starting at 1) to characters. Index 0 is left
        unmapped — presumably the CTC blank; confirm against decode_ctc_output.
    """
    with h5py.File(DATA_FILE, "r") as hf:
        char_list = [c.decode("utf-8") for c in hf["char_list"][:]]
    int_to_char = {i + 1: char for i, char in enumerate(char_list)}
    model = CRNN(num_chars=len(char_list)).to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    return model, int_to_char


def _load_page_image(pdf_path, page_number):
    """Rasterize one PDF page with pdf2image and return it as a grayscale NumPy array."""
    # Local import: only this diagnostic path needs pdf2image.
    from pdf2image import convert_from_path

    pil_image = convert_from_path(pdf_path, first_page=page_number, last_page=page_number)[0]
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2GRAY)


def _stack_for_comparison(scan_crop, rendered_line):
    """Stack a raw page crop above a synthetic render for visual comparison.

    The crop is resized to the render's height (aspect ratio preserved). Then
    whichever image is narrower is right-padded with white (255) so both rows
    share one width before being stacked vertically.
    """
    target_h = rendered_line.shape[0]
    crop_h, crop_w = scan_crop.shape
    scale = target_h / crop_h
    resized_crop = cv2.resize(scan_crop, (int(crop_w * scale), target_h))

    width_diff = abs(resized_crop.shape[1] - rendered_line.shape[1])
    if resized_crop.shape[1] < rendered_line.shape[1]:
        resized_crop = cv2.copyMakeBorder(resized_crop, 0, 0, 0, width_diff, cv2.BORDER_CONSTANT, value=255)
    else:
        rendered_line = cv2.copyMakeBorder(rendered_line, 0, 0, 0, width_diff, cv2.BORDER_CONSTANT, value=255)
    return np.vstack([resized_crop, rendered_line])


def main():
    """Run an A/B OCR check: lines cropped from a rasterized PDF vs. clean renders.

    For each entry of TEST_LINES the CRNN is run on (A) the matching crop of
    page 12 of the sample PDF and (B) a synthetic render of the ground-truth
    text. Both decoded predictions are printed. A stacked comparison image is
    written to diagnostic_comparison_line_<n>.png in the current working
    directory.

    Raises:
        ValueError: if the number of hard-coded crop boxes differs from the
            number of TEST_LINES.
    """
    print("--- Running Final Diagnostic A/B Test ---")
    print("Loading model...")
    model, int_to_char = _load_model_and_vocab()
    print("Model loaded.")
    font = ImageFont.truetype(FONT_PATH, FONT_SIZE)

    # (x, y, width, height) pixel boxes of the TEST_LINES on the rasterized page
    # (pdf2image's default resolution, since no dpi is passed). Paired with
    # TEST_LINES by position, so the two lists must stay the same length.
    line_coords = [
        (118, 114, 1551, 44),
        (118, 178, 1549, 36),
        (118, 298, 1551, 35),
    ]
    if len(line_coords) != len(TEST_LINES):
        raise ValueError(
            f"line_coords has {len(line_coords)} boxes but TEST_LINES has {len(TEST_LINES)} entries"
        )

    pdf_path = os.path.join(PROJECT_ROOT, "sample_documents/books/Applied-Machine-Learning-and-AI-for-Engineers.pdf")
    page_image = _load_page_image(pdf_path, page_number=12)

    with torch.no_grad():
        for line_no, (ground_truth_text, (x, y, w, h)) in enumerate(zip(TEST_LINES, line_coords), start=1):
            print("\n" + "=" * 50)
            print(f"TESTING LINE {line_no}: '{ground_truth_text}'")
            print("=" * 50)

            # A: crop of the rasterized page (binarized inside preprocess_for_model).
            real_line_crop = page_image[y:y + h, x:x + w]
            real_preds = model(preprocess_for_model(real_line_crop, is_from_scan=True))
            real_decoded_text = decode_ctc_output(real_preds, int_to_char)[0]
            print(f" -> Prediction from REAL SCAN: '{real_decoded_text}'")

            # B: synthetic render of the ground-truth text.
            perfect_line_image = render_perfect_line(ground_truth_text, font)
            perfect_preds = model(preprocess_for_model(perfect_line_image, is_from_scan=False))
            perfect_decoded_text = decode_ctc_output(perfect_preds, int_to_char)[0]
            print(f" -> Prediction from PERFECT RENDER: '{perfect_decoded_text}'")

            out_path = f"diagnostic_comparison_line_{line_no}.png"
            comparison_image = _stack_for_comparison(real_line_crop, perfect_line_image)
            # cv2.imwrite signals a failed write through its boolean return value,
            # so check it instead of unconditionally claiming success.
            if cv2.imwrite(out_path, comparison_image):
                print(f" -> Saved visual evidence to '{out_path}'")
            else:
                print(f" -> WARNING: could not write '{out_path}'")
# Run the diagnostic only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()