# agentic-ocr-extractor / agentic_ocr_extractor.py
# Uploaded by ahczhg via huggingface_hub (revision 3354205, verified).
"""
Lightweight Agentic OCR Document Extraction (Tesseract)
A lightweight, agentic OCR pipeline to extract text and structured fields from document images.
Key features:
- Multiple preprocessing variants (grayscale, thresholding, sharpening, denoise, resize)
- Multiple Tesseract page segmentation modes (PSM)
- Candidate scoring via average OCR confidence
- Simple rule-based field extraction (DOI, title, authors, abstract, keywords)
"""
import os
import re
import json
import argparse
import unicodedata
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import cv2
from PIL import Image
import pytesseract
from pytesseract import Output
# ============================================================================
# Preprocessing Variants
# ============================================================================
def _ensure_uint8(img: np.ndarray) -> np.ndarray:
"""Ensure image is uint8 dtype, clipping values if needed."""
if img.dtype == np.uint8:
return img
return np.clip(img, 0, 255).astype(np.uint8)
def preprocess_variants(rgb_img: np.ndarray, scale_factor: float = 1.5) -> Dict[str, np.ndarray]:
    """Build a dict of preprocessed versions of *rgb_img* to try with OCR.

    Keys (in insertion order): raw, upscaled, gray, otsu, adaptive, denoise,
    sharpen, clahe, morph_close. Every variant is stored as a 3-channel RGB
    array so Tesseract can consume any of them interchangeably.
    """
    out: Dict[str, np.ndarray] = {'raw': rgb_img}

    # Upscaling frequently helps Tesseract with small glyphs.
    height, width = rgb_img.shape[:2]
    out['upscaled'] = cv2.resize(
        rgb_img,
        (int(width * scale_factor), int(height * scale_factor)),
        interpolation=cv2.INTER_CUBIC,
    )

    grayscale = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY)
    out['gray'] = cv2.cvtColor(grayscale, cv2.COLOR_GRAY2RGB)

    # Global (Otsu) binarization.
    _, otsu_bin = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    out['otsu'] = cv2.cvtColor(otsu_bin, cv2.COLOR_GRAY2RGB)

    # Local (adaptive Gaussian) binarization for unevenly lit scans.
    adaptive_bin = cv2.adaptiveThreshold(
        grayscale, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 35, 11
    )
    out['adaptive'] = cv2.cvtColor(adaptive_bin, cv2.COLOR_GRAY2RGB)

    # Non-local-means denoising.
    denoised = cv2.fastNlMeansDenoising(grayscale, None, h=15, templateWindowSize=7, searchWindowSize=21)
    out['denoise'] = cv2.cvtColor(denoised, cv2.COLOR_GRAY2RGB)

    # 3x3 sharpening kernel; result clipped back to uint8.
    sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=np.float32)
    sharpened = cv2.filter2D(grayscale, -1, sharpen_kernel)
    out['sharpen'] = cv2.cvtColor(_ensure_uint8(sharpened), cv2.COLOR_GRAY2RGB)

    # CLAHE boosts local contrast without blowing out highlights.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    out['clahe'] = cv2.cvtColor(clahe.apply(grayscale), cv2.COLOR_GRAY2RGB)

    # Morphological closing of the Otsu image reconnects broken strokes.
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    closed = cv2.morphologyEx(otsu_bin, cv2.MORPH_CLOSE, close_kernel)
    out['morph_close'] = cv2.cvtColor(closed, cv2.COLOR_GRAY2RGB)

    return out
# ============================================================================
# OCR Functions
# ============================================================================
def ocr_with_confidence(rgb_img: np.ndarray, psm: int = 6) -> Tuple[str, float, int]:
    """Run Tesseract on *rgb_img*; return (text, mean word confidence, word count).

    Confidence is averaged over word boxes with non-empty text and a
    non-negative conf value (Tesseract uses -1 for non-word boxes); it is
    0.0 when no such words exist.
    """
    config = f'--oem 3 --psm {psm}'
    # Word-level data supplies one confidence per recognized token.
    data = pytesseract.image_to_data(rgb_img, output_type=Output.DICT, config=config)
    confidences: List[float] = []
    words = 0
    for raw_conf, token in zip(data.get('conf', []), data.get('text', [])):
        try:
            conf_value = float(raw_conf)
        except (ValueError, TypeError):
            continue
        # Skip empty tokens and Tesseract's negative placeholder confidences.
        if not token or not str(token).strip() or conf_value < 0:
            continue
        confidences.append(conf_value)
        words += 1
    mean_conf = float(np.mean(confidences)) if confidences else 0.0
    # Second pass renders the layout-preserving plain-text output.
    extracted = pytesseract.image_to_string(rgb_img, config=config)
    return extracted, mean_conf, words
