|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
# Process-wide environment defaults. They are applied before the paddle / gradio
# imports further down, presumably because those libraries read them when they
# are imported. ``setdefault`` semantics: a value already present in the
# environment (e.g. set by the deployment) always wins.
_ENV_DEFAULTS = {
    # Paddle CPU backend: keep MKL-DNN (oneDNN) switched off.
    "FLAGS_use_mkldnn": "0",
    "FLAGS_enable_mkldnn": "0",
    # OpenMP threading knobs.
    "OMP_NUM_THREADS": "1",
    "KMP_BLOCKTIME": "0",
    # Gradio: listen on all interfaces and disable usage analytics.
    "GRADIO_SERVER_NAME": "0.0.0.0",
    "GRADIO_ANALYTICS_ENABLED": "False",
}
for _env_name, _env_default in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_default)
|
|
|
|
|
import io |
|
|
import sys |
|
|
import json |
|
|
import traceback |
|
|
from typing import List, Tuple |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import fitz |
|
|
import cv2 |
|
|
import gradio as gr |
|
|
from paddleocr import PaddleOCR |
|
|
|
|
|
|
|
|
# --- Runtime configuration, overridable through OCR_* environment variables ---

# Language code handed to PaddleOCR.
LANG = os.getenv("OCR_LANG", "en")

# GPU inference is enabled only when OCR_USE_GPU is "true" (any letter case).
USE_GPU = os.getenv("OCR_USE_GPU", "false").lower() == "true"

# NOTE(review): DET and REC are read here, but _build_ocr in this file passes
# det_model_dir=None / rec_model_dir=None and never uses them — confirm whether
# they are still needed.
DET = os.getenv("OCR_DET_MODEL", "ch_PP-OCRv4_det")
REC = os.getenv("OCR_REC_MODEL", "en_PP-OCRv4")

# Angle classification is always requested for the primary OCR engine.
CLS = True

