Spaces:
Build error
Build error
| import os | |
| import io | |
| import json | |
| import time | |
| import shutil | |
| import tempfile | |
| from typing import Tuple | |
| import cv2 | |
| import fitz # PyMuPDF | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from detectron2.config import get_cfg | |
| from detectron2.engine import DefaultPredictor | |
| from detectron2.data import MetadataCatalog | |
| from detectron2 import model_zoo | |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
# -----------------------------
# Configuration (override via env if needed)
# -----------------------------
def _env_flag(name: str, default: str) -> bool:
    """Read a boolean setting from the environment.

    The value (or ``default`` when the variable is unset) counts as true when,
    after stripping whitespace and lower-casing, it is one of
    "1", "true", "yes" or "on"; anything else counts as false.
    """
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


# Path to the fine-tuned Detectron2 text-line weights.
TEXTLINE_MODEL_PATH = os.getenv("TEXTLINE_MODEL_PATH", "./model_final.pth")
# Use CUDA when available; set USE_GPU=false (or 0/no/off) to force CPU.
USE_GPU = _env_flag("USE_GPU", "true")
# Minimum detector confidence for a predicted instance to be kept.
SCORE_THRESHOLD = float(os.getenv("SCORE_THRESHOLD", "0.5"))
# Boxes smaller than this percentage of the mean box area are treated as margin boxes.
AREA_THRESHOLD_PERCENT = float(os.getenv("AREA_THRESHOLD_PERCENT", "12.5"))
# Resolution used when rasterising PDF pages.
DPI = int(os.getenv("PDF_DPI", "200"))
# Primary TrOCR checkpoint, and the checkpoint used if the primary fails to load.
TROCR_SPANISH_MODEL = os.getenv("TROCR_SPANISH_MODEL", "qantev/trocr-large-spanish")
TROCR_FALLBACK_MODEL = os.getenv("TROCR_FALLBACK_MODEL", "microsoft/trocr-base-printed")
class EnhancedTextlineExtractor:
    """Detect text lines on page images with Detectron2 and transcribe them with TrOCR.

    A Mask R-CNN (model-zoo R101-FPN config, weights loaded from ``model_path``)
    predicts "textline" and "baseline" instances. Each kept textline box is
    cropped from the page and passed to a TrOCR model: ``TROCR_SPANISH_MODEL``,
    or ``TROCR_FALLBACK_MODEL`` if the former cannot be loaded.
    """

    # Index of "textline" in the label list registered by _setup_cfg.
    TEXTLINE_CLASS_ID = 0

    def __init__(self, model_path: str):
        # Choose the device first so the detector and TrOCR run on the same one.
        self.device = self._select_device()
        self.cfg = self._setup_cfg(model_path)
        self.predictor = DefaultPredictor(self.cfg)
        self.trocr_processor, self.trocr_model = self._load_trocr()
        self.trocr_model.to(self.device)

    @staticmethod
    def _select_device() -> torch.device:
        """Return CUDA when USE_GPU is enabled and CUDA is available, else CPU."""
        use_cuda = USE_GPU and torch.cuda.is_available()
        return torch.device("cuda" if use_cuda else "cpu")

    def _setup_cfg(self, model_path: str):
        """Build the Detectron2 inference config for the text-line model."""
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # textline, baseline
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESHOLD
        cfg.MODEL.WEIGHTS = model_path
        # Detectron2 defaults MODEL.DEVICE to "cuda". Without this override the
        # predictor cannot be built on CPU-only hosts or when USE_GPU is off.
        cfg.MODEL.DEVICE = self._select_device().type
        cfg.DATASETS.TEST = ("page_test",)
        cfg.DATALOADER.NUM_WORKERS = 2
        MetadataCatalog.get("page_test").thing_classes = ["textline", "baseline"]
        return cfg

    def _load_trocr(self):
        """Load the TrOCR processor and model, falling back to TROCR_FALLBACK_MODEL.

        Returns:
            (processor, model) tuple.
        """
        try:
            processor = TrOCRProcessor.from_pretrained(TROCR_SPANISH_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_SPANISH_MODEL)
        except Exception:
            import logging

            # Fallback is intentional, but record why the primary model failed.
            logging.getLogger(__name__).warning(
                "Could not load TrOCR model %r; falling back to %r",
                TROCR_SPANISH_MODEL, TROCR_FALLBACK_MODEL, exc_info=True,
            )
            processor = TrOCRProcessor.from_pretrained(TROCR_FALLBACK_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_FALLBACK_MODEL)
        return processor, model

    def pdf_to_images(self, pdf_path: str, dpi: int = DPI):
        """Render every page of a PDF to an image decoded by OpenCV (BGR order).

        Args:
            pdf_path: path of the PDF file.
            dpi: rendering resolution.

        Returns:
            List of images, one per page, in page order.
        """
        zoom = dpi / 72  # PDF coordinates use 72 points per inch
        matrix = fitz.Matrix(zoom, zoom)
        images = []
        with fitz.open(pdf_path) as doc:
            for page in doc:
                png_bytes = page.get_pixmap(matrix=matrix).tobytes("png")
                buffer = np.frombuffer(png_bytes, np.uint8)
                images.append(cv2.imdecode(buffer, cv2.IMREAD_COLOR))
        return images

    def filter_margin_boxes_by_area(self, boxes, scores, area_threshold_percent: float = AREA_THRESHOLD_PERCENT):
        """Split boxes into main-text and margin boxes by area.

        A box is main text when its area is at least ``area_threshold_percent``
        percent of the mean box area; smaller boxes are treated as margin boxes.

        Args:
            boxes: (x1, y1, x2, y2) rows.
            scores: one score per box.
            area_threshold_percent: threshold as a percentage of the mean area.

        Returns:
            (main_boxes, main_scores, margin_boxes, margin_scores) as numpy arrays.
        """
        if len(boxes) == 0:
            return np.array([]), np.array([]), np.array([]), np.array([])
        boxes = np.asarray(boxes).reshape(-1, 4)
        scores = np.asarray(scores)
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        area_threshold = areas.mean() * (area_threshold_percent / 100.0)
        is_main = areas >= area_threshold
        return boxes[is_main], scores[is_main], boxes[~is_main], scores[~is_main]

    def _transcribe_crop(self, crop_bgr):
        """Run TrOCR on one BGR line crop.

        Returns:
            The stripped transcription, or None if the crop is empty or OCR failed.
        """
        if crop_bgr.size == 0:
            return None
        try:
            pil_image = Image.fromarray(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB))
            pixel_values = self.trocr_processor(images=pil_image, return_tensors="pt").pixel_values
            with torch.no_grad():
                generated_ids = self.trocr_model.generate(pixel_values.to(self.device), max_new_tokens=128)
            decoded = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)
        except Exception:
            import logging

            # One bad line should not abort the page; log it and move on.
            logging.getLogger(__name__).warning("TrOCR failed on a line crop", exc_info=True)
            return None
        return decoded[0].strip()

    def process_page_standard(self, image):
        """Detect text lines on one page image and transcribe each of them.

        Args:
            image: page image in OpenCV BGR channel order, as produced by
                pdf_to_images.

        Returns:
            On success, ``{"success": True, "line_segments": [...], "full_text": str}``,
            with lines ordered top-to-bottom, then left-to-right. Otherwise
            ``{"success": False, "error": str}``.
        """
        instances = self.predictor(image)["instances"]
        boxes = instances.pred_boxes.tensor.cpu().numpy()
        scores = instances.scores.cpu().numpy()
        # The model also predicts "baseline" instances; only textlines are OCR'd.
        if instances.has("pred_classes"):
            is_textline = instances.pred_classes.cpu().numpy() == self.TEXTLINE_CLASS_ID
            boxes, scores = boxes[is_textline], scores[is_textline]
        if len(boxes) == 0:
            return {"success": False, "error": "No textlines detected"}

        main_boxes, main_scores, _, _ = self.filter_margin_boxes_by_area(boxes, scores)
        if len(main_boxes) == 0:
            return {"success": False, "error": "No textlines after filtering"}

        # The predictor does not return instances in positional order. Sort by
        # top edge, then left edge, so full_text follows a simple reading order.
        order = np.lexsort((main_boxes[:, 0], main_boxes[:, 1]))

        height, width = image.shape[:2]
        line_segments = []
        full_text_lines = []
        for line_index, box_index in enumerate(order):
            x1, y1, x2, y2 = (int(v) for v in main_boxes[box_index])
            # Clamp to the image so a negative coordinate cannot produce a
            # wrapping slice.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            text = self._transcribe_crop(image[y1:y2, x1:x2])
            if text is not None:
                full_text_lines.append(text)
            line_segments.append({
                "line_index": line_index,
                "bbox": [x1, y1, x2, y2],
                "score": float(main_scores[box_index]),
                "text": text if text is not None else "",
                # Placeholder rather than a model probability: 1.0 marks a
                # successful transcription, 0.0 a failed one.
                "confidence": 1.0 if text is not None else 0.0,
            })
        return {
            "success": True,
            "line_segments": line_segments,
            "full_text": "\n".join(full_text_lines),
        }
