# dhivehi-ocr / app.py
# Hugging Face Space source file (last commit 286aa99 by alakxender, 19.4 kB).
import spaces
import gradio as gr
# The model names live here so that starting the app does not import
# torch/transformers. Those heavy imports (torch, transformers, peft) only
# happen inside the handler modules, on the first GPU call.
# This file reads only the keys, to populate the model dropdowns.
PALIGEMMA_MODELS = {
    name: {}
    for name in (
        "Medium-14k, Single Line",
        "Medium-16k, Single Line",
        "Small, Single Line",
    )
}
GEMMA_MODELS = {name: {} for name in ("Gemma-3 10k",)}
GEMMA_MULTILINE_MODELS = {
    name: ""
    for name in (
        "Gemma Multiline - no-format",
        "Gemma Multiline - line",
    )
}

# Handler singletons. Each stays None until its get_*_handler() accessor
# first builds it.
_paligemma_handler = _gemma_handler = _gemma_multiline_handler = None
def get_paligemma_handler():
    """Return the shared PaliGemma2 handler, building it on first use.

    ``paligemma2`` is imported only on that first call, so importing this
    file does not pull in the model stack. Later calls return the cached
    instance stored in the module-level ``_paligemma_handler``.
    """
    global _paligemma_handler
    if _paligemma_handler is not None:
        return _paligemma_handler
    from paligemma2 import PaliGemma2Handler
    _paligemma_handler = PaliGemma2Handler()
    return _paligemma_handler
def get_gemma_handler():
    """Return the shared single-line Gemma handler, building it on first use.

    ``gemma`` is imported only on that first call. Later calls return the
    cached instance stored in the module-level ``_gemma_handler``.
    """
    global _gemma_handler
    if _gemma_handler is not None:
        return _gemma_handler
    from gemma import GemmaHandler
    _gemma_handler = GemmaHandler()
    return _gemma_handler
def get_gemma_multiline_handler():
    """Return the shared multi-line (document) Gemma handler, building it on first use.

    ``gemma_multiline`` is imported only on that first call. Later calls
    return the cached instance stored in ``_gemma_multiline_handler``.
    """
    global _gemma_multiline_handler
    if _gemma_multiline_handler is not None:
        return _gemma_multiline_handler
    from gemma_multiline import GemmaMultilineHandler
    _gemma_multiline_handler = GemmaMultilineHandler()
    return _gemma_multiline_handler
@spaces.GPU
def process_image_paligemma(model_name, image, progress=gr.Progress()):
    """OCR a single image with the selected PaliGemma2 model.

    Wrapped in ``spaces.GPU`` so the call is made with a GPU allocated.
    Returns whatever ``PaliGemma2Handler.process_image`` returns. The UI
    routes it to an "Extracted Text" textbox and a "Detected Text Regions"
    gallery.
    """
    handler = get_paligemma_handler()
    return handler.process_image(model_name, image, progress)
@spaces.GPU
def process_image_gemma(model_name, image, progress=gr.Progress()):
    """OCR a single image with the selected single-line Gemma model.

    Wrapped in ``spaces.GPU``. Returns whatever ``GemmaHandler.process_image``
    returns. The UI routes it to a text box and a gallery.
    """
    handler = get_gemma_handler()
    return handler.process_image(model_name, image, progress)
@spaces.GPU
def process_pdf_paligemma(pdf_path, model_name, progress=gr.Progress()):
    """OCR an uploaded PDF with the selected PaliGemma2 model.

    The PDF comes first and the model name second. The UI's input list
    ``[pdf_input_paligemma, model_dropdown_paligemma]`` depends on that order.
    Returns ``PaliGemma2Handler.process_pdf``'s result unchanged.
    """
    handler = get_paligemma_handler()
    return handler.process_pdf(pdf_path, model_name, progress)
@spaces.GPU
def process_pdf_gemma(pdf_path, model_name, progress=gr.Progress()):
    """OCR an uploaded PDF with the selected single-line Gemma model.

    The PDF comes first and the model name second, matching the UI's input
    list. Returns ``GemmaHandler.process_pdf``'s result unchanged.
    """
    handler = get_gemma_handler()
    return handler.process_pdf(pdf_path, model_name, progress)
