# ViTOCR-base / app.py
# Author: metythorn — commit 6703a47 ("add torchvision")
from __future__ import annotations
import importlib.util
import os
from functools import lru_cache
from types import ModuleType
from typing import Optional
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
# Hub repo and artifact filenames for the OCR model; each is overridable
# through an environment variable so the Space can point at a fork or a
# differently named export without code changes.
DEFAULT_MODEL_REPO = os.environ.get("OCR_MODEL_REPO_ID", "metythorn/ViTOCR-base")
DEFAULT_MODEL_FILENAME = os.environ.get("OCR_MODEL_FILENAME", "model.onnx")
DEFAULT_CONFIG_FILENAME = os.environ.get("OCR_CONFIG_FILENAME", "config.json")
DEFAULT_INFERENCE_FILENAME = os.environ.get("OCR_INFERENCE_FILENAME", "onnx_inference.py")
# Optional git revision (branch / tag / commit) to pin Hub downloads;
# None means the repo's default branch.
DEFAULT_REVISION = os.environ.get("OCR_MODEL_REVISION")
def _resolve_providers() -> list[str]:
raw = os.environ.get("OCR_ONNX_PROVIDERS", "CPUExecutionProvider")
providers = [p.strip() for p in raw.split(",") if p.strip()]
return providers or ["CPUExecutionProvider"]
def _download_from_hub(filename: str) -> str:
    """Download *filename* from the configured Hub model repo.

    Honors DEFAULT_REVISION for pinning and HF_TOKEN (from the
    environment) for private repositories. Returns the local cache path
    of the downloaded file, as reported by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id=DEFAULT_MODEL_REPO,
        filename=filename,
        repo_type="model",
        revision=DEFAULT_REVISION,
        token=os.environ.get("HF_TOKEN"),
    )
@lru_cache(maxsize=1)
def _load_inference_module() -> ModuleType:
inference_path = os.environ.get("OCR_ONNX_INFERENCE_PATH")
if inference_path:
if not os.path.exists(inference_path):
raise FileNotFoundError(f"OCR_ONNX_INFERENCE_PATH does not exist: {inference_path}")
else:
inference_path = _download_from_hub(DEFAULT_INFERENCE_FILENAME)
spec = importlib.util.spec_from_file_location("onnx_inference_runtime", inference_path)
if spec is None or spec.loader is None:
raise ImportError(f"Could not load module spec from: {inference_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
@lru_cache(maxsize=1)
def _get_predictor():
model_path = os.environ.get("OCR_ONNX_PATH")
config_path = os.environ.get("OCR_CONFIG_PATH")
if model_path and not os.path.exists(model_path):
raise FileNotFoundError(f"OCR_ONNX_PATH does not exist: {model_path}")
if config_path and not os.path.exists(config_path):
raise FileNotFoundError(f"OCR_CONFIG_PATH does not exist: {config_path}")
if not model_path:
model_path = _download_from_hub(DEFAULT_MODEL_FILENAME)
if not config_path:
config_path = _download_from_hub(DEFAULT_CONFIG_FILENAME)
inference_module = _load_inference_module()
predictor_cls = getattr(inference_module, "ONNXPredictor", None)
if predictor_cls is None:
raise AttributeError("onnx_inference.py does not define ONNXPredictor")
return predictor_cls(
model_path=model_path,
config_path=config_path,
providers=_resolve_providers(),
)
def ocr_image(image: Optional[Image.Image]) -> str:
    """Run OCR on a PIL image and return the recognized text.

    Returns an empty string when no image was provided. Any failure
    during prediction is reported inline as a message rather than raised,
    so the Gradio UI shows the error instead of crashing.
    """
    if image is None:
        return ""
    try:
        return _get_predictor().predict(image)
    except Exception as exc:  # surface the error in the textbox, not a traceback
        return f"OCR error: {exc}"
def build_app() -> gr.Blocks:
    """Assemble and return the Gradio demo interface.

    Layout: a single row with the image input beside the output textbox,
    plus a run button underneath. OCR fires both on button click and
    whenever the input image changes.
    """
    with gr.Blocks(title="ViTOCR ONNX Demo") as demo:
        gr.Markdown(
            "Upload an image to test Khmer OCR. Inference on CPU can be slow, so please be patient."
        )
        with gr.Row():
            image_in = gr.Image(
                label="Input Image",
                type="pil",
                sources=["upload", "clipboard", "webcam"],
                image_mode="RGB",
            )
            text_out = gr.Textbox(label="Recognized Text", lines=8)
        run_button = gr.Button("Run OCR", variant="primary")
        # Same handler for both triggers: explicit click and image change.
        for trigger in (run_button.click, image_in.change):
            trigger(fn=ocr_image, inputs=image_in, outputs=text_out)
    return demo
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside a container/Space.
    build_app().launch(server_name="0.0.0.0")