def _zip_directory(src_dir: str, zip_path: str) -> str:
    """Zip the contents of ``src_dir`` and return the path of the archive written.

    ``shutil.make_archive`` appends the ``.zip`` suffix itself, so any
    extension on ``zip_path`` is dropped before the call.
    """
    archive_base = os.path.splitext(zip_path)[0]
    return shutil.make_archive(archive_base, format='zip', root_dir=src_dir)
# Extractors keyed by model path. Building one loads the Detectron2 weights
# and a TrOCR model, which is too slow to repeat on every run_ocr call.
_EXTRACTOR_CACHE = {}


def _get_extractor(model_path: str) -> "EnhancedTextlineExtractor":
    """Return the cached extractor for ``model_path``, creating it on first use."""
    extractor = _EXTRACTOR_CACHE.get(model_path)
    if extractor is None:
        extractor = EnhancedTextlineExtractor(model_path)
        _EXTRACTOR_CACHE[model_path] = extractor
    return extractor


def _correct_with_gemini(text: str, api_key: str) -> str:
    """Ask Gemini to correct OCR text in a single pass.

    Returns the stripped corrected text. Returns ``text`` unchanged if the
    call fails or the response carries no text.
    """
    try:
        import google.generativeai as genai

        genai.configure(api_key=api_key)
        prompt = (
            "Correct the following historical Spanish OCR text while preserving grammar and style. "
            "Fix orthography, punctuation, and obvious OCR mistakes. Return only corrected text.\n\n" + text
        )
        response = genai.GenerativeModel('gemini-2.5-pro').generate_content(prompt)
        corrected = getattr(response, 'text', None)
    except Exception:
        import logging

        # The LLM pass is optional: keep the raw OCR text, but leave a trace
        # instead of failing silently.
        logging.getLogger(__name__).warning(
            "Gemini correction failed; returning uncorrected text", exc_info=True
        )
        return text
    return corrected.strip() if corrected else text


