Update app.py
Browse files
app.py
CHANGED
|
@@ -10,34 +10,35 @@ import cv2
|
|
| 10 |
import tempfile
|
| 11 |
import os
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
from kiri_ocr import OCR
|
| 17 |
-
print("Loading OCR model...")
|
| 18 |
return OCR(
|
| 19 |
model_path="mrrtmob/kiri-ocr",
|
| 20 |
det_method="db",
|
|
|
|
| 21 |
device="cpu",
|
| 22 |
verbose=False
|
| 23 |
)
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
if ocr is None:
|
| 32 |
-
ocr = load_ocr()
|
| 33 |
-
return ocr
|
| 34 |
|
| 35 |
-
def process_document_stream(image):
|
| 36 |
"""
|
| 37 |
Process document image with real-time character streaming.
|
| 38 |
|
| 39 |
Args:
|
| 40 |
image: Input image (PIL Image or numpy array)
|
|
|
|
| 41 |
|
| 42 |
Yields:
|
| 43 |
Tuple of (annotated_image, extracted_text)
|
|
@@ -47,7 +48,7 @@ def process_document_stream(image):
|
|
| 47 |
return
|
| 48 |
|
| 49 |
try:
|
| 50 |
-
ocr_engine = get_ocr()
|
| 51 |
|
| 52 |
# Save temp file for processing (required by current API)
|
| 53 |
# Convert PIL to BGR numpy array first if needed
|
|
@@ -117,16 +118,20 @@ def process_document_stream(image):
|
|
| 117 |
yield image, f"Error: {str(e)}\n{traceback.format_exc()}"
|
| 118 |
|
| 119 |
|
| 120 |
-
def recognize_line_stream(image):
|
| 121 |
"""
|
| 122 |
Stream text from single line image.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
"""
|
| 124 |
if image is None:
|
| 125 |
yield "Please upload an image."
|
| 126 |
return
|
| 127 |
|
| 128 |
try:
|
| 129 |
-
ocr_engine = get_ocr()
|
| 130 |
|
| 131 |
# Save temp file
|
| 132 |
if isinstance(image, Image.Image):
|
|
@@ -181,6 +186,14 @@ with gr.Blocks(title="Kiri OCR - Streaming Demo", css=css, theme=gr.themes.Soft(
|
|
| 181 |
sources=["upload", "clipboard", "webcam"]
|
| 182 |
)
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
with gr.Row():
|
| 185 |
doc_btn = gr.Button("⚡ Stream Text", variant="primary")
|
| 186 |
doc_stop = gr.Button("⏹️ Stop", variant="secondary", visible=False)
|
|
@@ -207,6 +220,15 @@ with gr.Blocks(title="Kiri OCR - Streaming Demo", css=css, theme=gr.themes.Soft(
|
|
| 207 |
type="pil",
|
| 208 |
sources=["upload", "clipboard"]
|
| 209 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
with gr.Row():
|
| 211 |
line_btn = gr.Button("⚡ Stream Recognize", variant="primary")
|
| 212 |
line_stop = gr.Button("⏹️ Stop", variant="secondary", visible=False)
|
|
@@ -231,7 +253,7 @@ with gr.Blocks(title="Kiri OCR - Streaming Demo", css=css, theme=gr.themes.Soft(
|
|
| 231 |
outputs=[doc_btn, doc_stop]
|
| 232 |
).then(
|
| 233 |
fn=process_document_stream,
|
| 234 |
-
inputs=[doc_input],
|
| 235 |
outputs=[doc_output_img, doc_output_text]
|
| 236 |
).then(
|
| 237 |
fn=reset_doc_buttons,
|
|
@@ -255,7 +277,7 @@ with gr.Blocks(title="Kiri OCR - Streaming Demo", css=css, theme=gr.themes.Soft(
|
|
| 255 |
outputs=[line_btn, line_stop]
|
| 256 |
).then(
|
| 257 |
fn=recognize_line_stream,
|
| 258 |
-
inputs=line_input,
|
| 259 |
outputs=line_output_text
|
| 260 |
).then(
|
| 261 |
fn=reset_line_buttons,
|
|
@@ -270,10 +292,17 @@ with gr.Blocks(title="Kiri OCR - Streaming Demo", css=css, theme=gr.themes.Soft(
|
|
| 270 |
|
| 271 |
gr.Markdown(
|
| 272 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
[GitHub Repository](https://github.com/mrrtmob/kiri-ocr) | [Hugging Face Model](https://huggingface.co/mrrtmob/kiri-ocr)
|
| 274 |
"""
|
| 275 |
)
|
| 276 |
|
| 277 |
# Launch
|
| 278 |
if __name__ == "__main__":
|
| 279 |
-
demo.queue().launch()
|
|
|
|
| 10 |
import tempfile
|
| 11 |
import os
|
| 12 |
|
| 13 |
+
# Global OCR instances (one per decode method)
# Keyed by the decode_method string (e.g. "fast" / "accurate" / "beam");
# populated lazily by get_ocr() so each engine is built at most once.
ocr_instances = {}
|
| 15 |
+
|
| 16 |
+
def load_ocr(decode_method="accurate"):
    """Build and return a fresh OCR engine configured for *decode_method*.

    Args:
        decode_method: Decoding strategy passed through to the OCR
            constructor (e.g. "fast", "accurate", or "beam").

    Returns:
        A newly constructed ``kiri_ocr.OCR`` instance running on CPU.
    """
    # Imported lazily so the heavy model dependency is only pulled in
    # when an engine is actually requested.
    from kiri_ocr import OCR

    print(f"Loading OCR model with decode_method={decode_method}...")

    # Collect the constructor arguments in one place for readability.
    engine_config = {
        "model_path": "mrrtmob/kiri-ocr",
        "det_method": "db",
        "decode_method": decode_method,
        "device": "cpu",
        "verbose": False,
    }
    return OCR(**engine_config)
