|
|
""" |
|
|
Kiri OCR - Gradio Demo for Hugging Face Spaces |
|
|
|
|
|
A lightweight OCR library for English and Khmer documents with streaming output support. |
|
|
""" |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import cv2 |
|
|
import tempfile |
|
|
import os |
|
|
|
|
|
|
|
|
# Process-wide cache of OCR engines, keyed by decode method
# ("fast" / "accurate" / "beam") so each model configuration is
# loaded at most once.
ocr_instances = {}
|
|
|
|
|
def load_ocr(decode_method="accurate"):
    """Build a fresh OCR engine configured for the given decode method.

    kiri_ocr is imported lazily so the heavy model machinery is only
    pulled in when an engine is actually requested.
    """
    from kiri_ocr import OCR

    print(f"Loading OCR model with decode_method={decode_method}...")
    config = {
        "model_path": "mrrtmob/kiri-ocr",
        "det_method": "db",
        "decode_method": decode_method,
        "device": "cpu",
        "verbose": False,
    }
    return OCR(**config)
|
|
|
|
|
def get_ocr(decode_method="accurate"):
    """Return the cached OCR engine for *decode_method*, creating it on first use."""
    # EAFP: a plain lookup first, build-and-cache only on a miss.
    try:
        return ocr_instances[decode_method]
    except KeyError:
        engine = load_ocr(decode_method)
        ocr_instances[decode_method] = engine
        return engine
|
|
|
|
|
def process_document_stream(image, decode_method):
    """
    Process a document image with real-time character streaming.

    Args:
        image: Input image (PIL Image or numpy array), or None.
        decode_method: Decode method to use ("fast", "accurate", or "beam").

    Yields:
        Tuple of (annotated_image, extracted_text). The annotated image is
        an RGB numpy array with green boxes drawn around detected regions;
        the text grows as tokens stream in. On error, yields the original
        image and an error message with a traceback.
    """
    if image is None:
        yield None, "Please upload an image."
        return

    try:
        ocr_engine = get_ocr(decode_method)

        # Normalize input to a numpy array.
        if isinstance(image, Image.Image):
            img_array = np.array(image)
        else:
            img_array = image

        # Convert to 3-channel BGR for OpenCV drawing and saving.
        if len(img_array.shape) == 2:
            img_display = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
        elif img_array.shape[2] == 4:
            img_display = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
        else:
            img_display = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

        # The OCR engine takes a file path, so persist the image to a
        # unique temp file.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
            temp_path = f.name
        cv2.imwrite(temp_path, img_display)

        # BUGFIX: the original only unlinked the temp file on the success
        # path, leaking it whenever OCR raised or the stream was cancelled.
        # try/finally guarantees cleanup.
        try:
            annotated = img_display.copy()
            extracted_text = ""
            current_region_text = ""

            for chunk in ocr_engine.extract_text_stream_chars(temp_path, mode="lines"):
                if chunk.get("region_start"):
                    # Draw the bounding box and region number as soon as a
                    # new region is announced.
                    if "box" in chunk:
                        x, y, w, h = chunk["box"]
                        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)
                        cv2.putText(
                            annotated, str(chunk["region_number"]), (x, y - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1
                        )

                    # Separate consecutive regions with a newline.
                    if chunk["region_number"] > 1:
                        extracted_text += "\n"

                token = chunk.get("token", "")
                if token:
                    extracted_text += token
                    current_region_text += token

                # Throttle UI updates: yield on region boundaries and roughly
                # every third character otherwise.
                if chunk.get("region_start") or chunk.get("region_finished") or len(current_region_text) % 3 == 0:
                    yield cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), extracted_text

            # Final yield with the complete annotated image and text.
            yield cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), extracted_text
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        import traceback
        yield image, f"Error: {str(e)}\n{traceback.format_exc()}"
|
|
|
|
|
|
|
|
def recognize_line_stream(image, decode_method):
    """
    Stream text recognition for a single cropped text-line image.

    Args:
        image: Input image (PIL Image or numpy array), or None.
        decode_method: Decode method to use ("fast", "accurate", or "beam").

    Yields:
        The text recognized so far, growing token by token. On error,
        yields a single "Error: ..." message.
    """
    if image is None:
        yield "Please upload an image."
        return

    try:
        ocr_engine = get_ocr(decode_method)

        # BUGFIX: the original wrote to a fixed "temp_line.png" in the
        # working directory, so concurrent sessions clobbered each other's
        # images and the file was left behind on error. Use a unique temp
        # file and guarantee cleanup with try/finally.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
            path = f.name

        if isinstance(image, Image.Image):
            image.save(path)
        else:
            cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        try:
            extracted_text = ""
            for chunk in ocr_engine.recognize_streaming(path):
                token = chunk.get("token", "")
                if token:
                    extracted_text += token
                    yield extracted_text
        finally:
            if os.path.exists(path):
                os.unlink(path)

    except Exception as e:
        yield f"Error: {str(e)}"
