# Source: kiri-ocr / app.py
# Author: mrrtmob; commit a22d92e ("Update app.py", verified)
"""
Kiri OCR - Gradio Demo for Hugging Face Spaces
A lightweight OCR library for English and Khmer documents with streaming output support.
"""
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import tempfile
import os
# Lazily-populated cache of OCR engines, keyed by decode method name
# (the UI offers "fast", "accurate" and "beam"); filled by get_ocr().
ocr_instances: dict = {}
def load_ocr(decode_method="accurate"):
    """Build a brand-new Kiri OCR engine.

    Args:
        decode_method: Decoding strategy forwarded to ``kiri_ocr.OCR``
            (the UI offers "fast", "accurate" and "beam").

    Returns:
        A new ``kiri_ocr.OCR`` instance using the "mrrtmob/kiri-ocr" model,
        the "db" detector, CPU execution and non-verbose output.
    """
    # Imported inside the function so kiri_ocr is only imported when an
    # engine is actually constructed.
    from kiri_ocr import OCR

    print(f"Loading OCR model with decode_method={decode_method}...")
    settings = dict(
        model_path="mrrtmob/kiri-ocr",
        det_method="db",
        decode_method=decode_method,
        device="cpu",
        verbose=False,
    )
    return OCR(**settings)
def get_ocr(decode_method="accurate"):
    """Return the shared OCR engine for *decode_method*.

    The engine is created with ``load_ocr`` on first request and reused from
    the module-level ``ocr_instances`` cache afterwards.
    """
    cache = ocr_instances  # mutated in place only, so no ``global`` is needed
    if decode_method in cache:
        return cache[decode_method]
    engine = cache[decode_method] = load_ocr(decode_method)
    return engine
def _image_to_bgr(image):
    """Convert a PIL image or an RGB/RGBA/grayscale numpy array to a 3-channel BGR array."""
    if isinstance(image, Image.Image):
        # Normalise palette, grayscale+alpha, CMYK, ... modes to plain RGB so
        # the channel handling below only ever sees 2-D, RGB or RGBA data.
        image = np.array(image.convert("RGB"))
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def _draw_region(canvas, chunk):
    """Outline a detected region on *canvas* (in place) and label it with its region number."""
    # OpenCV drawing functions require integer coordinates.
    x, y, w, h = (int(v) for v in chunk["box"])
    cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 255, 0), 2)
    # Keep the label on-screen when the box touches the top edge of the image.
    label_y = max(y - 5, 12)
    cv2.putText(
        canvas, str(chunk["region_number"]), (x, label_y),
        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1
    )


def process_document_stream(image, decode_method, update_every=3):
    """
    Process a document image, streaming recognized characters as they arrive.

    Args:
        image: Input image (PIL Image or numpy array), or None.
        decode_method: Decode method to use (fast, accurate, or beam).
        update_every: Emit an intermediate UI update once at least this many
            characters have been appended since the previous update. Region
            starts/ends always trigger an update.

    Yields:
        Tuple of (annotated RGB image, extracted text so far). On failure a
        single (original image, error message with traceback) tuple.
    """
    if image is None:
        yield None, "Please upload an image."
        return

    temp_path = None
    try:
        ocr_engine = get_ocr(decode_method)
        img_display = _image_to_bgr(image)

        # The streaming API takes a file path, so persist the normalised image.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            temp_path = f.name
        if not cv2.imwrite(temp_path, img_display):
            raise RuntimeError(f"Could not write temporary image to {temp_path}")

        annotated = img_display.copy()
        extracted_text = ""
        chars_since_update = 0

        for chunk in ocr_engine.extract_text_stream_chars(temp_path, mode="lines"):
            region_start = chunk.get("region_start")
            if region_start:
                if "box" in chunk:
                    _draw_region(annotated, chunk)
                # Separate regions with a newline (none before the first one).
                if chunk["region_number"] > 1:
                    extracted_text += "\n"

            token = chunk.get("token", "")
            if token:
                extracted_text += token
                chars_since_update += len(token)

            # Throttle UI updates to keep the browser responsive, but always
            # refresh at region boundaries.
            if region_start or chunk.get("region_finished") or chars_since_update >= update_every:
                chars_since_update = 0
                yield cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), extracted_text

        # Final update so the last few characters are never dropped.
        yield cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), extracted_text
    except Exception as e:
        import traceback
        yield image, f"Error: {str(e)}\n{traceback.format_exc()}"
    finally:
        # Runs on success, on error, and when Gradio closes the generator
        # (e.g. the Stop button), so the temp file never leaks.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
def recognize_line_stream(image, decode_method):
    """
    Stream text recognized from a single cropped text-line image.

    Args:
        image: Input image (PIL Image or RGB/RGBA/grayscale numpy array), or None.
        decode_method: Decode method to use (fast, accurate, or beam).

    Yields:
        The text recognized so far, ending with the complete result (an empty
        string if nothing was recognized), or an "Error: ..." message.
    """
    if image is None:
        yield "Please upload an image."
        return

    path = None
    try:
        ocr_engine = get_ocr(decode_method)

        # Unique file per request: a fixed name in the working directory let
        # concurrent users overwrite each other's input image.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            path = f.name

        if isinstance(image, Image.Image):
            image.save(path)
        else:
            array = np.asarray(image)
            if array.ndim == 2:
                bgr = array
            elif array.shape[2] == 4:
                bgr = cv2.cvtColor(array, cv2.COLOR_RGBA2BGR)
            else:
                bgr = cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
            if not cv2.imwrite(path, bgr):
                raise RuntimeError(f"Could not write temporary image to {path}")

        extracted_text = ""
        for chunk in ocr_engine.recognize_streaming(path):
            token = chunk.get("token", "")
            if token:
                extracted_text += token
                yield extracted_text
        # Always emit the final result so a run with no tokens clears any
        # stale output from a previous run.
        yield extracted_text
    except Exception as e:
        yield f"Error: {str(e)}"
    finally:
        # Also runs when Gradio cancels the generator via the Stop button.
        if path is not None and os.path.exists(path):
            os.unlink(path)
