import io from typing import Any import os import gradio as gr import matplotlib.pyplot as plt import numpy as np from matplotlib.figure import Figure from PIL import Image import requests from dotenv import load_dotenv from onnxtr.io import DocumentFile from onnxtr.models import EngineConfig, from_hub, ocr_predictor from onnxtr.models.predictor import OCRPredictor from onnxtr.utils.visualization import visualize_page load_dotenv() DET_ARCHS: list[str] = [ "fast_base", "fast_small", "fast_tiny", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50", ] RECO_ARCHS: list[str] = [ "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large", "master", "sar_resnet31", "vitstr_small", "vitstr_base", "parseq", "viptr_tiny", ] CUSTOM_RECO_ARCHS: list[str] = [ "Felix92/onnxtr-parseq-multilingual-v1", ] def load_predictor( det_arch: str, reco_arch: str, assume_straight_pages: bool, straighten_pages: bool, export_as_straight_boxes: bool, detect_language: bool, bin_thresh: float, box_thresh: float, disable_crop_orientation: bool = False, disable_page_orientation: bool = False, ) -> OCRPredictor: """Load a predictor from doctr.models Args: ---- det_arch: detection architecture reco_arch: recognition architecture assume_straight_pages: whether to assume straight pages or not disable_crop_orientation: whether to disable crop orientation or not disable_page_orientation: whether to disable page orientation or not straighten_pages: whether to straighten rotated pages or not export_as_straight_boxes: whether to export straight boxes detect_language: whether to detect the language of the text bin_thresh: binarization threshold for the segmentation map box_thresh: minimal objectness score to consider a box Returns: ------- instance of OCRPredictor """ engine_cfg = EngineConfig( providers=[ ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"}) ] ) predictor = ocr_predictor( det_arch=det_arch, reco_arch=reco_arch if reco_arch not in CUSTOM_RECO_ARCHS else from_hub(reco_arch), assume_straight_pages=assume_straight_pages, straighten_pages=straighten_pages, detect_language=detect_language, export_as_straight_boxes=export_as_straight_boxes, detect_orientation=not assume_straight_pages, disable_crop_orientation=disable_crop_orientation, disable_page_orientation=disable_page_orientation, det_engine_cfg=engine_cfg, reco_engine_cfg=engine_cfg, clf_engine_cfg=engine_cfg, ) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh predictor.det_predictor.model.postprocessor.box_thresh = box_thresh return predictor def forward_image(predictor: OCRPredictor, image: np.ndarray) -> np.ndarray: """Forward an image through the predictor Args: ---- predictor: instance of OCRPredictor image: image to process Returns: ------- segmentation map """ processed_batches = predictor.det_predictor.pre_processor([image]) out = predictor.det_predictor.model(processed_batches[0], return_model_output=True) seg_map = out["out_map"] return seg_map def matplotlib_to_pil(fig: Figure | np.ndarray) -> Image.Image: """Convert a matplotlib figure to a PIL image Args: ---- fig: matplotlib figure or numpy array Returns: ------- PIL image """ buf = io.BytesIO() if isinstance(fig, Figure): fig.savefig(buf) else: plt.imsave(buf, fig) buf.seek(0) return Image.open(buf) def _simplify(text: str, access_token: str) -> str: """Simplify text using the T2K API Args: ---- text: text to simplify access_token: API token for the T2K API Returns: ------- simplified text """ if not access_token: return ( "Kein API-Zugriffstoken verfügbar. Bitte setzen Sie SIMPLIFICATION_TOKEN." ) url = os.getenv("SIMPLIFICATION_URL", "") json_data = { "stream": False, "inputText": text, "maxNewTokens": 256, "batchSize": 4, "decodingStrategy": "greedy", "separateCompounds": False, "filterComplexWords": False, "simplificationLevel": "A1", } headers = { "X-API-Token": access_token, "Content-Type": "application/json", "Accept": "application/json", } response = requests.post(url, json=json_data, headers=headers) if response.status_code == 200: return response.json()["result"] else: return f"Fehler in der Vereinfachungs-API (HTTP {response.status_code})" def analyze_page( uploaded_file: Any, page_idx: int, det_arch: str, reco_arch: str, assume_straight_pages: bool, disable_crop_orientation: bool, disable_page_orientation: bool, straighten_pages: bool, export_as_straight_boxes: bool, detect_language: bool, bin_thresh: float, box_thresh: float, t2k_access_token: str = "", ): """Analyze a page Args: ---- uploaded_file: file to analyze page_idx: index of the page to analyze det_arch: detection architecture reco_arch: recognition architecture assume_straight_pages: whether to assume straight pages or not disable_crop_orientation: whether to disable crop orientation or not disable_page_orientation: whether to disable page orientation or not straighten_pages: whether to straighten rotated pages or not export_as_straight_boxes: whether to export straight boxes detect_language: whether to detect the language of the text bin_thresh: binarization threshold for the segmentation map box_thresh: minimal objectness score to consider a box t2k_access_token: Access token for the T2K simplification API Returns: ------- input image, output image, OCR output, simplified output """ if uploaded_file is None: return None, None, "Bitte laden Sie ein Dokument hoch", None if uploaded_file.name.endswith(".pdf"): doc = DocumentFile.from_pdf(uploaded_file) else: doc = DocumentFile.from_images(uploaded_file) try: page = doc[page_idx - 1] except IndexError: page = doc[-1] img = page predictor = load_predictor( det_arch=det_arch, reco_arch=reco_arch, assume_straight_pages=assume_straight_pages, straighten_pages=straighten_pages, export_as_straight_boxes=export_as_straight_boxes, detect_language=detect_language, bin_thresh=bin_thresh, box_thresh=box_thresh, disable_crop_orientation=disable_crop_orientation, disable_page_orientation=disable_page_orientation, ) out = predictor([page]) page_rendered = out.pages[0].render() fig = visualize_page( out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False ) out_img = matplotlib_to_pil(fig) simplified_out = _simplify(page_rendered, t2k_access_token) return img, out_img, page_rendered, simplified_out def simplify_direct(text: str, t2k_access_token: str = "") -> str: """Simplify text entered directly by the user Args: ---- text: text to simplify t2k_access_token: Access token for the T2K simplification API Returns: ------- simplified text """ if not text or not text.strip(): return "Bitte geben Sie einen Text ein, der vereinfacht werden soll." return _simplify(text.strip(), t2k_access_token) # Attempt to obtain an API access token at startup t2k_access_token = os.getenv("SIMPLIFICATION_TOKEN", "") DISCLAIMER_HTML = """