Spaces:
Sleeping
Sleeping
| import io | |
| from typing import Any | |
| import os | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib.figure import Figure | |
| from PIL import Image | |
| import requests | |
| from dotenv import load_dotenv | |
| from onnxtr.io import DocumentFile | |
| from onnxtr.models import EngineConfig, from_hub, ocr_predictor | |
| from onnxtr.models.predictor import OCRPredictor | |
| from onnxtr.utils.visualization import visualize_page | |
# Populate os.environ from a local .env file (if present), so settings such as
# SIMPLIFICATION_URL / SIMPLIFICATION_TOKEN can be provided without exporting them.
load_dotenv()

# Text-detection architectures offered in the UI (the UI preselects the last one).
DET_ARCHS: list[str] = [
    "fast_base", "fast_small", "fast_tiny",
    "db_resnet50", "db_resnet34", "db_mobilenet_v3_large",
    "linknet_resnet18", "linknet_resnet34", "linknet_resnet50",
]

# Built-in text-recognition architectures, selected by name.
RECO_ARCHS: list[str] = [
    "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large",
    "master",
    "sar_resnet31",
    "vitstr_small", "vitstr_base",
    "parseq",
    "viptr_tiny",
]

# Recognition models referenced by hub repository id; these are loaded via
# from_hub() in load_predictor instead of being passed by name.
CUSTOM_RECO_ARCHS: list[str] = ["Felix92/onnxtr-parseq-multilingual-v1"]
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    export_as_straight_boxes: bool,
    detect_language: bool,
    bin_thresh: float,
    box_thresh: float,
    disable_crop_orientation: bool = False,
    disable_page_orientation: bool = False,
) -> OCRPredictor:
    """Build an onnxtr OCR predictor that runs all its models on the CPU.

    Args:
    ----
        det_arch: name of the text-detection architecture
        reco_arch: name of a built-in recognition architecture, or a hub
            repository id listed in CUSTOM_RECO_ARCHS
        assume_straight_pages: whether pages are assumed to be upright; when
            False, orientation detection is enabled
        straighten_pages: whether to straighten rotated pages
        export_as_straight_boxes: whether to export axis-aligned (straight) boxes
        detect_language: whether to detect the language of the text
        bin_thresh: binarization threshold for the segmentation map
        box_thresh: minimal objectness score to keep a box
        disable_crop_orientation: whether to disable crop orientation
        disable_page_orientation: whether to disable page orientation

    Returns:
    -------
        a configured OCRPredictor instance
    """
    # A single CPU engine configuration, reused for the detection,
    # recognition and classification (clf) models.
    cpu_engine = EngineConfig(
        providers=[
            ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})
        ]
    )

    # Hub-hosted recognition models are loaded through from_hub(); built-in
    # ones are passed to ocr_predictor by name.
    if reco_arch in CUSTOM_RECO_ARCHS:
        reco_model = from_hub(reco_arch)
    else:
        reco_model = reco_arch

    predictor = ocr_predictor(
        det_arch=det_arch,
        reco_arch=reco_model,
        assume_straight_pages=assume_straight_pages,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        export_as_straight_boxes=export_as_straight_boxes,
        # Orientation detection is only useful when pages may be rotated.
        detect_orientation=not assume_straight_pages,
        disable_crop_orientation=disable_crop_orientation,
        disable_page_orientation=disable_page_orientation,
        det_engine_cfg=cpu_engine,
        reco_engine_cfg=cpu_engine,
        clf_engine_cfg=cpu_engine,
    )

    # The thresholds are not constructor arguments here, so they are set
    # directly on the detection model's postprocessor.
    postprocessor = predictor.det_predictor.model.postprocessor
    postprocessor.bin_thresh = bin_thresh
    postprocessor.box_thresh = box_thresh
    return predictor
def forward_image(predictor: OCRPredictor, image: np.ndarray) -> np.ndarray:
    """Run only the detection model of a predictor on one image.

    Args:
    ----
        predictor: instance of OCRPredictor whose detection stage is used
        image: image to process

    Returns:
    -------
        the detection model's raw ``out_map`` output (the segmentation map;
        exact shape/dtype depends on the model — TODO confirm)
    """
    detector = predictor.det_predictor
    # The pre-processor works on a list of images and yields batches; a
    # single image means only the first batch is relevant.
    first_batch = detector.pre_processor([image])[0]
    model_output = detector.model(first_batch, return_model_output=True)
    return model_output["out_map"]
