| """ |
| Whole-page OCR inference for Ukrainian handwritten text using TrOCR. |
| |
| This script segments a page image into text lines (automatically, or from an existing Transkribus PAGE XML file) and transcribes each line. |
| |
| Usage: |
| # Basic usage with checkpoint |
| python inference_page.py --image path/to/page.jpg --checkpoint models/ukrainian_model/checkpoint-3000 |
| |
| # With custom settings |
| python inference_page.py --image page.jpg --checkpoint checkpoint-3000 --num_beams 4 --output output.txt |
| |
| # With Transkribus PAGE XML (uses existing segmentation) |
| python inference_page.py --image page.jpg --xml page.xml --checkpoint checkpoint-3000 |
| |
| Future: Can be extended with a GUI using tkinter or PyQt. |
| """ |
|
|
| import argparse |
| import torch |
| from pathlib import Path |
| from PIL import Image, ImageDraw |
| import numpy as np |
| from typing import List, Tuple, Optional |
| import xml.etree.ElementTree as ET |
| from dataclasses import dataclass |
| import cv2 |
|
|
| |
# Disable Pillow's decompression-bomb pixel limit so very large page scans can
# be opened. NOTE(review): this removes a safety check; only open trusted images.
Image.MAX_IMAGE_PIXELS = None
|
|
| from transformers import VisionEncoderDecoderModel, TrOCRProcessor |
|
|
|
|
@dataclass
class LineSegment:
    """One text line cut out of a page image.

    Segmenters fill ``image`` and ``bbox`` (and ``coords`` when polygon data is
    available). ``text``, ``confidence`` and ``char_confidences`` stay ``None``
    until a transcription step assigns them.
    """

    # Cropped line image.
    image: Image.Image
    # Crop box in page pixel coordinates, (x1, y1, x2, y2).
    bbox: Tuple[int, int, int, int]
    # Source polygon points, when the segmenter provides them (PAGE XML).
    coords: Optional[List[Tuple[int, int]]] = None
    # Recognised text for the line.
    text: Optional[str] = None
    # Line-level score, e.g. the mean token probability from transcription.
    confidence: Optional[float] = None
    # Per-token probabilities backing ``confidence``.
    char_confidences: Optional[List[float]] = None
|
|
|
|
def sort_lines_by_region(regions, lines):
    """
    Sort lines in reading order: regions left-to-right, lines top-to-bottom
    within each region.

    Works with SegRegion objects from kraken_segmenter (only their ``.bbox`` is
    used; ``.line_ids`` is ignored) and any list of line-like objects that have
    a ``.bbox`` attribute in (x1, y1, x2, y2) format.

    Each line goes to the first region (in left-to-right order) whose bbox
    contains the line's centre point. Lines outside every region are appended
    at the end, top-to-bottom.

    Args:
        regions: List of region objects with ``.bbox``. If empty/None, lines
            are returned sorted top-to-bottom.
        lines: List of LineSegment (or kraken LineSegment). ``None`` is treated
            as an empty list.

    Returns:
        New list of lines re-ordered by region reading order; every input line
        appears exactly once.
    """
    if not lines:
        return []

    def top(line):
        return line.bbox[1]

    if not regions:
        return sorted(lines, key=top)

    # Regions ordered left-to-right by their horizontal centre.
    sorted_regions = sorted(regions, key=lambda r: (r.bbox[0] + r.bbox[2]) / 2)

    # One bucket per region, indexed by position instead of keyed by
    # ``region.id``: regions with duplicate/missing ids would otherwise share a
    # bucket whose lines are then emitted once per region (duplicates).
    buckets = [[] for _ in sorted_regions]
    unassigned = []
    for line in lines:
        cx = (line.bbox[0] + line.bbox[2]) / 2
        cy = (line.bbox[1] + line.bbox[3]) / 2
        for idx, region in enumerate(sorted_regions):
            rx1, ry1, rx2, ry2 = region.bbox
            if rx1 <= cx <= rx2 and ry1 <= cy <= ry2:
                buckets[idx].append(line)
                break
        else:
            unassigned.append(line)

    ordered = []
    for bucket in buckets:
        ordered.extend(sorted(bucket, key=top))
    ordered.extend(sorted(unassigned, key=top))
    return ordered