# Recognised lines with a confidence below this value are dropped.
CONF_THRESHOLD = float(os.getenv("OCR_CONF_THRESHOLD", "0.0"))
|
|
|
|
|
def _pil_to_cv(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into an OpenCV-style BGR array.

    Images whose mode is not ``"RGB"`` (e.g. grayscale ``L``, palette ``P``,
    ``CMYK``, 16-bit modes) are converted to RGB first. Otherwise ``np.array``
    would produce an array with the wrong channel layout. ``COLOR_RGB2BGR``
    rejects single-channel input and would silently misinterpret 4-channel CMYK
    data as RGBA.

    Args:
        img: Source image in any PIL mode.

    Returns:
        A height x width x 3 array in BGR channel order.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
|
|
|
|
def _build_ocr(use_cls: bool) -> PaddleOCR:
    """Create a PaddleOCR engine configured from the module-level settings.

    Args:
        use_cls: Whether to load the text-angle classifier.

    Returns:
        A new ``PaddleOCR`` instance using ``LANG`` and ``USE_GPU``, with
        PaddleOCR's own logging silenced.
    """
    # det/rec model dirs are passed as None explicitly, leaving model selection
    # to PaddleOCR (presumably its defaults for ``lang`` — verify against the
    # installed PaddleOCR version).
    engine_options = {
        "use_angle_cls": use_cls,
        "lang": LANG,
        "use_gpu": USE_GPU,
        "det_model_dir": None,
        "rec_model_dir": None,
        "show_log": False,
    }
    return PaddleOCR(**engine_options)
|
|
|
|
|
|
|
|
# Primary OCR engine, built once at import time and shared by every request.
_OCR = _build_ocr(use_cls=CLS)
|
|
|
|
|
# Fallback engine without angle classification. It is built lazily the first
# time the primary engine fails with a backend error, then reused so later
# failing calls do not reload the models.
_FALLBACK_OCR = None

# Lower-cased substrings of a RuntimeError message that trigger the fallback.
_FALLBACK_ERROR_MARKERS = ("primitive", "mkldnn", "predictor.run")


def _get_fallback_ocr() -> PaddleOCR:
    """Return the cached ``use_cls=False`` engine, creating it on first use."""
    global _FALLBACK_OCR
    if _FALLBACK_OCR is None:
        _FALLBACK_OCR = _build_ocr(False)
    return _FALLBACK_OCR


def ocr_image(pil_img: Image.Image) -> List[Tuple[str, float]]:
    """Run OCR on a single image and return the recognised text lines.

    The primary engine ``_OCR`` is tried first. If it raises a RuntimeError
    whose message mentions a backend failure (see ``_FALLBACK_ERROR_MARKERS``),
    the image is re-run on a cached engine with angle classification disabled.

    Args:
        pil_img: Image to recognise.

    Returns:
        ``(text, confidence)`` pairs in the order PaddleOCR reports them, keeping
        only lines with ``confidence >= CONF_THRESHOLD``. Empty when nothing is
        detected.

    Raises:
        RuntimeError: re-raised from the primary engine when its message does
            not match a fallback marker. Errors from the fallback run also
            propagate unchanged.
    """
    img_cv = _pil_to_cv(pil_img)

    try:
        result = _OCR.ocr(img_cv, cls=CLS)
    except RuntimeError as e:
        msg = str(e).lower()
        if not any(marker in msg for marker in _FALLBACK_ERROR_MARKERS):
            raise
        result = _get_fallback_ocr().ocr(img_cv, cls=False)

    # The result holds one entry per input image. For an image with no detected
    # text, PaddleOCR 2.x can return ``[None]`` rather than ``[]``, so both
    # levels must be checked before iterating.
    if not result or not result[0]:
        return []

    lines: List[Tuple[str, float]] = []
    # Each detection is indexed as ``[box, (text, score)]``.
    for line in result[0]:
        txt = line[1][0]
        conf = float(line[1][1])
        if conf >= CONF_THRESHOLD:
            lines.append((txt, conf))
    return lines
|
|
|
|
|
def read_image(filepath: str) -> Image.Image:
    """Load an image file as an upright RGB ``PIL.Image``.

    Camera and phone photos often store pixels unrotated and record the
    intended orientation in the EXIF ``Orientation`` tag. That tag is applied
    here; otherwise the OCR engine would receive sideways or upside-down text.

    Args:
        filepath: Path to an image file readable by Pillow.

    Returns:
        A new RGB image that does not depend on the source file staying open.
    """
    # Local import: only this function needs ImageOps.
    from PIL import ImageOps

    with Image.open(filepath) as im:
        upright = ImageOps.exif_transpose(im)
        return upright.convert("RGB")
|
|
|
|
|
def read_pdf_pages(filepath: str, zoom: float = 2.0) -> List[Image.Image]:
    """Render every page of a PDF to an RGB ``PIL.Image``.

    Args:
        filepath: Path to the PDF document.
        zoom: Scale factor applied to both axes when rasterising. The default
            2.0 is the scale this function has always used. Larger values can
            help with small print, at the cost of memory and OCR time.

    Returns:
        One image per page, in document order. All pages are rendered before
        returning and are held in memory together.
    """
    # The transform is identical for every page, so build it once.
    matrix = fitz.Matrix(zoom, zoom)
    pages: List[Image.Image] = []
    with fitz.open(filepath) as doc:
        for page in doc:
            # alpha=False: no alpha channel, so the samples match PIL's "RGB" mode.
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            pages.append(Image.frombytes("RGB", (pix.width, pix.height), pix.samples))
    return pages
|
|
|
|
|
def extract_text_from_file(filepath: str) -> str:
    """OCR an image or PDF file and return its text.

    The file type is chosen by extension, case-insensitively.

    Args:
        filepath: Path to a ``.pdf`` file or to an image with one of the
            supported extensions listed below.

    Returns:
        For a PDF: one section per page, each starting with a
        ``--- Page N ---`` header line followed by that page's lines, with
        sections separated by a blank line. For an image: the recognised lines
        joined by newlines, with surrounding whitespace stripped.

    Raises:
        ValueError: for any other extension.
    """
    name = filepath.lower()
    image_suffixes = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp")

    if name.endswith(".pdf"):
        sections = []
        for page_no, page_img in enumerate(read_pdf_pages(filepath), start=1):
            body = "\n".join(text for text, _ in ocr_image(page_img))
            sections.append(f"--- Page {page_no} ---\n{body}".strip())
        return "\n\n".join(section for section in sections if section)

    if name.endswith(image_suffixes):
        recognised = ocr_image(read_image(filepath))
        return "\n".join(text for text, _ in recognised).strip()

    raise ValueError("Unsupported file type. Upload an image or a PDF.")
|
|
|
|
|
def infer(file_obj) -> str:
    """Gradio callback: OCR the uploaded file and return its text for display.

    Args:
        file_obj: Value delivered by the ``gr.File`` component. This may be a
            path string (or ``os.PathLike``), an upload object exposing the
            temp-file path as ``.name``, or ``None`` when nothing is uploaded or
            the upload was cleared.

    Returns:
        The extracted text; ``"[No text detected]"`` if OCR produced nothing;
        ``"No file uploaded."`` for ``None``; or ``"Error during OCR: ..."`` if
        any ``Exception`` occurred. Ordinary exceptions are not propagated: the
        traceback is printed server-side and the message is shown in the UI.
    """
    try:
        if file_obj is None:
            return "No file uploaded."

        if isinstance(file_obj, (str, os.PathLike)):
            # Already a path, so use it whole. ``pathlib.Path`` also has a
            # ``.name`` attribute, but that is only the basename; it must not
            # take the ``.name`` branch below.
            filepath = os.fspath(file_obj)
        elif hasattr(file_obj, "name"):
            # Upload wrappers that carry the temp-file path in ``.name``.
            filepath = file_obj.name
        else:
            filepath = str(file_obj)

        text = extract_text_from_file(filepath)
        print("\n===== OCR RAW TEXT =====\n", text, "\n===== END =====\n", flush=True)
        return text or "[No text detected]"
    except Exception as e:
        # UI boundary: log the full traceback and report a short message
        # instead of failing the Gradio event.
        traceback.print_exc()
        return f"Error during OCR: {e}"
|
|
|
|
|
TITLE = "PaddleOCR Text Extractor (Images & PDFs)"
DESC = "Upload an image or PDF. Runs PP-OCRv4 on CPU with Space-safe settings."

# Single-page UI: file upload and text output side by side, plus a run button.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown("# " + TITLE + "\n" + DESC)

    with gr.Row():
        file_in = gr.File(
            label="Upload Image or PDF",
            file_count="single",
            file_types=["image", ".pdf"],
        )
        out = gr.Textbox(label="Extracted Text", lines=25, show_copy_button=True)

    run_btn = gr.Button("Run OCR", variant="primary")

    # OCR runs on an explicit click and also whenever the uploaded file changes
    # (including when it is cleared, in which case infer receives None).
    for _register in (run_btn.click, file_in.change):
        _register(fn=infer, inputs=[file_in], outputs=[out])
|
|
|
|
|
if __name__ == "__main__":
    # Host defaults to all interfaces; the port comes from $PORT when the
    # hosting platform provides it, otherwise 7860.
    launch_host = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    launch_port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name=launch_host, server_port=launch_port, show_error=True)
|
|
|