@dataclass
class OcrCandidate:
    """Container for OCR candidate results (one preprocessing-variant/PSM run)."""
    variant: str      # preprocessing variant name (key from preprocess_variants)
    psm: int          # Tesseract page segmentation mode used for this run
    avg_conf: float   # mean per-word confidence (0.0 when no words recognized)
    text: str         # raw text from pytesseract.image_to_string
    word_count: int   # number of words counted toward avg_conf
    score: float  # Combined score for ranking (see compute_score)
def compute_score(avg_conf: float, text: str, word_count: int) -> float:
    """Combine confidence, output length, and word count into one ranking score.

    The base is the average confidence; near-empty output is halved, short
    output dampened, and candidates with many words get a small bonus.
    """
    stripped_len = len(text.strip())
    # Very short output usually means the OCR pass failed outright.
    if stripped_len < 40:
        length_factor = 0.5
    elif stripped_len < 100:
        length_factor = 0.8
    else:
        length_factor = 1.0
    # Many recognized words suggests genuine text was extracted.
    word_factor = 1.1 if word_count > 20 else 1.0
    return avg_conf * length_factor * word_factor
def _process_variant(args: Tuple[str, np.ndarray, int]) -> OcrCandidate:
    """OCR one (variant name, image, psm) task; thread-pool worker for run_agent."""
    name, image, psm = args
    extracted, mean_conf, words = ocr_with_confidence(image, psm=psm)
    return OcrCandidate(
        variant=name,
        psm=psm,
        avg_conf=mean_conf,
        text=extracted,
        word_count=words,
        score=compute_score(mean_conf, extracted, words),
    )
def run_agent(
    rgb_img: np.ndarray,
    psms: List[int] = None,
    scale_factor: float = 1.5,
    parallel: bool = True,
    top_k: int = 10,
    verbose: bool = True
) -> OcrCandidate:
    """Try every preprocessing-variant x PSM combination; return the top candidate.

    Candidates are ranked by compute_score (descending); when *verbose*, a
    leaderboard of the best *top_k* results is printed first.
    """
    if psms is None:
        psms = [3, 4, 6, 11]

    variants = preprocess_variants(rgb_img, scale_factor=scale_factor)
    tasks = [(name, image, psm) for name, image in variants.items() for psm in psms]

    results: List[OcrCandidate] = []
    if parallel:
        # Threads suffice here: pytesseract spends its time waiting on the
        # tesseract subprocess, so the GIL is not a bottleneck.
        with ThreadPoolExecutor(max_workers=min(8, len(tasks))) as pool:
            pending = [pool.submit(_process_variant, task) for task in tasks]
            results = [future.result() for future in as_completed(pending)]
    else:
        results = [_process_variant(task) for task in tasks]

    results.sort(key=lambda cand: cand.score, reverse=True)

    if verbose:
        print(f'Top {top_k} OCR candidates:')
        print('-' * 90)
        for cand in results[:top_k]:
            preview = cand.text.strip().replace('\n', ' ')[:60]
            print(f"{cand.variant:12s} psm={cand.psm:<2d} conf={cand.avg_conf:5.1f} "
                  f"words={cand.word_count:3d} score={cand.score:5.1f} '{preview}...'")
        print('-' * 90)

    return results[0]
# ============================================================================
# Text Cleaning Utilities
# ============================================================================
def clean_text(text: str) -> str:
    """Normalize line endings and whitespace in raw OCR output."""
    # Canonicalize every line-ending style to '\n'.
    unified = text.replace('\r\n', '\n').replace('\r', '\n')
    # Collapse horizontal whitespace runs (tabs, repeated spaces) to one space.
    collapsed = re.sub(r'[^\S\n]+', ' ', unified)
    # Strip spaces hugging the start or end of each line.
    trimmed = re.sub(r'^ +| +$', '', collapsed, flags=re.MULTILINE)
    # Allow at most one consecutive blank line.
    squeezed = re.sub(r'\n\s*\n+', '\n\n', trimmed)
    return squeezed.strip()