|
|
|
|
def normalize_background(image: Image.Image, clip_limit: float = 2.0,
                         tile_grid_size: Tuple[int, int] = (8, 8)) -> Image.Image:
    """
    Equalize contrast with CLAHE and return a grayscale image in RGB form.

    The lightness channel (L of CIELAB) is equalized with CLAHE. The result is
    collapsed to one gray channel and expanded back to three identical RGB
    channels, so it can be fed to an RGB model. This is intended to mimic the
    uniform backgrounds of the Efendiev training data.

    CRITICAL for Ukrainian dataset: Models trained on data with background
    normalization MUST have normalization applied at inference time as well.

    Args:
        image: PIL Image in any mode; it is converted to RGB first.
        clip_limit: CLAHE contrast clip limit (default 2.0).
        tile_grid_size: CLAHE tile grid as (cols, rows) (default (8, 8)).

    Returns:
        RGB PIL Image whose three channels are identical.
    """
    if image.mode != 'RGB':
        # COLOR_RGB2LAB requires exactly three channels; L/RGBA/P inputs
        # would otherwise make cv2.cvtColor raise.
        image = image.convert('RGB')
    rgb = np.array(image)

    lightness, a_chan, b_chan = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB))

    # Equalize only lightness so the colour channels are not distorted before
    # the final gray conversion.
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tuple(tile_grid_size))
    lab_equalized = cv2.merge([clahe.apply(lightness), a_chan, b_chan])
    rgb_equalized = cv2.cvtColor(lab_equalized, cv2.COLOR_LAB2RGB)

    gray = cv2.cvtColor(rgb_equalized, cv2.COLOR_RGB2GRAY)
    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
