File size: 3,876 Bytes
89f9504
 
0a5fe17
89f9504
 
0a5fe17
89f9504
 
 
 
 
 
 
 
 
0a5fe17
89f9504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a5fe17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89f9504
 
 
 
 
 
 
 
 
 
 
 
 
0a5fe17
 
 
 
 
 
89f9504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6703a47
89f9504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import annotations

import importlib.util
import os
from functools import lru_cache
from types import ModuleType
from typing import Optional

import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image

# Hub artifact names used when no local-path override env vars are provided.
# Each can be redirected via its own environment variable.
DEFAULT_MODEL_REPO = os.environ.get("OCR_MODEL_REPO_ID", "metythorn/ViTOCR-base")
DEFAULT_MODEL_FILENAME = os.environ.get("OCR_MODEL_FILENAME", "model.onnx")
DEFAULT_CONFIG_FILENAME = os.environ.get("OCR_CONFIG_FILENAME", "config.json")
DEFAULT_INFERENCE_FILENAME = os.environ.get("OCR_INFERENCE_FILENAME", "onnx_inference.py")
# Optional git revision/tag/commit; None selects the repo's default branch.
DEFAULT_REVISION = os.environ.get("OCR_MODEL_REVISION")


def _resolve_providers() -> list[str]:
    raw = os.environ.get("OCR_ONNX_PROVIDERS", "CPUExecutionProvider")
    providers = [p.strip() for p in raw.split(",") if p.strip()]
    return providers or ["CPUExecutionProvider"]


def _download_from_hub(filename: str) -> str:
    """Download *filename* from the configured model repo; return its local path.

    Reads HF_TOKEN from the environment on every call so a token exported
    after startup is still picked up. Revision defaults to the repo's
    default branch (DEFAULT_REVISION is None unless overridden).
    """
    return hf_hub_download(
        repo_id=DEFAULT_MODEL_REPO,
        repo_type="model",
        filename=filename,
        revision=DEFAULT_REVISION,
        token=os.environ.get("HF_TOKEN"),
    )


@lru_cache(maxsize=1)
def _load_inference_module() -> ModuleType:
    inference_path = os.environ.get("OCR_ONNX_INFERENCE_PATH")
    if inference_path:
        if not os.path.exists(inference_path):
            raise FileNotFoundError(f"OCR_ONNX_INFERENCE_PATH does not exist: {inference_path}")
    else:
        inference_path = _download_from_hub(DEFAULT_INFERENCE_FILENAME)

    spec = importlib.util.spec_from_file_location("onnx_inference_runtime", inference_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load module spec from: {inference_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


@lru_cache(maxsize=1)
def _get_predictor():
    """Construct (once, cached) the ONNX predictor from overrides or the Hub.

    OCR_ONNX_PATH / OCR_CONFIG_PATH take precedence when set; both overrides
    are validated before any download so a stale path fails fast. Missing
    pieces are fetched from the model repo. Raises AttributeError when the
    inference helper module does not expose an ``ONNXPredictor`` class.
    """
    model_path = os.environ.get("OCR_ONNX_PATH")
    config_path = os.environ.get("OCR_CONFIG_PATH")

    # Validate every provided override before touching the network.
    for var_name, candidate in (
        ("OCR_ONNX_PATH", model_path),
        ("OCR_CONFIG_PATH", config_path),
    ):
        if candidate and not os.path.exists(candidate):
            raise FileNotFoundError(f"{var_name} does not exist: {candidate}")

    model_path = model_path or _download_from_hub(DEFAULT_MODEL_FILENAME)
    config_path = config_path or _download_from_hub(DEFAULT_CONFIG_FILENAME)

    helper = _load_inference_module()
    if not hasattr(helper, "ONNXPredictor"):
        raise AttributeError("onnx_inference.py does not define ONNXPredictor")

    return helper.ONNXPredictor(
        model_path=model_path,
        config_path=config_path,
        providers=_resolve_providers(),
    )


def ocr_image(image: Optional[Image.Image]) -> str:
    """Run OCR on *image* and return the recognized text.

    Returns "" when no image was supplied. Any failure (download, model
    load, inference) is reported as an "OCR error: ..." string so the
    Gradio UI never surfaces a raw exception.
    """
    if image is None:
        return ""
    try:
        text = _get_predictor().predict(image)
    except Exception as exc:  # deliberate catch-all: errors become UI text
        return f"OCR error: {exc}"
    return text


def build_app() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI for the OCR demo."""
    with gr.Blocks(title="ViTOCR ONNX Demo") as demo:
        gr.Markdown(
            "Upload an image to test Khmer OCR. Inference on CPU can be slow, so please be patient."
        )

        with gr.Row():
            image_in = gr.Image(
                label="Input Image",
                type="pil",
                sources=["upload", "clipboard", "webcam"],
                image_mode="RGB",
            )
            text_out = gr.Textbox(label="Recognized Text", lines=8)

        trigger = gr.Button("Run OCR", variant="primary")

        # OCR runs on an explicit button press and on every image change.
        for event in (trigger.click, image_in.change):
            event(fn=ocr_image, inputs=image_in, outputs=text_out)

    return demo


if __name__ == "__main__":
    app = build_app()
    # Bind to all interfaces so the demo is reachable from outside a
    # container / Hugging Face Space, not just from localhost.
    app.launch(server_name="0.0.0.0")