def fix_ocr_artifacts(text: str) -> str:
    """Apply ordered regex substitutions that undo frequent OCR misreads."""
    # (pattern, replacement) pairs — applied sequentially, so order matters.
    fixes = (
        (r'\bl\b', 'I'),                      # lone lowercase L is almost always 'I'
        (r'(?<=[a-z])0(?=[a-z])', 'o'),       # digit zero wedged between letters
        (r'(?<=[a-z])1(?=[a-z])', 'l'),       # digit one wedged between letters
        (r'\bll\b', 'II'),                    # standalone 'll' read as Roman numeral II
        (r'(\w)-\n(\w)', r'\1\2'),            # rejoin words hyphenated at line breaks
        (r'\n[^\w\n]\n', '\n'),               # drop stray single symbols on their own line
        (r'\.{2,}', '...'),                   # normalize runs of periods to an ellipsis
        (r'\s+([.,;:!?])', r'\1'),            # no space before punctuation
        (r'([.,;:!?])(?=[A-Za-z])', r'\1 '),  # exactly one space after punctuation
    )
    for pattern, replacement in fixes:
        text = re.sub(pattern, replacement, text)
    return text
def normalize_unicode(text: str) -> str:
    """Map common typographic Unicode characters to ASCII equivalents."""
    # NFKC first: folds compatibility forms (ligatures, fullwidth chars, NBSP).
    text = unicodedata.normalize('NFKC', text)
    # Single-pass translation table for the remaining typographic characters.
    table = str.maketrans({
        '\u2018': "'", '\u2019': "'",    # curly single quotes
        '\u201c': '"', '\u201d': '"',    # curly double quotes
        '\u2013': '-', '\u2014': '-',    # en and em dashes
        '\u2026': '...',                 # horizontal ellipsis
        '\ufb01': 'fi', '\ufb02': 'fl',  # fi/fl ligatures (should be gone after NFKC)
        '\u00a0': ' ',                   # no-break space (should be gone after NFKC)
    })
    return text.translate(table)
def process_ocr_text(text: str, fix_artifacts: bool = True, normalize: bool = True) -> str:
    """Full cleanup pipeline: Unicode normalization, whitespace cleanup, artifact fixes."""
    if normalize:
        text = normalize_unicode(text)
    cleaned = clean_text(text)
    return fix_ocr_artifacts(cleaned) if fix_artifacts else cleaned
# ============================================================================
# Field Extraction
# ============================================================================
def _first_match(pattern: str, text: str, flags: int = 0) -> Optional[str]:
"""Return first regex capture group match, or None."""
m = re.search(pattern, text, flags)
return m.group(1).strip() if m else None
def _all_matches(pattern: str, text: str, flags: int = 0) -> List[str]:
"""Return all regex capture group matches."""
return [m.strip() for m in re.findall(pattern, text, flags) if m.strip()]
@dataclass
class ExtractedFields:
    """Structured container for extracted document fields.

    Every field defaults to None; to_dict() drops unset fields, so serialized
    output only contains what was actually found in the document.
    """
    doi: Optional[str] = None        # Digital Object Identifier (10.xxxx/...)
    issn: Optional[str] = None       # journal ISSN (NNNN-NNNX)
    volume: Optional[str] = None     # journal volume number
    issue: Optional[str] = None      # journal issue number
    year: Optional[str] = None       # publication year (first 19xx/20xx found)
    pages: Optional[str] = None      # page range, e.g. "12-34"
    received: Optional[str] = None   # manuscript-received date, as printed
    accepted: Optional[str] = None   # manuscript-accepted date, as printed
    published: Optional[str] = None  # publication date, as printed
    title: Optional[str] = None      # heuristic best-guess title line
    authors: Optional[List[str]] = None       # individual author names
    affiliations: Optional[List[str]] = None  # NOTE(review): not populated by extract_fields — confirm intended
    abstract: Optional[str] = None   # abstract body text (cleaned)
    keywords: Optional[List[str]] = None      # keyword list from the Keywords section
    email: Optional[str] = None      # first email address found in the text
    def to_dict(self) -> Dict[str, Any]:
        """Return only the fields that are set (non-None) as a plain dict."""
        return {k: v for k, v in asdict(self).items() if v is not None}