|
| 27 |
|
| 28 |
+
def get_ocr(decode_method="accurate"):
    """Return the cached OCR engine for *decode_method*, creating it on first use.

    Args:
        decode_method: Decoding strategy key (e.g. "fast", "accurate", "beam").

    Returns:
        The shared OCR instance associated with *decode_method*.
    """
    global ocr_instances
    # Lazy initialization: build the engine only the first time this
    # decode method is requested, then reuse it for every later call.
    engine = ocr_instances.get(decode_method)
    if engine is None:
        engine = load_ocr(decode_method)
        ocr_instances[decode_method] = engine
    return engine
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
def process_document_stream(image, decode_method):
|
| 36 |
"""
|
| 37 |
Process document image with real-time character streaming.
|
| 38 |
|
| 39 |
Args:
|
| 40 |
image: Input image (PIL Image or numpy array)
|
| 41 |
+
decode_method: Decode method to use (fast, accurate, or beam)
|
| 42 |
|
| 43 |
Yields:
|
| 44 |
Tuple of (annotated_image, extracted_text)
|
|
|
|
| 48 |
return
|
| 49 |
|
| 50 |
try:
|
| 51 |
+
ocr_engine = get_ocr(decode_method)
|
| 52 |
|
| 53 |
# Save temp file for processing (required by current API)
|
| 54 |
# Convert PIL to BGR numpy array first if needed
|
|
|
|
| 118 |
yield image, f"Error: {str(e)}\n{traceback.format_exc()}"
|
| 119 |
|
| 120 |
|
| 121 |
+
def recognize_line_stream(image, decode_method):
|
| 122 |
"""
|
| 123 |
Stream text from single line image.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
image: Input image
|
| 127 |
+
decode_method: Decode method to use (fast, accurate, or beam)
|
| 128 |
"""
|
| 129 |
if image is None:
|
| 130 |
yield "Please upload an image."
|
| 131 |
return
|
| 132 |
|
| 133 |
try:
|
| 134 |
+
ocr_engine = get_ocr(decode_method)
|
| 135 |
|
| 136 |
# Save temp file
|
| 137 |
if isinstance(image, Image.Image):
|
|
|
|
| 186 |
sources=["upload", "clipboard", "webcam"]
|
| 187 |
)
|
| 188 |
|
| 189 |
+
# Decode method selector
|
| 190 |
+
doc_decode_method = gr.Radio(
|
| 191 |
+
choices=["fast", "accurate", "beam"],
|
| 192 |
+
value="accurate",
|
| 193 |
+
label="Decode Method",
|
| 194 |
+
info="Fast: Fastest, lower accuracy | Accurate: Balanced | Beam: Slowest, highest accuracy"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
with gr.Row():
|
| 198 |
doc_btn = gr.Button("⚡ Stream Text", variant="primary")
|
| 199 |
doc_stop = gr.Button("⏹️ Stop", variant="secondary", visible=False)
|
|
|
|
| 220 |
type="pil",
|
| 221 |
sources=["upload", "clipboard"]
|
| 222 |
)
|
| 223 |
+
|
| 224 |
+
# Decode method selector
|
| 225 |
+
line_decode_method = gr.Radio(
|
| 226 |
+
choices=["fast", "accurate", "beam"],
|
| 227 |
+
value="accurate",
|
| 228 |
+
label="Decode Method",
|
| 229 |
+
info="Fast: Fastest, lower accuracy | Accurate: Balanced | Beam: Slowest, highest accuracy"
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
with gr.Row():
|
| 233 |
line_btn = gr.Button("⚡ Stream Recognize", variant="primary")
|
| 234 |
line_stop = gr.Button("⏹️ Stop", variant="secondary", visible=False)
|
|
|
|
| 253 |
outputs=[doc_btn, doc_stop]
|
| 254 |
).then(
|
| 255 |
fn=process_document_stream,
|
| 256 |
+
inputs=[doc_input, doc_decode_method],
|
| 257 |
outputs=[doc_output_img, doc_output_text]
|
| 258 |
).then(
|
| 259 |
fn=reset_doc_buttons,
|
|
|
|
| 277 |
outputs=[line_btn, line_stop]
|
| 278 |
).then(
|
| 279 |
fn=recognize_line_stream,
|
| 280 |
+
inputs=[line_input, line_decode_method],
|
| 281 |
outputs=line_output_text
|
| 282 |
).then(
|
| 283 |
fn=reset_line_buttons,
|
|
|
|
| 292 |
|
| 293 |
gr.Markdown(
|
| 294 |
"""
|
| 295 |
+
---
|
| 296 |
+
### 🔍 Decode Methods:
|
| 297 |
+
- **Fast**: Greedy decoding - fastest speed, good for quick previews
|
| 298 |
+
- **Accurate**: Default balanced mode - good speed and accuracy
|
| 299 |
+
- **Beam**: Beam search decoding - slowest but highest accuracy
|
| 300 |
+
|
| 301 |
+
---
|
| 302 |
[GitHub Repository](https://github.com/mrrtmob/kiri-ocr) | [Hugging Face Model](https://huggingface.co/mrrtmob/kiri-ocr)
|
| 303 |
"""
|
| 304 |
)
|
| 305 |
|
| 306 |
# Launch
|
| 307 |
if __name__ == "__main__":
|
| 308 |
+
demo.queue().launch()
|