@spaces.GPU
def process_image_multiline(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Run document (multi-line) OCR on one image and return the result in one piece.

    ``temp``, ``top_p`` and ``repetition_penalty`` come from the UI sliders.
    All arguments are passed positionally to
    ``GemmaMultilineHandler.generate_text_from_image``.
    """
    handler = get_gemma_multiline_handler()
    return handler.generate_text_from_image(model_name, image, temp, top_p, repetition_penalty, progress)
@spaces.GPU
def process_image_multiline_stream(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Stream document (multi-line) OCR output for one image.

    This is a generator: it re-yields every value from
    ``GemmaMultilineHandler.generate_text_stream``. Each yielded value is
    written to the output textbox. ``yield from`` also forwards close() from
    the Stop button's cancellation to the inner generator.
    """
    handler = get_gemma_multiline_handler()
    stream = handler.generate_text_stream(model_name, image, temp, top_p, repetition_penalty, progress)
    yield from stream
@spaces.GPU
def process_pdf_multiline(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Run document (multi-line) OCR on an uploaded PDF and return the result in one piece.

    Unlike the PaliGemma/Gemma PDF entry points, the model name comes first
    here. Arguments are passed positionally to
    ``GemmaMultilineHandler.process_pdf``.
    """
    handler = get_gemma_multiline_handler()
    return handler.process_pdf(model_name, pdf, temp, top_p, repetition_penalty, progress)
@spaces.GPU
def process_pdf_multiline_stream(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Stream document (multi-line) OCR output for an uploaded PDF.

    This is a generator: it re-yields every value from
    ``GemmaMultilineHandler.process_pdf_stream`` into the PDF output textbox.
    """
    handler = get_gemma_multiline_handler()
    stream = handler.process_pdf_stream(model_name, pdf, temp, top_p, repetition_penalty, progress)
    yield from stream
# Example inputs, stored as [image file, caption] rows. The UI passes only the
# file column to gr.Examples; the caption just labels the sample here.
# Samples for document-level (multi-line) OCR:
document_examples = [
    [filename, caption]
    for filename, caption in (
        ("ml.png", "Multi-line Dhivehi text sample"),
        ("ml1.png", "Multi-line Dhivehi text sample 2"),
        ("ml2.png", "Multi-line Dhivehi text sample 3"),
        ("ml3.png", "Multi-line Dhivehi text sample 4"),
    )
]
# Samples for sentence-level OCR: typed and handwritten lines, plus one
# multi-line image.
sentence_examples = [
    [filename, caption]
    for filename, caption in (
        ("type_1_sl.png", "Typed Dhivehi text sample 1"),
        ("type_2_sl.png", "Typed Dhivehi text sample 2"),
        ("hw_1_sl.png", "Handwritten Dhivehi text sample 1"),
        ("hw_2_sl.jpg", "Handwritten Dhivehi text sample 2"),
        ("hw_3_sl.png", "Handwritten Dhivehi text sample 3"),
        ("hw_4_sl.png", "Handwritten Dhivehi text sample 4"),
        ("ml.png", "Multi-line Dhivehi text sample"),
    )
]
# Custom stylesheet passed to gr.Blocks.
# .textbox1: larger Thaana-font text for the OCR output boxes.
# .textbox2: hides the textarea of any box that carries this class.
css = (
    "\n"
    ".textbox1 textarea {\n"
    "font-size: 18px !important;\n"
    "font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;\n"
    "line-height: 1.8 !important;\n"
    "}\n"
    ".textbox2 textarea {\n"
    "display: none;\n"
    "}\n"
)
# ---------------------------------------------------------------------------
# Streaming-button helpers shared by the "Gemma Document" image and PDF tabs.
# Both return updates for (stop button, streaming button, non-streaming
# button), in that order.
# ---------------------------------------------------------------------------
def show_stop_button():
    """Reveal Stop and disable both generate buttons while a stream runs."""
    return gr.update(visible=True), gr.update(interactive=False), gr.update(interactive=False)


def hide_stop_button():
    """Hide Stop and re-enable both generate buttons after a stream ends or is stopped."""
    return gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)


# The per-tab names are kept as aliases so existing references still resolve.
show_stop_button_image = show_stop_button_pdf = show_stop_button
hide_stop_button_image = hide_stop_button_pdf = hide_stop_button


def _wire_streaming(stream_btn, generate_btn, stop_btn, stream_fn, inputs, output):
    """Attach the stream/stop event chain for one input tab.

    Clicking *stream_btn* does three things in order:
    1. Shows *stop_btn* and disables both generate buttons.
    2. Runs the generator *stream_fn* with *inputs* into *output*.
    3. Restores the buttons.

    Clicking *stop_btn* cancels step 2 and restores the buttons itself,
    because a cancelled chain may not reach step 3.

    Call this inside the ``gr.Blocks`` context. Returns
    ``(show_event, gen_event)``.
    """
    toggled = [stop_btn, stream_btn, generate_btn]
    show_event = stream_btn.click(fn=show_stop_button, outputs=toggled)
    gen_event = show_event.then(fn=stream_fn, inputs=inputs, outputs=output, show_progress="full")
    gen_event.then(fn=hide_stop_button, outputs=toggled)
    stop_btn.click(fn=hide_stop_button, outputs=toggled, cancels=[gen_event])
    return show_event, gen_event


# ---------------------------------------------------------------------------
# UI: three top-level tabs.
# - "Gemma Document": multi-line Gemma, streaming and non-streaming.
# - "PaliGemma": single-line.
# - "Gemma Sentence": single-line.
# Each has image and PDF sub-tabs.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Dhivehi Image to Text", css=css) as demo:
    gr.Markdown("# Dhivehi Image to Text")
    gr.Markdown("Dhivehi Image to Text experimental finetunes")

    with gr.Tabs():
        with gr.Tab("Gemma Document"):
            with gr.Row():
                model_path_dropdown = gr.Dropdown(
                    label="Model Checkpoint",
                    choices=list(GEMMA_MULTILINE_MODELS.keys()),
                    value=list(GEMMA_MULTILINE_MODELS.keys())[0],
                    interactive=True,
                    scale=2,
                )
            # Sampling controls, shared by the image and PDF sub-tabs below.
            with gr.Accordion("Advanced Options", open=False):
                with gr.Row():
                    temperature_slider = gr.Slider(
                        minimum=0.1, maximum=1.9, value=0.2, step=0.1,
                        label="Temperature", info="Controls randomness in generation",
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.1, maximum=1.0, value=1, step=0.1,
                        label="Top-p", info="Controls diversity via nucleus sampling",
                    )
                    repetition_penalty_slider = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                        label="Repetition Penalty",
                        info="Penalizes repeated tokens. >1 encourages new tokens.",
                    )

            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column():
                            image_input = gr.Image(type="pil", label="Upload Image")
                            with gr.Row():
                                generate_button = gr.Button("Generate Text (Non-streaming)")
                                stream_button = gr.Button("Generate Text (Streaming)", variant="primary")
                                stop_button = gr.Button("Stop", visible=False, variant="stop")
                            gr.Examples(
                                examples=[[img] for img, _ in document_examples],
                                inputs=[image_input],
                                outputs=None,
                                label="Example Images",
                                examples_per_page=7,
                            )
                        with gr.Column():
                            text_output = gr.Textbox(
                                label="Extracted Dhivehi Text",
                                lines=20,
                                rtl=True,
                                elem_classes=["textbox1"],
                                show_copy_button=True,
                                scale=2,
                            )

                    # Positional order must match process_image_multiline(_stream):
                    # (model_name, image, temp, top_p, repetition_penalty).
                    image_inputs = [
                        model_path_dropdown, image_input,
                        temperature_slider, top_p_slider, repetition_penalty_slider,
                    ]
                    generate_button.click(
                        fn=process_image_multiline,
                        inputs=image_inputs,
                        outputs=text_output,
                        show_progress="full",
                    )
                    show_event, gen_event = _wire_streaming(
                        stream_button, generate_button, stop_button,
                        process_image_multiline_stream, image_inputs, text_output,
                    )

                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column():
                            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                            with gr.Row():
                                pdf_generate_button = gr.Button("Generate Text (Non-streaming)")
                                pdf_stream_button = gr.Button("Generate Text (Streaming)", variant="primary")
                                pdf_stop_button = gr.Button("Stop", visible=False, variant="stop")
                            # One value per input component: the single file input.
                            gr.Examples(
                                examples=[["example.pdf"]],
                                inputs=[pdf_input],
                                outputs=None,
                                label="Example PDFs",
                                examples_per_page=7,
                            )
                        with gr.Column():
                            pdf_text_output = gr.Textbox(
                                label="Extracted Dhivehi Text",
                                lines=20,
                                rtl=True,
                                elem_classes=["textbox1"],
                                show_copy_button=True,
                                scale=2,
                            )

                    # Positional order must match process_pdf_multiline(_stream):
                    # (model_name, pdf, temp, top_p, repetition_penalty).
                    pdf_inputs = [
                        model_path_dropdown, pdf_input,
                        temperature_slider, top_p_slider, repetition_penalty_slider,
                    ]
                    pdf_generate_button.click(
                        fn=process_pdf_multiline,
                        inputs=pdf_inputs,
                        outputs=pdf_text_output,
                        show_progress="full",
                    )
                    pdf_show_event, pdf_gen_event = _wire_streaming(
                        pdf_stream_button, pdf_generate_button, pdf_stop_button,
                        process_pdf_multiline_stream, pdf_inputs, pdf_text_output,
                    )

        with gr.Tab("PaliGemma"):
            model_dropdown_paligemma = gr.Dropdown(
                choices=list(PALIGEMMA_MODELS.keys()),
                value=list(PALIGEMMA_MODELS.keys())[0],
                label="Select PaliGemma Model",
            )
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            image_input_paligemma = gr.Image(type="pil", label="Input Image")
                            image_submit_btn_paligemma = gr.Button("Extract Text")
                            gr.Examples(
                                examples=[[img] for img, _ in sentence_examples],
                                inputs=[image_input_paligemma],
                                label="Example Images",
                                examples_per_page=8,
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    image_text_output_paligemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes="textbox1",
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    image_bbox_output_paligemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2,
                                    )
                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            pdf_input_paligemma = gr.File(
                                label="Input PDF",
                                file_types=[".pdf"],
                            )
                            pdf_submit_btn_paligemma = gr.Button("Extract Text from PDF")
                            gr.Examples(
                                examples=[["example.pdf"]],
                                inputs=[pdf_input_paligemma],
                                label="Example PDFs",
                                examples_per_page=8,
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    pdf_text_output_paligemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes="textbox1",
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    pdf_bbox_output_paligemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2,
                                    )

        with gr.Tab("Gemma Sentence"):
            model_dropdown_gemma = gr.Dropdown(
                choices=list(GEMMA_MODELS.keys()),
                value=list(GEMMA_MODELS.keys())[0],
                label="Select Gemma Model",
            )
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            image_input_gemma = gr.Image(type="pil", label="Input Image")
                            image_submit_btn_gemma = gr.Button("Extract Text")
                            gr.Examples(
                                examples=[[img] for img, _ in sentence_examples],
                                inputs=[image_input_gemma],
                                label="Example Images",
                                examples_per_page=8,
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    image_text_output_gemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes="textbox1",
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    image_bbox_output_gemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2,
                                    )
                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            pdf_input_gemma = gr.File(
                                label="Input PDF",
                                file_types=[".pdf"],
                            )
                            pdf_submit_btn_gemma = gr.Button("Extract Text from PDF")
                            gr.Examples(
                                examples=[["example.pdf"]],
                                inputs=[pdf_input_gemma],
                                label="Example PDFs",
                                examples_per_page=8,
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    pdf_text_output_gemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes="textbox1",
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    pdf_bbox_output_gemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2,
                                    )

    # PaliGemma event handlers. The PDF entry point takes (pdf, model), not
    # (model, pdf).
    image_submit_btn_paligemma.click(
        fn=process_image_paligemma,
        inputs=[model_dropdown_paligemma, image_input_paligemma],
        outputs=[image_text_output_paligemma, image_bbox_output_paligemma],
    )
    pdf_submit_btn_paligemma.click(
        fn=process_pdf_paligemma,
        inputs=[pdf_input_paligemma, model_dropdown_paligemma],
        outputs=[pdf_text_output_paligemma, pdf_bbox_output_paligemma],
    )

    # Gemma (single-line) event handlers, with the same argument order as above.
    image_submit_btn_gemma.click(
        fn=process_image_gemma,
        inputs=[model_dropdown_gemma, image_input_gemma],
        outputs=[image_text_output_gemma, image_bbox_output_gemma],
    )
    pdf_submit_btn_gemma.click(
        fn=process_pdf_gemma,
        inputs=[pdf_input_gemma, model_dropdown_gemma],
        outputs=[pdf_text_output_gemma, pdf_bbox_output_gemma],
    )

# Launch the app only when run as a script.
if __name__ == "__main__":
    demo.launch()