#!/usr/bin/env python3
"""Gradio demo that extracts text from images and PDF pages with LightOnOCR.

The script loads the ``lightonai/LightOnOCR-1B-1025`` model once at import
time, then serves a Gradio UI where a user uploads an image or a PDF, picks a
page (PDF only) and a sampling temperature, and receives the model output both
rendered as Markdown and as raw text.
"""
import importlib
import importlib.util
import os
import subprocess
import sys

# CRITICAL: Import spaces FIRST before any CUDA initialization
import spaces

# Now we can import torch and other packages
import torch


def _install_flash_attn() -> bool:
    """Install flash-attn with pip and report whether it is importable.

    Returns:
        True when pip exited successfully and the ``flash_attn`` module can be
        located afterwards, False otherwise. Never raises for a failed
        install; the caller is expected to fall back to another attention
        implementation.
    """
    command = [
        sys.executable, "-m", "pip", "install",
        "flash-attn", "--no-build-isolation",
    ]
    # Extend the current environment instead of replacing it: passing a bare
    # dict as ``env`` would drop PATH, HOME, proxy settings, etc.
    env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
    try:
        result = subprocess.run(command, env=env, check=False)
    except OSError as exc:
        print(f"flash-attn installation could not be started: {exc}")
        return False
    if result.returncode != 0:
        print(f"flash-attn installation failed (exit code {result.returncode})")
        return False
    # The package was installed while this interpreter was running, so the
    # import system's directory caches must be refreshed before looking it up.
    importlib.invalidate_caches()
    return importlib.util.find_spec("flash_attn") is not None


# Install flash-attn for GPU only (after spaces import)
flash_attn_ready = False
if torch.cuda.is_available():
    print("CUDA detected - installing flash-attn for optimal GPU performance...")
    flash_attn_ready = _install_flash_attn()

# These imports intentionally come after the install step above so that
# transformers is imported only once flash-attn has had a chance to be set up.
import gradio as gr
from PIL import Image
from io import BytesIO
import pypdfium2 as pdfium
from transformers import (
    LightOnOCRForConditionalGeneration,
    LightOnOCRProcessor,
)

MODEL_ID = "lightonai/LightOnOCR-1B-1025"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Choose best attention implementation based on device
if device == "cuda" and flash_attn_ready:
    attn_implementation = "flash_attention_2"  # Best for GPU
    dtype = torch.bfloat16
    print("Using flash_attention_2 for GPU")
elif device == "cuda":
    # flash-attn could not be installed; loading with flash_attention_2 would
    # fail, so use the implementation that needs no extra package.
    attn_implementation = "eager"
    dtype = torch.bfloat16
    print("flash-attn unavailable - falling back to eager attention on GPU")
else:
    attn_implementation = "eager"  # Best for CPU
    dtype = torch.float32
    print("Using eager attention for CPU")

# Initialize the LightOnOCR model and processor
print(f"Loading model on {device} with {attn_implementation} attention...")
model = LightOnOCRForConditionalGeneration.from_pretrained(
    MODEL_ID,
    attn_implementation=attn_implementation,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()

processor = LightOnOCRProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
print("Model loaded successfully!")


def _resolve_path(file_input):
    """Return the filesystem path for a Gradio file value.

    The value is used as-is when it is already a string; otherwise its
    ``name`` attribute is returned.
    """
    return file_input if isinstance(file_input, str) else file_input.name


def _is_pdf(file_path) -> bool:
    """Tell whether ``file_path`` has a ``.pdf`` extension (case-insensitive)."""
    return file_path.lower().endswith('.pdf')


def render_pdf_page(page, max_resolution=1540, scale=2.77):
    """Render a PDF page to PIL Image.

    Args:
        page: A pypdfium2 page object (must provide ``get_size`` and
            ``render``).
        max_resolution: Upper bound, in pixels, for the longer rendered side.
        scale: Preferred render scale; reduced when the page would otherwise
            exceed ``max_resolution``.

    Returns:
        The rendered page as a PIL image.
    """
    width, height = page.get_size()
    pixel_width = width * scale
    pixel_height = height * scale
    # Never upscale (factor capped at 1); shrink so neither side exceeds the
    # maximum resolution.
    resize_factor = min(
        1,
        max_resolution / pixel_width,
        max_resolution / pixel_height,
    )
    target_scale = scale * resize_factor
    return page.render(scale=target_scale, rev_byteorder=True).to_pil()


def process_pdf(pdf_path, page_num=1):
    """Extract a specific page from PDF.

    Args:
        pdf_path: Path of the PDF file to open.
        page_num: 1-based page number; values outside the document are
            clamped to the first/last page.

    Returns:
        A tuple ``(image, total_pages, actual_page)`` where ``actual_page`` is
        the 1-based number of the page that was rendered.

    Raises:
        ValueError: If the document contains no pages.
    """
    pdf = pdfium.PdfDocument(pdf_path)
    try:
        total_pages = len(pdf)
        if total_pages == 0:
            raise ValueError("PDF contains no pages")
        page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
        img = render_pdf_page(pdf[page_idx])
    finally:
        # Release the document even when rendering fails.
        pdf.close()
    return img, total_pages, page_idx + 1


def _to_model_device(value):
    """Move a tensor to the model device, casting floats to the model dtype.

    Non-tensor values are returned unchanged.
    """
    if not isinstance(value, torch.Tensor):
        return value
    if value.is_floating_point():
        # Floating-point inputs (e.g. pixel values) must match the model
        # weights' dtype, which is bfloat16 on GPU.
        return value.to(device=device, dtype=dtype)
    return value.to(device)


@spaces.GPU
def extract_text_from_image(image, temperature=0.2):
    """Extract text from image using LightOnOCR model.

    Args:
        image: The image to transcribe (a PIL image in this script).
        temperature: Sampling temperature. Values greater than 0 enable
            sampling; 0 (or less) selects non-sampling decoding.

    Returns:
        The decoded text generated by the model, without the prompt.
    """
    # Prepare the chat format
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image},
            ],
        }
    ]

    # Apply chat template and tokenize
    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Move inputs to device
    inputs = {key: _to_model_device(value) for key, value in inputs.items()}

    sampling = temperature is not None and temperature > 0
    generation_kwargs = {
        "max_new_tokens": 2048,
        "use_cache": True,
        "do_sample": sampling,
    }
    if sampling:
        # Temperature is only meaningful when sampling; passing 0.0 together
        # with do_sample=False is at best ignored with a warning.
        generation_kwargs["temperature"] = temperature

    # Generate text with appropriate settings
    with torch.no_grad():  # Disable gradients for inference
        outputs = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens so
    # the chat-template text does not leak into the OCR result.
    prompt_length = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][prompt_length:]

    # Decode the output
    output_text = processor.decode(generated_ids, skip_special_tokens=True)

    return output_text