# Custom CSS
# Stylesheet passed to gr.Blocks(css=...). The ".output-text" rule applies to
# the result textboxes, which opt in via elem_classes=["output-text"].
# NOTE(review): whether any element carries the ".container" class depends on
# Gradio's generated DOM; verify this rule actually takes effect.
css = """
.container { max-width: 1200px; margin: auto; }
.output-text { font-family: monospace; }
"""
# Create Gradio interface
with gr.Blocks(title="Kiri OCR - Streaming Demo", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ⚡ Kiri OCR Demo
        **Real-time OCR for English and Khmer documents**
        This demo showcases the **character-by-character streaming** capability of Kiri OCR.
        """
    )
    with gr.Tabs():
        # Document OCR Tab
        with gr.TabItem("📄 Document Stream"):
            gr.Markdown("Upload a document to see text appear in real-time as it's recognized.")
            with gr.Row():
                with gr.Column(scale=1):
                    doc_input = gr.Image(
                        label="Upload Document",
                        type="pil",
                        sources=["upload", "clipboard", "webcam"],
                    )
                    # Decode method selector
                    doc_decode_method = gr.Radio(
                        choices=["fast", "accurate", "beam"],
                        value="accurate",
                        label="Decode Method",
                        info="Fast: Fastest, lower accuracy | Accurate: Balanced | Beam: Slowest, highest accuracy",
                    )
                    with gr.Row():
                        doc_btn = gr.Button("⚡ Stream Text", variant="primary")
                        doc_stop = gr.Button("⏹️ Stop", variant="secondary", visible=False)
                with gr.Column(scale=1):
                    # Annotated image updates in real-time
                    doc_output_img = gr.Image(label="Live Detection")
                    # Text updates character-by-character
                    doc_output_text = gr.Textbox(
                        label="Streaming Text",
                        lines=15,
                        autoscroll=True,
                        elem_classes=["output-text"],
                    )

        # Single Line OCR Tab
        with gr.TabItem("✏️ Single Line Stream"):
            gr.Markdown("Stream text recognition for a single cropped text line.")
            with gr.Row():
                with gr.Column(scale=1):
                    line_input = gr.Image(
                        label="Upload Text Line",
                        type="pil",
                        sources=["upload", "clipboard"],
                    )
                    # Decode method selector
                    line_decode_method = gr.Radio(
                        choices=["fast", "accurate", "beam"],
                        value="accurate",
                        label="Decode Method",
                        info="Fast: Fastest, lower accuracy | Accurate: Balanced | Beam: Slowest, highest accuracy",
                    )
                    with gr.Row():
                        line_btn = gr.Button("⚡ Stream Recognize", variant="primary")
                        line_stop = gr.Button("⏹️ Stop", variant="secondary", visible=False)
                with gr.Column(scale=1):
                    line_output_text = gr.Textbox(
                        label="Streaming Output",
                        lines=3,
                        elem_classes=["output-text"],
                    )

    # Toggle buttons visibility
    def toggle_doc_buttons():
        """Running state: hide the start button, show the stop button."""
        return gr.update(visible=False), gr.update(visible=True)

    def reset_doc_buttons():
        """Idle state: show the start button, hide the stop button."""
        return gr.update(visible=True), gr.update(visible=False)

    # Both tabs use the same start/stop button pair behaviour.
    toggle_line_buttons = toggle_doc_buttons
    reset_line_buttons = reset_doc_buttons

    # Each *_event is the streaming step itself. A listener's ``cancels``
    # targets exactly the events it is given. Previously it pointed at the
    # trailing reset ``.then`` step, so pressing Stop left the OCR generator
    # running. The reset step is attached separately below.
    doc_event = doc_btn.click(
        fn=toggle_doc_buttons,
        outputs=[doc_btn, doc_stop],
    ).then(
        fn=process_document_stream,
        inputs=[doc_input, doc_decode_method],
        outputs=[doc_output_img, doc_output_text],
    )
    doc_event.then(
        fn=reset_doc_buttons,
        outputs=[doc_btn, doc_stop],
    )
    doc_stop.click(
        fn=reset_doc_buttons,
        outputs=[doc_btn, doc_stop],
        cancels=[doc_event],
    )

    line_event = line_btn.click(
        fn=toggle_line_buttons,
        outputs=[line_btn, line_stop],
    ).then(
        fn=recognize_line_stream,
        inputs=[line_input, line_decode_method],
        outputs=line_output_text,
    )
    line_event.then(
        fn=reset_line_buttons,
        outputs=[line_btn, line_stop],
    )
    line_stop.click(
        fn=reset_line_buttons,
        outputs=[line_btn, line_stop],
        cancels=[line_event],
    )

    gr.Markdown(
        """
        ---
        ### 🔍 Decode Methods:
        - **Fast**: Greedy decoding - fastest speed, good for quick previews
        - **Accurate**: Default balanced mode - good speed and accuracy
        - **Beam**: Beam search decoding - slowest but highest accuracy
        ---
        [GitHub Repository](https://github.com/mrrtmob/kiri-ocr) | [Hugging Face Model](https://huggingface.co/mrrtmob/kiri-ocr)
        """
    )
# Launch: enable the request queue on the Blocks app, then start the server.
if __name__ == "__main__":
    queued_app = demo.queue()
    queued_app.launch()