# AdvancedOCR / inference.py
# satvikjain — initial commit (1024113)
import io
import json
import os
import shutil
import tempfile
import time
from typing import Optional, Tuple

import cv2
import fitz  # PyMuPDF
import numpy as np
import torch
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# -----------------------------
# Configuration (override via env if needed)
# -----------------------------
TEXTLINE_MODEL_PATH = os.getenv("TEXTLINE_MODEL_PATH", "./model_final.pth")
USE_GPU = os.getenv("USE_GPU", "true").lower() == "true"
SCORE_THRESHOLD = float(os.getenv("SCORE_THRESHOLD", "0.5"))
AREA_THRESHOLD_PERCENT = float(os.getenv("AREA_THRESHOLD_PERCENT", "12.5"))
DPI = int(os.getenv("PDF_DPI", "200"))
TROCR_SPANISH_MODEL = os.getenv("TROCR_SPANISH_MODEL", "qantev/trocr-large-spanish")
TROCR_FALLBACK_MODEL = os.getenv("TROCR_FALLBACK_MODEL", "microsoft/trocr-base-printed")
class EnhancedTextlineExtractor:
    """Detect textlines on page images and transcribe them.

    Detection uses a Detectron2 Mask R-CNN loaded from ``model_path``;
    recognition uses a Spanish TrOCR checkpoint with a generic
    printed-text fallback.
    """

    def __init__(self, model_path: str):
        self.cfg = self._setup_cfg(model_path)
        self.predictor = DefaultPredictor(self.cfg)
        # TrOCR device: GPU only when CUDA exists AND USE_GPU is enabled.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() and USE_GPU else "cpu"
        )
        self.trocr_processor, self.trocr_model = self._load_trocr()
        self.trocr_model.to(self.device)

    def _setup_cfg(self, model_path: str):
        """Build the Detectron2 config for the textline detector.

        Also registers the ``page_test`` metadata (class names) as a
        side effect.
        """
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
        ))
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # textline, baseline
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESHOLD
        cfg.MODEL.WEIGHTS = model_path
        # Fix: align the detector with the same device policy TrOCR uses.
        # The detectron2 default is "cuda", which fails on CPU-only hosts.
        cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() and USE_GPU else "cpu"
        cfg.DATASETS.TEST = ("page_test",)
        cfg.DATALOADER.NUM_WORKERS = 2
        MetadataCatalog.get("page_test").thing_classes = ["textline", "baseline"]
        return cfg

    def _load_trocr(self):
        """Return ``(processor, model)``, preferring the Spanish checkpoint.

        Any failure (missing checkpoint, network error, ...) falls back
        to the generic printed-text TrOCR model.
        """
        try:
            processor = TrOCRProcessor.from_pretrained(TROCR_SPANISH_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_SPANISH_MODEL)
            return processor, model
        except Exception:
            # Broad on purpose: best-effort fallback rather than hard failure.
            processor = TrOCRProcessor.from_pretrained(TROCR_FALLBACK_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_FALLBACK_MODEL)
            return processor, model

    def pdf_to_images(self, pdf_path: str, dpi: int = DPI):
        """Rasterize every page of ``pdf_path`` into a BGR numpy image.

        Returns the images in page order. The document handle is always
        closed, even when a page fails to render.
        """
        doc = fitz.open(pdf_path)
        images = []
        try:
            # PDF's native unit is 72 dpi; the scale matrix is loop-invariant.
            mat = fitz.Matrix(dpi / 72, dpi / 72)
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                pix = page.get_pixmap(matrix=mat)
                nparr = np.frombuffer(pix.tobytes("png"), np.uint8)
                images.append(cv2.imdecode(nparr, cv2.IMREAD_COLOR))
        finally:
            doc.close()
        return images

    def filter_margin_boxes_by_area(self, boxes, scores,
                                    area_threshold_percent: float = AREA_THRESHOLD_PERCENT):
        """Split detections into main-text vs. margin boxes by relative area.

        A box counts as "margin" when its area is below
        ``area_threshold_percent`` percent of the average box area.

        Returns:
            (main_boxes, main_scores, margin_boxes, margin_scores)
        """
        if len(boxes) == 0:
            return np.array([]), np.array([]), np.array([]), np.array([])
        boxes = np.asarray(boxes)
        scores = np.asarray(scores)
        # Vectorized (x2 - x1) * (y2 - y1) instead of a Python loop.
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        keep = areas >= areas.mean() * (area_threshold_percent / 100.0)
        return boxes[keep], scores[keep], boxes[~keep], scores[~keep]

    def _ocr_crop(self, crop_bgr) -> str:
        """Transcribe one BGR line crop with TrOCR; returns stripped text."""
        crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
        pixel_values = self.trocr_processor(
            images=Image.fromarray(crop_rgb), return_tensors="pt"
        ).pixel_values.to(self.device)
        with torch.no_grad():
            generated_ids = self.trocr_model.generate(pixel_values, max_new_tokens=128)
        return self.trocr_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()

    def process_page_standard(self, image):
        """Detect and transcribe all main-text lines on one page image.

        Returns:
            On success ``{"success": True, "line_segments": [...],
            "full_text": str}``; otherwise ``{"success": False, "error": str}``.
            A line whose OCR fails keeps its bbox but gets text "" and
            confidence 0.0.
        """
        outputs = self.predictor(image)
        instances = outputs["instances"]
        boxes = instances.pred_boxes.tensor.cpu().numpy()
        scores = instances.scores.cpu().numpy()
        if len(boxes) == 0:
            return {"success": False, "error": "No textlines detected"}
        main_boxes, main_scores, _, _ = self.filter_margin_boxes_by_area(boxes, scores)
        if len(main_boxes) == 0:
            return {"success": False, "error": "No textlines after filtering"}
        height, width = image.shape[:2]
        line_segments = []
        full_text_lines = []
        for i, (box, score) in enumerate(zip(main_boxes, main_scores)):
            x1, y1, x2, y2 = map(int, box)
            # Fix: clamp to image bounds — negative coordinates would wrap
            # around in numpy slicing and silently crop the wrong region.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(width, x2), min(height, y2)
            segment = {
                "line_index": i,
                "bbox": [x1, y1, x2, y2],
                "score": float(score),
                "text": "",
                "confidence": 0.0,
            }
            crop_bgr = image[y1:y2, x1:x2]
            if crop_bgr.size > 0:
                try:
                    segment["text"] = self._ocr_crop(crop_bgr)
                    segment["confidence"] = 1.0
                    full_text_lines.append(segment["text"])
                except Exception:
                    # Keep the placeholder segment; one bad line must not
                    # abort the whole page.
                    pass
            line_segments.append(segment)
        return {
            "success": True,
            "line_segments": line_segments,
            "full_text": "\n".join(full_text_lines),
        }