def run_ocr(pdf_path: str, split_page_enabled: bool = False, use_llm: bool = False, gemini_key: str = None) -> Tuple[str, str]:
    """
    Run OCR on the provided PDF.

    Args:
        pdf_path: path of the PDF to process.
        split_page_enabled: accepted for API compatibility; currently unused.
        use_llm: when true and ``gemini_key`` is set, post-correct the combined
            text with Gemini. LLM errors fall back to the uncorrected text.
        gemini_key: API key passed to ``google.generativeai``.

    Returns:
        combined_text (str), zip_file_path (str):
        - combined_text: the text of every successfully processed page, with
          pages separated by blank lines.
        - zip_file_path: a zip containing one ``page_<n>_result.json`` per page.
    """
    extractor = _get_extractor(TEXTLINE_MODEL_PATH)
    images = extractor.pdf_to_images(pdf_path, dpi=DPI)
    temp_dir = tempfile.mkdtemp(prefix="ocr_outputs_")
    try:
        inferences_dir = os.path.join(temp_dir, "inferences")
        os.makedirs(inferences_dir, exist_ok=True)
        page_texts = []
        for page_number, image in enumerate(images, start=1):
            result = extractor.process_page_standard(image)
            page_file = os.path.join(inferences_dir, f"page_{page_number}_result.json")
            with open(page_file, "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
            if result.get("success"):
                page_texts.append(result.get("full_text", ""))
        combined_text = "\n\n".join(page_texts)

        # Optional Gemini correction over the combined text (simple, single pass).
        if use_llm and gemini_key and combined_text.strip():
            combined_text = _correct_with_gemini(combined_text, gemini_key)

        zip_path = os.path.join(temp_dir, "per_page_jsons.zip")
        archive_path = _zip_directory(inferences_dir, zip_path)
    except BaseException:
        # Nothing is handed back on failure, so don't leak the temp directory.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return combined_text, archive_path