Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import os | |
| import re | |
| import tempfile | |
| from PIL import Image | |
| from docx import Document | |
| from bs4 import BeautifulSoup | |
| from threading import Thread | |
| # --- Transformers Import --- | |
| try: | |
| from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor, TextIteratorStreamer | |
| except ImportError as e: | |
| raise ImportError("Transformers library not found. Please install git+https://github.com/huggingface/transformers.git") from e | |
# --- Global Model Loading ---
print("Loading AI Model (2.1B Parameters)... This may take a minute...")

try:
    # Choose the execution target first. CUDA is used only when it is actually
    # available, so CPU-only hosts fall through to the CPU branch.
    if not torch.cuda.is_available():
        device = "cpu"
        dtype = torch.float32
        print("Running on CPU mode")
    else:
        device = "cuda"
        # bfloat16 when the GPU reports support for it, float16 otherwise.
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
        else:
            dtype = torch.float16
        print(f"Running on GPU: {torch.cuda.get_device_name(0)}")

    model_id = "lightonai/LightOnOCR-2-1B"
    processor = LightOnOcrProcessor.from_pretrained(model_id)

    # Keyword arguments for the weight loader, gathered in one place.
    _load_options = {
        "torch_dtype": dtype,
        "attn_implementation": "sdpa",
        "low_cpu_mem_usage": True,
    }
    model = LightOnOcrForConditionalGeneration.from_pretrained(model_id, **_load_options)
    model = model.to(device)
    model.eval()
    print("Model Loaded Successfully!")
except Exception as e:
    # Best-effort start-up: the app still launches and stream_ocr reports the
    # missing model to the user instead of the process crashing here.
    print(f"Failed to load model: {e}")
    model = None
    processor = None
| # --- Helper Functions --- | |
def resize_for_ocr(image, max_dim=768):
    """Downscale ``image`` so its longest side is at most ``max_dim`` pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the very same object, not a copy).

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method (a PIL image in this app), or ``None``.
        max_dim: Maximum length, in pixels, allowed for the longest side.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise the original image or a
        smaller LANCZOS-resampled one.
    """
    if image is None:
        return None

    width, height = image.size
    longest = max(width, height)
    if longest <= max_dim:
        return image

    scale = max_dim / longest
    # Clamp each side to 1px: for very elongated images int() truncates the
    # short side to 0, and a zero-sized target makes the resize call fail.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return image.resize(new_size, Image.Resampling.LANCZOS)
# LaTeX command name (without the backslash) -> plain-text replacement.
# Symbols are written as escapes so the table survives any re-encoding of this
# file; the Unicode character name is given next to each entry.
_LATEX_SYMBOLS = {
    'rightarrow': '\u2192',      # RIGHTWARDS ARROW
    'leftarrow': '\u2190',       # LEFTWARDS ARROW
    'leftrightarrow': '\u2194',  # LEFT RIGHT ARROW
    'Rightarrow': '\u21d2',      # RIGHTWARDS DOUBLE ARROW
    'downarrow': '\u2193',       # DOWNWARDS ARROW
    'uparrow': '\u2191',         # UPWARDS ARROW
    'ldots': '...',
    'cdots': '...',
    'times': '\u00d7',           # MULTIPLICATION SIGN
    'approx': '\u2248',          # ALMOST EQUAL TO
    'le': '\u2264',              # LESS-THAN OR EQUAL TO
    'leq': '\u2264',
    'ge': '\u2265',              # GREATER-THAN OR EQUAL TO
    'geq': '\u2265',
}

# A backslash followed by a run of letters, i.e. one complete command name.
_LATEX_COMMAND = re.compile(r'\\([A-Za-z]+)')


def _replace_latex_symbol(match):
    """Return the plain-text symbol for a known command, else the command itself."""
    return _LATEX_SYMBOLS.get(match.group(1), match.group(0))


