Spaces:
Sleeping
Sleeping
import threading

# Keep `spaces` ahead of other third-party imports: ZeroGPU expects it to be
# imported before anything that may initialise CUDA.
import spaces
import gradio as gr
# Only the checkpoint *names* are declared here so the UI can be built without
# importing torch/transformers at startup; the heavy imports (torch,
# transformers, peft) are deferred until the first GPU call, inside the
# handler modules loaded by the get_*_handler() functions.
PALIGEMMA_MODELS = {
    name: {}
    for name in (
        "Medium-14k, Single Line",
        "Medium-16k, Single Line",
        "Small, Single Line",
    )
}
GEMMA_MODELS = {name: {} for name in ("Gemma-3 10k",)}
GEMMA_MULTILINE_MODELS = dict.fromkeys(
    ("Gemma Multiline - no-format", "Gemma Multiline - line"),
    "",
)
# Handler singletons. Each one is built on first use by its getter so that
# importing this module (and building the UI) does not import any model code.
_paligemma_handler = None
_gemma_handler = None
_gemma_multiline_handler = None

# Gradio runs event handlers on worker threads, and several event listeners
# share each handler, so two "first" requests can arrive at the same time.
# Without a lock both would see ``None`` and each construct a handler
# (possibly loading model weights twice). Re-entrant so that a handler
# constructor could itself call one of these getters without deadlocking.
_handler_lock = threading.RLock()


def get_paligemma_handler():
    """Return the shared ``PaliGemma2Handler``, creating it on first call.

    The ``paligemma2`` import is deferred to this point so its dependencies
    are only loaded once a PaliGemma request actually arrives.
    """
    global _paligemma_handler
    if _paligemma_handler is None:
        with _handler_lock:
            # Re-check under the lock: another thread may have created it
            # while this one was waiting.
            if _paligemma_handler is None:
                from paligemma2 import PaliGemma2Handler

                _paligemma_handler = PaliGemma2Handler()
    return _paligemma_handler


def get_gemma_handler():
    """Return the shared single-line ``GemmaHandler``, creating it on first call.

    The ``gemma`` import is deferred to this point so its dependencies are
    only loaded once a Gemma sentence request actually arrives.
    """
    global _gemma_handler
    if _gemma_handler is None:
        with _handler_lock:
            # Re-check under the lock (see get_paligemma_handler).
            if _gemma_handler is None:
                from gemma import GemmaHandler

                _gemma_handler = GemmaHandler()
    return _gemma_handler


def get_gemma_multiline_handler():
    """Return the shared ``GemmaMultilineHandler``, creating it on first call.

    The ``gemma_multiline`` import is deferred to this point so its
    dependencies are only loaded once a document request actually arrives.
    """
    global _gemma_multiline_handler
    if _gemma_multiline_handler is None:
        with _handler_lock:
            # Re-check under the lock (see get_paligemma_handler).
            if _gemma_multiline_handler is None:
                from gemma_multiline import GemmaMultilineHandler

                _gemma_multiline_handler = GemmaMultilineHandler()
    return _gemma_multiline_handler
def process_image_paligemma(model_name, image, progress=gr.Progress()):
    """Extract text from one image with the chosen PaliGemma2 checkpoint.

    Thin UI adapter over the lazily created PaliGemma2 handler. The
    ``gr.Progress()`` default is Gradio's hook for injecting a progress
    tracker into the event handler.
    """
    handler = get_paligemma_handler()
    return handler.process_image(model_name, image, progress)
def process_image_gemma(model_name, image, progress=gr.Progress()):
    """Extract text from one image with the chosen single-line Gemma checkpoint.

    Thin UI adapter over the lazily created Gemma handler; ``progress`` is
    Gradio's injected progress tracker.
    """
    handler = get_gemma_handler()
    return handler.process_image(model_name, image, progress)
def process_pdf_paligemma(pdf_path, model_name, progress=gr.Progress()):
    """Extract text from an uploaded PDF with the chosen PaliGemma2 checkpoint.

    Note the argument order (file first, model second), which the UI wiring
    follows. ``progress`` is Gradio's injected progress tracker.
    """
    handler = get_paligemma_handler()
    return handler.process_pdf(pdf_path, model_name, progress)
def process_pdf_gemma(pdf_path, model_name, progress=gr.Progress()):
    """Extract text from an uploaded PDF with the chosen single-line Gemma checkpoint.

    Argument order is file first, model second. ``progress`` is Gradio's
    injected progress tracker.
    """
    handler = get_gemma_handler()
    return handler.process_pdf(pdf_path, model_name, progress)
