# innoconf-demo / app.py
# Uploaded by t2k-list · commit b0ae3c9 ("😱") · 16.2 kB
import io
from typing import Any
import os
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image
import requests
from dotenv import load_dotenv
from onnxtr.io import DocumentFile
from onnxtr.models import EngineConfig, from_hub, ocr_predictor
from onnxtr.models.predictor import OCRPredictor
from onnxtr.utils.visualization import visualize_page
# Load variables from a local .env file into the process environment so the
# os.getenv lookups for SIMPLIFICATION_URL / SIMPLIFICATION_TOKEN below see them.
load_dotenv()
# Detection architectures offered in the UI; the names are passed unchanged
# to onnxtr's ocr_predictor as ``det_arch``.
DET_ARCHS: list[str] = [
    "fast_base",
    "fast_small",
    "fast_tiny",
    "db_resnet50",
    "db_resnet34",
    "db_mobilenet_v3_large",
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
]

# Built-in recognition architectures, passed by name as ``reco_arch``.
RECO_ARCHS: list[str] = [
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "crnn_mobilenet_v3_large",
    "master",
    "sar_resnet31",
    "vitstr_small",
    "vitstr_base",
    "parseq",
    "viptr_tiny",
]

# Recognition models identified by a hub repository id; load_predictor loads
# these through ``from_hub`` instead of passing the name directly.
CUSTOM_RECO_ARCHS: list[str] = [
    "Felix92/onnxtr-parseq-multilingual-v1",
]
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    export_as_straight_boxes: bool,
    detect_language: bool,
    bin_thresh: float,
    box_thresh: float,
    disable_crop_orientation: bool = False,
    disable_page_orientation: bool = False,
) -> OCRPredictor:
    """Build an onnxtr OCR predictor for the given detection/recognition pair.

    The detection, recognition and classifier (clf) engines all share one
    EngineConfig that uses only the CPU execution provider.

    Args:
    ----
        det_arch: detection architecture name (e.g. an entry of DET_ARCHS)
        reco_arch: recognition architecture name; entries of CUSTOM_RECO_ARCHS
            are hub ids and are loaded through ``from_hub`` first
        assume_straight_pages: whether to assume straight pages; when False,
            orientation detection is enabled
        straighten_pages: whether to straighten rotated pages or not
        export_as_straight_boxes: whether to export straight boxes
        detect_language: whether to detect the language of the text
        bin_thresh: binarization threshold for the segmentation map
        box_thresh: minimal objectness score to consider a box
        disable_crop_orientation: whether to disable crop orientation or not
        disable_page_orientation: whether to disable page orientation or not

    Returns:
    -------
        configured OCRPredictor instance
    """
    # NOTE(review): "kSameAsRequested" presumably keeps the memory arena from
    # over-allocating on growth; confirm the CPU provider honours this option.
    engine_cfg = EngineConfig(
        providers=[
            ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})
        ]
    )
    # Built-in architectures are passed by name; custom ones are hub ids that
    # have to be loaded with from_hub before being handed to ocr_predictor.
    reco_model = from_hub(reco_arch) if reco_arch in CUSTOM_RECO_ARCHS else reco_arch
    predictor = ocr_predictor(
        det_arch=det_arch,
        reco_arch=reco_model,
        assume_straight_pages=assume_straight_pages,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        export_as_straight_boxes=export_as_straight_boxes,
        # Orientation is only worth detecting when pages may be rotated.
        detect_orientation=not assume_straight_pages,
        disable_crop_orientation=disable_crop_orientation,
        disable_page_orientation=disable_page_orientation,
        det_engine_cfg=engine_cfg,
        reco_engine_cfg=engine_cfg,
        clf_engine_cfg=engine_cfg,
    )
    # The detection thresholds are overridden on the post-processor after
    # construction rather than passed to ocr_predictor.
    postprocessor = predictor.det_predictor.model.postprocessor
    postprocessor.bin_thresh = bin_thresh
    postprocessor.box_thresh = box_thresh
    return predictor