def matplotlib_to_pil(fig: Figure | np.ndarray) -> Image.Image:
    """Render a matplotlib figure (or a raw image array) into a PIL image.

    The content is written to an in-memory buffer in matplotlib's default
    save format and reopened with PIL. No file is touched.

    Args:
    ----
        fig: matplotlib Figure, or a numpy array that is saved with plt.imsave

    Returns:
    -------
        PIL image backed by the in-memory buffer
    """
    buffer = io.BytesIO()
    if not isinstance(fig, Figure):
        # Plain arrays have no savefig(); let pyplot encode them.
        plt.imsave(buffer, fig)
    else:
        fig.savefig(buffer)
    # Rewind so PIL reads from the start of what was just written.
    buffer.seek(0)
    return Image.open(buffer)
| def _simplify(text: str, access_token: str) -> str: | |
| """Simplify text using the T2K API | |
| Args: | |
| ---- | |
| text: text to simplify | |
| access_token: API token for the T2K API | |
| Returns: | |
| ------- | |
| simplified text | |
| """ | |
| if not access_token: | |
| return ( | |
| "Kein API-Zugriffstoken verfΓΌgbar. Bitte setzen Sie SIMPLIFICATION_TOKEN." | |
| ) | |
| url = os.getenv("SIMPLIFICATION_URL", "") | |
| json_data = { | |
| "stream": False, | |
| "inputText": text, | |
| "maxNewTokens": 256, | |
| "batchSize": 4, | |
| "decodingStrategy": "greedy", | |
| "separateCompounds": False, | |
| "filterComplexWords": False, | |
| "simplificationLevel": "A1", | |
| } | |
| headers = { | |
| "X-API-Token": access_token, | |
| "Content-Type": "application/json", | |
| "Accept": "application/json", | |
| } | |
| response = requests.post(url, json=json_data, headers=headers) | |
| if response.status_code == 200: | |
| return response.json()["result"] | |
| else: | |
| return f"Fehler in der Vereinfachungs-API (HTTP {response.status_code})" | |
| def analyze_page( | |
| uploaded_file: Any, | |
| page_idx: int, | |
| det_arch: str, | |
| reco_arch: str, | |
| assume_straight_pages: bool, | |
| disable_crop_orientation: bool, | |
| disable_page_orientation: bool, | |
| straighten_pages: bool, | |
| export_as_straight_boxes: bool, | |
| detect_language: bool, | |
| bin_thresh: float, | |
| box_thresh: float, | |
| t2k_access_token: str = "", | |
| ): | |
| """Analyze a page | |
| Args: | |
| ---- | |
| uploaded_file: file to analyze | |
| page_idx: index of the page to analyze | |
| det_arch: detection architecture | |
| reco_arch: recognition architecture | |
| assume_straight_pages: whether to assume straight pages or not | |
| disable_crop_orientation: whether to disable crop orientation or not | |
| disable_page_orientation: whether to disable page orientation or not | |
| straighten_pages: whether to straighten rotated pages or not | |
| export_as_straight_boxes: whether to export straight boxes | |
| detect_language: whether to detect the language of the text | |
| bin_thresh: binarization threshold for the segmentation map | |
| box_thresh: minimal objectness score to consider a box | |
| t2k_access_token: Access token for the T2K simplification API | |
| Returns: | |
| ------- | |
| input image, output image, OCR output, simplified output | |
| """ | |
| if uploaded_file is None: | |
| return None, None, "Bitte laden Sie ein Dokument hoch", None | |
| if uploaded_file.name.endswith(".pdf"): | |
| doc = DocumentFile.from_pdf(uploaded_file) | |
| else: | |
| doc = DocumentFile.from_images(uploaded_file) | |
| try: | |
| page = doc[page_idx - 1] | |
| except IndexError: | |
| page = doc[-1] | |
| img = page | |
| predictor = load_predictor( | |
| det_arch=det_arch, | |
| reco_arch=reco_arch, | |
| assume_straight_pages=assume_straight_pages, | |
| straighten_pages=straighten_pages, | |
| export_as_straight_boxes=export_as_straight_boxes, | |
| detect_language=detect_language, | |
| bin_thresh=bin_thresh, | |
| box_thresh=box_thresh, | |
| disable_crop_orientation=disable_crop_orientation, | |
| disable_page_orientation=disable_page_orientation, | |
| ) | |
| out = predictor([page]) | |
| page_rendered = out.pages[0].render() | |
| fig = visualize_page( | |
| out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False | |
| ) | |
| out_img = matplotlib_to_pil(fig) | |
| simplified_out = _simplify(page_rendered, t2k_access_token) | |
| return img, out_img, page_rendered, simplified_out | |
def simplify_direct(text: str, t2k_access_token: str = "") -> str:
    """Simplify text typed or pasted directly by the user.

    Surrounding whitespace is stripped before the text is sent. Empty or
    whitespace-only input (or None) produces a prompt instead of an API call.

    Args:
    ----
        text: text to simplify
        t2k_access_token: Access token for the T2K simplification API

    Returns:
    -------
        simplified text, or a hint asking for input
    """
    cleaned = text.strip() if text else ""
    if cleaned:
        return _simplify(cleaned, t2k_access_token)
    return "Bitte geben Sie einen Text ein, der vereinfacht werden soll."