def _zip_directory(src_dir: str, zip_path: str) -> str:
    """Create a zip archive of ``src_dir``'s contents at ``zip_path``.

    ``shutil.make_archive`` appends the ``.zip`` suffix itself, so the
    requested path is stripped of its extension before the call. Returns
    the path of the archive actually written.
    """
    archive_root, _unused_ext = os.path.splitext(zip_path)
    return shutil.make_archive(archive_root, 'zip', src_dir)
def _correct_with_gemini(text: str, gemini_key: str) -> str:
    """Best-effort single-pass correction of OCR text via Gemini.

    Returns the corrected text, or ``text`` unchanged on any failure
    (missing dependency, bad key, network error, empty response).
    """
    try:
        # Local import: google-generativeai is an optional dependency.
        import google.generativeai as genai
        genai.configure(api_key=gemini_key)
        prompt = (
            "Correct the following historical Spanish OCR text while preserving grammar and style. "
            "Fix orthography, punctuation, and obvious OCR mistakes. Return only corrected text.\n\n" + text
        )
        response = genai.GenerativeModel('gemini-2.5-pro').generate_content(prompt)
        corrected = getattr(response, 'text', None)
        return corrected.strip() if corrected else text
    except Exception:
        # Swallow LLM errors and return original text
        return text


def run_ocr(pdf_path: str, split_page_enabled: bool = False, use_llm: bool = False,
            gemini_key: Optional[str] = None) -> Tuple[str, str]:
    """
    Run OCR on the provided PDF.

    Args:
        pdf_path: Path to the input PDF file.
        split_page_enabled: Accepted for API compatibility; not used here.
        use_llm: When True (and ``gemini_key`` is set), post-correct the
            combined text with Gemini.
        gemini_key: API key for the optional Gemini correction pass.

    Returns:
        combined_text (str), zip_file_path (str)

    Note:
        Outputs are written to a fresh temporary directory which is NOT
        removed here, because the returned zip path points into it.
    """
    extractor = EnhancedTextlineExtractor(TEXTLINE_MODEL_PATH)
    images = extractor.pdf_to_images(pdf_path, dpi=DPI)
    temp_dir = tempfile.mkdtemp(prefix="ocr_outputs_")
    inferences_dir = os.path.join(temp_dir, "inferences")
    os.makedirs(inferences_dir, exist_ok=True)
    all_results = []
    for i, image in enumerate(images):
        result = extractor.process_page_standard(image)
        all_results.append(result)
        # One JSON per page so partial results survive a later failure.
        page_file = os.path.join(inferences_dir, f"page_{i+1}_result.json")
        with open(page_file, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
    # Only successful pages contribute text; failed pages are skipped.
    combined_text = "\n\n".join(
        r.get("full_text", "") for r in all_results if r.get("success")
    )
    # Optional Gemini correction over combined text (simple, single pass)
    if use_llm and gemini_key and combined_text.strip():
        combined_text = _correct_with_gemini(combined_text, gemini_key)
    zip_path = os.path.join(temp_dir, "per_page_jsons.zip")
    archive_path = _zip_directory(inferences_dir, zip_path)
    return combined_text, archive_path