def extract_doi(text: str) -> Optional[str]:
    """Extract a DOI, trying URL, labeled, and bare-identifier forms in turn.

    Returns the identifier with trailing punctuation stripped, or None.

    Fix: the label pattern previously used the character class
    '[::\u00ef\u00bc\u009a]' — the mojibake of a mis-decoded fullwidth colon
    (UTF-8 ef bc 9a read as Latin-1). It now matches ':' or the real
    fullwidth colon U+FF1A.
    """
    patterns = [
        # doi.org URL form (optional scheme and "dx." prefix)
        r'(?:https?://)?(?:dx\.)?doi\.org/\s*(10\.[^\s]+)',
        # "DOI:" label with an ASCII or fullwidth colon
        r'DOI\s*[:\uff1a]\s*(10\.[^\s]+)',
        # bare 10.NNNN/suffix identifier
        r'\b(10\.\d{4,}/[^\s]+)',
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            # Strip punctuation that trails the identifier in running text.
            return re.sub(r'[.,;:)\]]+$', '', match.group(1).strip())
    return None
def extract_identifiers(text: str) -> Dict[str, Optional[str]]:
    """Extract ISSN / ISBN / PMID / arXiv identifiers (values are None when absent).

    Fix: the optional label-colon class previously contained the mojibake
    bytes of a mis-decoded fullwidth colon; it now matches ':' or the real
    fullwidth colon U+FF1A.
    """
    return {
        'issn': _first_match(r'ISSN\s*[:\uff1a]?\s*([0-9]{4}-[0-9]{3}[0-9Xx])', text, re.IGNORECASE),
        'isbn': _first_match(r'ISBN\s*[:\uff1a]?\s*([\d-]{10,17})', text, re.IGNORECASE),
        'pmid': _first_match(r'PMID\s*[:\uff1a]?\s*(\d+)', text, re.IGNORECASE),
        'arxiv': _first_match(r'arXiv\s*[:\uff1a]?\s*(\d+\.\d+)', text, re.IGNORECASE),
    }
def extract_publication_info(text: str) -> Dict[str, Optional[str]]:
    """Extract volume, issue, pages, and year (values are None when absent).

    Fix: the optional label-colon class previously contained the mojibake
    bytes of a mis-decoded fullwidth colon; it now matches ':' or the real
    fullwidth colon U+FF1A.
    """
    return {
        'volume': _first_match(r'Vol(?:ume)?\.?\s*[:\uff1a]?\s*(\d{1,4})', text, re.IGNORECASE),
        'issue': _first_match(r'(?:Issue|No\.?|Number)\s*[:\uff1a]?\s*(\d{1,4})', text, re.IGNORECASE),
        'pages': _first_match(r'(?:pp?\.?|pages?)\s*[:\uff1a]?\s*(\d+\s*[-\u2013]\s*\d+)', text, re.IGNORECASE),
        # First plausible 4-digit year (1900-2099) anywhere in the text.
        'year': _first_match(r'\b((?:19|20)\d{2})\b', text),
    }
def extract_dates(text: str) -> Dict[str, Optional[str]]:
    """Extract received/accepted/published dates (values are None when absent).

    Fix: the optional label-colon class previously contained the mojibake
    bytes of a mis-decoded fullwidth colon; it now matches ':' or the real
    fullwidth colon U+FF1A.
    """
    # Matches "January 5, 2020", "5/1/2020", or "2020-01-05" style dates.
    date_pattern = r'[:\uff1a]?\s*([A-Za-z]+\.?\s+\d{1,2},?\s+\d{4}|\d{1,2}[-/]\d{1,2}[-/]\d{2,4}|\d{4}[-/]\d{1,2}[-/]\d{1,2})'
    return {
        'received': _first_match(rf'Received{date_pattern}', text, re.IGNORECASE),
        'accepted': _first_match(rf'Accepted{date_pattern}', text, re.IGNORECASE),
        'published': _first_match(rf'Published{date_pattern}', text, re.IGNORECASE),
    }
def extract_abstract(text: str) -> Optional[str]:
    """Extract the abstract: text between an 'Abstract' heading and the next section.

    Falls back to "up to the first blank line" when no section heading
    follows. Candidates shorter than 50 characters are rejected as noise.

    Fix: the optional label-colon class previously contained the mojibake
    bytes of a mis-decoded fullwidth colon; it now matches ':' or the real
    fullwidth colon U+FF1A.
    """
    patterns = [
        # Stop at Keywords / Introduction / a numbered first section.
        r'Abstract\s*[:\uff1a]?\s*(.*?)(?=\n\s*(?:Keywords?|Key\s*words|Introduction|1\.|1\s))',
        # Fallback: stop at the first blank line.
        r'Abstract\s*[:\uff1a]?\s*(.*?)(?=\n\n)',
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
        if match:
            abstract = match.group(1).strip()
            if len(abstract) > 50:  # too short to be a real abstract
                return clean_text(abstract)
    return None
def extract_keywords(text: str) -> Optional[List[str]]:
"""Extract keywords list."""
pattern = r'(?:Keywords?|Key\s*words)\s*[::\u00ef\u00bc\u009a]?\s*(.*?)(?=\n\n|\n\s*[A-Z][a-z]+:|\Z)'
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
if match:
kw_text = match.group(1).strip()
# Split on semicolon, comma, or bullet points
parts = re.split(r'[;,\u2022\u00b7]|\s{2,}', kw_text)
keywords = [p.strip().strip('.-') for p in parts if p.strip() and len(p.strip()) > 2]
return keywords if keywords else None
return None
def extract_title(lines: List[str]) -> Optional[str]:
    """Heuristically pick the paper title from the first lines of the document.

    Lines that look like journal/author metadata are filtered out; the
    survivors are scored by position, capitalization, and length, and the
    highest scorer is returned (None when nothing qualifies).
    """
    skip_markers = {
        'journal', 'issn', 'isbn', 'volume', 'issue', 'article', 'research article',
        'department', 'university', 'corresponding', 'received', 'accepted',
        'abstract', 'keywords', 'http', 'doi', 'email', '@', 'copyright'
    }
    scored: List[Tuple[int, str]] = []
    for pos, line in enumerate(lines[:15]):  # titles live near the top of the page
        lowered = line.lower()
        if any(marker in lowered for marker in skip_markers):
            continue  # looks like metadata, not a title
        if len(line) < 25 or len(line) > 200:
            continue  # implausibly short or long for a title
        if len(line.split()) < 4:
            continue  # titles have several words
        alpha_chars = sum(ch.isalpha() for ch in line)
        if alpha_chars / max(1, len(line)) < 0.6:
            continue  # mostly digits/symbols
        # Earlier lines score higher; capitalization and typical length add bonuses.
        line_score = 100 - pos * 5
        if line[0].isupper():
            line_score += 10
        if 50 < len(line) < 150:
            line_score += 10
        scored.append((line_score, line))
    scored.sort(reverse=True)
    return scored[0][1] if scored else None
def extract_authors(text: str, lines: List[str], title: Optional[str]) -> Optional[List[str]]:
    """Extract author names from the lines directly below the title.

    Scans up to three lines after the title for a comma/'and'/'&'-joined list
    of capitalized names. Returns None when there is no title anchor or no
    line looks author-like.
    """
    if not title or title not in lines:
        return None
    anchor = lines.index(title)
    for line in lines[anchor + 1:min(anchor + 4, len(lines))]:
        # Author lines typically join names with commas, 'and', or '&'.
        joined = bool(re.search(r'\b(?:and|&)\b', line, re.IGNORECASE)) or line.count(',') >= 1
        if not joined:
            continue
        # Require at least two capitalized, name-like tokens.
        if len(re.findall(r'\b[A-Z][a-z]+\b', line)) < 2:
            continue
        names = re.split(r',\s*(?:and\s+)?|\s+and\s+|\s*&\s*', line)
        names = [n.strip() for n in names if n.strip() and len(n.strip()) > 2]
        if names:
            return names
    return None
def extract_email(text: str) -> Optional[str]:
    """Return the first email address found in *text*, or None."""
    match = re.search(r'[\w.-]+@[\w.-]+\.\w+', text)
    return match.group(0) if match else None
def extract_fields(text: str) -> ExtractedFields:
    """Run every field extractor over *text* and bundle the results."""
    # Non-empty stripped lines feed the line-oriented heuristics.
    lines = [stripped for stripped in (ln.strip() for ln in text.split('\n')) if stripped]
    identifiers = extract_identifiers(text)
    pub_info = extract_publication_info(text)
    dates = extract_dates(text)
    title = extract_title(lines)
    return ExtractedFields(
        doi=extract_doi(text),
        issn=identifiers.get('issn'),
        volume=pub_info.get('volume'),
        issue=pub_info.get('issue'),
        pages=pub_info.get('pages'),
        year=pub_info.get('year'),
        received=dates.get('received'),
        accepted=dates.get('accepted'),
        published=dates.get('published'),
        title=title,
        authors=extract_authors(text, lines, title),
        abstract=extract_abstract(text),
        keywords=extract_keywords(text),
        email=extract_email(text),
    )
# ============================================================================
# Main Processing Function
# ============================================================================
def process_image(
    image_path: str,
    output_text_path: Optional[str] = None,
    output_json_path: Optional[str] = None,
    scale_factor: float = 1.5,
    psms: List[int] = None,
    verbose: bool = True
) -> Tuple[str, ExtractedFields, OcrCandidate]:
    """
    Process a document image and extract text and structured fields.
    Args:
        image_path: Path to the input image file
        output_text_path: Optional path to save extracted text
        output_json_path: Optional path to save extracted fields as JSON
        scale_factor: Scale factor for image upscaling
        psms: List of Tesseract page segmentation modes to try
        verbose: Whether to print progress information
    Returns:
        Tuple of (cleaned_text, extracted_fields, best_ocr_candidate)
    Raises:
        ValueError: If the image cannot be read from image_path
    """
    if psms is None:
        psms = [3, 4, 6, 11]

    # OpenCV loads BGR; convert to RGB for the OCR pipeline.
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        raise ValueError(f'Failed to read image: {image_path}')
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    if verbose:
        print(f'Processing image: {image_path}')
        print(f'Image size: {rgb_image.shape[1]}x{rgb_image.shape[0]}')

    # Try all variant/PSM combinations and keep the top-scoring candidate.
    best = run_agent(rgb_image, psms=psms, scale_factor=scale_factor, verbose=verbose)
    if verbose:
        print(f'\nSelected: {best.variant} | PSM={best.psm} | conf={best.avg_conf:.1f} | score={best.score:.1f}')

    cleaned_text = process_ocr_text(best.text)
    fields = extract_fields(cleaned_text)

    if output_text_path:
        with open(output_text_path, 'w', encoding='utf-8') as text_file:
            text_file.write(cleaned_text)
        if verbose:
            print(f'Saved text: {output_text_path}')

    if output_json_path:
        with open(output_json_path, 'w', encoding='utf-8') as json_file:
            json.dump(fields.to_dict(), json_file, indent=2, ensure_ascii=False)
        if verbose:
            print(f'Saved JSON: {output_json_path}')

    return cleaned_text, fields, best
# ============================================================================
# CLI Entry Point
# ============================================================================
def main():
    """CLI entry point: parse arguments, run the pipeline, print extracted fields.

    Returns a process exit code: 0 on success, 1 on a missing file or error.
    """
    parser = argparse.ArgumentParser(
        description='Lightweight Agentic OCR Document Extraction',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
python agentic_ocr_extractor.py document.jpg
python agentic_ocr_extractor.py document.png -o output.txt -j fields.json
python agentic_ocr_extractor.py scan.jpg --scale 2.0 --psm 3 6 11
'''
    )
    parser.add_argument('image', help='Path to the input image file')
    parser.add_argument('-o', '--output-text', help='Path to save extracted text')
    parser.add_argument('-j', '--output-json', help='Path to save extracted fields as JSON')
    parser.add_argument('--scale', type=float, default=1.5, help='Scale factor for upscaling (default: 1.5)')
    parser.add_argument('--psm', type=int, nargs='+', default=[3, 4, 6, 11],
                        help='Tesseract PSM modes to try (default: 3 4 6 11)')
    parser.add_argument('-q', '--quiet', action='store_true', help='Suppress progress output')
    opts = parser.parse_args()

    if not os.path.exists(opts.image):
        print(f'Error: Image file not found: {opts.image}')
        return 1

    try:
        _, fields, _ = process_image(
            opts.image,
            output_text_path=opts.output_text,
            output_json_path=opts.output_json,
            scale_factor=opts.scale,
            psms=opts.psm,
            verbose=not opts.quiet
        )
        # Summarize the structured extraction result on stdout.
        print('\nExtracted Fields:')
        print(json.dumps(fields.to_dict(), indent=2, ensure_ascii=False))
        return 0
    except Exception as e:
        print(f'Error: {e}')
        return 1
if __name__ == '__main__':
    # raise SystemExit(...) is equivalent to sys.exit(...) and needs no import;
    # the builtin exit() used before is a site-module convenience intended for
    # interactive sessions and is absent when Python runs with -S.
    raise SystemExit(main())