def process_input(file_input, temperature, page_num):
    """Process uploaded file (image or PDF) and extract text.

    Args:
        file_input: Gradio file value (path string or object with ``name``),
            or None when nothing was uploaded.
        temperature: Sampling temperature forwarded to the model.
        page_num: 1-based page number, used for PDFs only.

    Returns:
        A 5-tuple matching the UI outputs: rendered text, raw text, page info,
        preview image, and an update for the page slider. Errors are reported
        through the text fields rather than raised.
    """
    if file_input is None:
        return "Please upload an image or PDF first.", "", "", None, gr.update()

    image_to_process = None
    page_info = ""

    file_path = _resolve_path(file_input)

    # Handle PDF files
    if _is_pdf(file_path):
        try:
            image_to_process, total_pages, actual_page = process_pdf(
                file_path, int(page_num)
            )
            page_info = f"Processing page {actual_page} of {total_pages}"
        except Exception as e:
            return f"Error processing PDF: {str(e)}", "", "", None, gr.update()
    # Handle image files
    else:
        try:
            # Image.open is lazy and keeps the file open; copy the pixel data
            # inside the context manager so the handle is released promptly.
            with Image.open(file_path) as opened:
                image_to_process = opened.copy()
            page_info = "Processing image"
        except Exception as e:
            return f"Error opening image: {str(e)}", "", "", None, gr.update()

    try:
        # Extract text using LightOnOCR
        extracted_text = extract_text_from_image(image_to_process, temperature)
        return extracted_text, extracted_text, page_info, image_to_process, gr.update()
    except Exception as e:
        error_msg = f"Error during text extraction: {str(e)}"
        return error_msg, error_msg, page_info, image_to_process, gr.update()


def update_slider(file_input):
    """Update page slider based on PDF page count.

    Returns a Gradio update that resets the slider value to 1 and sets its
    maximum to the PDF page count, to 1 for non-PDF files, or to 20 when no
    file is present or the PDF cannot be read.
    """
    if file_input is None:
        return gr.update(maximum=20, value=1)

    file_path = _resolve_path(file_input)
    if not _is_pdf(file_path):
        return gr.update(maximum=1, value=1)

    try:
        pdf = pdfium.PdfDocument(file_path)
        try:
            total_pages = len(pdf)
        finally:
            pdf.close()
    except Exception:
        # Best effort: an unreadable PDF keeps the default slider range; the
        # actual error is surfaced later by process_input.
        return gr.update(maximum=20, value=1)
    return gr.update(maximum=total_pages, value=1)


INTRO_MARKDOWN = f"""
# 📖 Image/PDF to Text Extraction (LightOnOCR + Zero GPU)

**💡 How to use:**
1. Upload an image or PDF
2. For PDFs: select which page to extract (1-20)
3. Adjust temperature if needed (0.0 for deterministic, higher for more varied output)
4. Click "Extract Text"

**Note:** The Markdown rendering for tables may not always be perfect. Check the raw output for complex tables!

**Model:** LightOnOCR-1B-1025 by LightOn AI
**Device:** {device.upper()}
**Attention:** {attn_implementation}
"""

# Create Gradio interface
with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="🖼️ Upload Image or PDF",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                type="filepath",
            )
            rendered_image = gr.Image(
                label="📄 Preview",
                type="pil",
                height=400,
                interactive=False,
            )
            num_pages = gr.Slider(
                minimum=1,
                maximum=20,
                value=1,
                step=1,
                label="PDF: Page Number",
                info="Select which page to extract",
            )
            page_info = gr.Textbox(
                label="Processing Info",
                value="",
                interactive=False,
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                label="Temperature",
                info="0.0 = deterministic, Higher = more varied",
            )
            submit_btn = gr.Button("Extract Text", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=2):
            output_text = gr.Markdown(
                label="📄 Extracted Text (Rendered)",
                value="*Extracted text will appear here...*",
            )

    with gr.Row():
        with gr.Column():
            raw_output = gr.Textbox(
                label="Raw Markdown Output",
                placeholder="Raw text will appear here...",
                lines=20,
                max_lines=30,
                show_copy_button=True,
            )

    # Event handlers
    submit_btn.click(
        fn=process_input,
        inputs=[file_input, temperature, num_pages],
        outputs=[output_text, raw_output, page_info, rendered_image, num_pages],
    )

    file_input.change(
        fn=update_slider,
        inputs=[file_input],
        outputs=[num_pages],
    )

    # One value per output component, in the same order as ``outputs``.
    clear_btn.click(
        fn=lambda: (None, "*Extracted text will appear here...*", "", "", None, 1),
        outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages],
    )


if __name__ == "__main__":
    demo.launch()