def clean_latex_for_word(text):
    """Simplify common LaTeX markup so the text reads well in a Word document.

    Steps, in order:
      1. drop ``\\begin{array}{...}`` / ``\\end{array}`` wrappers;
      2. unwrap ``\\text{}``, ``\\textbf{}`` and ``\\textit{}`` to their content;
      3. turn the LaTeX line break ``\\\\`` into a newline;
      4. replace known symbol commands (arrows, dots, ``\\times``, ``\\approx``,
         ``\\le``/``\\leq``, ``\\ge``/``\\geq``) with plain-text characters.

    Symbol commands are matched as whole command names, so ``\\le`` no longer
    eats the start of ``\\left`` or ``\\leq``. Unknown commands are left as is.

    Args:
        text: Text possibly containing LaTeX commands.

    Returns:
        The cleaned text.
    """
    text = re.sub(r'\\begin\{array\}\{.*?\}', '', text)
    text = text.replace(r'\end{array}', '')
    # Applied one after another (not as a single alternation) so that a
    # \text{} nested inside \textbf{} is unwrapped before the outer command.
    for command in ('text', 'textbf', 'textit'):
        text = re.sub(r'\\' + command + r'\{([^}]*)\}', r'\1', text)
    text = text.replace('\\\\', '\n')
    return _LATEX_COMMAND.sub(_replace_latex_symbol, text)
# Markers that suggest a line is a math/chemical formula: arrow and fraction
# commands, braced sub/superscripts, a digit subscript, or an element symbol
# followed by a digit subscript. Compiled once instead of on every call
# (this function runs for each streamed chunk).
_LATEX_HINT = re.compile(r"(\\xrightarrow|\\rightarrow|\\frac|\^\{|_\{|_[0-9]|[A-Z][a-z]?_\d)")


def format_latex_for_display(text):
    """Wrap formula-looking lines in ``$$ ... $$`` for Markdown rendering.

    Each line is checked independently. A line is wrapped when it contains one
    of the LaTeX markers above and has no ``$`` delimiter of its own. Lines
    that already carry inline (``$...$``) or display (``$$...$$``) math are
    left untouched, because wrapping them again would nest delimiters.

    Args:
        text: Raw OCR text, possibly spanning several lines.

    Returns:
        The text with qualifying lines wrapped, joined by newlines.
    """
    formatted = []
    for line in text.split('\n'):
        if "$" not in line and _LATEX_HINT.search(line):
            formatted.append(f"$${line}$$")
        else:
            formatted.append(line)
    return "\n".join(formatted)
# A complete math span: display math ($$...$$) is tried first, then inline
# math ($...$). The content may not itself contain '$'.
_MATH_SPAN = re.compile(r'\$\$([^$]+)\$\$|\$([^$]+)\$')


def _add_runs_with_math(paragraph, line):
    """Append ``line`` to ``paragraph``, styling each complete math span as math.

    Text outside a matched span is added verbatim, so an unmatched ``$``
    (for example a price) is kept rather than dropped.
    """
    cursor = 0
    for match in _MATH_SPAN.finditer(line):
        if match.start() > cursor:
            paragraph.add_run(line[cursor:match.start()])
        math_run = paragraph.add_run(match.group(1) or match.group(2))
        math_run.italic = True
        math_run.font.name = 'Cambria Math'
        cursor = match.end()
    if cursor < len(line):
        paragraph.add_run(line[cursor:])


def process_markdown_segment(text, doc):
    """Append a Markdown text segment to ``doc``, one paragraph per line.

    Blank lines are skipped. Every other line is first passed through
    ``clean_latex_for_word`` and then mapped as follows:

    * ``#``-prefixed lines (hashes, a space, then content) become headings
      whose level is the number of hashes, capped at 9;
    * lines starting with ``- `` or ``* `` become 'List Bullet' paragraphs;
    * ``$...$`` and ``$$...$$`` spans are added as italic 'Cambria Math' runs,
      also inside bullet items.

    Args:
        text: Markdown text, possibly spanning several lines.
        doc: The python-docx document to append to (modified in place).
    """
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        line = clean_latex_for_word(line)

        if line.startswith('#'):
            hashes, separator, content = line.partition(' ')
            # Only a pure run of '#' followed by a space marks a heading;
            # anything else (e.g. "#tag") falls through as ordinary text.
            if separator and all(char == '#' for char in hashes):
                doc.add_heading(content, level=min(len(hashes), 9))
                continue

        style = None
        if line.startswith(('- ', '* ')):
            style = 'List Bullet'
            line = line[2:].strip()

        if '$' in line:
            _add_runs_with_math(doc.add_paragraph(style=style), line)
        else:
            doc.add_paragraph(line, style=style)