|
|
|
|
|
|
|
|
# Custom CSS for the demo: cap the layout width and render the streamed
# OCR output in a monospace font (applied via elem_classes below).
css = """
.container { max-width: 1200px; margin: auto; }
.output-text { font-family: monospace; }
"""
|
|
|
|
|
|
|
|
# Build the Gradio UI: two tabs (full-document streaming OCR and
# single-line streaming OCR), each with a decode-method selector and
# Start/Stop buttons wired to the streaming generators defined above.
with gr.Blocks(title="Kiri OCR - Streaming Demo", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ⚡ Kiri OCR Demo

        **Real-time OCR for English and Khmer documents**

        This demo showcases the **character-by-character streaming** capability of Kiri OCR.
        """
    )

    with gr.Tabs():

        # --- Tab 1: whole-document OCR with live region annotation ---
        with gr.TabItem("📄 Document Stream"):
            gr.Markdown("Upload a document to see text appear in real-time as it's recognized.")

            with gr.Row():
                with gr.Column(scale=1):
                    doc_input = gr.Image(
                        label="Upload Document",
                        type="pil",
                        sources=["upload", "clipboard", "webcam"]
                    )

                    doc_decode_method = gr.Radio(
                        choices=["fast", "accurate", "beam"],
                        value="accurate",
                        label="Decode Method",
                        info="Fast: Fastest, lower accuracy | Accurate: Balanced | Beam: Slowest, highest accuracy"
                    )

                    with gr.Row():
                        doc_btn = gr.Button("⚡ Stream Text", variant="primary")
                        # Hidden until a stream starts; cancels the run.
                        doc_stop = gr.Button("⏹️ Stop", variant="secondary", visible=False)

                with gr.Column(scale=1):
                    # Updated live with green detection boxes as regions appear.
                    doc_output_img = gr.Image(label="Live Detection")

                    doc_output_text = gr.Textbox(
                        label="Streaming Text",
                        lines=15,
                        autoscroll=True,
                        elem_classes=["output-text"]
                    )

        # --- Tab 2: OCR on a single cropped text line ---
        with gr.TabItem("✏️ Single Line Stream"):
            gr.Markdown("Stream text recognition for a single cropped text line.")

            with gr.Row():
                with gr.Column(scale=1):
                    line_input = gr.Image(
                        label="Upload Text Line",
                        type="pil",
                        sources=["upload", "clipboard"]
                    )

                    line_decode_method = gr.Radio(
                        choices=["fast", "accurate", "beam"],
                        value="accurate",
                        label="Decode Method",
                        info="Fast: Fastest, lower accuracy | Accurate: Balanced | Beam: Slowest, highest accuracy"
                    )

                    with gr.Row():
                        line_btn = gr.Button("⚡ Stream Recognize", variant="primary")
                        line_stop = gr.Button("⏹️ Stop", variant="secondary", visible=False)

                with gr.Column(scale=1):
                    line_output_text = gr.Textbox(
                        label="Streaming Output",
                        lines=3,
                        elem_classes=["output-text"]
                    )

    # Swap Start/Stop button visibility while a document stream is running.
    def toggle_doc_buttons():
        return gr.update(visible=False), gr.update(visible=True)

    def reset_doc_buttons():
        return gr.update(visible=True), gr.update(visible=False)

    # Event chain: show Stop -> run the streaming generator -> restore Start.
    doc_event = doc_btn.click(
        fn=toggle_doc_buttons,
        outputs=[doc_btn, doc_stop]
    ).then(
        fn=process_document_stream,
        inputs=[doc_input, doc_decode_method],
        outputs=[doc_output_img, doc_output_text]
    ).then(
        fn=reset_doc_buttons,
        outputs=[doc_btn, doc_stop]
    )

    # Stop cancels the in-flight streaming event and restores the buttons.
    doc_stop.click(
        fn=reset_doc_buttons,
        outputs=[doc_btn, doc_stop],
        cancels=[doc_event]
    )

    # Same Start/Stop wiring for the single-line tab.
    def toggle_line_buttons():
        return gr.update(visible=False), gr.update(visible=True)

    def reset_line_buttons():
        return gr.update(visible=True), gr.update(visible=False)

    line_event = line_btn.click(
        fn=toggle_line_buttons,
        outputs=[line_btn, line_stop]
    ).then(
        fn=recognize_line_stream,
        inputs=[line_input, line_decode_method],
        outputs=line_output_text
    ).then(
        fn=reset_line_buttons,
        outputs=[line_btn, line_stop]
    )

    line_stop.click(
        fn=reset_line_buttons,
        outputs=[line_btn, line_stop],
        cancels=[line_event]
    )

    gr.Markdown(
        """
        ---
        ### 🔍 Decode Methods:
        - **Fast**: Greedy decoding - fastest speed, good for quick previews
        - **Accurate**: Default balanced mode - good speed and accuracy
        - **Beam**: Beam search decoding - slowest but highest accuracy

        ---
        [GitHub Repository](https://github.com/mrrtmob/kiri-ocr) | [Hugging Face Model](https://huggingface.co/mrrtmob/kiri-ocr)
        """
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # queue() is required for generator-based (streaming) outputs.
    demo.queue().launch()