| import torch |
| import cv2 |
| import numpy as np |
| import os |
| from PIL import Image, ImageDraw, ImageFont |
| import h5py |
|
|
| from src.crnn_model import CRNN |
| from train_crnn import decode_ctc_output |
|
|
# --- Paths and runtime configuration -------------------------------------

# Directory containing this script; project-relative paths hang off it.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Run on Apple's MPS backend when PyTorch reports it available, else on CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Fine-tuned CRNN weights and the HDF5 dataset that stores `char_list`.
MODEL_PATH = os.path.join(PROJECT_ROOT, "models/crnn_finetuned/crnn_book_model.pth")
DATA_FILE = os.path.join(PROJECT_ROOT, "data/line_dataset.h5")

# TrueType font used for the synthetic renders (absolute macOS system path).
FONT_PATH = "/System/Library/Fonts/Supplemental/Times New Roman.ttf"

# --- Rendering / preprocessing geometry ----------------------------------

# Pixel height every line image is brought to before entering the model.
IMAGE_HEIGHT = 32
# Point size passed to the font loader for synthetic renders.
FONT_SIZE = 28

# Ground-truth text of the lines under test, one entry per scanned line.
TEST_LINES = [
    "Praise for Applied Machine Learning and AI for Engineers",
    "This book is a fantastic guide to machine learning and AI",
    "the concrete examples with working code show how to take",
]
|
|
|
|
|
|
def render_perfect_line(text, font):
    """Render ``text`` as a clean synthetic grayscale line image.

    The canvas is white (255), ``IMAGE_HEIGHT`` pixels tall and 10 pixels
    wider than the text's bounding box as reported by ``font.getbbox``.
    The text is drawn in black (0) starting 5 pixels from the left edge, at
    a vertical offset of ``(IMAGE_HEIGHT - FONT_SIZE) // 2``.

    NOTE(review): the original author states this reproduces the training
    data rendering; that cannot be confirmed from this file, so the
    geometry is deliberately left untouched.

    Args:
        text: The string to render.
        font: A PIL font object exposing ``getbbox``.

    Returns:
        The rendered single-channel ("L" mode) image as a NumPy array.
    """
    left, _top, right, _bottom = font.getbbox(text)
    canvas_size = (right - left + 10, IMAGE_HEIGHT)
    canvas = Image.new("L", canvas_size, 255)

    text_origin = (5, (IMAGE_HEIGHT - FONT_SIZE) // 2)
    ImageDraw.Draw(canvas).text(text_origin, text, font=font, fill=0)
    return np.array(canvas)
|
|
|
|
def preprocess_for_model(line_image, is_from_scan=False):
    """Turn a single-channel line image into a model-ready tensor.

    Steps:
      1. If ``is_from_scan`` is true, binarise the image with Otsu's
         threshold; otherwise use the image as given.
      2. Resize to a height of ``IMAGE_HEIGHT``, scaling the width by the
         same factor (truncated to an int, but never below 1).
      3. Divide by 255 and cast to float32 (yields values in [0, 1] for
         8-bit input).
      4. Add two leading singleton dimensions, giving a 4-D tensor of shape
         ``(1, 1, IMAGE_HEIGHT, new_width)``, and move it to ``DEVICE``.

    Args:
        line_image: 2-D array holding one text line.
            NOTE(review): presumably uint8, which is what the Otsu branch
            of ``cv2.threshold`` expects; confirm against callers.
        is_from_scan: Whether to binarise the image before resizing.

    Returns:
        A float32 ``torch.Tensor`` on ``DEVICE``.

    Raises:
        ValueError: If ``line_image`` is not a non-empty 2-D array, for
            example when a crop fell outside the page and came back empty.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (h == 0)
    # or an opaque OpenCV assertion further down.
    if np.ndim(line_image) != 2 or np.size(line_image) == 0:
        raise ValueError(
            "preprocess_for_model expects a non-empty 2-D grayscale image, "
            f"got shape {np.shape(line_image)}"
        )

    if is_from_scan:
        # With THRESH_OTSU the threshold is chosen automatically, so the
        # explicit threshold argument (0) is ignored.
        _, image_to_process = cv2.threshold(
            line_image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
    else:
        image_to_process = line_image

    h, w = image_to_process.shape
    scale_factor = IMAGE_HEIGHT / h
    # Clamp to 1: a tall, very narrow crop would otherwise truncate to a
    # zero width, which cv2.resize rejects.
    new_w = max(1, int(w * scale_factor))
    resized_image = cv2.resize(
        image_to_process, (new_w, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA
    )

    normalized_image = (resized_image / 255.0).astype(np.float32)
    # (H, W) -> (1, 1, H, W)
    tensor = torch.from_numpy(normalized_image).unsqueeze(0).unsqueeze(0)
    return tensor.to(DEVICE)
|
|
|
|
def main():
    """Run the A/B diagnostic: real scanned lines versus synthetic renders.

    The function:
      * reads ``char_list`` from ``DATA_FILE`` and maps labels 1..N to its
        characters,
      * builds the CRNN, loads the weights at ``MODEL_PATH`` and switches
        the model to eval mode,
      * rasterises page 12 of the sample PDF and converts it to grayscale,
      * for each entry in ``TEST_LINES``, crops the matching hard-coded
        rectangle from the page, runs the model on that crop and on a
        synthetic render of the ground-truth text, and prints both
        decoded predictions,
      * writes a stacked comparison image (scan on top, render below) as
        ``diagnostic_comparison_line_<n>.png``, a path relative to the
        current working directory.

    Returns:
        None.
    """
    print("--- Running Final Diagnostic A/B Test ---")

    print("Loading model...")
    with h5py.File(DATA_FILE, 'r') as dataset:
        char_list = [raw.decode('utf-8') for raw in dataset['char_list'][:]]
    # Labels start at 1; index 0 is left unmapped.
    int_to_char = dict(enumerate(char_list, start=1))
    model = CRNN(num_chars=len(char_list)).to(DEVICE)
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded.")
    font = ImageFont.truetype(FONT_PATH, FONT_SIZE)

    # Hard-coded (x, y, width, height) rectangles, one per TEST_LINES entry.
    line_coords = [
        (118, 114, 1551, 44),
        (118, 178, 1549, 36),
        (118, 298, 1551, 35)
    ]

    from pdf2image import convert_from_path
    pdf_path = os.path.join(PROJECT_ROOT, "sample_documents/books/Applied-Machine-Learning-and-AI-for-Engineers.pdf")
    pil_image = convert_from_path(pdf_path, first_page=12, last_page=12)[0]
    page_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2GRAY)

    def recognise(tensor):
        """Run the model on ``tensor`` and return the first decoded string."""
        return decode_ctc_output(model(tensor), int_to_char)[0]

    with torch.no_grad():
        for index, expected_text in enumerate(TEST_LINES):
            line_no = index + 1
            print("\n" + "=" * 50)
            print(f"TESTING LINE {line_no}: '{expected_text}'")
            print("=" * 50)

            # A: the line as it appears in the rasterised page.
            left, top, box_w, box_h = line_coords[index]
            scan_crop = page_image[top:top + box_h, left:left + box_w]
            scan_text = recognise(preprocess_for_model(scan_crop, is_from_scan=True))

            print(f" -> Prediction from REAL SCAN: '{scan_text}'")

            # B: the same text rendered synthetically.
            rendered = render_perfect_line(expected_text, font)
            rendered_text = recognise(preprocess_for_model(rendered, is_from_scan=False))

            print(f" -> Prediction from PERFECT RENDER: '{rendered_text}'")

            # Bring the scan crop to the render's height for stacking.
            target_h = rendered.shape[0]
            crop_h, crop_w = scan_crop.shape
            ratio = target_h / crop_h
            scan_resized = cv2.resize(scan_crop, (int(crop_w * ratio), target_h))

            # Right-pad the narrower image with white so the widths match.
            pad = rendered.shape[1] - scan_resized.shape[1]
            if pad > 0:
                scan_resized = cv2.copyMakeBorder(
                    scan_resized, 0, 0, 0, pad, cv2.BORDER_CONSTANT, value=255
                )
            else:
                rendered = cv2.copyMakeBorder(
                    rendered, 0, 0, 0, -pad, cv2.BORDER_CONSTANT, value=255
                )

            stacked = np.vstack([scan_resized, rendered])
            cv2.imwrite(f"diagnostic_comparison_line_{line_no}.png", stacked)
            print(f" -> Saved visual evidence to 'diagnostic_comparison_line_{line_no}.png'")
|
|
|
|
# Script entry point: run the diagnostic only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()