def process_html_table(html_str, doc):
    """Parse one HTML table and append it to ``doc`` as a 'Table Grid' table.

    The table gets one row per ``<tr>`` and as many columns as the widest row
    has ``<td>``/``<th>`` cells; shorter rows leave their trailing cells empty.
    Nothing is added when the markup has no rows or no cells.

    Conversion is best-effort: on any failure the error is printed and a
    ``[Error parsing table]`` placeholder paragraph is added instead of
    aborting the whole document.

    Args:
        html_str: HTML markup containing the table.
        doc: The python-docx document to append to (modified in place).
    """
    try:
        soup = BeautifulSoup(html_str, 'html.parser')
        # Collect each row's cells once; reused for sizing and for filling.
        rows_of_cells = [row.find_all(['td', 'th']) for row in soup.find_all('tr')]
        column_count = max((len(cells) for cells in rows_of_cells), default=0)
        if column_count == 0:
            return

        table = doc.add_table(rows=len(rows_of_cells), cols=column_count)
        table.style = 'Table Grid'
        for row_index, cells in enumerate(rows_of_cells):
            for col_index, cell in enumerate(cells):
                table.cell(row_index, col_index).text = cell.get_text(strip=True)
    except Exception as e:
        # Surface the cause in the logs; the placeholder keeps the export going.
        print(f"Failed to convert HTML table: {e}")
        doc.add_paragraph("[Error parsing table]")
def markdown_to_docx(text):
    """Build a Word document from OCR output mixing Markdown and HTML tables.

    The text is cut at every ``<table ...>...</table>`` span (matched
    case-insensitively and across newlines). Table chunks go to
    ``process_html_table``, everything else to ``process_markdown_segment``;
    whitespace-only chunks are ignored.

    Args:
        text: The extracted text.

    Returns:
        The populated python-docx ``Document``.
    """
    document = Document()
    # The capturing group makes re.split keep the table chunks in the result.
    chunks = re.split(
        r'(<table.*?>.*?</table>)',
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    for chunk in chunks:
        trimmed = chunk.strip()
        if not trimmed:
            continue
        is_table = trimmed.lower().startswith('<table')
        handler = process_html_table if is_table else process_markdown_segment
        handler(chunk, document)
    return document
| # --- Gradio Logic --- | |
def stream_ocr(image):
    """Transcribe ``image`` and stream the result to the UI.

    Generator used as the click handler. Every yielded item is a
    ``(markdown_text, docx_path)`` pair: while tokens arrive the path is
    ``None``; the final item carries the complete text and the path of the
    generated ``.docx`` file. Problems are reported as a message in the text
    slot rather than raised.

    Args:
        image: The uploaded image (PIL, per the input component), or ``None``.

    Yields:
        ``(str, str | None)`` tuples as described above.
    """
    if model is None:
        yield "Error: Model not loaded.", None
        return
    if image is None:
        yield "Please upload an image.", None
        return

    try:
        # Downscale before inference to keep the run time reasonable.
        valid_image = resize_for_ocr(image, max_dim=896)

        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": valid_image},
                    {"type": "text", "text": "Transcribe this document exactly."},
                ],
            }
        ]
        inputs = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        if "pixel_values" in inputs:
            # Match the dtype the model weights were loaded with.
            inputs["pixel_values"] = inputs["pixel_values"].to(dtype=dtype)

        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=2048,
            repetition_penalty=1.1,
            do_sample=False,  # greedy decoding; temperature/top_p do not apply
            use_cache=True,
        )

        generation_errors = []

        def _generate():
            """Run generation; on failure record the error and unblock the reader."""
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                generation_errors.append(exc)
                # Without this the loop over `streamer` below would wait
                # forever for an end-of-stream signal that never comes.
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            # Partial text, with formula lines wrapped for rendering.
            yield format_latex_for_display(generated_text), None

        thread.join()
        if generation_errors:
            raise generation_errors[0]

        # The DOCX is built from the raw text, not the display-formatted one.
        doc = markdown_to_docx(generated_text)

        # A unique file per request: a fixed name in the shared temp directory
        # would let concurrent requests overwrite each other's result.
        with tempfile.NamedTemporaryFile(
            prefix="ocr_result_", suffix=".docx", delete=False
        ) as handle:
            output_path = handle.name
        doc.save(output_path)

        yield format_latex_for_display(generated_text), output_path
    except Exception as e:
        yield f"Error during processing: {str(e)}", None
