File size: 3,577 Bytes
1aaa7f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import json
import cv2
import torch
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from PIL import Image

# -----------------------------
# Device
# -----------------------------
# Use the current CUDA device when PyTorch can see one; otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# -----------------------------
# Load vocab
# -----------------------------
def load_vocab(vocab_path):
    """Read the character vocabulary stored as JSON at ``vocab_path``.

    The file must contain an object with ``"char_to_idx"`` and
    ``"idx_to_char"`` entries.

    Returns:
        tuple: ``(char_to_idx, idx_to_char)``. The first is returned exactly
        as stored in the file. The second has its keys converted to ``int``,
        because JSON object keys are always strings.
    """
    with open(vocab_path, encoding="utf-8") as fh:
        data = json.load(fh)
    forward = data["char_to_idx"]
    backward = {}
    for key, char in data["idx_to_char"].items():
        backward[int(key)] = char
    return forward, backward

# -----------------------------
# Greedy decoder
# -----------------------------
def greedy_decode(output, idx_to_char, blank=0):
    """Greedy (best-path) CTC-style decoding of a batch of model outputs.

    Take the arg-max class at every step, collapse consecutive repeats,
    drop the ``blank`` class, and map the remaining indices to characters.

    Args:
        output: score tensor. The arg-max is taken over dim 2 and the result
            is iterated over dim 0, one decoded string per dim-0 entry.
            NOTE(review): this assumes a batch-first (B, T, C) layout. A
            time-first (T, B, C) head would be decoded per time step
            instead; confirm against the model.
        idx_to_char: mapping from integer class index to character. Indices
            missing from the mapping decode to "".
        blank: class index treated as the blank symbol (default 0).

    Returns:
        list[str]: one decoded string per dim-0 entry of ``output``.
    """
    from itertools import groupby

    # Do one device->host transfer for the whole batch. tolist() yields plain
    # Python ints, which match the int keys of idx_to_char directly.
    best_paths = output.argmax(2).cpu().tolist()
    texts = []
    for path in best_paths:
        # groupby collapses runs of identical indices; blanks are then dropped.
        # A blank between two equal symbols therefore keeps both of them.
        chars = (idx_to_char.get(k, "") for k, _ in groupby(path) if k != blank)
        texts.append("".join(chars))
    return texts

# -----------------------------
# Transforms
# -----------------------------
class OCRTestTransform:
    """Turn a text-line image into a fixed-size, normalized tensor.

    The image is converted to grayscale and resized to ``img_height``.
    Its width follows the aspect ratio, clamped to ``[1, max_width]``. It is
    then pasted at the left edge of a white ``max_width`` x ``img_height``
    canvas, converted to a tensor and normalized with mean 0.5 / std 0.5.
    """

    def __init__(self, img_height=64, max_width=1600):
        self.img_height = img_height  # output height in pixels
        self.max_width = max_width  # output (canvas) width in pixels

    def __call__(self, img):
        """Return the normalized tensor for PIL image ``img``.

        Lines wider than ``max_width`` after scaling are squashed to
        ``max_width``, so their aspect ratio is not preserved.
        """
        img = img.convert("L")
        w, h = img.size
        new_w = int(w * self.img_height / h)
        # Clamp to at least 1 px: a very narrow crop would otherwise truncate
        # to width 0, which Image.resize rejects.
        target_w = max(1, min(new_w, self.max_width))
        img = img.resize((target_w, self.img_height), Image.BICUBIC)
        # White (255) background so the padding matches blank paper.
        canvas = Image.new("L", (self.max_width, self.img_height), 255)
        canvas.paste(img, (0, 0))
        tensor = TF.to_tensor(canvas)
        return TF.normalize(tensor, (0.5,), (0.5,))

# Shared instance used by ocr_page, with the default 64 x 1600 geometry.
transform_test = OCRTestTransform()

# -----------------------------
# Line segmentation
# -----------------------------
def _show_lines(lines):
    """Display each line crop in its own matplotlib figure, titled by 1-based index."""
    for i, line_img in enumerate(lines, 1):
        plt.figure(figsize=(12, 2))
        plt.imshow(line_img, cmap="gray")
        plt.axis("off")
        plt.title(f"Line {i}")
        plt.show()


def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False,
                          dilate_divisor=30):
    """Split a page image into text-line crops, ordered top to bottom.

    The page is binarized with Otsu thresholding (ink -> 255). It is then
    dilated with a 1-px-tall horizontal kernel so that text on the same line
    merges into one connected region. The bounding box of each external
    contour becomes one line.

    Args:
        image_path: path of the page image (read as grayscale).
        min_line_height: boxes shorter than this many pixels are discarded.
        margin: rows of padding added above and below each box (clipped to
            the page).
        visualize: if True, show every crop with matplotlib.
        dilate_divisor: the kernel width is ``page_width // dilate_divisor``
            (at least 1 px).

    Returns:
        list[PIL.Image.Image]: grayscale crops of the original page.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    page_h, page_w = img.shape
    _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Clamp to >= 1 so pages narrower than dilate_divisor pixels still work.
    kernel_w = max(1, page_w // dilate_divisor)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_w, 1))
    morphed = cv2.dilate(binary, kernel, iterations=1)
    # [-2] selects the contour list under both OpenCV 3.x (3-tuple return)
    # and OpenCV 4.x (2-tuple return).
    contours = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Compute each bounding box once, then order by top edge (stable sort).
    boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda box: box[1])
    lines = []
    for x, y, w, h in boxes:
        if h < min_line_height:
            continue
        y1 = max(0, y - margin)
        y2 = min(page_h, y + h + margin)
        lines.append(Image.fromarray(img[y1:y2, x:x + w]))
    if visualize:
        _show_lines(lines)
    return lines

# -----------------------------
# OCR function
# -----------------------------
def ocr_page(image_path, model, idx_to_char, visualize=False):
    """Segment a page into lines and recognize each line with ``model``.

    Each result is printed as ``Line <n>: <text>`` as soon as it is decoded.

    Args:
        image_path: path of the page image, passed to segment_lines_precise.
        model: callable taking a (1, C, H, W) tensor on ``device``.
            NOTE(review): model.eval() is not called here; the caller
            presumably does so.
        idx_to_char: index-to-character mapping passed to greedy_decode.
        visualize: forwarded to segment_lines_precise to display the crops.

    Returns:
        str: the recognized lines joined with newlines ("" when no line is found).
    """
    crops = segment_lines_precise(image_path, visualize=visualize)
    recognized = []
    with torch.no_grad():
        for line_no, crop in enumerate(crops, start=1):
            batch = transform_test(crop).unsqueeze(0).to(device)
            text = greedy_decode(model(batch), idx_to_char)[0]
            recognized.append(text)
            print(f"Line {line_no}: {text}")
    return "\n".join(recognized)