def forward_image(predictor: OCRPredictor, image: np.ndarray) -> np.ndarray:
    """Run only the detection stage of *predictor* on a single image.

    The image goes through the detection pre-processor. The first resulting
    batch is then fed to the detection model with ``return_model_output=True``,
    and the raw ``"out_map"`` entry of the model output is returned.

    Args:
    ----
        predictor: OCRPredictor whose detection branch is used
        image: image to process (presumably an H x W x C array; TODO confirm)

    Returns:
    -------
        the detection model's raw output map (segmentation map)
    """
    detector = predictor.det_predictor
    batches = detector.pre_processor([image])
    model_output = detector.model(batches[0], return_model_output=True)
    return model_output["out_map"]
def matplotlib_to_pil(fig: Figure | np.ndarray) -> Image.Image:
    """Rasterise a matplotlib figure, or an image array, into a PIL image.

    The input is written into an in-memory buffer with matplotlib's default
    save format. That buffer is then opened with PIL. ``Image.open`` is lazy,
    so the returned image keeps a reference to the buffer.

    Args:
    ----
        fig: matplotlib Figure, or a numpy array written via ``plt.imsave``

    Returns:
    -------
        PIL image backed by the in-memory buffer
    """
    stream = io.BytesIO()
    if not isinstance(fig, Figure):
        # Raw arrays use pyplot's image writer.
        plt.imsave(stream, fig)
    else:
        fig.savefig(stream)
    stream.seek(0)  # rewind so PIL reads from the start of the data
    return Image.open(stream)
def _simplify(text: str, access_token: str) -> str:
"""Simplify text using the T2K API
Args:
----
text: text to simplify
access_token: API token for the T2K API
Returns:
-------
simplified text
"""
if not access_token:
return (
"Kein API-Zugriffstoken verfΓΌgbar. Bitte setzen Sie SIMPLIFICATION_TOKEN."
)
url = os.getenv("SIMPLIFICATION_URL", "")
json_data = {
"stream": False,
"inputText": text,
"maxNewTokens": 256,
"batchSize": 4,
"decodingStrategy": "greedy",
"separateCompounds": False,
"filterComplexWords": False,
"simplificationLevel": "A1",
}
headers = {
"X-API-Token": access_token,
"Content-Type": "application/json",
"Accept": "application/json",
}
response = requests.post(url, json=json_data, headers=headers)
if response.status_code == 200:
return response.json()["result"]
else:
return f"Fehler in der Vereinfachungs-API (HTTP {response.status_code})"
def analyze_page(
uploaded_file: Any,
page_idx: int,
det_arch: str,
reco_arch: str,
assume_straight_pages: bool,
disable_crop_orientation: bool,
disable_page_orientation: bool,
straighten_pages: bool,
export_as_straight_boxes: bool,
detect_language: bool,
bin_thresh: float,
box_thresh: float,
t2k_access_token: str = "",
):
"""Analyze a page
Args:
----
uploaded_file: file to analyze
page_idx: index of the page to analyze
det_arch: detection architecture
reco_arch: recognition architecture
assume_straight_pages: whether to assume straight pages or not
disable_crop_orientation: whether to disable crop orientation or not
disable_page_orientation: whether to disable page orientation or not
straighten_pages: whether to straighten rotated pages or not
export_as_straight_boxes: whether to export straight boxes
detect_language: whether to detect the language of the text
bin_thresh: binarization threshold for the segmentation map
box_thresh: minimal objectness score to consider a box
t2k_access_token: Access token for the T2K simplification API
Returns:
-------
input image, output image, OCR output, simplified output
"""
if uploaded_file is None:
return None, None, "Bitte laden Sie ein Dokument hoch", None
if uploaded_file.name.endswith(".pdf"):
doc = DocumentFile.from_pdf(uploaded_file)
else:
doc = DocumentFile.from_images(uploaded_file)
try:
page = doc[page_idx - 1]
except IndexError:
page = doc[-1]
img = page
predictor = load_predictor(
det_arch=det_arch,
reco_arch=reco_arch,
assume_straight_pages=assume_straight_pages,
straighten_pages=straighten_pages,
export_as_straight_boxes=export_as_straight_boxes,
detect_language=detect_language,
bin_thresh=bin_thresh,
box_thresh=box_thresh,
disable_crop_orientation=disable_crop_orientation,
disable_page_orientation=disable_page_orientation,
)
out = predictor([page])
page_rendered = out.pages[0].render()
fig = visualize_page(
out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False
)
out_img = matplotlib_to_pil(fig)
simplified_out = _simplify(page_rendered, t2k_access_token)
return img, out_img, page_rendered, simplified_out
def simplify_direct(text: str, t2k_access_token: str = "") -> str:
    """Simplify free text typed into the "direct simplification" tab.

    Leading and trailing whitespace is stripped before the text is sent. Empty
    or whitespace-only input (including None) short-circuits with a hint
    message and does not call the API.

    Args:
    ----
        text: user-entered text
        t2k_access_token: access token for the T2K simplification API

    Returns:
    -------
        the simplified text, or a hint/error message
    """
    cleaned = text.strip() if text else ""
    if cleaned:
        return _simplify(cleaned, t2k_access_token)
    return "Bitte geben Sie einen Text ein, der vereinfacht werden soll."