# Read the T2K simplification API token from the environment once at startup.
# An empty string means "no token"; the simplification helper then returns a
# hint to the user instead of calling the API.
t2k_access_token = os.getenv("SIMPLIFICATION_TOKEN", "")

# Dismissible usage notice shown at the top of the app. The "Verstanden"
# button hides the outermost <div>, three parentElement hops up from the button.
DISCLAIMER_HTML = """
<div style="
    background: linear-gradient(135deg, #e3f2fd 0%, #f0f7ff 100%);
    border-bottom: 1px solid #bae6fd;
    padding: 16px 24px;
    border-radius: 0 0 12px 12px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.08);
    margin-bottom: 24px;
    font-family: sans-serif;
    font-size: 14px;
    color: #1e40af;
">
    <div style="display: flex; align-items: center; justify-content: space-between; max-width: 1200px; margin: 0 auto;">
        <div style="display: flex; align-items: center; gap: 12px; flex: 1;">
            <span style="font-size: 20px;">i</span>
            <div style="flex: 1;">
                <strong style="color: #0c4a6e; font-size: 15px; display: block; margin-bottom: 4px;">
                    Hinweis zur Nutzung
                </strong>
                <div style="color: #374151; line-height: 1.5;">
                    Dieser Dienst nutzt KI-Modelle zur Textverarbeitung und -vereinfachung.
                    Die Ergebnisse werden automatisch generiert und können Fehler enthalten.
                    Es werden keine Daten gespeichert - alle Verarbeitungen erfolgen ausschließlich im Arbeitsspeicher.
                    <span style="color: #0c4a6e; font-weight: 500;">t2k GmbH übernimmt keine Haftung für die generierten Ausgaben.</span>
                </div>
            </div>
        </div>
        <div style="margin-left: 20px;">
            <button onclick="this.parentElement.parentElement.parentElement.style.display='none'"
                style="
                    background: white;
                    border: 1px solid #bae6fd;
                    color: #0c4a6e;
                    border-radius: 8px;
                    padding: 8px 20px;
                    cursor: pointer;
                    font-size: 14px;
                    font-weight: 500;
                    transition: all 0.2s;
                "
                onmouseover="this.style.background='#f0f9ff'; this.style.borderColor='#67e8f9'"
                onmouseout="this.style.background='white'; this.style.borderColor='#bae6fd'">
                Verstanden
            </button>
        </div>
    </div>
</div>
"""
with gr.Blocks(fill_height=True) as demo:
    gr.HTML(DISCLAIMER_HTML)
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>OCR + Sprachvereinfachungs-Demo</h1>
        </div>
        """
    )

    with gr.Tabs():
        # ── Tab 1: Document OCR + Simplification ──────────────────────────────
        with gr.TabItem("Dokument-OCR"):
            gr.HTML(
                """
                <h3>1. Dokument hochladen (PDF, JPG oder PNG)</h3>
                <h3>2. Wählen Sie die Modellarchitekturen für Textdetektion und -erkennung aus</h3>
                <h3>3. Drücken Sie <em>Seite analysieren</em>, um OCR auszuführen und den extrahierten Text zu vereinfachen</h3>
                """
            )
            with gr.Row():
                # Left column: input file, model choice and OCR settings.
                with gr.Column(scale=1):
                    upload = gr.File(
                        label="Datei hochladen [JPG | PNG | PDF]",
                        file_types=[".pdf", ".jpg", ".png"],
                    )
                    # 1-based page number; analyze_page falls back to the last
                    # page when the document has fewer pages than selected.
                    page_selection = gr.Slider(
                        minimum=1, maximum=10, step=1, value=1, label="Seitenauswahl"
                    )
                    det_model = gr.Dropdown(
                        choices=DET_ARCHS,
                        value=DET_ARCHS[-1],
                        label="Textdetektionsmodell",
                    )
                    reco_model = gr.Dropdown(
                        choices=RECO_ARCHS + CUSTOM_RECO_ARCHS,
                        value=CUSTOM_RECO_ARCHS[0],
                        label="Texterkennungsmodell",
                    )
                    assume_straight = gr.Checkbox(
                        value=True, label="Gerade Seiten annehmen"
                    )
                    disable_crop_orientation = gr.Checkbox(
                        value=False, label="Orientierung der Textausschnitte deaktivieren"
                    )
                    disable_page_orientation = gr.Checkbox(
                        value=False, label="Seitenorientierung deaktivieren"
                    )
                    straighten = gr.Checkbox(value=False, label="Seiten begradigen")
                    export_as_straight_boxes = gr.Checkbox(
                        value=True, label="Als gerade Boxen exportieren"
                    )
                    det_language = gr.Checkbox(value=True, label="Sprache erkennen")
                    binarization_threshold = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        value=0.3,
                        step=0.1,
                        label="Binarisierungsschwelle",
                    )
                    box_threshold = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        value=0.1,
                        step=0.1,
                        label="Box-Schwelle",
                    )
                    analyze_button = gr.Button("Seite analysieren", variant="primary")

                # Right column: input/annotated page images and text outputs.
                with gr.Column(scale=3):
                    with gr.Row():
                        input_image = gr.Image(
                            label="Eingabeseite", width=1000, height=500
                        )
                        output_image = gr.Image(
                            label="Ausgabeseite", width=1000, height=500
                        )
                    with gr.Row():
                        ocr_output = gr.Textbox(label="OCR-Ausgabe", scale=1, lines=20)
                        ocr_simplified_out = gr.Textbox(
                            label="Vereinfachte Ausgabe", scale=1, lines=20
                        )

            # The token is injected server-side rather than exposed as a UI input.
            # The input order must match analyze_page's positional parameters.
            analyze_button.click(
                lambda *inputs: analyze_page(
                    *inputs, t2k_access_token=t2k_access_token
                ),
                inputs=[
                    upload,
                    page_selection,
                    det_model,
                    reco_model,
                    assume_straight,
                    disable_crop_orientation,
                    disable_page_orientation,
                    straighten,
                    export_as_straight_boxes,
                    det_language,
                    binarization_threshold,
                    box_threshold,
                ],
                outputs=[input_image, output_image, ocr_output, ocr_simplified_out],
            )

        # ── Tab 2: Direct Text Simplification ─────────────────────────────────
        with gr.TabItem("Direkte Textvereinfachung"):
            gr.HTML(
                """
                <h3>Geben Sie unten einen Text ein oder fügen Sie ihn ein und drücken Sie <em>Text vereinfachen</em>, um eine vereinfachte Version zu erhalten.</h3>
                """
            )
            with gr.Row():
                with gr.Column(scale=1):
                    direct_text_input = gr.Textbox(
                        label="Eingabetext",
                        placeholder="Geben Sie den zu vereinfachenden Text ein oder fügen Sie ihn ein…",
                        lines=20,
                    )
                    simplify_button = gr.Button("Text vereinfachen", variant="primary")
                with gr.Column(scale=1):
                    direct_simplified_out = gr.Textbox(
                        label="Vereinfachte Ausgabe", lines=20
                    )

            simplify_button.click(
                lambda text: simplify_direct(text, t2k_access_token=t2k_access_token),
                inputs=[direct_text_input],
                outputs=[direct_simplified_out],
            )

# TODO: Add bounding boxes display
# TODO: beautify + t2k logo

if __name__ == "__main__":
    # Only launch when run as a script (e.g. `python app.py`). Importing the
    # module (tests, Gradio reload mode) just builds `demo` without serving it.
    # allowed_paths lets Gradio serve the logo file.
    demo.launch(inbrowser=True, allowed_paths=["./data/logo.jpg"])