|
|
import json
|
|
|
import cv2
|
|
|
import torch
|
|
|
import torchvision.transforms.functional as TF
|
|
|
import matplotlib.pyplot as plt
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_vocab(vocab_path):
    """Load the character vocabulary mappings from a JSON file.

    The file must contain the top-level keys ``"char_to_idx"`` and
    ``"idx_to_char"``.

    Args:
        vocab_path: Path to the UTF-8 encoded JSON vocabulary file.

    Returns:
        A ``(char_to_idx, idx_to_char)`` tuple. ``char_to_idx`` is returned
        exactly as stored; ``idx_to_char`` has its keys converted to ``int``.

    Raises:
        KeyError: If either expected top-level key is missing.
        ValueError: If an ``idx_to_char`` key cannot be parsed as an integer.
    """
    with open(vocab_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    forward = payload["char_to_idx"]

    # JSON object keys are always strings, so restore the integer indices.
    reverse = {}
    for raw_index, symbol in payload["idx_to_char"].items():
        reverse[int(raw_index)] = symbol

    return forward, reverse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def greedy_decode(output, idx_to_char, blank_idx=0):
    """Decode model scores into strings by greedy best-path collapsing.

    For each sequence the highest-scoring class is taken at every step,
    consecutive repeats are merged, and the blank index is dropped.

    Args:
        output: Tensor of scores; the arg-max is taken over dimension 2 and
            the first dimension is iterated as separate sequences
            (presumably ``(batch, time, classes)`` -- confirm against the
            model).
        idx_to_char: Mapping from integer class index to character. Indices
            missing from the mapping contribute an empty string.
        blank_idx: Class index treated as the blank symbol. Defaults to 0,
            the value that was previously hard-coded.

    Returns:
        A list with one decoded string per sequence in ``output``.
    """
    # Move the whole batch to the host once (instead of once per sequence)
    # and convert to plain Python ints for the dictionary lookups.
    best_paths = output.argmax(2).cpu().tolist()

    texts = []
    for path in best_paths:
        prev = -1  # sentinel: no valid class index is negative
        chars = []
        for idx in path:
            # Emit only when the class changes and it is not the blank.
            if idx != prev and idx != blank_idx:
                chars.append(idx_to_char.get(idx, ""))
            prev = idx
        texts.append("".join(chars))
    return texts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OCRTestTransform:
    """Prepare a text-line image for inference.

    The image is converted to grayscale, resized to a fixed height while
    keeping its aspect ratio (width capped at ``max_width``), pasted at the
    left edge of a white canvas of size ``(max_width, img_height)``,
    converted to a tensor and normalized with mean 0.5 and std 0.5.
    """

    def __init__(self, img_height=64, max_width=1600):
        """Store the target geometry.

        Args:
            img_height: Height in pixels of the output canvas.
            max_width: Width in pixels of the output canvas; wider lines are
                squeezed to this width.
        """
        self.img_height = img_height
        self.max_width = max_width

    def __call__(self, img):
        """Transform one image.

        Args:
            img: A PIL image (any mode accepted by ``convert("L")``).

        Returns:
            The normalized tensor built from the padded grayscale canvas.
        """
        img = img.convert("L")
        w, h = img.size

        # Aspect-preserving width. Clamp to at least 1 px: a very tall,
        # narrow crop would otherwise truncate to 0, which Image.resize
        # rejects.
        new_w = max(1, int(w * self.img_height / h))
        new_w = min(new_w, self.max_width)
        img = img.resize((new_w, self.img_height), Image.BICUBIC)

        # White (255) canvas so the padding matches a blank page background.
        canvas = Image.new("L", (self.max_width, self.img_height), 255)
        canvas.paste(img, (0, 0))

        tensor = TF.to_tensor(canvas)
        return TF.normalize(tensor, (0.5,), (0.5,))
|
|
|
|
|
|
# Shared module-level transform built with the default height and width;
# ocr_page applies it to every segmented line image.
transform_test = OCRTestTransform()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
    """Split a page image into individual text-line crops.

    The page is binarized with an inverted Otsu threshold, dilated with a
    wide one-pixel-high kernel so the ink on a line merges horizontally,
    and the external contours of the result are taken as line regions.

    Args:
        image_path: Path of the page image; it is read as grayscale.
        min_line_height: Regions whose bounding box is shorter than this
            (in pixels) are discarded.
        margin: Extra pixels added above and below each region, clipped to
            the page. No horizontal margin is added.
        visualize: If true, show every extracted line in its own
            matplotlib figure.

    Returns:
        A list of PIL images, one per kept region, ordered by the top
        coordinate of the region's bounding box.

    Raises:
        OSError: If the image cannot be read or decoded.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside OpenCV.
        raise OSError(f"Could not read image: {image_path}")

    page_h, page_w = img.shape[:2]

    _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Kernel width scales with the page width; keep it at least 1 px so
    # pages narrower than 30 px do not produce an invalid zero-width kernel.
    kernel_w = max(1, page_w // 30)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_w, 1))
    morphed = cv2.dilate(binary, kernel, iterations=1)

    contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Compute each bounding box once, then order regions top to bottom.
    boxes = sorted((cv2.boundingRect(ctr) for ctr in contours), key=lambda box: box[1])

    lines = []
    for x, y, w, h in boxes:
        if h < min_line_height:
            continue
        y1 = max(0, y - margin)
        y2 = min(page_h, y + h + margin)
        lines.append(Image.fromarray(img[y1:y2, x:x + w]))

    if visualize:
        for i, line_img in enumerate(lines):
            plt.figure(figsize=(12, 2))
            plt.imshow(line_img, cmap='gray')
            plt.axis('off')
            plt.title(f"Line {i+1}")
            plt.show()

    return lines
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ocr_page(image_path, model, idx_to_char, visualize=False):
    """Recognize the text on a page image, line by line.

    The page is split with ``segment_lines_precise``; each line is run
    through ``transform_test`` and ``model`` as a batch of one, decoded with
    ``greedy_decode``, and printed as it is recognized.

    NOTE(review): the model is not switched to eval mode here; presumably
    the caller does that -- confirm.

    Args:
        image_path: Path of the page image.
        model: Callable applied to a single-line batch tensor on ``device``.
        idx_to_char: Mapping from class index to character, passed to
            ``greedy_decode``.
        visualize: Forwarded to ``segment_lines_precise``.

    Returns:
        The recognized lines joined with newline characters.
    """
    line_images = segment_lines_precise(image_path, visualize=visualize)

    recognized = []
    for line_no, crop in enumerate(line_images, start=1):
        # Add a leading batch dimension and move the tensor to the device.
        batch = transform_test(crop)
        batch = batch.unsqueeze(0)
        batch = batch.to(device)

        with torch.no_grad():
            scores = model(batch)

        decoded = greedy_decode(scores, idx_to_char)
        text = decoded[0]

        recognized.append(text)
        print(f"Line {line_no}: {text}")

    return "\n".join(recognized)
|
|
|
|