# Read the T2K API token once, at import time. An empty string means "not
# configured"; in that case _simplify returns a hint instead of calling the API.
t2k_access_token = os.environ.get("SIMPLIFICATION_TOKEN", "")
# Usage notice rendered at the top of the page. The "Verstanden" button hides
# the banner client-side: three parentElement hops from the button reach the
# outermost banner <div>.
DISCLAIMER_HTML = """
<div style="
background: linear-gradient(135deg, #e3f2fd 0%, #f0f7ff 100%);
border-bottom: 1px solid #bae6fd;
padding: 16px 24px;
border-radius: 0 0 12px 12px;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
margin-bottom: 24px;
font-family: sans-serif;
font-size: 14px;
color: #1e40af;
">
<div style="display: flex; align-items: center; justify-content: space-between; max-width: 1200px; margin: 0 auto;">
<div style="display: flex; align-items: center; gap: 12px; flex: 1;">
<span style="font-size: 20px;">i</span>
<div style="flex: 1;">
<strong style="color: #0c4a6e; font-size: 15px; display: block; margin-bottom: 4px;">
Hinweis zur Nutzung
</strong>
<div style="color: #374151; line-height: 1.5;">
Dieser Dienst nutzt KI-Modelle zur Textverarbeitung und -vereinfachung.
Die Ergebnisse werden automatisch generiert und können Fehler enthalten.
Es werden keine Daten gespeichert - alle Verarbeitungen erfolgen ausschließlich im Arbeitsspeicher.
<span style="color: #0c4a6e; font-weight: 500;">t2k GmbH übernimmt keine Haftung für die generierten Ausgaben.</span>
</div>
</div>
</div>
<div style="margin-left: 20px;">
<button onclick="this.parentElement.parentElement.parentElement.style.display='none'"
style="
background: white;
border: 1px solid #bae6fd;
color: #0c4a6e;
border-radius: 8px;
padding: 8px 20px;
cursor: pointer;
font-size: 14px;
font-weight: 500;
transition: all 0.2s;
"
onmouseover="this.style.background='#f0f9ff'; this.style.borderColor='#67e8f9'"
onmouseout="this.style.background='white'; this.style.borderColor='#bae6fd'">
Verstanden
</button>
</div>
</div>
</div>
"""
# Gradio UI. Tab 1 runs OCR on an uploaded document and simplifies the
# extracted text; tab 2 simplifies text entered directly by the user.
with gr.Blocks(fill_height=True) as demo:
    gr.HTML(DISCLAIMER_HTML)
    gr.HTML(
        """
        <div style="text-align: center;">
        <h1>OCR + Sprachvereinfachungs-Demo</h1>
        </div>
        """
    )
    with gr.Tabs():
        # ── Tab 1: Document OCR + Simplification ──────────────────────────────
        with gr.TabItem("Dokument-OCR"):
            gr.HTML(
                """
                <h3>1. Dokument hochladen (PDF, JPG oder PNG)</h3>
                <h3>2. Wählen Sie die Modellarchitekturen für Textdetektion und -erkennung aus</h3>
                <h3>3. Drücken Sie <em>Seite analysieren</em>, um OCR auszuführen und den extrahierten Text zu vereinfachen</h3>
                """
            )
            with gr.Row():
                with gr.Column(scale=1):
                    upload = gr.File(
                        label="Datei hochladen [JPG | PNG | PDF]",
                        # ".jpeg" is the same format as ".jpg"; accept both spellings.
                        file_types=[".pdf", ".jpg", ".jpeg", ".png"],
                    )
                    page_selection = gr.Slider(
                        minimum=1, maximum=10, step=1, value=1, label="Seitenauswahl"
                    )
                    det_model = gr.Dropdown(
                        choices=DET_ARCHS,
                        value=DET_ARCHS[-1],
                        label="Textdetektionsmodell",
                    )
                    reco_model = gr.Dropdown(
                        choices=RECO_ARCHS + CUSTOM_RECO_ARCHS,
                        value=CUSTOM_RECO_ARCHS[0],
                        label="Texterkennungsmodell",
                    )
                    assume_straight = gr.Checkbox(
                        value=True, label="Gerade Seiten annehmen"
                    )
                    disable_crop_orientation = gr.Checkbox(
                        value=False, label="Zuschneidorientierung deaktivieren"
                    )
                    disable_page_orientation = gr.Checkbox(
                        value=False, label="Seitenorientierung deaktivieren"
                    )
                    # "begradigen" = straighten (was mistranslated as "begraben" = bury).
                    straighten = gr.Checkbox(value=False, label="Seiten begradigen")
                    export_as_straight_boxes = gr.Checkbox(
                        value=True, label="Als gerade Boxen exportieren"
                    )
                    det_language = gr.Checkbox(value=True, label="Sprache erkennen")
                    binarization_threshold = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        value=0.3,
                        step=0.1,
                        label="Binarisierungsschwelle",
                    )
                    box_threshold = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        value=0.1,
                        step=0.1,
                        label="Box-Schwelle",
                    )
                    analyze_button = gr.Button("Seite analysieren", variant="primary")
                with gr.Column(scale=3):
                    with gr.Row():
                        input_image = gr.Image(
                            label="Eingabeseite", width=1000, height=500
                        )
                        output_image = gr.Image(
                            label="Ausgabeseite", width=1000, height=500
                        )
                    with gr.Row():
                        ocr_output = gr.Textbox(label="OCR-Ausgabe", scale=1, lines=20)
                        ocr_simplified_out = gr.Textbox(
                            label="Vereinfachte Ausgabe", scale=1, lines=20
                        )
            # The input order must match analyze_page's positional parameters;
            # the API token is injected from the module-level setting.
            analyze_button.click(
                lambda *inputs: analyze_page(
                    *inputs, t2k_access_token=t2k_access_token
                ),
                inputs=[
                    upload,
                    page_selection,
                    det_model,
                    reco_model,
                    assume_straight,
                    disable_crop_orientation,
                    disable_page_orientation,
                    straighten,
                    export_as_straight_boxes,
                    det_language,
                    binarization_threshold,
                    box_threshold,
                ],
                outputs=[input_image, output_image, ocr_output, ocr_simplified_out],
            )
        # ── Tab 2: Direct Text Simplification ─────────────────────────────────
        with gr.TabItem("Direkte Textvereinfachung"):
            gr.HTML(
                """
                <h3>Geben Sie unten einen Text ein oder fügen Sie ihn ein und drücken Sie <em>Text vereinfachen</em>, um eine vereinfachte Version zu erhalten.</h3>
                """
            )
            with gr.Row():
                with gr.Column(scale=1):
                    direct_text_input = gr.Textbox(
                        label="Eingabetext",
                        placeholder="Geben Sie den zu vereinfachenden Text ein oder fügen Sie ihn ein…",
                        lines=20,
                    )
                    simplify_button = gr.Button("Text vereinfachen", variant="primary")
                with gr.Column(scale=1):
                    direct_simplified_out = gr.Textbox(
                        label="Vereinfachte Ausgabe", lines=20
                    )
            simplify_button.click(
                lambda text: simplify_direct(text, t2k_access_token=t2k_access_token),
                inputs=[direct_text_input],
                outputs=[direct_simplified_out],
            )
demo.launch(inbrowser=True, allowed_paths=["./data/logo.jpg"])
# TODO: Add bounding boxes display
# TODO: beautify + add t2k logo