|
|
|
|
class LineSegmenter:
    """Split a page image into full-width text-line strips via a horizontal projection profile."""

    def __init__(self, min_line_height: int = 15, min_gap: int = 5,
                 sensitivity: float = 0.02, use_morph: bool = True):
        """
        Initialize LineSegmenter.

        Args:
            min_line_height: Minimum height of a line in pixels; shorter bands are
                dropped (default: 15, lowered for tighter spacing)
            min_gap: Number of consecutive non-text rows that ends a line
                (default: 5, lowered for tight spacing)
            sensitivity: Fraction of the projection peak a row must exceed to count
                as text (0.01-0.1, lower = more sensitive, default: 0.02)
            use_morph: Apply a morphological closing to the binary mask (default: True)
        """
        self.min_line_height = min_line_height
        self.min_gap = min_gap
        self.sensitivity = sensitivity
        self.use_morph = use_morph

    def segment_lines(self, image: Image.Image, debug: bool = False) -> List[LineSegment]:
        """
        Segment a page image into text lines using a horizontal projection profile.

        Steps:
            1. Binarize (see ``_binarize``).
            2. Count ink pixels per row; rows above ``sensitivity * max`` count as
               text, median-filtered to drop isolated 1-row flips.
            3. Turn runs of text rows into bands (see ``_find_line_spans``) and
               merge bands separated by at most ``2 * min_gap`` blank rows.
            4. Crop each band across the full page width with 8 px vertical padding.

        Args:
            image: Input page image (PIL Image)
            debug: If True, show the segmentation overlay and projection plot

        Returns:
            List of LineSegment objects, top-to-bottom, with ``image`` and ``bbox`` set.
        """
        from scipy.ndimage import median_filter

        binary = self._binarize(np.array(image.convert('L')))

        # Row-wise ink count: text lines appear as bands of high values.
        h_projection = binary.sum(axis=1)
        if h_projection.max() > 0:
            threshold = h_projection.max() * self.sensitivity
        else:
            # Blank page: nothing can exceed this, so no lines are produced.
            threshold = 1
        is_text = h_projection > threshold
        is_text_smoothed = median_filter(is_text.astype(float), size=3) > 0.5

        spans = self._find_line_spans(is_text_smoothed)
        merged_lines = self._merge_close_lines(spans, max_gap=self.min_gap * 2)

        padding = 8
        width, height = image.width, image.height
        segments = []
        for y1, y2 in merged_lines:
            bbox = (0, max(0, y1 - padding), width, min(height, y2 + padding))
            segments.append(LineSegment(image=image.crop(bbox), bbox=bbox))

        if debug:
            self._visualize_segmentation(image, segments, h_projection)

        print(f"[LineSegmenter] Detected {len(segments)} lines (sensitivity={self.sensitivity}, min_height={self.min_line_height})")
        return segments

    def _binarize(self, gray: np.ndarray) -> np.ndarray:
        """Return a boolean ink mask (True = dark pixel) for a uint8 grayscale page."""
        from scipy.ndimage import binary_closing, gaussian_filter

        # Smooth pixel-level noise before the global threshold.
        blurred = gaussian_filter(gray, sigma=1.0)
        # Otsu's t is the last gray level of the dark class, hence `<=`.
        binary = blurred <= self._otsu_threshold(blurred)
        # Union with a local threshold to recover strokes the global cut misses.
        binary |= self._adaptive_threshold(gray)

        if self.use_morph:
            # A wide (3x5) closing bridges small horizontal gaps between strokes.
            binary = binary_closing(binary, structure=np.ones((3, 5)), iterations=2)
        return binary

    def _find_line_spans(self, is_text: np.ndarray) -> List[Tuple[int, int]]:
        """Return (start, end) row spans of text bands, with ``end`` exclusive.

        A band closes once ``min_gap`` consecutive non-text rows are seen; its
        end is the first of those blank rows. Bands shorter than
        ``min_line_height`` rows are discarded.
        """
        spans = []
        start_y = None
        gap_count = 0
        for y, row_has_text in enumerate(is_text):
            if row_has_text:
                if start_y is None:
                    start_y = y
                gap_count = 0
            elif start_y is not None:
                gap_count += 1
                if gap_count >= self.min_gap:
                    end_y = y - gap_count + 1
                    if end_y - start_y >= self.min_line_height:
                        spans.append((start_y, end_y))
                    start_y = None
                    gap_count = 0

        if start_y is not None:
            # Band reaches the bottom of the page; exclude trailing blank rows.
            end_y = len(is_text) - gap_count
            if end_y - start_y >= self.min_line_height:
                spans.append((start_y, end_y))
        return spans

    @staticmethod
    def _adaptive_threshold(gray: np.ndarray, block_size: int = 35) -> np.ndarray:
        """
        Apply adaptive (Gaussian-weighted local) thresholding.
        Better for images with varying illumination or contrast.

        Returns a boolean mask, True where the pixel is darker than its
        neighbourhood by more than the constant 10.
        """
        try:
            binary = cv2.adaptiveThreshold(
                gray.astype(np.uint8),
                255,
                cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY_INV,
                block_size,
                10,
            )
        except cv2.error:
            # OpenCV rejected the input/arguments (e.g. an even block_size):
            # fall back to one global cut half a std-dev below the mean.
            threshold = np.mean(gray) - np.std(gray) * 0.5
            return gray < threshold
        return binary > 0

    @staticmethod
    def _merge_close_lines(lines: List[Tuple[int, int]], max_gap: int = 10) -> List[Tuple[int, int]]:
        """Merge consecutive (start, end) spans whose gap ``start - prev_end`` is at most ``max_gap``.

        Expects spans ordered top-to-bottom (as produced by ``_find_line_spans``);
        this rejoins one line that was split incorrectly.
        """
        if not lines:
            return lines

        merged = [lines[0]]
        for y1, y2 in lines[1:]:
            prev_y1, prev_y2 = merged[-1]
            if y1 - prev_y2 <= max_gap:
                merged[-1] = (prev_y1, y2)
            else:
                merged.append((y1, y2))
        return merged

    @staticmethod
    def _otsu_threshold(gray_array: np.ndarray) -> int:
        """Compute Otsu's threshold for values in [0, 255].

        Returns the level t maximizing the between-class variance when the
        classes are [0, t] and [t+1, 255], so pixels ``<= t`` form the dark
        class. Returns 0 for an empty array.
        """
        hist, _ = np.histogram(gray_array, bins=256, range=(0, 256))
        hist = hist.astype(float)
        total = hist.sum()
        if total == 0:
            return 0
        hist /= total

        levels = np.arange(256)
        weight1 = np.cumsum(hist)                              # P(v <= t)
        weight2 = np.cumsum(hist[::-1])[::-1]                  # P(v >= t)
        mean1 = np.cumsum(hist * levels)                       # sum_{v<=t} v p(v)
        mean2 = np.cumsum((hist * levels)[::-1])[::-1]         # sum_{v>=t} v p(v)

        # Avoid division by zero for empty classes (their variance term ~ 0).
        weight1 = np.clip(weight1, 1e-10, 1)
        weight2 = np.clip(weight2, 1e-10, 1)

        # Class 2 must start at t+1; using index t for both classes would
        # count bin t twice.
        w1, w2 = weight1[:-1], weight2[1:]
        mu1 = mean1[:-1] / w1
        mu2 = mean2[1:] / w2
        variance = w1 * w2 * (mu1 - mu2) ** 2
        return int(np.argmax(variance))

    @staticmethod
    def _visualize_segmentation(image: Image.Image, segments: List[LineSegment],
                                h_projection: Optional[np.ndarray] = None):
        """Show line boxes over the page and, optionally, the projection profile."""
        vis = image.copy()
        draw = ImageDraw.Draw(vis)

        for i, seg in enumerate(segments):
            x1, y1, x2, y2 = seg.bbox
            # Alternate colours so adjacent boxes are distinguishable.
            color = 'red' if i % 2 == 0 else 'blue'
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
            draw.text((x1 + 5, y1 + 5), f"Line {i+1}", fill=color)

        vis.show()

        if h_projection is not None:
            import matplotlib.pyplot as plt
            plt.figure(figsize=(12, 4))
            plt.plot(h_projection)
            plt.title("Horizontal Projection Profile")
            plt.xlabel("Y Position")
            plt.ylabel("Text Density")
            plt.grid(True)
            plt.show()