# --- Prepare Examples ---
# Sample images offered in the UI: up to five image files from the `data`
# directory that sits next to this file. Each entry is a one-element list
# holding an absolute path.
example_images = []
base_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(base_dir, 'data')

# isdir (rather than exists) so a plain file named `data` is treated as
# "no examples" instead of making os.listdir raise at import time.
if os.path.isdir(data_dir):
    valid_exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
    # os.listdir returns entries in arbitrary order; sorting makes the five
    # chosen samples the same on every start.
    found_files = sorted(
        f for f in os.listdir(data_dir)
        if os.path.splitext(f)[1].lower() in valid_exts
    )
    print(f"DEBUG: Found {len(found_files)} images in {data_dir}")
    # Capped at five entries to keep the gallery small.
    example_images = [[os.path.join(data_dir, f)] for f in found_files[:5]]
else:
    print(f"DEBUG: Data directory not found at {data_dir}")
| # --- Aesthetic Custom CSS --- | |
# Stylesheet string handed to gr.Blocks(css=...) where the UI is built. It sets
# the animated gradient page background, the primary-button and label styling,
# the `.header-text` card used by the header HTML, and the `.scrollable-md` box
# applied to the live-text output through elem_classes.
# NOTE(review): the rule commented as hiding the 'Document Source' label only
# recolours and bolds `label span`; it does not hide anything.
| custom_css = """ | |
| /* Dark Purple Gradient Background */ | |
| body, .gradio-container { | |
| background-color: #0f0c29 !important; /* Fallback */ | |
| background: linear-gradient(-45deg, #0f0c29, #302b63, #24243e) !important; | |
| background-size: 400% 400%; | |
| animation: gradient 15s ease infinite; | |
| color: #e0e7ff !important; | |
| } | |
| /* | |
| UI Fixes for deployment | |
| - Ensure inputs and buttons are clearly visible | |
| - Remove overlay icons on images | |
| */ | |
| /* Reset z-indexes to avoid layering issues */ | |
| .gradio-container button, .gradio-container img { | |
| z-index: auto; | |
| } | |
| /* Specific fix for the main image container to prevent glass overlay */ | |
| .image-container, div[data-testid="image"] { | |
| background: transparent !important; | |
| border: none !important; | |
| backdrop-filter: none !important; | |
| } | |
| /* Hide the 'upload' icon/placeholder when an image is showing */ | |
| /* This targets the SVG usually found in the center */ | |
| div[data-testid="image"] svg { | |
| display: none !important; | |
| } | |
| /* Styling for the buttons to pop out */ | |
| button.primary { | |
| background: linear-gradient(90deg, #8b5cf6, #d946ef) !important; | |
| border: none !important; | |
| color: white !important; | |
| box-shadow: 0 4px 15px rgba(139, 92, 246, 0.4) !important; | |
| } | |
| /* Hide the label 'Document Source' if it overlaps */ | |
| label span { | |
| color: #e0e7ff !important; | |
| font-weight: bold; | |
| font-size: 1.1em; | |
| } | |
| @keyframes gradient { | |
| 0% { background-position: 0% 50%; } | |
| 50% { background-position: 100% 50%; } | |
| 100% { background-position: 0% 50%; } | |
| } | |
| /* Enhanced Glassmorphism Classes */ | |
| .header-text { | |
| text-align: center; | |
| margin-bottom: 2rem; | |
| padding: 3rem; | |
| background: rgba(255, 255, 255, 0.05); | |
| border-radius: 20px; | |
| backdrop-filter: blur(16px); | |
| -webkit-backdrop-filter: blur(16px); | |
| border: 1px solid rgba(255, 255, 255, 0.1); | |
| box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.37); | |
| } | |
| .header-text h1 { | |
| font-family: 'Inter', sans-serif; | |
| font-weight: 800; | |
| color: #ffffff; | |
| text-shadow: 0 0 25px rgba(167, 139, 250, 0.6); | |
| margin-bottom: 0.8rem; | |
| font-size: 3.5rem; | |
| letter-spacing: -1.5px; | |
| } | |
| .header-text p { | |
| font-size: 1.1rem; | |
| color: #c4b5fd; | |
| font-weight: 400; | |
| letter-spacing: 2px; | |
| text-transform: uppercase; | |
| } | |
| /* Scrollable Markdown Area */ | |
| .scrollable-md { | |
| height: 400px; | |
| overflow-y: auto; | |
| border: 1px solid rgba(255, 255, 255, 0.1); | |
| border-radius: 8px; | |
| padding: 10px; | |
| background: rgba(0, 0, 0, 0.2); | |
| } | |
| """ | |
# Base: the built-in Glass theme in violet/slate/stone, preferring the Inter
# web font and falling back to generic sans-serif families.
_glass_base = gr.themes.Glass(
    primary_hue="violet",
    secondary_hue="slate",
    neutral_hue="stone",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# Overrides layered on the base theme: a transparent body (so the page
# background defined in custom_css shows through), translucent fills for
# blocks and inputs, and a violet gradient for the primary button.
_theme_overrides = {
    "body_background_fill": "transparent",
    "body_text_color": "#e0e7ff",
    "background_fill_primary": "rgba(20, 20, 35, 0.2)",
    "background_fill_secondary": "rgba(20, 20, 35, 0.2)",
    "border_color_primary": "rgba(255, 255, 255, 0.1)",
    "block_background_fill": "rgba(30, 25, 45, 0.2)",
    "block_border_width": "1px",
    "block_label_background_fill": "rgba(50, 40, 70, 0.4)",
    "input_background_fill": "rgba(20, 20, 40, 0.3)",
    "button_primary_background_fill": "linear-gradient(90deg, #8b5cf6 0%, #6d28d9 100%)",
    "button_primary_border_color": "rgba(255, 255, 255, 0.3)",
    "button_primary_text_color": "#ffffff",
    "button_primary_shadow": "0 0 20px rgba(139, 92, 246, 0.6)",
    "slider_color": "#8b5cf6",
}

theme = _glass_base.set(**_theme_overrides)
# --- Gradio UI Layout ---
# Emoji in the labels below are written as escape sequences so they survive
# re-encoding of this file (they had been corrupted into stray Greek letters).
# NOTE(review): the nesting of the layout blocks was reconstructed from a copy
# of this file whose indentation had been stripped -- confirm against the
# deployed layout.
with gr.Blocks(title="Ultra OCR", theme=theme, css=custom_css) as demo:
    with gr.Column():
        # Header card, styled by the .header-text rules in custom_css.
        gr.Markdown(
            '<div class="header-text">\n'
            '<h1>\U0001F916 Ultra OCR</h1>\n'  # robot face
            '<p>Crafted with \u2764\ufe0f by The Best Team</p>\n'  # red heart
            '</div>'
        )

    with gr.Row(equal_height=False, variant="panel"):
        # Left: image input and the run button.
        with gr.Column(scale=4):
            input_img = gr.Image(
                type="pil",
                label="Document Source",
                height=500,
                sources=['upload', 'clipboard'],
                format="png",
                show_label=False,  # hidden so the label does not overlay the image
            )
            run_btn = gr.Button(
                "\u26a1 Start Transcription",  # high voltage sign
                variant="primary",
                size="lg",
            )

        # Right: streamed text and the downloadable document.
        with gr.Column(scale=5):
            with gr.Tabs():
                with gr.TabItem("\U0001F4DD Live Text"):  # memo
                    output_text = gr.Markdown(
                        label="Real-time Extraction",
                        elem_classes=["scrollable-md"],
                    )
                with gr.TabItem("\U0001F4BE Export"):  # floppy disk
                    gr.Markdown("### Download Results")
                    output_file = gr.File(label="Download Word (.docx)", type="filepath")

    # Example gallery, shown only when sample images were found.
    if example_images:
        gr.HTML("<hr>")
        # NOTE(review): the original icon could not be recovered from the
        # corrupted text; an open-folder icon is used as a stand-in.
        gr.Markdown("### \U0001F4C2 Sample Documents")
        gr.Examples(
            examples=example_images,
            inputs=input_img,
            label="Click a sample to test",
            examples_per_page=5,
        )

    # stream_ocr is a generator, so both outputs update while it runs.
    run_btn.click(
        fn=stream_ocr,
        inputs=[input_img],
        outputs=[output_text, output_file],
        concurrency_limit=5,
    )
if __name__ == "__main__":
    # The directory containing this file is whitelisted so that files under it
    # (such as the sample images, which are referenced by absolute path) may be
    # served. ssr_mode is deliberately left at its default.
    app_dir = os.path.dirname(os.path.abspath(__file__))
    demo.launch(allowed_paths=[app_dir])