def process_image_multiline(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Transcribe a multi-line document image in a single (non-streaming) call.

    ``temp``, ``top_p`` and ``repetition_penalty`` are sampling settings that
    are forwarded unchanged to the multiline handler; ``progress`` is Gradio's
    injected progress tracker.
    """
    handler = get_gemma_multiline_handler()
    return handler.generate_text_from_image(
        model_name, image, temp, top_p, repetition_penalty, progress
    )
def process_image_multiline_stream(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Stream a multi-line transcription of a document image.

    A generator: it re-yields whatever the multiline handler's stream yields,
    and Gradio pushes each yielded value to the bound output. ``yield from``
    (rather than a manual loop) also forwards close()/throw() to the inner
    stream, e.g. when the run is cancelled.
    """
    handler = get_gemma_multiline_handler()
    yield from handler.generate_text_stream(
        model_name, image, temp, top_p, repetition_penalty, progress
    )
def process_pdf_multiline(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Transcribe an uploaded PDF with the multiline model in one (non-streaming) call.

    Unlike the single-line PDF helpers, the model name comes first here.
    Sampling settings are forwarded unchanged; ``progress`` is Gradio's
    injected progress tracker.
    """
    handler = get_gemma_multiline_handler()
    return handler.process_pdf(
        model_name, pdf, temp, top_p, repetition_penalty, progress
    )
def process_pdf_multiline_stream(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Stream a multiline transcription of an uploaded PDF.

    A generator that delegates to the handler's PDF stream via ``yield from``
    so cancellation (close/throw) reaches the inner generator too.
    """
    handler = get_gemma_multiline_handler()
    yield from handler.process_pdf_stream(
        model_name, pdf, temp, top_p, repetition_penalty, progress
    )
# Example rows are [image file, caption]. The UI only feeds the file name into
# its example pickers; captions are descriptive.

# Samples for the multi-line (document-level) OCR tab.
document_examples = [
    list(pair)
    for pair in (
        ("ml.png", "Multi-line Dhivehi text sample"),
        ("ml1.png", "Multi-line Dhivehi text sample 2"),
        ("ml2.png", "Multi-line Dhivehi text sample 3"),
        ("ml3.png", "Multi-line Dhivehi text sample 4"),
    )
]

# Samples for the single-line (sentence-level) OCR tabs: typed, handwritten,
# and one multi-line page.
sentence_examples = [
    list(pair)
    for pair in (
        ("type_1_sl.png", "Typed Dhivehi text sample 1"),
        ("type_2_sl.png", "Typed Dhivehi text sample 2"),
        ("hw_1_sl.png", "Handwritten Dhivehi text sample 1"),
        ("hw_2_sl.jpg", "Handwritten Dhivehi text sample 2"),
        ("hw_3_sl.png", "Handwritten Dhivehi text sample 3"),
        ("hw_4_sl.png", "Handwritten Dhivehi text sample 4"),
        ("ml.png", "Multi-line Dhivehi text sample"),
    )
]
# Custom stylesheet passed to gr.Blocks(css=...).
#   .textbox1 - applied to the Dhivehi output textboxes (elem_classes
#               "textbox1"): larger text, taller line height, and a list of
#               Dhivehi (Thaana) font families tried in order.
#   .textbox2 - hides the <textarea> of any component tagged with it.
#               NOTE(review): no component in the visible code uses
#               "textbox2"; it may be dead CSS - confirm before removing.
css = """
.textbox1 textarea {
font-size: 18px !important;
font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
line-height: 1.8 !important;
}
.textbox2 textarea {
display: none;
}
"""
# ---------------------------------------------------------------------------
# UI
#
# Three top-level tabs:
#   * "Gemma Document" - multi-line OCR via the multiline handler, with
#                        sampling controls and streaming + Stop support.
#   * "PaliGemma"      - single-line OCR; text output plus a gallery of
#                        detected text regions.
#   * "Gemma Sentence" - same layout as "PaliGemma", backed by the Gemma
#                        handler.
#
# Statement order matters: Gradio lays components out in creation order and
# numbers event listeners in registration order, so keep it stable.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Dhivehi Image to Text", css=css) as demo:
    gr.Markdown("# Dhivehi Image to Text")
    gr.Markdown("Dhivehi Image to Text experimental finetunes")
    with gr.Tabs():
        # ------------------------------------------------------------------
        # Multi-line (document-level) OCR
        # ------------------------------------------------------------------
        with gr.Tab("Gemma Document"):
            with gr.Row():
                model_path_dropdown = gr.Dropdown(
                    label="Model Checkpoint",
                    choices=list(GEMMA_MULTILINE_MODELS.keys()),
                    value=list(GEMMA_MULTILINE_MODELS.keys())[0],
                    interactive=True,
                    scale=2,
                )
            # Sampling controls shared by the Image and PDF sub-tabs below.
            with gr.Accordion("Advanced Options", open=False):
                with gr.Row():
                    temperature_slider = gr.Slider(
                        minimum=0.1, maximum=1.9, value=0.2, step=0.1,
                        label="Temperature", info="Controls randomness in generation",
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.1, maximum=1.0, value=1, step=0.1,
                        label="Top-p", info="Controls diversity via nucleus sampling",
                    )
                    repetition_penalty_slider = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                        label="Repetition Penalty", info="Penalizes repeated tokens. >1 encourages new tokens.",
                    )
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column():
                            image_input = gr.Image(type="pil", label="Upload Image")
                            with gr.Row():
                                generate_button = gr.Button("Generate Text (Non-streaming)")
                                stream_button = gr.Button("Generate Text (Streaming)", variant="primary")
                                stop_button = gr.Button("Stop", visible=False, variant="stop")
                            # Examples fill only the image input; captions are dropped.
                            gr.Examples(
                                examples=[[img] for img, _ in document_examples],
                                inputs=[image_input],
                                outputs=None,
                                label="Example Images",
                                examples_per_page=7,
                            )
                        with gr.Column():
                            text_output = gr.Textbox(
                                label="Extracted Dhivehi Text",
                                lines=20,
                                rtl=True,
                                elem_classes=["textbox1"],
                                show_copy_button=True,
                                scale=2,
                            )

                    # While a stream runs: reveal Stop and lock both Generate
                    # buttons. Return order matches
                    # [stop_button, stream_button, generate_button].
                    def show_stop_button_image():
                        return gr.update(visible=True), gr.update(interactive=False), gr.update(interactive=False)

                    # Inverse of show_stop_button_image: hide Stop, unlock Generate.
                    def hide_stop_button_image():
                        return gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)

                    generate_button.click(
                        fn=process_image_multiline,
                        inputs=[model_path_dropdown, image_input, temperature_slider, top_p_slider, repetition_penalty_slider],
                        outputs=text_output,
                        show_progress="full",
                    )
                    # Streaming run: show Stop -> stream text -> restore buttons.
                    # `.then` fires after the previous step whether or not it
                    # succeeded, so the buttons come back even if generation fails.
                    show_event = stream_button.click(
                        fn=show_stop_button_image,
                        outputs=[stop_button, stream_button, generate_button],
                    )
                    gen_event = show_event.then(
                        fn=process_image_multiline_stream,
                        inputs=[model_path_dropdown, image_input, temperature_slider, top_p_slider, repetition_penalty_slider],
                        outputs=text_output,
                        show_progress="full",
                    )
                    gen_event.then(
                        fn=hide_stop_button_image,
                        outputs=[stop_button, stream_button, generate_button],
                    )
                    # Stop cancels the streaming step and restores the buttons
                    # itself rather than relying on the trailing `.then`.
                    stop_button.click(
                        fn=hide_stop_button_image,
                        outputs=[stop_button, stream_button, generate_button],
                        cancels=[gen_event],
                    )

                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column():
                            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                            with gr.Row():
                                pdf_generate_button = gr.Button("Generate Text (Non-streaming)")
                                pdf_stream_button = gr.Button("Generate Text (Streaming)", variant="primary")
                                pdf_stop_button = gr.Button("Stop", visible=False, variant="stop")
                            # One value per input component, matching the image
                            # examples (Gradio ignored the old caption column).
                            gr.Examples(
                                examples=[["example.pdf"]],
                                inputs=[pdf_input],
                                outputs=None,
                                label="Example PDFs",
                                examples_per_page=7,
                            )
                        with gr.Column():
                            pdf_text_output = gr.Textbox(
                                label="Extracted Dhivehi Text",
                                lines=20,
                                rtl=True,
                                elem_classes=["textbox1"],
                                show_copy_button=True,
                                scale=2,
                            )

                    # Same button-state toggles as the Image tab, for the PDF
                    # controls: [pdf_stop_button, pdf_stream_button, pdf_generate_button].
                    def show_stop_button_pdf():
                        return gr.update(visible=True), gr.update(interactive=False), gr.update(interactive=False)

                    def hide_stop_button_pdf():
                        return gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)

                    pdf_generate_button.click(
                        fn=process_pdf_multiline,
                        inputs=[model_path_dropdown, pdf_input, temperature_slider, top_p_slider, repetition_penalty_slider],
                        outputs=pdf_text_output,
                        show_progress="full",
                    )
                    # Streaming run: show Stop -> stream text -> restore buttons.
                    pdf_show_event = pdf_stream_button.click(
                        fn=show_stop_button_pdf,
                        outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button],
                    )
                    pdf_gen_event = pdf_show_event.then(
                        fn=process_pdf_multiline_stream,
                        inputs=[model_path_dropdown, pdf_input, temperature_slider, top_p_slider, repetition_penalty_slider],
                        outputs=pdf_text_output,
                        show_progress="full",
                    )
                    pdf_gen_event.then(
                        fn=hide_stop_button_pdf,
                        outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button],
                    )
                    pdf_stop_button.click(
                        fn=hide_stop_button_pdf,
                        outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button],
                        cancels=[pdf_gen_event],
                    )

        # ------------------------------------------------------------------
        # Single-line OCR with PaliGemma2
        # ------------------------------------------------------------------
        with gr.Tab("PaliGemma"):
            model_dropdown_paligemma = gr.Dropdown(
                choices=list(PALIGEMMA_MODELS.keys()),
                value=list(PALIGEMMA_MODELS.keys())[0],
                label="Select PaliGemma Model",
            )
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            image_input_paligemma = gr.Image(type="pil", label="Input Image")
                            image_submit_btn_paligemma = gr.Button("Extract Text")
                            gr.Examples(
                                examples=[[img] for img, _ in sentence_examples],
                                inputs=[image_input_paligemma],
                                label="Example Images",
                                examples_per_page=8,
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    image_text_output_paligemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes=["textbox1"],
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    image_bbox_output_paligemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2,
                                    )
                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            pdf_input_paligemma = gr.File(
                                label="Input PDF",
                                file_types=[".pdf"],
                            )
                            pdf_submit_btn_paligemma = gr.Button("Extract Text from PDF")
                            gr.Examples(
                                examples=[["example.pdf"]],
                                inputs=[pdf_input_paligemma],
                                label="Example PDFs",
                                examples_per_page=8,
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    pdf_text_output_paligemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes=["textbox1"],
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    pdf_bbox_output_paligemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2,
                                    )

        # ------------------------------------------------------------------
        # Single-line OCR with Gemma (same layout as the PaliGemma tab)
        # ------------------------------------------------------------------
        with gr.Tab("Gemma Sentence"):
            model_dropdown_gemma = gr.Dropdown(
                choices=list(GEMMA_MODELS.keys()),
                value=list(GEMMA_MODELS.keys())[0],
                label="Select Gemma Model",
            )
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            image_input_gemma = gr.Image(type="pil", label="Input Image")
                            image_submit_btn_gemma = gr.Button("Extract Text")
                            gr.Examples(
                                examples=[[img] for img, _ in sentence_examples],
                                inputs=[image_input_gemma],
                                label="Example Images",
                                examples_per_page=8,
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    image_text_output_gemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes=["textbox1"],
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    image_bbox_output_gemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2,
                                    )
                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            pdf_input_gemma = gr.File(
                                label="Input PDF",
                                file_types=[".pdf"],
                            )
                            pdf_submit_btn_gemma = gr.Button("Extract Text from PDF")
                            gr.Examples(
                                examples=[["example.pdf"]],
                                inputs=[pdf_input_gemma],
                                label="Example PDFs",
                                examples_per_page=8,
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    pdf_text_output_gemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes=["textbox1"],
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    pdf_bbox_output_gemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2,
                                    )

    # PaliGemma event handlers. Each returns a (text, gallery) pair, hence two
    # outputs. The PDF helpers take the file first, the model second.
    image_submit_btn_paligemma.click(
        fn=process_image_paligemma,
        inputs=[model_dropdown_paligemma, image_input_paligemma],
        outputs=[image_text_output_paligemma, image_bbox_output_paligemma],
    )
    pdf_submit_btn_paligemma.click(
        fn=process_pdf_paligemma,
        inputs=[pdf_input_paligemma, model_dropdown_paligemma],
        outputs=[pdf_text_output_paligemma, pdf_bbox_output_paligemma],
    )
    # Gemma (single-line) event handlers, wired like the PaliGemma ones.
    image_submit_btn_gemma.click(
        fn=process_image_gemma,
        inputs=[model_dropdown_gemma, image_input_gemma],
        outputs=[image_text_output_gemma, image_bbox_output_gemma],
    )
    pdf_submit_btn_gemma.click(
        fn=process_pdf_gemma,
        inputs=[pdf_input_gemma, model_dropdown_gemma],
        outputs=[pdf_text_output_gemma, pdf_bbox_output_gemma],
    )

# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()