|
|
|
|
class PageXMLSegmenter:
    """Segment a page using the line polygons of an existing (e.g. Transkribus) PAGE XML file."""

    NS = {'page': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15'}
    # Second known PAGE schema version, tried when a document is not 2013-07-15.
    NS_2019 = {'page': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'}

    def __init__(self, xml_path: str):
        self.xml_path = Path(xml_path)
        # Per-region summaries ({"id", "bbox", "num_lines"}); rebuilt by segment_lines().
        self.region_data: list = []

    def segment_lines(self, image: Image.Image) -> List[LineSegment]:
        """
        Extract line crops from the PAGE XML polygons in reading order.

        Coordinates are rescaled when the XML's ``imageWidth``/``imageHeight``
        differ from the actual image size. Lines within each region, and the
        regions themselves, are ordered by the ``readingOrder {index:N;}`` entry
        of their ``custom`` attribute. Elements without one come after those
        that have one, ordered by their top y coordinate.

        Side effect: ``self.region_data`` is rebuilt with one
        ``{"id", "bbox", "num_lines"}`` dict per region that has Coords.

        Args:
            image: Page image the XML describes.

        Returns:
            LineSegment list with ``image``, ``bbox`` (5 px padding, clipped to
            the image) and ``coords`` (scaled polygon) set.
        """
        root = ET.parse(self.xml_path).getroot()
        ns, page_elem = self._resolve_namespace(root)

        xml_w = int(page_elem.get('imageWidth', image.width)) if page_elem is not None else image.width
        xml_h = int(page_elem.get('imageHeight', image.height)) if page_elem is not None else image.height
        scale_x = image.width / xml_w if xml_w > 0 else 1.0
        scale_y = image.height / xml_h if xml_h > 0 else 1.0

        self.region_data: list = []
        regions_with_order = []
        padding = 5

        for region in root.findall('.//page:TextRegion', ns):
            region_order = self._extract_reading_order(region.get('custom', ''))
            region_y = self._get_region_y_position(region, ns)

            lines_with_order = []
            # Direct children only: nested TextRegions are visited by the outer
            # loop themselves, so './/' here would emit their lines twice.
            for text_line in region.findall('page:TextLine', ns):
                coords = self._read_coords(text_line, ns, scale_x, scale_y)
                if not coords:
                    continue
                x1, y1, x2, y2 = self._get_bounding_box(coords)

                bbox = (
                    max(0, x1 - padding),
                    max(0, y1 - padding),
                    min(image.width, x2 + padding),
                    min(image.height, y2 + padding),
                )
                if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
                    # Polygon lies outside the image; an empty crop would break OCR.
                    print(f"WARNING: skipping TextLine '{text_line.get('id')}' (empty after clipping to the image)")
                    continue

                segment = LineSegment(image=image.crop(bbox), bbox=bbox, coords=coords)
                line_order = self._extract_reading_order(text_line.get('custom', ''))
                lines_with_order.append((self._order_key(line_order, y1), segment))

            lines_with_order.sort(key=lambda item: item[0])
            sorted_lines = [seg for _, seg in lines_with_order]

            region_id = region.get('id', f'region_{len(regions_with_order)}')
            region_coords = self._read_coords(region, ns, scale_x, scale_y)
            if region_coords:
                rx1, ry1, rx2, ry2 = self._get_bounding_box(region_coords)
                self.region_data.append({
                    "id": region_id,
                    "bbox": [rx1, ry1, rx2, ry2],
                    "num_lines": len(sorted_lines),
                })

            regions_with_order.append((self._order_key(region_order, region_y), sorted_lines))

        regions_with_order.sort(key=lambda item: item[0])
        return [seg for _, region_lines in regions_with_order for seg in region_lines]

    def _resolve_namespace(self, root):
        """Return (namespace map, Page element or None) for this document.

        The root element's own namespace is tried first, so any PAGE schema
        version works, followed by the 2013 and 2019 schemas.
        """
        candidates = []
        if root.tag.startswith('{'):
            # ElementTree tags look like '{namespace-uri}PcGts'.
            candidates.append({'page': root.tag[1:].partition('}')[0]})
        candidates.extend([self.NS, self.NS_2019])

        for ns in candidates:
            page_elem = root.find('.//page:Page', ns)
            if page_elem is not None:
                return ns, page_elem
        return candidates[0], None

    def _read_coords(self, elem, ns, scale_x: float, scale_y: float) -> List[Tuple[int, int]]:
        """Parse elem's direct <Coords points="..."> child, scaled to image pixels; [] if absent/empty."""
        coords_elem = elem.find('page:Coords', ns)
        if coords_elem is None:
            return []
        coords = self._parse_coords(coords_elem.get('points') or '')
        if coords and (scale_x != 1.0 or scale_y != 1.0):
            coords = [(int(x * scale_x), int(y * scale_y)) for x, y in coords]
        return coords

    @staticmethod
    def _order_key(order: Optional[int], fallback_y: int) -> Tuple[int, int]:
        """Sort key: explicit readingOrder indices first, then by y position.

        Separate tiers prevent comparing a small index (e.g. 3) with a pixel
        coordinate (e.g. 1200) when only some elements carry a readingOrder.
        """
        return (0, order) if order is not None else (1, fallback_y)

    @staticmethod
    def _extract_reading_order(custom_attr: str) -> Optional[int]:
        """Extract the reading order index from a PAGE ``custom`` attribute.

        Format: custom="readingOrder {index:5;}"
        Returns: 5 (or None if there is no parseable readingOrder index)
        """
        import re

        if not custom_attr:
            return None
        # Look for 'index' only inside the readingOrder {...} entry, not in
        # other custom tags that may share the attribute.
        match = re.search(r'readingOrder\s*\{[^}]*?\bindex\s*:\s*(-?\d+)', custom_attr)
        return int(match.group(1)) if match else None

    def _get_region_y_position(self, region, ns=None) -> int:
        """Get the top y of a region for fallback sorting (unscaled XML units).

        Uses the region's own Coords, else its first TextLine's Coords, else 0.
        """
        if ns is None:
            ns = self.NS

        for elem in (region, region.find('.//page:TextLine', ns)):
            if elem is None:
                continue
            coords = self._read_coords(elem, ns, 1.0, 1.0)
            if coords:
                return self._get_bounding_box(coords)[1]
        return 0

    @staticmethod
    def _parse_coords(coords_str: str) -> List[Tuple[int, int]]:
        """Parse a PAGE ``points`` string ("x1,y1 x2,y2 ...") into integer (x, y) tuples.

        Fractional values, which some exporters write, are rounded.
        Raises ValueError on malformed pairs.
        """
        points = []
        for pair in coords_str.split():
            x_str, _, y_str = pair.partition(',')
            points.append((int(round(float(x_str))), int(round(float(y_str)))))
        return points

    @staticmethod
    def _get_bounding_box(coords: List[Tuple[int, int]]) -> Tuple[int, int, int, int]:
        """Get the (x1, y1, x2, y2) bounding box of non-empty polygon coordinates."""
        xs = [p[0] for p in coords]
        ys = [p[1] for p in coords]
        return min(xs), min(ys), max(xs), max(ys)
|
|
|
|
class TrOCRInference:
    """TrOCR model inference."""

    def __init__(self, model_path: str, device: Optional[str] = None,
                 base_model: str = "kazars24/trocr-base-handwritten-ru",
                 normalize_bg: bool = False,
                 flip_rtl: bool = False,
                 is_huggingface: bool = False):
        """
        Initialize TrOCR inference.

        Args:
            model_path: Path to local checkpoint or HuggingFace model ID
            device: 'cuda', 'cpu', or None for auto-detect
            base_model: Base model for processor (used when the checkpoint has none)
            normalize_bg: Apply background normalization
            flip_rtl: Flip line images horizontally for RTL scripts
            is_huggingface: If True, load from HuggingFace Hub instead of local path
        """
        self.model_path = model_path
        self.base_model = base_model
        self.normalize_bg = normalize_bg
        self.flip_rtl = flip_rtl
        self.is_huggingface = is_huggingface

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"Loading model from {'HuggingFace Hub' if is_huggingface else 'local checkpoint'}: {model_path}...")
        print(f"Using device: {self.device}")
        print(f"Background normalization: {'Enabled' if self.normalize_bg else 'Disabled'}")

        if is_huggingface:
            print(f"Downloading from HuggingFace Hub (if not cached): {model_path}")

            try:
                print(f"Attempting to load processor from {model_path}...")
                self.processor = TrOCRProcessor.from_pretrained(model_path)
                # A tokenizer with a tiny vocabulary cannot decode real text;
                # swap in the stock TrOCR tokenizer in that case.
                if self.processor.tokenizer.vocab_size < 100:
                    print(f"WARNING: tokenizer from '{model_path}' has vocab_size="
                          f"{self.processor.tokenizer.vocab_size} (looks broken). "
                          f"Replacing tokenizer with microsoft/trocr-base-handwritten.")
                    _fallback = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
                    self.processor.tokenizer = _fallback.tokenizer
            except Exception as e:
                # Best effort: any processor-loading failure falls back to base_model.
                print(f"Failed to load processor from model: {e}")
                print(f"Falling back to base model processor: {self.base_model}")
                self.processor = TrOCRProcessor.from_pretrained(self.base_model)

            self.model = VisionEncoderDecoderModel.from_pretrained(
                model_path, low_cpu_mem_usage=False)
            self.checkpoint_path = model_path
        else:
            self.checkpoint_path = Path(model_path)

            # A path to a file inside the checkpoint resolves to its directory.
            if self.checkpoint_path.is_file():
                model_dir = self.checkpoint_path.parent
                print(f"Model path is a file, using directory: {model_dir}")
            else:
                model_dir = self.checkpoint_path

            try:
                print(f"Attempting to load processor from local model: {model_dir}")
                self.processor = TrOCRProcessor.from_pretrained(model_dir)
            except Exception as e:
                print(f"Local processor not found ({e}), falling back to base model: {self.base_model}")
                self.processor = TrOCRProcessor.from_pretrained(self.base_model)
            self.model = VisionEncoderDecoderModel.from_pretrained(
                model_dir, low_cpu_mem_usage=False)

        self.model.to(self.device)
        # NOTE(review): `_float_tensor` is presumably a plain attribute (not a
        # registered buffer) that .to() does not move, hence the explicit move.
        for m in self.model.modules():
            if hasattr(m, '_float_tensor'):
                m._float_tensor = m._float_tensor.to(self.device)
        self.model.eval()

        print("Model loaded successfully!")

    def transcribe_line(self, line_image: Image.Image, num_beams: int = 4,
                        max_length: int = 128, return_confidence: bool = False):
        """
        Transcribe a single line image.

        Args:
            line_image: PIL Image of text line (any mode; converted to RGB)
            num_beams: Number of beams for beam search (higher = better quality, slower)
            max_length: Maximum sequence length
            return_confidence: If True, also return confidence information

        Returns:
            If return_confidence=False: Transcribed text string
            If return_confidence=True: Tuple of (text, confidence_score, char_confidences),
            where char_confidences are the probabilities of the generated tokens of
            the returned sequence and confidence_score is their mean (0.0 if none).
        """
        # Convert first: normalize_background's RGB->LAB step needs exactly
        # three channels, so L/RGBA crops must become RGB beforehand.
        if line_image.mode != 'RGB':
            line_image = line_image.convert('RGB')

        if self.normalize_bg:
            line_image = normalize_background(line_image)

        if self.flip_rtl:
            line_image = line_image.transpose(Image.FLIP_LEFT_RIGHT)

        pixel_values = self.processor(
            images=line_image,
            return_tensors="pt"
        ).pixel_values.to(self.device)

        generate_kwargs = dict(num_beams=num_beams, max_length=max_length, early_stopping=True)
        token_confidences: List[float] = []

        with torch.no_grad():
            if return_confidence:
                outputs = self.model.generate(
                    pixel_values,
                    output_scores=True,
                    return_dict_in_generate=True,
                    **generate_kwargs,
                )
                generated_ids = outputs.sequences
                token_confidences = self._token_confidences(outputs)
            else:
                generated_ids = self.model.generate(pixel_values, **generate_kwargs)

        text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        if not return_confidence:
            return text
        avg_confidence = (sum(token_confidences) / len(token_confidences)
                          if token_confidences else 0.0)
        return text, avg_confidence, token_confidences

    def _token_confidences(self, outputs) -> List[float]:
        """Probability of each generated token of the returned (first) sequence.

        Beam search reorders beams between steps, so row 0 of ``outputs.scores``
        at a step is generally NOT the ancestor of the final best sequence.
        ``compute_transition_scores`` follows ``beam_indices`` back through the
        search and yields the log-probability of each token actually emitted.
        """
        scores = getattr(outputs, 'scores', None)
        if not scores:
            return []
        sequences = outputs.sequences

        if hasattr(self.model, 'compute_transition_scores'):
            beam_indices = getattr(outputs, 'beam_indices', None)
            # Greedy scores are processed logits and need log-softmax; beam
            # search scores are already log-probabilities.
            transition = self.model.compute_transition_scores(
                sequences, scores, beam_indices,
                normalize_logits=beam_indices is None,
            )
            return torch.exp(transition[0]).float().cpu().tolist()

        # Older transformers without compute_transition_scores: read beam 0's
        # distribution per step (exact for greedy decoding only).
        import torch.nn.functional as F
        tokens = sequences[0].cpu().numpy()
        confidences = []
        for step_idx, score_tensor in enumerate(scores):
            if step_idx + 1 < len(tokens):
                probs = F.softmax(score_tensor, dim=-1)
                confidences.append(probs[0, tokens[step_idx + 1]].item())
        return confidences

    def transcribe_segments(self, segments: List[LineSegment],
                            num_beams: int = 4, max_length: int = 128,
                            show_progress: bool = True,
                            return_confidence: bool = False) -> List[LineSegment]:
        """
        Transcribe multiple line segments in place.

        Args:
            segments: List of LineSegment objects
            num_beams: Beam search parameter
            max_length: Max sequence length
            show_progress: Show progress bar
            return_confidence: If True, also fill ``confidence`` and
                ``char_confidences`` on each segment

        Returns:
            The same segments with the ``text`` field (and optionally the
            confidence fields) filled
        """
        iterator = segments
        if show_progress:
            from tqdm import tqdm
            iterator = tqdm(segments, desc="Transcribing lines")

        for segment in iterator:
            if return_confidence:
                segment.text, segment.confidence, segment.char_confidences = self.transcribe_line(
                    segment.image,
                    num_beams=num_beams,
                    max_length=max_length,
                    return_confidence=True,
                )
            else:
                segment.text = self.transcribe_line(
                    segment.image,
                    num_beams=num_beams,
                    max_length=max_length,
                )

        return segments
|
|
|
|
def main():
    """Command-line entry point: segment a page image, transcribe every line, save the text.

    Returns:
        None on success, or 1 when no lines were detected (used as the exit code).
    """
    parser = argparse.ArgumentParser(
        description="Whole-page OCR inference for Ukrainian handwritten text"
    )
    parser.add_argument(
        '--image',
        type=str,
        required=True,
        help='Path to input page image'
    )
    parser.add_argument(
        '--checkpoint',
        type=str,
        required=True,
        help='Path to TrOCR checkpoint directory'
    )
    parser.add_argument(
        '--xml',
        type=str,
        default=None,
        help='Optional: PAGE XML file for line segmentation (if not provided, automatic segmentation is used)'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='Output text file (default: <image_name>_transcription.txt)'
    )
    parser.add_argument(
        '--num_beams',
        type=int,
        default=4,
        help='Number of beams for beam search (default: 4, higher=better quality but slower)'
    )
    parser.add_argument(
        '--max_length',
        type=int,
        default=128,
        help='Maximum sequence length (default: 128)'
    )
    parser.add_argument(
        '--min_line_height',
        type=int,
        default=20,
        help='Minimum line height for automatic segmentation (default: 20)'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Visualize line segmentation'
    )
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        choices=['cuda', 'cpu'],
        help='Device to use for inference (default: auto-detect)'
    )
    parser.add_argument(
        '--base_model',
        type=str,
        default='kazars24/trocr-base-handwritten-ru',
        help='Base model for processor (default: kazars24/trocr-base-handwritten-ru)'
    )
    parser.add_argument(
        '--normalize-background',
        action='store_true',
        help='Apply background normalization (REQUIRED if model was trained with --normalize-background)'
    )
    parser.add_argument(
        '--flip-rtl',
        action='store_true',
        help='Flip line images horizontally for RTL scripts (REQUIRED if model was trained with --flip-rtl)'
    )

    args = parser.parse_args()

    print("=" * 80)
    print("TrOCR Whole-Page Inference")
    print("=" * 80)
    print(f"Input image: {args.image}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Segmentation: {'PAGE XML' if args.xml else 'Automatic'}")
    print(f"Beam search: {args.num_beams}")
    print("=" * 80)

    print("\nLoading image...")
    from PIL import ImageOps
    # The context manager closes the source file; exif_transpose/convert
    # produce new images with the pixel data loaded before that happens.
    with Image.open(args.image) as raw_image:
        image = ImageOps.exif_transpose(raw_image).convert('RGB')
    print(f"Image size: {image.width}x{image.height}")

    print("\nSegmenting lines...")
    if args.xml:
        segmenter = PageXMLSegmenter(args.xml)
        segments = segmenter.segment_lines(image)
        print(f"Found {len(segments)} lines in PAGE XML")
    else:
        segmenter = LineSegmenter(
            min_line_height=args.min_line_height
        )
        segments = segmenter.segment_lines(image, debug=args.debug)
        print(f"Detected {len(segments)} lines")

    if not segments:
        print("ERROR: No lines detected!")
        return 1

    print("\nInitializing TrOCR model...")
    ocr = TrOCRInference(
        args.checkpoint,
        device=args.device,
        base_model=args.base_model,
        normalize_bg=args.normalize_background,
        flip_rtl=args.flip_rtl
    )

    print(f"\nTranscribing {len(segments)} lines...")
    segments = ocr.transcribe_segments(
        segments,
        num_beams=args.num_beams,
        max_length=args.max_length
    )

    # Lines with empty transcriptions are omitted from the output.
    transcription = "\n".join(seg.text for seg in segments if seg.text)

    if args.output:
        output_path = Path(args.output)
    else:
        image_path = Path(args.image)
        output_path = image_path.parent / f"{image_path.stem}_transcription.txt"

    print(f"\nSaving transcription to {output_path}...")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(transcription)

    print("\n" + "=" * 80)
    print("TRANSCRIPTION RESULT")
    print("=" * 80)
    print(transcription)
    print("=" * 80)
    print(f"\nTranscription saved to: {output_path}")
    print(f"Total lines: {len(segments)}")
    print("Average confidence: N/A (not implemented yet)")
    return None


if __name__ == '__main__':
    raise SystemExit(main())
|
|