OCR-MixVL / app.py
automation
Upd req
4591541
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
import spaces
import os
import sys
import tempfile
import shutil
from PIL import Image, ImageDraw, ImageFont, ImageOps
import fitz
import re
import warnings
import numpy as np
import base64
from io import StringIO, BytesIO
import subprocess
import importlib
import time
import zipfile
import atexit
import functools
from queue import Queue
from threading import Event, Thread
from threading import Lock
from itertools import cycle
import logging
import json
# Suppress transformers warnings
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
warnings.filterwarnings("ignore", message=".*attention mask.*")
warnings.filterwarnings("ignore", message=".*pad token.*")
warnings.filterwarnings("ignore", message=".*seen_tokens.*")
warnings.filterwarnings("ignore", message=".*get_usable_length.*")
warnings.filterwarnings("ignore", message=".*get_max_cache.*")
warnings.filterwarnings("ignore", message=".*get_seq_length.*")
warnings.filterwarnings("ignore", message=".*position_ids.*")
warnings.filterwarnings("ignore", message=".*position_embeddings.*")
warnings.filterwarnings("ignore", message=".*Setting `pad_token_id`.*")
# Patch DynamicCache to fix deprecated attribute errors
# This is a compatibility fix for transformers >= 4.47.0 where seen_tokens, get_max_length, and get_usable_length were deprecated
try:
from transformers.cache_utils import DynamicCache
if not hasattr(DynamicCache, 'seen_tokens'):
# Add seen_tokens property for backward compatibility
def _get_seen_tokens(self):
"""Backward compatibility property for seen_tokens"""
# Calculate seen_tokens from the cache structure
if hasattr(self, 'key_cache') and self.key_cache:
# Return the length of the first layer's key cache
first_layer_keys = list(self.key_cache.values())[0] if self.key_cache else None
if first_layer_keys is not None and len(first_layer_keys) > 0:
return first_layer_keys[0].shape[-2] if hasattr(first_layer_keys[0], 'shape') else 0
return 0
DynamicCache.seen_tokens = property(_get_seen_tokens)
if not hasattr(DynamicCache, 'get_max_length'):
# Add get_max_length method for backward compatibility
# In newer transformers, this was replaced with cache_position or similar
def _get_max_length(self):
"""Backward compatibility method for get_max_length"""
# Try to get max length from cache structure
if hasattr(self, 'key_cache') and self.key_cache:
first_layer_keys = list(self.key_cache.values())[0] if self.key_cache else None
if first_layer_keys is not None and len(first_layer_keys) > 0:
# Return the sequence length dimension
if hasattr(first_layer_keys[0], 'shape') and len(first_layer_keys[0].shape) >= 2:
return first_layer_keys[0].shape[-2]
# Fallback: try cache_position if available (newer API)
if hasattr(self, 'cache_position') and self.cache_position is not None:
if hasattr(self.cache_position, '__len__'):
return len(self.cache_position)
elif hasattr(self.cache_position, 'shape'):
return self.cache_position.shape[-1] if len(self.cache_position.shape) > 0 else 0
# Default fallback
return 0
DynamicCache.get_max_length = _get_max_length
if not hasattr(DynamicCache, 'get_usable_length'):
# Add get_usable_length method for backward compatibility
# In newer transformers, this was replaced with get_seq_length or similar
def _get_usable_length(self, seq_length=None):
"""Backward compatibility method for get_usable_length
Args:
seq_length: Optional sequence length parameter (for compatibility with old API)
"""
# Try to use get_seq_length if available (newer API)
if hasattr(self, 'get_seq_length'):
try:
return self.get_seq_length()
except:
pass
# Try to get usable length from cache structure
if hasattr(self, 'key_cache') and self.key_cache:
first_layer_keys = list(self.key_cache.values())[0] if self.key_cache else None
if first_layer_keys is not None and len(first_layer_keys) > 0:
# Return the sequence length dimension
if hasattr(first_layer_keys[0], 'shape') and len(first_layer_keys[0].shape) >= 2:
cache_length = first_layer_keys[0].shape[-2]
# If seq_length is provided, return the minimum (usable portion)
if seq_length is not None:
return min(cache_length, seq_length)
return cache_length
# Fallback: try cache_position if available (newer API)
if hasattr(self, 'cache_position') and self.cache_position is not None:
if hasattr(self.cache_position, '__len__'):
pos_len = len(self.cache_position)
if seq_length is not None:
return min(pos_len, seq_length)
return pos_len
elif hasattr(self.cache_position, 'shape'):
pos_len = self.cache_position.shape[-1] if len(self.cache_position.shape) > 0 else 0
if seq_length is not None:
return min(pos_len, seq_length)
return pos_len
# If seq_length is provided, return it; otherwise return 0
return seq_length if seq_length is not None else 0
DynamicCache.get_usable_length = _get_usable_length
# Also add get_seq_length as an alias if it doesn't exist and get_usable_length does
if not hasattr(DynamicCache, 'get_seq_length') and hasattr(DynamicCache, 'get_usable_length'):
def _get_seq_length(self):
"""Backward compatibility method for get_seq_length (alias for get_usable_length)"""
return self.get_usable_length()
DynamicCache.get_seq_length = _get_seq_length
except (ImportError, AttributeError):
# If DynamicCache doesn't exist or patch fails, continue anyway
pass
# Optional dependency installer (used to keep base image small on Spaces)
_OPTIONAL_INSTALL_LOCK = Lock()
_OPTIONAL_INSTALL_CACHE = set()
def _install_optional_packages(packages, context="optional dependency"):
"""Install optional packages lazily using pip."""
pending = [pkg for pkg in packages if pkg not in _OPTIONAL_INSTALL_CACHE]
if not pending:
return
with _OPTIONAL_INSTALL_LOCK:
pending = [pkg for pkg in pending if pkg not in _OPTIONAL_INSTALL_CACHE]
if not pending:
return
try:
# Try installing packages one by one for better error reporting
for pkg in pending:
try:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "--no-cache-dir", pkg],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE
)
_OPTIONAL_INSTALL_CACHE.add(pkg)
except subprocess.CalledProcessError as e:
stderr_output = e.stderr.decode('utf-8') if e.stderr else str(e)
error_msg = f"Failed to install {pkg}: {stderr_output}"
warnings.warn(f"{context} - {error_msg}")
# For PaddleOCR, try without version constraints
if "paddle" in pkg.lower():
try:
# Try installing paddlepaddle first, then paddleocr
if "paddlepaddle" in pkg:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "--no-cache-dir", "paddlepaddle"],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE
)
_OPTIONAL_INSTALL_CACHE.add("paddlepaddle")
elif "paddleocr" in pkg:
# Try installing paddleocr without doc-parser first
try:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "--no-cache-dir", "paddleocr"],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE
)
_OPTIONAL_INSTALL_CACHE.add("paddleocr")
except:
raise
except Exception:
raise RuntimeError(f"{context} - Failed to install {pkg}. Error: {error_msg}")
else:
raise RuntimeError(f"{context} - Failed to install {pkg}. Error: {error_msg}")
importlib.invalidate_caches()
except Exception as install_error:
warnings.warn(f"Failed to install {context} packages {pending}: {install_error}")
raise
_DEEPSEEK_OPTIONAL_PACKAGES = {
"matplotlib": os.getenv("DEEPSEEK_MATPLOTLIB_SPEC", "matplotlib>=3.8.0"),
"torchvision": os.getenv("DEEPSEEK_TORCHVISION_SPEC", "torchvision>=0.19.0"),
}
def _ensure_deepseek_visual_deps():
"""Ensure DeepSeekOCR's optional visualization dependencies are installed."""
install_specs = []
try:
import matplotlib # noqa: F401
except ImportError:
install_specs.append(_DEEPSEEK_OPTIONAL_PACKAGES["matplotlib"])
try:
import torchvision # noqa: F401
except ImportError:
install_specs.append(_DEEPSEEK_OPTIONAL_PACKAGES["torchvision"])
if install_specs:
_install_optional_packages(install_specs, "DeepSeekOCR visual dependencies")
import matplotlib # noqa: F401
import torchvision # noqa: F401
# PaddleOCR-VL imports are deferred to reduce build-time dependencies
PADDLEOCRVL_AVAILABLE = True
PaddleOCRVL = None
PaddleOCR = None
PADDLEOCRVL_ERROR_MESSAGE = None
_PADDLE_OPTIONAL_PACKAGES = [
os.getenv("PADDLEPADDLE_SPEC", "paddlepaddle>=2.5.0"),
os.getenv("PADDLEOCR_SPEC", "paddleocr[doc-parser]>=2.7.0"),
]
def _import_paddleocr():
global PaddleOCR, PaddleOCRVL, PADDLEOCRVL_AVAILABLE, PADDLEOCRVL_ERROR_MESSAGE
try:
import paddleocr # type: ignore
except ImportError:
try:
_install_optional_packages(_PADDLE_OPTIONAL_PACKAGES, "PaddleOCR support")
import paddleocr # type: ignore
except Exception as install_error:
PADDLEOCRVL_AVAILABLE = False
PADDLEOCRVL_ERROR_MESSAGE = f"PaddleOCR installation failed: {install_error}. Try installing manually: pip install paddlepaddle paddleocr[doc-parser]"
raise RuntimeError(PADDLEOCRVL_ERROR_MESSAGE)
PaddleOCR = paddleocr.PaddleOCR
try:
from paddleocr import PaddleOCRVL # type: ignore
PADDLEOCRVL_AVAILABLE = True
except ImportError:
try:
from paddleocr.paddleocr_vl import PaddleOCRVL # type: ignore
PADDLEOCRVL_AVAILABLE = True
except ImportError:
try:
if hasattr(paddleocr, 'PaddleOCRVL'):
PaddleOCRVL = paddleocr.PaddleOCRVL
PADDLEOCRVL_AVAILABLE = True
elif hasattr(paddleocr, 'paddleocr_vl'):
try:
from paddleocr.paddleocr_vl import PaddleOCRVL # type: ignore
PADDLEOCRVL_AVAILABLE = True
except Exception:
pass
except Exception as e:
PADDLEOCRVL_ERROR_MESSAGE = f"PaddleOCR-VL import failed: {str(e)}"
if not PADDLEOCRVL_AVAILABLE:
try:
_ = PaddleOCR(use_angle_cls=True, lang='en')
PADDLEOCRVL_ERROR_MESSAGE = "PaddleOCR-VL class not found. Using regular PaddleOCR instead. For document parsing, ensure 'paddleocr[doc-parser]' is installed."
except Exception as test_error:
PADDLEOCRVL_ERROR_MESSAGE = f"PaddleOCR not working: {str(test_error)}"
# Gemini imports (optional)
try:
from google import genai
from google.genai import types
GEMINI_AVAILABLE = True
except Exception:
GEMINI_AVAILABLE = False
types = None
warnings.warn("Gemini SDK not available. Install with: pip install google-generativeai")
# olmOCR imports (optional)
OLMOCR_AVAILABLE = False
OLMOCR_MODEL = None
OLMOCR_PROCESSOR = None
OLMOCR_ERROR_MESSAGE = None
# Try to install olmocr conditionally if Python >= 3.11
import sys
if sys.version_info >= (3, 11):
try:
import olmocr
except ImportError:
# Try to install olmocr if not available
try:
import subprocess
subprocess.check_call([
sys.executable, '-m', 'pip', 'install', '--no-cache-dir',
'git+https://github.com/allenai/olmocr.git'
])
# Reload import after installation
import importlib
importlib.invalidate_caches()
except Exception as install_error:
warnings.warn(f"Failed to auto-install olmocr: {install_error}. You may need to install it manually: pip install git+https://github.com/allenai/olmocr.git")
try:
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
OLMOCR_AVAILABLE = True
except ImportError as e:
OLMOCR_AVAILABLE = False
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
if sys.version_info < (3, 11):
# Check if we're on ZeroGPU (which is locked to Python 3.10.13)
is_zerogpu = os.getenv("SPACES_ZERO_GPU", "").lower() in ("true", "1") or "zerogpu" in os.getenv("SPACE_ID", "").lower()
if is_zerogpu:
OLMOCR_ERROR_MESSAGE = f"olmOCR requires Python >=3.11, but ZeroGPU Spaces are locked to Python 3.10.13. olmOCR is not available on ZeroGPU. Use a regular GPU Space or a different OCR engine."
else:
OLMOCR_ERROR_MESSAGE = f"olmOCR requires Python >=3.11, but you have Python {python_version}. For Hugging Face Spaces, create a runtime.txt file with 'python-3.11' or higher. Note: ZeroGPU Spaces are locked to Python 3.10.13 and do not support olmOCR."
else:
OLMOCR_ERROR_MESSAGE = f"olmOCR not available. Install with: pip install git+https://github.com/allenai/olmocr.git. Error: {str(e)}"
warnings.warn(OLMOCR_ERROR_MESSAGE)
except Exception as e:
OLMOCR_AVAILABLE = False
OLMOCR_ERROR_MESSAGE = f"olmOCR setup failed: {str(e)}"
warnings.warn(OLMOCR_ERROR_MESSAGE)
# dots.ocr imports (optional)
DOTSOCR_AVAILABLE = True
DOTSOCR_MODEL = None
DOTSOCR_PROCESSOR = None
DOTSOCR_ERROR_MESSAGE = None
_QWEN_OPTIONAL_PACKAGES = [
os.getenv("QWEN_VL_UTILS_SPEC", "qwen-vl-utils"),
]
try:
from transformers import AutoModelForCausalLM, AutoProcessor
except ImportError as e:
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr not available. Install with: pip install transformers. Error: {str(e)}"
warnings.warn(DOTSOCR_ERROR_MESSAGE)
except Exception as e:
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr setup failed: {str(e)}"
warnings.warn(DOTSOCR_ERROR_MESSAGE)
process_vision_info = None
def _ensure_dotsocr_dependencies():
"""Ensure qwen-vl-utils is available before using dots.ocr."""
global process_vision_info, DOTSOCR_AVAILABLE, DOTSOCR_ERROR_MESSAGE
if process_vision_info is not None:
return
try:
from qwen_vl_utils import process_vision_info as _process_vision_info
process_vision_info = _process_vision_info
except ImportError:
try:
_install_optional_packages(_QWEN_OPTIONAL_PACKAGES, "dots.ocr support (qwen-vl-utils)")
from qwen_vl_utils import process_vision_info as _process_vision_info # type: ignore
process_vision_info = _process_vision_info
except Exception as install_error:
DOTSOCR_AVAILABLE = False
error_msg = str(install_error)
# Provide more helpful error message
if "returned non-zero exit status" in error_msg:
DOTSOCR_ERROR_MESSAGE = f"dots.ocr not available. qwen-vl-utils installation failed. Try installing manually: pip install qwen-vl-utils. Error: {error_msg}"
else:
DOTSOCR_ERROR_MESSAGE = f"dots.ocr not available. Install with: pip install qwen-vl-utils. Error: {error_msg}"
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
# Setup logger
logger = logging.getLogger(__name__)
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
logger.addHandler(handler)
logger.setLevel(logging.ERROR)
# Gather Gemini API keys from environment and prepare round-robin iterator
GEMINI_KEYS = [
os.getenv("GEMINI_API_1"),
os.getenv("GEMINI_API_2"),
os.getenv("GEMINI_API_3"),
os.getenv("GEMINI_API_4"),
os.getenv("GEMINI_API_5"),
]
GEMINI_KEYS = [k for k in GEMINI_KEYS if k]
_gemini_cycle = cycle(GEMINI_KEYS) if GEMINI_KEYS else None
_gemini_lock = Lock()
def _get_next_gemini_key():
if not GEMINI_AVAILABLE or not _gemini_cycle:
return None
with _gemini_lock:
return next(_gemini_cycle)
# Allow overriding model via env, default to a stable flash model with fallback
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
class GeminiClient:
"""Gemini API client for generating responses"""
def __init__(self, api_key: str):
self.client = genai.Client(api_key=api_key)
def generate_content(self, prompt: str, image_data: bytes = None, mime_type: str = "image/jpeg", model: str = "gemini-2.5-flash", temperature: float = 0.7) -> str:
"""Generate content using Gemini API
Args:
prompt: Text prompt
image_data: Optional image bytes
mime_type: MIME type of the image (default: "image/jpeg")
model: Model name to use
temperature: Temperature for generation
"""
try:
# Build parts list
parts = [types.Part(text=prompt)]
if image_data:
# Use InlineData with Blob for image data
parts.append(types.Part(
inline_data=types.Blob(
mime_type=mime_type,
data=image_data
)
))
# Create Content object with role and parts
content = types.Content(role="user", parts=parts)
# Generate content
response = self.client.models.generate_content(
model=model,
contents=[content]
)
return response.text
except Exception as e:
logger.error(f"[LLM] ❌ Error calling Gemini API: {e}")
return "Error generating response from Gemini."
MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'
# Monkey-patch LlamaFlashAttention2 if it doesn't exist (for compatibility with transformers >= 4.47.0)
# The model's custom code tries to import LlamaFlashAttention2 which was removed in transformers 4.47.0+
try:
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
except ImportError:
# LlamaFlashAttention2 doesn't exist, create a compatibility wrapper
try:
import transformers.models.llama.modeling_llama as llama_module
import inspect
# Try to use LlamaSdpaAttention or LlamaAttention as base
BaseAttention = None
if hasattr(llama_module, 'LlamaSdpaAttention'):
BaseAttention = llama_module.LlamaSdpaAttention
elif hasattr(llama_module, 'LlamaAttention'):
BaseAttention = llama_module.LlamaAttention
if BaseAttention is not None:
# Create a compatibility wrapper class that handles signature differences
class LlamaFlashAttention2(BaseAttention):
"""Compatibility wrapper for LlamaFlashAttention2 (removed in transformers 4.47.0+)"""
def forward(self, hidden_states, attention_mask=None, position_ids=None,
past_key_value=None, output_attentions=False, use_cache=False,
cache_position=None, position_embeddings=None, **kwargs):
"""
Forward method that adapts between old LlamaFlashAttention2 signature
and new LlamaAttention signature.
"""
# Get the signature of the parent class forward method
parent_forward = super().forward
sig = inspect.signature(parent_forward)
params = sig.parameters
# Build arguments dict, excluding 'self'
forward_kwargs = {
'hidden_states': hidden_states,
}
# Add optional arguments only if they're in the parent signature
if 'attention_mask' in params:
forward_kwargs['attention_mask'] = attention_mask
if 'position_ids' in params:
forward_kwargs['position_ids'] = position_ids
if 'past_key_value' in params:
forward_kwargs['past_key_value'] = past_key_value
if 'output_attentions' in params:
forward_kwargs['output_attentions'] = output_attentions
if 'use_cache' in params:
forward_kwargs['use_cache'] = use_cache
if 'cache_position' in params:
forward_kwargs['cache_position'] = cache_position
# Handle position_embeddings - critical for compatibility
if 'position_embeddings' in params:
# Parent accepts position_embeddings
param = params['position_embeddings']
if position_embeddings is not None:
forward_kwargs['position_embeddings'] = position_embeddings
elif param.default is inspect.Parameter.empty:
# Required parameter but not provided - try to generate it
if hasattr(self, 'rotary_emb') and position_ids is not None:
try:
# Generate position embeddings using rotary_emb
forward_kwargs['position_embeddings'] = self.rotary_emb(
hidden_states, position_ids=position_ids
)
except Exception:
# If generation fails, pass None and let parent handle it
forward_kwargs['position_embeddings'] = None
else:
# Can't generate, pass None and hope parent can handle it
forward_kwargs['position_embeddings'] = None
# If it has a default, we don't need to pass it
# If parent doesn't accept position_embeddings, we ignore it
# (parent will generate it internally from position_ids)
# Add any additional kwargs that parent accepts
for key, value in kwargs.items():
if key in params:
forward_kwargs[key] = value
# Call parent forward with adapted arguments
return parent_forward(**forward_kwargs)
llama_module.LlamaFlashAttention2 = LlamaFlashAttention2
else:
# Last resort: create a minimal dummy class
class LlamaFlashAttention2:
"""Compatibility alias for LlamaFlashAttention2 (removed in transformers 4.47.0+)"""
def __init__(self, *args, **kwargs):
pass
llama_module.LlamaFlashAttention2 = LlamaFlashAttention2
except Exception as e:
warnings.warn(f"Could not create LlamaFlashAttention2 compatibility layer: {e}. Model loading may fail.")
# Ensure DeepSeek visual dependencies are ready (installs lazily if missing)
try:
_ensure_deepseek_visual_deps()
except Exception as dep_error:
warnings.warn(f"DeepSeekOCR dependencies missing: {dep_error}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Set padding side early
tokenizer.padding_side = 'right'
def ensure_flash_attn_if_cuda():
# Only attempt install when CUDA is available
if not torch.cuda.is_available():
return False
try:
importlib.import_module('flash_attn')
return True
except Exception:
pass
try:
# Install without build isolation so setup can import torch
subprocess.check_call([
sys.executable, '-m', 'pip', 'install', '--no-build-isolation', '--no-cache-dir', 'flash-attn==2.7.3'
])
importlib.invalidate_caches()
importlib.import_module('flash_attn')
return True
except Exception:
return False
flash_ok = ensure_flash_attn_if_cuda()
model = None
# Try loading with flash attention first if available
if flash_ok and torch.cuda.is_available():
try:
model = AutoModel.from_pretrained(
MODEL_NAME,
_attn_implementation='flash_attention_2',
torch_dtype=torch.bfloat16,
trust_remote_code=True,
use_safetensors=True,
)
model = model.eval().cuda()
except (ImportError, AttributeError) as e:
error_str = str(e)
# If LlamaFlashAttention2 import fails, fall back to default attention
if "LlamaFlashAttention2" in error_str or "cannot import name" in error_str:
warnings.warn(f"Flash attention not available due to transformers version ({error_str}); falling back to default attention.")
model = None # Will be loaded below
else:
# Other import errors, try fallback
warnings.warn(f"Flash attention unavailable ({error_str}); falling back to default attention.")
model = None
except Exception as e:
warnings.warn(f"Flash attention/CUDA unavailable ({e}); falling back to default attention.")
model = None
# Load with default attention if flash attention failed or wasn't attempted
if model is None:
try:
model = AutoModel.from_pretrained(
MODEL_NAME,
_attn_implementation=None, # Explicitly use default attention
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
trust_remote_code=True,
use_safetensors=True,
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()
except (ImportError, AttributeError) as e:
error_str = str(e)
# If still failing due to LlamaFlashAttention2, try without specifying attn_implementation
if "LlamaFlashAttention2" in error_str or "cannot import name" in error_str:
warnings.warn(f"Model custom code requires LlamaFlashAttention2 but it's not available. Trying without explicit attention setting.")
try:
model = AutoModel.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
trust_remote_code=True,
use_safetensors=True,
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()
except Exception as e2:
try:
import transformers
transformers_version = transformers.__version__
except:
transformers_version = "unknown"
raise RuntimeError(f"Failed to load DeepSeekOCR model. The model's custom code requires LlamaFlashAttention2 which is not available in transformers {transformers_version}. Please upgrade transformers: pip install transformers>=4.47.0. Error: {e2}")
else:
raise RuntimeError(f"Failed to load DeepSeekOCR model: {e}")
except Exception as e:
raise RuntimeError(f"Failed to load DeepSeekOCR model: {e}")
# Configure pad token after model is loaded
try:
if tokenizer.pad_token_id is None:
if tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
else:
# Add a new pad token
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
try:
model.resize_token_embeddings(len(tokenizer))
except Exception:
pass
# Ensure model config has pad_token_id set
if hasattr(model, 'config'):
if model.config.pad_token_id is None:
model.config.pad_token_id = tokenizer.pad_token_id
# Also set pad_token_id in generation config if it exists
if hasattr(model.config, 'generation_config') and hasattr(model.config.generation_config, 'pad_token_id'):
if model.config.generation_config.pad_token_id is None:
model.config.generation_config.pad_token_id = tokenizer.pad_token_id
except Exception as e:
pass # Suppress warnings for pad token configuration
MODEL_CONFIGS = {
"Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
"Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
"Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
"Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
"Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False}
}
# UI labels mapped to internal keys (use plain labels to match dropdown values)
MODE_LABEL_TO_KEY = {
"Gundam": "Gundam",
"Tiny": "Tiny",
"Small": "Small",
"Base": "Base",
"Large": "Large",
}
KEY_TO_MODE_LABEL = {v: k for k, v in MODE_LABEL_TO_KEY.items()}
# PaddleOCR-VL Configuration
# PaddleOCR-VL is a document parsing model that doesn't need language-specific configs
# It uses a single pipeline for all languages
# Defer PaddleOCR-VL pipeline init until first use to avoid startup errors on some builds
paddleocrvl_pipeline = None
PADDLEOCRVL_ERROR_MESSAGE = None
# Defer olmOCR model init until first use to avoid startup errors
olmocr_model = None
olmocr_processor = None
# Defer dots.ocr model init until first use to avoid startup errors
dotsocr_model = None
dotsocr_processor = None
TASK_PROMPTS = {
"Markdown": {"prompt": "<image>\n<|grounding|>Convert the document to GitHub-flavored Markdown. Preserve headings, lists, links, code blocks, and tables.", "has_grounding": True},
"Tables": {"prompt": "<image>\n<|grounding|>Extract ALL tables only as GitHub Markdown tables. Preserve merged cells as best as possible. Do not include non-table content.", "has_grounding": True},
"Locate": {"prompt": "<image>\nLocate <|ref|>text<|/ref|> in the image.", "has_grounding": True},
"Describe": {"prompt": "<image>\nDescribe this image in detail.", "has_grounding": False},
"Custom": {"prompt": "", "has_grounding": False}
}
TASK_LABEL_TO_KEY = {
"Markdown": "Markdown",
"Tables": "Tables",
"Locate": "Locate",
"Describe": "Describe",
"Custom": "Custom",
}
KEY_TO_TASK_LABEL = {v: k for k, v in TASK_LABEL_TO_KEY.items()}
# -----------------
# Simple in-memory LRU cache for per-page results
# -----------------
PAGE_CACHE = {}
PAGE_CACHE_ORDER = []
PAGE_CACHE_CAPACITY = int(os.getenv("PAGE_CACHE_CAPACITY", "512"))
PAGE_CACHE_LOCK = Lock()
def _page_cache_get(key):
with PAGE_CACHE_LOCK:
val = PAGE_CACHE.get(key)
if val is not None:
# Move to end (most recent)
try:
PAGE_CACHE_ORDER.remove(key)
except ValueError:
pass
PAGE_CACHE_ORDER.append(key)
return val
def _page_cache_set(key, value):
with PAGE_CACHE_LOCK:
if key in PAGE_CACHE:
PAGE_CACHE[key] = value
try:
PAGE_CACHE_ORDER.remove(key)
except ValueError:
pass
PAGE_CACHE_ORDER.append(key)
return
# Evict if needed
while len(PAGE_CACHE_ORDER) >= PAGE_CACHE_CAPACITY:
old_key = PAGE_CACHE_ORDER.pop(0)
PAGE_CACHE.pop(old_key, None)
PAGE_CACHE[key] = value
PAGE_CACHE_ORDER.append(key)
def extract_grounding_references(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
return re.findall(pattern, text, re.DOTALL)
def draw_bounding_boxes(image, refs, extract_images=False):
if not refs or len(refs) == 0:
return image.copy(), []
img_w, img_h = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
font = ImageFont.load_default()
crops = []
for ref in refs:
if not ref or len(ref) < 3:
continue
try:
label = ref[1] if len(ref) > 1 else 'unknown'
coords_str = ref[2] if len(ref) > 2 else '[]'
coords = eval(coords_str) if coords_str else []
if not coords or not isinstance(coords, (list, tuple)):
continue
except Exception as e:
warnings.warn(f"Failed to parse ref {ref}: {e}")
continue
color = (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255))
color_a = color + (60,)
for box in coords:
if not box or len(box) < 4:
continue
try:
x1, y1, x2, y2 = int(box[0]/999*img_w), int(box[1]/999*img_h), int(box[2]/999*img_w), int(box[3]/999*img_h)
if extract_images and label == 'image':
crops.append(image.crop((x1, y1, x2, y2)))
width = 5 if label == 'title' else 3
draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
draw2.rectangle([x1, y1, x2, y2], fill=color_a)
text_bbox = draw.textbbox((0, 0), label, font=font)
tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
ty = max(0, y1 - 20)
draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))
except Exception as e:
warnings.warn(f"Failed to draw box {box}: {e}")
continue
img_draw.paste(overlay, (0, 0), overlay)
return img_draw, crops
def clean_output(text, include_images=False, remove_labels=False):
if not text:
return ""
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
img_num = 0
for match in matches:
if '<|ref|>image<|/ref|>' in match[0]:
if include_images:
text = text.replace(match[0], f'\n\n**[Figure {img_num + 1}]**\n\n', 1)
img_num += 1
else:
text = text.replace(match[0], '', 1)
else:
if remove_labels:
text = text.replace(match[0], '', 1)
else:
text = text.replace(match[0], match[1], 1)
return text.strip()
def embed_images(markdown, crops):
if not crops:
return markdown
for i, img in enumerate(crops):
buf = BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
markdown = markdown.replace(f'**[Figure {i + 1}]**', f'\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n', 1)
return markdown
def _image_to_jpeg_bytes(image: Image.Image) -> bytes:
if image.mode in ('RGBA', 'LA', 'P'):
image = image.convert('RGB')
buf = BytesIO()
image.save(buf, format='JPEG', quality=95)
return buf.getvalue()
def _build_gemini_system_prompt():
return (
"You are an expert document parser. Convert the given document image to GitHub-flavored Markdown. "
"Preserve headings, lists, links, code blocks, and tables with correct alignment and borders. "
"Keep the reading order, avoid hallucinations, and output Markdown only."
)
def process_image_gemini(image: Image.Image):
if image is None:
return " Error Upload image", "", "", None, []
if not GEMINI_AVAILABLE or not GEMINI_KEYS:
return " Gemini not available or no API keys set in .env", "", "", None, []
key = _get_next_gemini_key()
if not key:
return " Gemini API keys not configured", "", "", None, []
try:
client = GeminiClient(api_key=key)
img_bytes = _image_to_jpeg_bytes(image)
system_prompt = _build_gemini_system_prompt()
try:
md = client.generate_content(
prompt=system_prompt,
image_data=img_bytes,
mime_type="image/jpeg",
model=GEMINI_MODEL
)
except Exception:
# Fallback for users who configured an unsupported model name
md = client.generate_content(
prompt=system_prompt,
image_data=img_bytes,
mime_type="image/jpeg",
model="gemini-2.5-flash"
)
md = (md or "").strip()
if not md or md == "Error generating response from Gemini.":
return "No text" if md != "Error generating response from Gemini." else md, "", "", None, []
return md, md, md, None, []
except Exception as e:
return f"Error: {str(e)}", "", "", None, []
@spaces.GPU(duration=120)
def process_image(image, mode_label, task_label, custom_prompt, embed_figures=False, high_accuracy=False):
if image is None:
return " Error Upload image", "", "", None, []
if task_label in ["Custom", "Locate"] and not custom_prompt.strip():
return "Enter prompt", "", "", None, []
if image.mode in ('RGBA', 'LA', 'P'):
image = image.convert('RGB')
image = ImageOps.exif_transpose(image)
# Normalize labels to internal keys
mode_key = MODE_LABEL_TO_KEY.get(mode_label, mode_label)
task_key = TASK_LABEL_TO_KEY.get(task_label, task_label)
config = MODEL_CONFIGS[mode_key]
if task_label == "Custom":
prompt = f"<image>\n{custom_prompt.strip()}"
has_grounding = '<|grounding|>' in custom_prompt
elif task_label == "Locate":
prompt = f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image."
has_grounding = True
else:
prompt = TASK_PROMPTS[task_key]["prompt"]
has_grounding = TASK_PROMPTS[task_key]["has_grounding"]
tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
image.save(tmp.name, 'JPEG', quality=95)
tmp.close()
out_dir = tempfile.mkdtemp()
stdout = sys.stdout
sys.stdout = StringIO()
try:
model.infer(tokenizer=tokenizer, prompt=prompt, image_file=tmp.name, output_path=out_dir,
base_size=config["base_size"], image_size=config["image_size"], crop_mode=config["crop_mode"])
except Exception as e:
sys.stdout = stdout
os.unlink(tmp.name)
shutil.rmtree(out_dir, ignore_errors=True)
warnings.warn(f"Inference error: {e}")
return f"Inference error: {str(e)}", "", "", None, []
# Get result from stdout
stdout_output = sys.stdout.getvalue()
sys.stdout = stdout
# Filter stdout output
result = '\n'.join([l for l in stdout_output.split('\n')
if not any(s in l for s in ['image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size', 'torch.cuda', 'loading', 'INFO:', 'WARNING:', 'ERROR:'])]).strip()
# Also check output directory for markdown/text files
if os.path.exists(out_dir):
# Look for markdown files first
md_files = sorted([f for f in os.listdir(out_dir) if f.endswith('.md')])
txt_files = sorted([f for f in os.listdir(out_dir) if f.endswith('.txt')])
# Read markdown files if available
if md_files:
file_contents = []
for md_file in md_files:
md_path = os.path.join(out_dir, md_file)
try:
with open(md_path, 'r', encoding='utf-8') as f:
content = f.read().strip()
if content:
file_contents.append(content)
except Exception as e:
warnings.warn(f"Failed to read {md_path}: {e}")
if file_contents:
result = '\n\n'.join(file_contents) if not result else result + '\n\n' + '\n\n'.join(file_contents)
# Fallback to text files if no markdown
elif txt_files and not result:
file_contents = []
for txt_file in txt_files:
txt_path = os.path.join(out_dir, txt_file)
try:
with open(txt_path, 'r', encoding='utf-8') as f:
content = f.read().strip()
if content:
file_contents.append(content)
except Exception:
pass
if file_contents:
result = '\n\n'.join(file_contents)
os.unlink(tmp.name)
shutil.rmtree(out_dir, ignore_errors=True)
if not result:
return "No text", "", "", None, []
cleaned = clean_output(result, False, False)
markdown = clean_output(result, True, True)
img_out = None
crops = []
if has_grounding and '<|ref|>' in result:
refs = extract_grounding_references(result)
if refs and len(refs) > 0:
try:
img_out, crops = draw_bounding_boxes(image, refs, True)
except Exception as e:
warnings.warn(f"Failed to draw bounding boxes: {e}")
img_out = None
crops = []
if embed_figures:
markdown = embed_images(markdown, crops)
# Optional second pass for high accuracy (focus on tables refinement)
if high_accuracy and task_key in ["Markdown", "Tables"]:
refine_prompt = "<image>\nRefine the previous extraction with emphasis on accurate table structure and alignment. Output GitHub Markdown only."
tmp2 = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
image.save(tmp2.name, 'JPEG', quality=95)
tmp2.close()
out_dir2 = tempfile.mkdtemp()
stdout2 = sys.stdout
sys.stdout = StringIO()
try:
model.infer(tokenizer=tokenizer, prompt=refine_prompt, image_file=tmp2.name, output_path=out_dir2,
base_size=config["base_size"], image_size=config["image_size"], crop_mode=config["crop_mode"])
except Exception:
pass
stdout_output2 = sys.stdout.getvalue()
sys.stdout = stdout2
refine_result = '\n'.join([l for l in stdout_output2.split('\n')
if not any(s in l for s in ['image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size', 'torch.cuda', 'loading', 'INFO:', 'WARNING:', 'ERROR:'])]).strip()
# Check output directory for refine results
if os.path.exists(out_dir2):
md_files2 = sorted([f for f in os.listdir(out_dir2) if f.endswith('.md')])
if md_files2:
for md_file in md_files2:
md_path2 = os.path.join(out_dir2, md_file)
try:
with open(md_path2, 'r', encoding='utf-8') as f:
content2 = f.read().strip()
if content2:
refine_result = content2 if not refine_result else refine_result + '\n\n' + content2
except Exception:
pass
os.unlink(tmp2.name)
shutil.rmtree(out_dir2, ignore_errors=True)
if refine_result:
refined_md = clean_output(refine_result, embed_figures, True)
# Prefer refined markdown if longer (heuristic)
if len(refined_md) > len(markdown):
markdown = refined_md
return cleaned, markdown, result, img_out, crops
def process_image_paddleocrvl(image, prompt=None):
"""Process image using PaddleOCR-VL and return results in Markdown format.
Args:
image: PIL Image to process
prompt: Optional custom prompt. If None, uses default document parsing prompt.
"""
if image is None:
return " Error Upload image", "", "", None, []
# Lazy init to avoid import-time errors on some environments
global paddleocrvl_pipeline, PADDLEOCRVL_AVAILABLE, PADDLEOCRVL_ERROR_MESSAGE, PaddleOCRVL
if PaddleOCR is None or PaddleOCRVL is None:
try:
_import_paddleocr()
except Exception as e:
PADDLEOCRVL_AVAILABLE = False
PADDLEOCRVL_ERROR_MESSAGE = f"PaddleOCR-VL setup failed: {e}"
if not PADDLEOCRVL_AVAILABLE or PaddleOCRVL is None:
msg = PADDLEOCRVL_ERROR_MESSAGE or "PaddleOCR-VL not available. Install with: pip install 'paddleocr[doc-parser]'"
return f" {msg}", "", "", None, []
if paddleocrvl_pipeline is None:
try:
paddleocrvl_pipeline = PaddleOCRVL()
except Exception as e:
PADDLEOCRVL_AVAILABLE = False
PADDLEOCRVL_ERROR_MESSAGE = f"PaddleOCR-VL init failed: {e}"
return f" {PADDLEOCRVL_ERROR_MESSAGE}", "", "", None, []
if image.mode in ('RGBA', 'LA', 'P'):
image = image.convert('RGB')
image = ImageOps.exif_transpose(image)
# Save image to temporary file for PaddleOCR
tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
image.save(tmp.name, 'JPEG', quality=95)
tmp.close()
try:
# Try different approaches based on PaddleOCR-VL API
# First, try the simple predict method
try:
if prompt:
# If custom prompt is provided, try to use it with structured messages
# Format: PaddleOCR-VL might accept messages with image and text
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": tmp.name},
{"type": "text", "text": prompt}
]
}
]
# Try if the pipeline has a chat or generate method
if hasattr(paddleocrvl_pipeline, 'chat'):
output = paddleocrvl_pipeline.chat(messages)
elif hasattr(paddleocrvl_pipeline, 'generate'):
output = paddleocrvl_pipeline.generate(messages)
else:
# Fallback to predict with just image
output = paddleocrvl_pipeline.predict(tmp.name)
else:
# Default: use predict with just image (document parsing mode)
output = paddleocrvl_pipeline.predict(tmp.name)
except (TypeError, AttributeError):
# If structured messages don't work, fallback to simple predict
output = paddleocrvl_pipeline.predict(tmp.name)
if not output or len(output) == 0:
os.unlink(tmp.name)
return "No text detected", "", "", None, []
# PaddleOCR-VL returns results that can be saved to markdown
# Save to a temporary directory
out_dir = tempfile.mkdtemp()
markdown_results = []
text_results = []
raw_results = []
# Handle different output types
if isinstance(output, str):
# Direct string output
markdown_results.append(output)
text_results.append(output)
raw_results.append("PaddleOCR-VL direct output")
elif isinstance(output, list):
# List of results
for res in output:
try:
if isinstance(res, str):
# String result
markdown_results.append(res)
text_results.append(res)
raw_results.append("PaddleOCR-VL result")
elif hasattr(res, 'save_to_markdown'):
# Object with save_to_markdown method
res.save_to_markdown(save_path=out_dir)
# Find the markdown file
md_files = [f for f in os.listdir(out_dir) if f.endswith('.md')]
if md_files:
md_path = os.path.join(out_dir, md_files[0])
with open(md_path, 'r', encoding='utf-8') as f:
markdown_content = f.read()
markdown_results.append(markdown_content)
text_results.append(markdown_content)
raw_results.append(f"PaddleOCR-VL result: {str(res)}")
# Remove the file to avoid conflicts
os.remove(md_path)
else:
# Try to convert to string
markdown_results.append(str(res))
text_results.append(str(res))
raw_results.append(f"PaddleOCR-VL result: {str(res)}")
except Exception as e:
warnings.warn(f"Failed to process PaddleOCR-VL result: {e}")
else:
# Try to convert to string
markdown_results.append(str(output))
text_results.append(str(output))
raw_results.append(f"PaddleOCR-VL result: {str(output)}")
# Clean up temp directory
shutil.rmtree(out_dir, ignore_errors=True)
os.unlink(tmp.name)
if not markdown_results:
return "No text detected", "", "", None, []
# Combine results
markdown = "\n\n".join(markdown_results)
text = "\n\n".join(text_results)
raw = "\n\n".join(raw_results)
# For bounding boxes, we can try to extract from the result if available
img_out = None
try:
# Draw bounding boxes if available in the result
img_draw = image.copy()
# PaddleOCR-VL may have layout information we can use
# This is a simplified version - adjust based on actual API
img_out = img_draw
except Exception as e:
warnings.warn(f"Failed to draw bounding boxes: {e}")
return text, markdown, raw, img_out, []
except Exception as e:
os.unlink(tmp.name)
import traceback
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
warnings.warn(error_msg)
return f"Error: {str(e)}", "", "", None, []
def _init_olmocr_model():
"""Lazy initialization of olmOCR model."""
global olmocr_model, olmocr_processor, OLMOCR_AVAILABLE, OLMOCR_ERROR_MESSAGE
if not OLMOCR_AVAILABLE:
msg = OLMOCR_ERROR_MESSAGE or "olmOCR not available. Install with: pip install git+https://github.com/allenai/olmocr.git (requires Python >=3.11)"
raise RuntimeError(msg)
if olmocr_model is None or olmocr_processor is None:
try:
model_name = "allenai/olmOCR-2-7B-1025"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
olmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
trust_remote_code=True
).eval()
olmocr_model.to(device)
olmocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True)
except Exception as e:
OLMOCR_AVAILABLE = False
OLMOCR_ERROR_MESSAGE = f"olmOCR model initialization failed: {str(e)}"
raise RuntimeError(OLMOCR_ERROR_MESSAGE)
def _resize_image_for_olmocr(image: Image.Image, target_longest_dim: int = 1288) -> Image.Image:
"""Resize image so longest dimension is target_longest_dim pixels."""
width, height = image.size
longest_dim = max(width, height)
if longest_dim == target_longest_dim:
return image
scale = target_longest_dim / longest_dim
new_width = int(width * scale)
new_height = int(height * scale)
return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def _init_dotsocr_model():
"""Lazy initialization of dots.ocr model."""
global dotsocr_model, dotsocr_processor, DOTSOCR_AVAILABLE, DOTSOCR_ERROR_MESSAGE
if not DOTSOCR_AVAILABLE:
msg = DOTSOCR_ERROR_MESSAGE or "dots.ocr not available. Install with: pip install qwen-vl-utils"
raise RuntimeError(msg)
try:
_ensure_dotsocr_dependencies()
except Exception as dep_error:
raise RuntimeError(str(dep_error))
if dotsocr_model is None or dotsocr_processor is None:
try:
model_path = "rednote-hilab/dots.ocr"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Check transformers version
try:
import transformers
transformers_version = transformers.__version__
# dots.ocr may require transformers >= 4.47.0 for Qwen2_5_VLProcessor
try:
from packaging import version
if version.parse(transformers_version) < version.parse("4.47.0"):
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr requires transformers >= 4.47.0, but you have {transformers_version}. Please upgrade: pip install transformers>=4.47.0"
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
except ImportError:
# packaging not available, try to check version manually
version_parts = transformers_version.split('.')
if len(version_parts) >= 2:
major, minor = int(version_parts[0]), int(version_parts[1])
if major < 4 or (major == 4 and minor < 47):
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr requires transformers >= 4.47.0, but you have {transformers_version}. Please upgrade: pip install transformers>=4.47.0"
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
except RuntimeError:
# Re-raise version check errors
raise
except Exception:
# If version check fails, continue anyway
pass
# Check for flash attention
flash_attn_available = False
try:
importlib.import_module('flash_attn')
flash_attn_available = True
except:
pass
# Try loading with flash attention first, fallback to default if it fails
try:
dotsocr_model = AutoModelForCausalLM.from_pretrained(
model_path,
attn_implementation="flash_attention_2" if flash_attn_available else None,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
).eval()
except (ImportError, AttributeError) as e:
error_str = str(e)
# If flash attention fails (e.g., LlamaFlashAttention2 not available), try without it
if "LlamaFlashAttention2" in error_str or "flash_attention" in error_str.lower() or "cannot import name" in error_str:
warnings.warn(f"Flash attention not available for dots.ocr, falling back to default attention: {error_str}")
dotsocr_model = AutoModelForCausalLM.from_pretrained(
model_path,
attn_implementation=None, # Use default attention
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
).eval()
else:
raise
if not torch.cuda.is_available():
dotsocr_model.to(device)
try:
dotsocr_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
except (ImportError, AttributeError, TypeError) as e:
error_str = str(e)
# Handle video_processor error - use manual patch (skip gated repo)
if "video_processor" in error_str or "BaseVideoProcessor" in error_str:
# Try manual patch first (skip gated repo to avoid access issues)
try:
from transformers import Qwen2_5_VLProcessor
from transformers import AutoImageProcessor, AutoTokenizer
image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Create processor manually with video_processor=None
# Qwen2_5_VLProcessor may require video_processor, so we create a dummy one if needed
try:
dotsocr_processor = Qwen2_5_VLProcessor(
image_processor=image_processor,
tokenizer=tokenizer
)
except TypeError:
# If video_processor is required, try to create a minimal one or pass None explicitly
# Some versions may accept None, others may need a dummy processor
try:
# Try with explicit None
dotsocr_processor = Qwen2_5_VLProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
video_processor=None
)
except (TypeError, ValueError):
# Last resort: try to create processor by patching the class signature
# Check if we can inspect the __init__ signature
import inspect
try:
sig = inspect.signature(Qwen2_5_VLProcessor.__init__)
# If video_processor has a default, we can call it
params = sig.parameters
if 'video_processor' in params and params['video_processor'].default is not inspect.Parameter.empty:
# video_processor has a default, call normally
dotsocr_processor = Qwen2_5_VLProcessor(
image_processor=image_processor,
tokenizer=tokenizer
)
else:
# Need to patch the class to accept None
# Create processor without video_processor argument
dotsocr_processor = Qwen2_5_VLProcessor.__new__(Qwen2_5_VLProcessor)
# Try to call parent __init__ if available
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
if hasattr(Qwen2VLProcessor, '__init__'):
try:
Qwen2VLProcessor.__init__(dotsocr_processor, image_processor=image_processor, tokenizer=tokenizer)
except:
# Fallback to manual assignment
dotsocr_processor.image_processor = image_processor
dotsocr_processor.tokenizer = tokenizer
else:
dotsocr_processor.image_processor = image_processor
dotsocr_processor.tokenizer = tokenizer
# Set video_processor to None if the attribute exists
if hasattr(dotsocr_processor, 'video_processor'):
dotsocr_processor.video_processor = None
except Exception:
# Final fallback: create processor without video_processor argument
dotsocr_processor = Qwen2_5_VLProcessor.__new__(Qwen2_5_VLProcessor)
dotsocr_processor.image_processor = image_processor
dotsocr_processor.tokenizer = tokenizer
if hasattr(dotsocr_processor, 'video_processor'):
dotsocr_processor.video_processor = None
except Exception as patch_error:
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr processor initialization failed with video_processor error. Manual patch failed. Original error: {error_str}, Patch error: {str(patch_error)}. Note: The gated repo fix is not accessible. Please ensure you have transformers >= 4.47.0 installed."
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
elif "LlamaFlashAttention2" in error_str or "cannot import name" in error_str:
# If processor loading fails due to flash attention, it's likely a transformers version issue
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr processor initialization failed. This may require a newer transformers version. Error: {error_str}"
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
else:
raise
except RuntimeError:
# Re-raise RuntimeError as-is (from version check)
raise
except ImportError as e:
error_str = str(e)
if "Qwen2_5_VLProcessor" in error_str:
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr requires transformers >= 4.47.0 for Qwen2_5_VLProcessor. Current transformers version may be too old. Please upgrade: pip install transformers>=4.47.0. Error: {error_str}"
elif "LlamaFlashAttention2" in error_str or "cannot import name" in error_str:
# This error was already handled above, but if it reaches here, provide helpful message
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr model initialization failed due to flash attention compatibility issue. This may require a newer transformers version or disabling flash attention. Error: {error_str}"
else:
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr model initialization failed: {error_str}"
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
except Exception as e:
error_str = str(e)
if "video_processor" in error_str or "BaseVideoProcessor" in error_str:
# Try manual patch for processor (skip gated repo)
try:
from transformers import Qwen2_5_VLProcessor
from transformers import AutoImageProcessor, AutoTokenizer
image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Create processor manually
try:
dotsocr_processor = Qwen2_5_VLProcessor(
image_processor=image_processor,
tokenizer=tokenizer
)
except TypeError:
try:
dotsocr_processor = Qwen2_5_VLProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
video_processor=None
)
except (TypeError, ValueError):
# Last resort: try to create processor by patching the class signature
import inspect
try:
sig = inspect.signature(Qwen2_5_VLProcessor.__init__)
params = sig.parameters
if 'video_processor' in params and params['video_processor'].default is not inspect.Parameter.empty:
dotsocr_processor = Qwen2_5_VLProcessor(
image_processor=image_processor,
tokenizer=tokenizer
)
else:
# Create processor without video_processor argument
dotsocr_processor = Qwen2_5_VLProcessor.__new__(Qwen2_5_VLProcessor)
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
if hasattr(Qwen2VLProcessor, '__init__'):
try:
Qwen2VLProcessor.__init__(dotsocr_processor, image_processor=image_processor, tokenizer=tokenizer)
except Exception:
dotsocr_processor.image_processor = image_processor
dotsocr_processor.tokenizer = tokenizer
else:
dotsocr_processor.image_processor = image_processor
dotsocr_processor.tokenizer = tokenizer
if hasattr(dotsocr_processor, 'video_processor'):
dotsocr_processor.video_processor = None
except Exception:
# Final fallback
dotsocr_processor = Qwen2_5_VLProcessor.__new__(Qwen2_5_VLProcessor)
dotsocr_processor.image_processor = image_processor
dotsocr_processor.tokenizer = tokenizer
if hasattr(dotsocr_processor, 'video_processor'):
dotsocr_processor.video_processor = None
except Exception as patch_error:
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr model initialization failed with video_processor error. Manual patch failed. Original error: {error_str}, Patch error: {str(patch_error)}. Note: The gated repo fix is not accessible. Please ensure you have transformers >= 4.47.0 installed."
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
elif "Qwen2_5_VLProcessor" in error_str:
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr requires transformers >= 4.47.0 for Qwen2_5_VLProcessor. Current transformers version may be too old. Please upgrade: pip install transformers>=4.47.0. Error: {error_str}"
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
elif "LlamaFlashAttention2" in error_str or "cannot import name" in error_str:
# This error was already handled above, but if it reaches here, provide helpful message
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr model initialization failed due to flash attention compatibility issue. This may require a newer transformers version or disabling flash attention. Error: {error_str}"
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
else:
DOTSOCR_AVAILABLE = False
DOTSOCR_ERROR_MESSAGE = f"dots.ocr model initialization failed: {error_str}"
raise RuntimeError(DOTSOCR_ERROR_MESSAGE)
def process_image_olmocr(image, prompt=None):
"""Process image using olmOCR and return results in Markdown format.
Args:
image: PIL Image to process
prompt: Optional custom prompt. If None, uses default document parsing prompt.
"""
if image is None:
return " Error Upload image", "", "", None, []
# Lazy init to avoid import-time errors
global olmocr_model, olmocr_processor, OLMOCR_AVAILABLE, OLMOCR_ERROR_MESSAGE
if not OLMOCR_AVAILABLE:
msg = OLMOCR_ERROR_MESSAGE or "olmOCR not available. Install with: pip install git+https://github.com/allenai/olmocr.git (requires Python >=3.11)"
return f" {msg}", "", "", None, []
try:
_init_olmocr_model()
except RuntimeError as e:
return f" {str(e)}", "", "", None, []
if image.mode in ('RGBA', 'LA', 'P'):
image = image.convert('RGB')
image = ImageOps.exif_transpose(image)
# Resize image so longest dimension is 1288 pixels
image_resized = _resize_image_for_olmocr(image, target_longest_dim=1288)
# Build the prompt
if prompt:
text_prompt = prompt
else:
text_prompt = build_no_anchoring_v4_yaml_prompt()
# Convert image to base64
buf = BytesIO()
image_resized.save(buf, format='PNG')
image_base64 = base64.b64encode(buf.getvalue()).decode()
# Build messages
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": text_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
]
try:
# Apply chat template and processor
text = olmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
inputs = olmocr_processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = {key: value.to(device) for (key, value) in inputs.items()}
# Generate output
with torch.no_grad():
output = olmocr_model.generate(
**inputs,
temperature=0.1,
max_new_tokens=2048,
num_return_sequences=1,
do_sample=True,
)
# Decode output
prompt_length = inputs["input_ids"].shape[1]
new_tokens = output[:, prompt_length:]
text_output = olmocr_processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
result = text_output[0] if text_output else ""
if not result or not result.strip():
return "No text detected", "", "", None, []
# Extract markdown from YAML frontmatter if present
# olmOCR output format: ---\n...\n---\n<content>
if result.startswith("---"):
parts = result.split("---", 2)
if len(parts) >= 3:
result = parts[2].strip()
# Return same content for text, markdown, and raw
return result, result, result, None, []
except Exception as e:
import traceback
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
warnings.warn(error_msg)
return f"Error: {str(e)}", "", "", None, []
def process_image_dotsocr(image, prompt=None):
"""Process image using dots.ocr and return results in Markdown format.
Args:
image: PIL Image to process
prompt: Optional custom prompt. If None, uses default document parsing prompt.
"""
if image is None:
return " Error Upload image", "", "", None, []
# Lazy init to avoid import-time errors
global dotsocr_model, dotsocr_processor, DOTSOCR_AVAILABLE, DOTSOCR_ERROR_MESSAGE
if not DOTSOCR_AVAILABLE:
msg = DOTSOCR_ERROR_MESSAGE or "dots.ocr not available. Install with: pip install qwen-vl-utils"
return f" {msg}", "", "", None, []
try:
_init_dotsocr_model()
except RuntimeError as e:
return f" {str(e)}", "", "", None, []
if image.mode in ('RGBA', 'LA', 'P'):
image = image.convert('RGB')
image = ImageOps.exif_transpose(image)
# Save image to temporary file
tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
image.save(tmp.name, 'JPEG', quality=95)
tmp.close()
# Build the prompt
if prompt:
text_prompt = prompt
else:
# Default prompt for dots.ocr document parsing
text_prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object."""
# Build messages
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": tmp.name
},
{"type": "text", "text": text_prompt}
]
}
]
try:
# Preparation for inference
text = dotsocr_processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = dotsocr_processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = {key: value.to(device) for (key, value) in inputs.items()}
# Generate output
with torch.no_grad():
generated_ids = dotsocr_model.generate(**inputs, max_new_tokens=24000)
# Decode output
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
]
output_text = dotsocr_processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
result = output_text[0] if output_text else ""
# Clean up temp file
os.unlink(tmp.name)
if not result or not result.strip():
return "No text detected", "", "", None, []
# Try to parse JSON and convert to markdown
try:
layout_data = json.loads(result)
# Convert layout data to markdown
markdown_parts = []
if isinstance(layout_data, dict) and "layout" in layout_data:
layout_items = layout_data["layout"]
elif isinstance(layout_data, list):
layout_items = layout_data
else:
layout_items = [layout_data]
for item in layout_items:
if isinstance(item, dict):
category = item.get("category", "")
text_content = item.get("text", "")
bbox = item.get("bbox", [])
if category == "Title":
markdown_parts.append(f"# {text_content}")
elif category == "Section-header":
markdown_parts.append(f"## {text_content}")
elif category == "Table":
markdown_parts.append(text_content) # Already HTML formatted
elif category == "Formula":
markdown_parts.append(f"$${text_content}$$") # LaTeX formula
elif category == "Picture":
markdown_parts.append(f"![Image](bbox: {bbox})")
else:
markdown_parts.append(text_content)
markdown_result = "\n\n".join(markdown_parts)
return markdown_result, markdown_result, result, None, []
except json.JSONDecodeError:
# If not JSON, return as-is
return result, result, result, None, []
except Exception as e:
# Clean up temp file
try:
os.unlink(tmp.name)
except:
pass
import traceback
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
warnings.warn(error_msg)
return f"Error: {str(e)}", "", "", None, []
@spaces.GPU(duration=120)
def process_pdf(path, mode_label, task_label, custom_prompt, dpi=300, page_indices=None, embed_figures=False, high_accuracy=False, insert_separators=True, max_retries=3, retry_backoff_seconds=3):
doc = fitz.open(path)
texts, markdowns, raws, all_crops = [], [], [], []
if page_indices is None:
page_indices = list(range(len(doc)))
for i in page_indices:
# Cache key for DeepSeekOCR per page
cache_key = ("DeepSeekOCR", path, int(dpi), int(i), mode_label, task_label, bool(embed_figures), bool(high_accuracy))
cached = _page_cache_get(cache_key)
if cached:
text, md, raw, crops = cached
if text and text.strip() and text != "No text":
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n{md}")
raws.append(f"=== Page {i + 1} ===\n{raw}")
all_crops.extend(crops or [])
continue
page = doc.load_page(i)
pix = page.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72), alpha=False)
img = Image.open(BytesIO(pix.tobytes("png")))
# Retry loop to handle GPU timeouts/busy states gracefully
attempt = 0
while True:
try:
text, md, raw, _, crops = process_image(img, mode_label, task_label, custom_prompt, embed_figures=embed_figures, high_accuracy=high_accuracy)
# If we got a result (even if it's "No text"), break the retry loop
if text is not None:
break
# If we got None or empty, retry
attempt += 1
if attempt >= max_retries:
text, md, raw, crops = "", f"<!-- Failed to process page {i+1} after retries -->", "", []
break
time.sleep(retry_backoff_seconds * attempt)
except Exception as e:
attempt += 1
if attempt >= max_retries:
error_msg = f"Error processing page {i+1}: {str(e)}"
text, md, raw, crops = error_msg, f"<!-- {error_msg} -->", error_msg, []
warnings.warn(error_msg)
break
time.sleep(retry_backoff_seconds * attempt)
# Check for valid text results (not empty, not error messages, not "No text")
if text and text.strip() and text != "No text" and not text.startswith("Error") and not text.startswith(" "):
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n{md}")
raws.append(f"=== Page {i + 1} ===\n{raw}")
all_crops.extend(crops)
_page_cache_set(cache_key, (text, md, raw, crops))
elif text and (text.startswith("Error") or text.startswith("Inference error")):
# Include error messages in output for debugging
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n<!-- {text} -->")
raws.append(f"=== Page {i + 1} ===\n{text}")
doc.close()
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts) if texts else "",
sep.join(markdowns) if markdowns else "",
"\n\n".join(raws), None, all_crops)
def process_pdf_all(path, mode_label, task_label, custom_prompt, dpi=300, page_range_text="", embed_figures=False, high_accuracy=False, insert_separators=True, batch_size=3, max_retries=5, retry_backoff_seconds=5):
doc = fitz.open(path)
total_pages = len(doc)
doc.close()
# Parse page range like "1-3,5"
def parse_ranges(s, total):
if not s.strip():
return list(range(total))
pages = set()
parts = [p.strip() for p in s.split(',') if p.strip()]
for part in parts:
if '-' in part:
a, b = part.split('-', 1)
try:
a, b = int(a) - 1, int(b) - 1
except:
continue
for x in range(max(0, a), min(total - 1, b) + 1):
pages.add(x)
else:
try:
idx = int(part) - 1
if 0 <= idx < total:
pages.add(idx)
except:
continue
return sorted(pages)
target_pages = parse_ranges(page_range_text, total_pages)
texts_all, mds_all, raws_all, crops_all = [], [], [], []
for start in range(0, len(target_pages), batch_size):
batch = target_pages[start:start+batch_size]
# Orchestrate retries outside GPU scope (retries at chunk level)
attempt = 0
while True:
try:
tx, mdx, rawx, _, cropsx = process_pdf(path, mode_label, task_label, custom_prompt, dpi=dpi, page_indices=batch, embed_figures=embed_figures, high_accuracy=high_accuracy, insert_separators=insert_separators)
break
except Exception:
attempt += 1
if attempt >= max_retries:
tx, mdx, rawx, cropsx = "", "\n\n".join([f"<!-- Failed batch {start//batch_size+1} -->"]), "", []
break
time.sleep(retry_backoff_seconds * attempt)
if tx:
texts_all.append(tx)
if mdx:
mds_all.append(mdx)
if rawx:
raws_all.append(rawx)
crops_all.extend(cropsx)
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts_all) if texts_all else "No text in PDF",
sep.join(mds_all) if mds_all else "No text in PDF",
"\n\n".join(raws_all), None, crops_all)
def process_pdf_all_gemini(path, dpi=300, page_range_text="", insert_separators=True, batch_size=3, max_retries=5, retry_backoff_seconds=5):
if not GEMINI_AVAILABLE or not GEMINI_KEYS:
return "Gemini not available or no keys", "", "", None, []
doc = fitz.open(path)
total_pages = len(doc)
doc.close()
def parse_ranges(s, total):
if not s.strip():
return list(range(total))
pages = set()
parts = [p.strip() for p in s.split(',') if p.strip()]
for part in parts:
if '-' in part:
a, b = part.split('-', 1)
try:
a, b = int(a) - 1, int(b) - 1
except:
continue
for x in range(max(0, a), min(total - 1, b) + 1):
pages.add(x)
else:
try:
idx = int(part) - 1
if 0 <= idx < total:
pages.add(idx)
except:
continue
return sorted(pages)
target_pages = parse_ranges(page_range_text, total_pages)
texts_all, mds_all, raws_all, crops_all = [], [], [], []
for start in range(0, len(target_pages), batch_size):
batch = target_pages[start:start+batch_size]
attempt = 0
while True:
try:
# Process each page to image and pass to Gemini
docx = fitz.open(path)
for i in batch:
cache_key = ("Gemini", path, int(dpi), int(i), GEMINI_MODEL)
cached = _page_cache_get(cache_key)
if cached:
text, md, raw, crops = cached
if text and text.strip() and text != "No text":
texts_all.append(f"### Page {i + 1}\n\n{text}")
mds_all.append(f"### Page {i + 1}\n\n{md}")
raws_all.append(f"=== Page {i + 1} ===\n{raw}")
crops_all.extend(crops or [])
continue
page = docx.load_page(i)
pix = page.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72), alpha=False)
img = Image.open(BytesIO(pix.tobytes("png")))
text, md, raw, _, crops = process_image_gemini(img)
if text and text != "No text" and not text.startswith("Error"):
texts_all.append(f"### Page {i + 1}\n\n{text}")
mds_all.append(f"### Page {i + 1}\n\n{md}")
raws_all.append(f"=== Page {i + 1} ===\n{raw}")
crops_all.extend(crops or [])
_page_cache_set(cache_key, (text, md, raw, crops))
elif text and text.startswith("Error"):
texts_all.append(f"### Page {i + 1}\n\n{text}")
mds_all.append(f"### Page {i + 1}\n\n<!-- {text} -->")
raws_all.append(f"=== Page {i + 1} ===\n{text}")
docx.close()
break
except Exception:
attempt += 1
if attempt >= max_retries:
mds_all.append(f"<!-- Failed batch {start//batch_size+1} -->")
break
time.sleep(retry_backoff_seconds * attempt)
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts_all) if texts_all else "No text in PDF",
sep.join(mds_all) if mds_all else "No text in PDF",
"\n\n".join(raws_all), None, crops_all)
def process_pdf_paddleocrvl(path, dpi=300, page_indices=None, insert_separators=True, max_retries=3, retry_backoff_seconds=3):
"""Process PDF using PaddleOCR-VL."""
global paddleocrvl_pipeline, PADDLEOCRVL_AVAILABLE, PADDLEOCRVL_ERROR_MESSAGE, PaddleOCRVL
if PaddleOCR is None or PaddleOCRVL is None:
try:
_import_paddleocr()
except Exception as e:
PADDLEOCRVL_AVAILABLE = False
PADDLEOCRVL_ERROR_MESSAGE = f"PaddleOCR-VL setup failed: {e}"
# Early exit if engine is unavailable
if not PADDLEOCRVL_AVAILABLE or paddleocrvl_pipeline is None or PaddleOCRVL is None:
msg = PADDLEOCRVL_ERROR_MESSAGE or "PaddleOCR-VL not available. Install with: pip install 'paddleocr[doc-parser]'"
return msg, f"<!-- {msg} -->", msg, None, []
doc = fitz.open(path)
texts, markdowns, raws, all_crops = [], [], [], []
if page_indices is None:
page_indices = list(range(len(doc)))
for i in page_indices:
cache_key = ("PaddleOCR-VL", path, int(dpi), int(i))
cached = _page_cache_get(cache_key)
if cached:
text, md, raw, crops = cached
if text and text.strip() and text != "No text detected":
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n{md}")
raws.append(f"=== Page {i + 1} ===\n{raw}")
all_crops.extend(crops or [])
continue
page = doc.load_page(i)
pix = page.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72), alpha=False)
img = Image.open(BytesIO(pix.tobytes("png")))
attempt = 0
while True:
try:
text, md, raw, _, crops = process_image_paddleocrvl(img)
break
except Exception:
attempt += 1
if attempt >= max_retries:
text, md, raw, crops = "", f"<!-- Failed to process page {i+1} after retries -->", "", []
break
time.sleep(retry_backoff_seconds * attempt)
if text and text != "No text detected":
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n{md}")
raws.append(f"=== Page {i + 1} ===\n{raw}")
all_crops.extend(crops)
_page_cache_set(cache_key, (text, md, raw, crops))
doc.close()
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts) if texts else "",
sep.join(markdowns) if markdowns else "",
"\n\n".join(raws), None, all_crops)
def process_pdf_all_paddleocrvl(path, dpi=300, page_range_text="", insert_separators=True, batch_size=3, max_retries=5, retry_backoff_seconds=5):
"""Process all pages of PDF using PaddleOCR-VL."""
global paddleocrvl_pipeline, PADDLEOCRVL_AVAILABLE, PADDLEOCRVL_ERROR_MESSAGE, PaddleOCRVL
if PaddleOCR is None or PaddleOCRVL is None:
try:
_import_paddleocr()
except Exception as e:
PADDLEOCRVL_AVAILABLE = False
PADDLEOCRVL_ERROR_MESSAGE = f"PaddleOCR-VL setup failed: {e}"
# Early exit if engine is unavailable
if not PADDLEOCRVL_AVAILABLE or paddleocrvl_pipeline is None or PaddleOCRVL is None:
msg = PADDLEOCRVL_ERROR_MESSAGE or "PaddleOCR-VL not available. Install with: pip install 'paddleocr[doc-parser]'"
return msg, f"<!-- {msg} -->", msg, None, []
doc = fitz.open(path)
total_pages = len(doc)
doc.close()
def parse_ranges(s, total):
if not s.strip():
return list(range(total))
pages = set()
parts = [p.strip() for p in s.split(',') if p.strip()]
for part in parts:
if '-' in part:
a, b = part.split('-', 1)
try:
a, b = int(a) - 1, int(b) - 1
except:
continue
for x in range(max(0, a), min(total - 1, b) + 1):
pages.add(x)
else:
try:
idx = int(part) - 1
if 0 <= idx < total:
pages.add(idx)
except:
continue
return sorted(pages)
target_pages = parse_ranges(page_range_text, total_pages)
texts_all, mds_all, raws_all, crops_all = [], [], [], []
for start in range(0, len(target_pages), batch_size):
batch = target_pages[start:start+batch_size]
attempt = 0
while True:
try:
tx, mdx, rawx, _, cropsx = process_pdf_paddleocrvl(path, dpi=dpi, page_indices=batch, insert_separators=insert_separators)
break
except Exception:
attempt += 1
if attempt >= max_retries:
tx, mdx, rawx, cropsx = "", "\n\n".join([f"<!-- Failed batch {start//batch_size+1} -->"]), "", []
break
time.sleep(retry_backoff_seconds * attempt)
if tx:
texts_all.append(tx)
if mdx:
mds_all.append(mdx)
if rawx:
raws_all.append(rawx)
crops_all.extend(cropsx)
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts_all) if texts_all else "No text in PDF",
sep.join(mds_all) if mds_all else "No text in PDF",
"\n\n".join(raws_all), None, crops_all)
def process_pdf_olmocr(path, dpi=300, page_indices=None, insert_separators=True, max_retries=3, retry_backoff_seconds=3):
"""Process PDF using olmOCR."""
# Early exit if engine is unavailable
if not OLMOCR_AVAILABLE:
msg = OLMOCR_ERROR_MESSAGE or "olmOCR not available. Install with: pip install git+https://github.com/allenai/olmocr.git (requires Python >=3.11)"
return msg, f"<!-- {msg} -->", msg, None, []
try:
_init_olmocr_model()
except RuntimeError as e:
return str(e), f"<!-- {str(e)} -->", str(e), None, []
doc = fitz.open(path)
texts, markdowns, raws, all_crops = [], [], [], []
if page_indices is None:
page_indices = list(range(len(doc)))
for i in page_indices:
cache_key = ("olmOCR", path, int(dpi), int(i))
cached = _page_cache_get(cache_key)
if cached:
text, md, raw, crops = cached
if text and text.strip() and text != "No text detected":
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n{md}")
raws.append(f"=== Page {i + 1} ===\n{raw}")
all_crops.extend(crops or [])
continue
page = doc.load_page(i)
pix = page.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72), alpha=False)
img = Image.open(BytesIO(pix.tobytes("png")))
attempt = 0
while True:
try:
text, md, raw, _, crops = process_image_olmocr(img)
break
except Exception:
attempt += 1
if attempt >= max_retries:
text, md, raw, crops = "", f"<!-- Failed to process page {i+1} after retries -->", "", []
break
time.sleep(retry_backoff_seconds * attempt)
if text and text != "No text detected":
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n{md}")
raws.append(f"=== Page {i + 1} ===\n{raw}")
all_crops.extend(crops)
_page_cache_set(cache_key, (text, md, raw, crops))
doc.close()
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts) if texts else "",
sep.join(markdowns) if markdowns else "",
"\n\n".join(raws), None, all_crops)
def process_pdf_all_olmocr(path, dpi=300, page_range_text="", insert_separators=True, batch_size=3, max_retries=5, retry_backoff_seconds=5):
"""Process all pages of PDF using olmOCR."""
# Early exit if engine is unavailable
if not OLMOCR_AVAILABLE:
msg = OLMOCR_ERROR_MESSAGE or "olmOCR not available. Install with: pip install git+https://github.com/allenai/olmocr.git (requires Python >=3.11)"
return msg, f"<!-- {msg} -->", msg, None, []
doc = fitz.open(path)
total_pages = len(doc)
doc.close()
def parse_ranges(s, total):
if not s.strip():
return list(range(total))
pages = set()
parts = [p.strip() for p in s.split(',') if p.strip()]
for part in parts:
if '-' in part:
a, b = part.split('-', 1)
try:
a, b = int(a) - 1, int(b) - 1
except:
continue
for x in range(max(0, a), min(total - 1, b) + 1):
pages.add(x)
else:
try:
idx = int(part) - 1
if 0 <= idx < total:
pages.add(idx)
except:
continue
return sorted(pages)
target_pages = parse_ranges(page_range_text, total_pages)
texts_all, mds_all, raws_all, crops_all = [], [], [], []
for start in range(0, len(target_pages), batch_size):
batch = target_pages[start:start+batch_size]
attempt = 0
while True:
try:
tx, mdx, rawx, _, cropsx = process_pdf_olmocr(path, dpi=dpi, page_indices=batch, insert_separators=insert_separators)
break
except Exception:
attempt += 1
if attempt >= max_retries:
tx, mdx, rawx, cropsx = "", "\n\n".join([f"<!-- Failed batch {start//batch_size+1} -->"]), "", []
break
time.sleep(retry_backoff_seconds * attempt)
if tx:
texts_all.append(tx)
if mdx:
mds_all.append(mdx)
if rawx:
raws_all.append(rawx)
crops_all.extend(cropsx)
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts_all) if texts_all else "No text in PDF",
sep.join(mds_all) if mds_all else "No text in PDF",
"\n\n".join(raws_all), None, crops_all)
def process_pdf_dotsocr(path, dpi=300, page_indices=None, insert_separators=True, max_retries=3, retry_backoff_seconds=3):
"""Process PDF using dots.ocr."""
# Early exit if engine is unavailable
if not DOTSOCR_AVAILABLE:
msg = DOTSOCR_ERROR_MESSAGE or "dots.ocr not available. Install with: pip install qwen-vl-utils"
return msg, f"<!-- {msg} -->", msg, None, []
try:
_init_dotsocr_model()
except RuntimeError as e:
return str(e), f"<!-- {str(e)} -->", str(e), None, []
doc = fitz.open(path)
texts, markdowns, raws, all_crops = [], [], [], []
if page_indices is None:
page_indices = list(range(len(doc)))
for i in page_indices:
cache_key = ("dots.ocr", path, int(dpi), int(i))
cached = _page_cache_get(cache_key)
if cached:
text, md, raw, crops = cached
if text and text.strip() and text != "No text detected":
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n{md}")
raws.append(f"=== Page {i + 1} ===\n{raw}")
all_crops.extend(crops or [])
continue
page = doc.load_page(i)
pix = page.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72), alpha=False)
img = Image.open(BytesIO(pix.tobytes("png")))
attempt = 0
while True:
try:
text, md, raw, _, crops = process_image_dotsocr(img)
break
except Exception:
attempt += 1
if attempt >= max_retries:
text, md, raw, crops = "", f"<!-- Failed to process page {i+1} after retries -->", "", []
break
time.sleep(retry_backoff_seconds * attempt)
if text and text != "No text detected":
texts.append(f"### Page {i + 1}\n\n{text}")
markdowns.append(f"### Page {i + 1}\n\n{md}")
raws.append(f"=== Page {i + 1} ===\n{raw}")
all_crops.extend(crops)
_page_cache_set(cache_key, (text, md, raw, crops))
doc.close()
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts) if texts else "",
sep.join(markdowns) if markdowns else "",
"\n\n".join(raws), None, all_crops)
def process_pdf_all_dotsocr(path, dpi=300, page_range_text="", insert_separators=True, batch_size=3, max_retries=5, retry_backoff_seconds=5):
"""Process all pages of PDF using dots.ocr."""
# Early exit if engine is unavailable
if not DOTSOCR_AVAILABLE:
msg = DOTSOCR_ERROR_MESSAGE or "dots.ocr not available. Install with: pip install qwen-vl-utils"
return msg, f"<!-- {msg} -->", msg, None, []
doc = fitz.open(path)
total_pages = len(doc)
doc.close()
def parse_ranges(s, total):
if not s.strip():
return list(range(total))
pages = set()
parts = [p.strip() for p in s.split(',') if p.strip()]
for part in parts:
if '-' in part:
a, b = part.split('-', 1)
try:
a, b = int(a) - 1, int(b) - 1
except:
continue
for x in range(max(0, a), min(total - 1, b) + 1):
pages.add(x)
else:
try:
idx = int(part) - 1
if 0 <= idx < total:
pages.add(idx)
except:
continue
return sorted(pages)
target_pages = parse_ranges(page_range_text, total_pages)
texts_all, mds_all, raws_all, crops_all = [], [], [], []
for start in range(0, len(target_pages), batch_size):
batch = target_pages[start:start+batch_size]
attempt = 0
while True:
try:
tx, mdx, rawx, _, cropsx = process_pdf_dotsocr(path, dpi=dpi, page_indices=batch, insert_separators=insert_separators)
break
except Exception:
attempt += 1
if attempt >= max_retries:
tx, mdx, rawx, cropsx = "", "\n\n".join([f"<!-- Failed batch {start//batch_size+1} -->"]), "", []
break
time.sleep(retry_backoff_seconds * attempt)
if tx:
texts_all.append(tx)
if mdx:
mds_all.append(mdx)
if rawx:
raws_all.append(rawx)
crops_all.extend(cropsx)
sep = "\n\n---\n\n" if insert_separators else "\n\n"
return (sep.join(texts_all) if texts_all else "No text in PDF",
sep.join(mds_all) if mds_all else "No text in PDF",
"\n\n".join(raws_all), None, crops_all)
def process_file(path, mode_label, task_label, custom_prompt, dpi=300, page_range_text="", embed_figures=False, high_accuracy=False, insert_separators=True, ocr_engine="DeepSeekOCR"):
if not path:
return "Error Upload file", "", "", None, []
if ocr_engine == "PaddleOCR-VL" or ocr_engine == "PaddleOCR":
if path.lower().endswith('.pdf'):
return process_pdf_all_paddleocrvl(path, dpi=dpi, page_range_text=page_range_text, insert_separators=insert_separators)
else:
return process_image_paddleocrvl(Image.open(path))
elif ocr_engine == "olmOCR":
if path.lower().endswith('.pdf'):
return process_pdf_all_olmocr(path, dpi=dpi, page_range_text=page_range_text, insert_separators=insert_separators)
else:
return process_image_olmocr(Image.open(path))
elif ocr_engine == "dots.ocr":
if path.lower().endswith('.pdf'):
return process_pdf_all_dotsocr(path, dpi=dpi, page_range_text=page_range_text, insert_separators=insert_separators)
else:
return process_image_dotsocr(Image.open(path))
else:
if path.lower().endswith('.pdf'):
return process_pdf_all(path, mode_label, task_label, custom_prompt, dpi=dpi, page_range_text=page_range_text, embed_figures=embed_figures, high_accuracy=high_accuracy, insert_separators=insert_separators)
else:
return process_image(Image.open(path), mode_label, task_label, custom_prompt, embed_figures=embed_figures, high_accuracy=high_accuracy)
def toggle_prompt(task_label):
if task_label == "Custom":
return gr.update(visible=True, label="Custom Prompt", placeholder="Add <|grounding|> for boxes")
elif task_label == "Locate":
return gr.update(visible=True, label="Text to Locate", placeholder="Enter text")
return gr.update(visible=False)
def load_image(file_path):
if not file_path:
return None
if file_path.lower().endswith('.pdf'):
doc = fitz.open(file_path)
page = doc.load_page(0)
pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
img = Image.open(BytesIO(pix.tobytes("png")))
doc.close()
return img
else:
return Image.open(file_path)
def get_pdf_page_count(file_path):
try:
doc = fitz.open(file_path)
n = len(doc)
doc.close()
return n
except Exception:
return 1
def render_pdf_page(file_path, page_number, dpi_value):
try:
doc = fitz.open(file_path)
idx = max(1, min(page_number, len(doc))) - 1
page = doc.load_page(idx)
pix = page.get_pixmap(matrix=fitz.Matrix(dpi_value/72, dpi_value/72), alpha=False)
img = Image.open(BytesIO(pix.tobytes("png")))
doc.close()
return img
except Exception:
return None
def build_blocks(theme):
with gr.Blocks(theme=theme, title="OCR-VLs") as demo:
gr.Markdown("""
# OCR-VLs WebUI
**Convert documents to markdown, extract raw text, and locate specific content with bounding boxes.**
""")
with gr.Row():
with gr.Column(scale=1):
# Uploader container
file_in = gr.File(label="Upload Image or PDF", file_types=["image", ".pdf"], type="filepath")
input_img = gr.Image(label="Input Image", type="pil", height=300)
# PDF preview page selector container (visible only for PDFs)
page_seps = gr.Checkbox(value=True, label="Insert page separators (---)")
page_slider = gr.Slider(1, 1, value=1, step=1, label="Preview page", visible=False)
# OCR Engine selector
ocr_engine = gr.Radio(
choices=[c for c in ["DeepSeekOCR", "PaddleOCR-VL", "olmOCR", "dots.ocr", "Gemini Flash 2.5"] if (c != "PaddleOCR-VL" or PADDLEOCRVL_AVAILABLE) and (c != "Gemini Flash 2.5" or GEMINI_AVAILABLE) and (c != "olmOCR" or OLMOCR_AVAILABLE) and (c != "dots.ocr" or DOTSOCR_AVAILABLE)],
value="DeepSeekOCR",
label="OCR Engine",
info="Choose between DeepSeekOCR, PaddleOCR-VL, olmOCR, or dots.ocr, Gemini Flash 2.5"
)
# Processing options container (for DeepSeekOCR)
mode = gr.Dropdown(list(MODE_LABEL_TO_KEY.keys()), value="Gundam", label="Mode (DeepSeekOCR)")
task = gr.Dropdown(list(TASK_LABEL_TO_KEY.keys()), value="Markdown", label="Task (DeepSeekOCR)")
prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
with gr.Row():
embed_fig = gr.Checkbox(value=True, label="Embed figures into Markdown")
high_acc = gr.Checkbox(value=False, label="High accuracy (slower)")
with gr.Row():
dpi = gr.Slider(150, 600, value=300, step=50, label="PDF DPI")
page_range = gr.Textbox(label="Page range (e.g. 1-3,5)", placeholder="All pages")
btn = gr.Button("Extract", variant="primary", size="lg")
# Second row container
with gr.Column(scale=2):
with gr.Tabs():
with gr.Tab("Text"):
text_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
dl_txt = gr.DownloadButton(label="Download Text", value=None)
with gr.Tab("Markdown"):
md_out = gr.Markdown("")
with gr.Row():
dl_md = gr.DownloadButton(label="Download Markdown", value=None)
dl_md_zip = gr.DownloadButton(label="Download Markdown (split pages)", value=None)
with gr.Tab("Boxes"):
img_out = gr.Image(type="pil", height=500, show_label=False)
with gr.Tab("Cropped Images"):
gallery = gr.Gallery(show_label=False, columns=3, height=400)
with gr.Tab("Raw"):
raw_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
with gr.Accordion("ℹ️ Info", open=False):
gr.Markdown("""
### OCR Engines
- **DeepSeekOCR**: AI-powered OCR with advanced document understanding and markdown conversion
- **PaddleOCR-VL**: Document parsing model that converts documents to markdown format (install with: `pip install 'paddleocr[doc-parser]'`)
- **Gemini Flash 2.5**: Google Gemini model for fast, high-quality Markdown conversion (set GEMINI_API_1..5 in .env)
- **olmOCR**: Vision-language model for document OCR (requires Python >=3.11)
- **dots.ocr**: Multilingual document parser with SOTA performance on layout detection and content recognition (install with: `pip install qwen-vl-utils`)
### DeepSeekOCR Modes
- Gundam: 1024 base + 640 tiles with cropping - Best balance
- Tiny: 512×512, no crop - Fastest
- Small: 640×640, no crop - Quick
- Base: 1024×1024, no crop - Standard
- Large: 1280×1280, no crop - Highest quality
### DeepSeekOCR Tasks
- Markdown: Convert document to structured markdown (grounding)
- Tables: Extract tables only as Markdown (grounding)
- Locate: Find specific text in image (grounding)
- Describe: General image description
- Custom: Your own prompt (add `<|grounding|>` for boxes)
### PaddleOCR-VL
- Document parsing model that automatically converts documents to markdown
- Supports both images and PDFs
### olmOCR
- Vision-language model based on Qwen2.5-VL-7B-Instruct
- Automatically converts documents to markdown format
- Supports both images and PDFs
- Model: allenai/olmOCR-2-7B-1025
- **Requires Python >=3.11** - For Hugging Face Spaces, create a `runtime.txt` file with `python-3.11` or higher
### dots.ocr
- Multilingual document parser based on 1.7B LLM with SOTA performance
- Achieves state-of-the-art results for text, tables, and reading order
- Supports both images and PDFs
- Model: rednote-hilab/dots.ocr
- **Requires transformers >= 4.47.0** - Please upgrade transformers if you see import errors
### Gemini Flash 2.5
- Google Gemini model for fast, high-quality Markdown conversion
""")
# Enhanced preview logic for PDFs: show the selected page and slider
def init_preview(file_path, dpi_value):
fp = None
if isinstance(file_path, str):
fp = file_path
elif isinstance(file_path, dict):
fp = file_path.get('name') or file_path.get('path')
if not fp:
return None, gr.update(visible=False)
if fp.lower().endswith('.pdf'):
total = get_pdf_page_count(fp)
img = render_pdf_page(fp, 1, int(dpi_value))
return img, gr.update(visible=True, minimum=1, maximum=max(1, total), value=1)
# Non-PDF
try:
return Image.open(fp), gr.update(visible=False)
except Exception:
return None, gr.update(visible=False)
def update_preview_page(file_path, page_num, dpi_value):
fp = None
if isinstance(file_path, str):
fp = file_path
elif isinstance(file_path, dict):
fp = file_path.get('name') or file_path.get('path')
if fp and fp.lower().endswith('.pdf'):
return render_pdf_page(fp, int(page_num), int(dpi_value))
return input_img.value
file_in.change(init_preview, [file_in, dpi], [input_img, page_slider])
page_slider.change(update_preview_page, [file_in, page_slider, dpi], [input_img])
dpi.release(update_preview_page, [file_in, page_slider, dpi], [input_img])
task.change(toggle_prompt, [task], [prompt])
def toggle_ocr_engine(engine):
"""Show/hide controls based on selected OCR engine."""
if engine in ["PaddleOCR-VL", "Gemini Flash 2.5", "olmOCR", "dots.ocr"]:
return (
gr.update(visible=False), # mode
gr.update(visible=False), # task
gr.update(visible=False), # prompt
gr.update(visible=False), # embed_fig
gr.update(visible=False) # high_acc
)
else:
return (
gr.update(visible=True), # mode
gr.update(visible=True), # task
gr.update(visible=False), # prompt (will be toggled by task)
gr.update(visible=True), # embed_fig
gr.update(visible=True) # high_acc
)
ocr_engine.change(
toggle_ocr_engine,
[ocr_engine],
[mode, task, prompt, embed_fig, high_acc]
)
def run(image, file_path, ocr_engine_val, mode_label, task_label, custom_prompt, dpi_val, page_range_text, embed, hiacc, sep_pages):
# Normalize file path value from Gradio (can be str or dict)
fp = None
if isinstance(file_path, str):
fp = file_path
elif isinstance(file_path, dict):
fp = file_path.get('name') or file_path.get('path')
# Route to appropriate OCR engine
if ocr_engine_val == "PaddleOCR-VL":
# PaddleOCR-VL processing
if fp and isinstance(fp, str) and fp.lower().endswith('.pdf'):
text, md, raw, img, crops = process_file(fp, mode_label, task_label, custom_prompt, dpi=int(dpi_val), page_range_text=page_range_text, embed_figures=embed, high_accuracy=hiacc, insert_separators=sep_pages, ocr_engine="PaddleOCR-VL")
elif image is not None:
text, md, raw, img, crops = process_image_paddleocrvl(image)
elif fp:
text, md, raw, img, crops = process_file(fp, mode_label, task_label, custom_prompt, dpi=int(dpi_val), page_range_text=page_range_text, embed_figures=embed, high_accuracy=hiacc, insert_separators=sep_pages, ocr_engine="PaddleOCR-VL")
else:
return "Error uploading file or image", "", "", None, [], None, None, None
elif ocr_engine_val == "Gemini Flash 2.5":
# Gemini processing
if fp and isinstance(fp, str) and fp.lower().endswith('.pdf'):
text, md, raw, img, crops = process_pdf_all_gemini(fp, dpi=int(dpi_val), page_range_text=page_range_text, insert_separators=sep_pages)
elif image is not None:
text, md, raw, img, crops = process_image_gemini(image)
elif fp:
text, md, raw, img, crops = process_pdf_all_gemini(fp, dpi=int(dpi_val), page_range_text=page_range_text, insert_separators=sep_pages)
else:
return "Error uploading file or image", "", "", None, [], None, None, None
elif ocr_engine_val == "olmOCR":
# olmOCR processing
if fp and isinstance(fp, str) and fp.lower().endswith('.pdf'):
text, md, raw, img, crops = process_file(fp, mode_label, task_label, custom_prompt, dpi=int(dpi_val), page_range_text=page_range_text, embed_figures=embed, high_accuracy=hiacc, insert_separators=sep_pages, ocr_engine="olmOCR")
elif image is not None:
text, md, raw, img, crops = process_image_olmocr(image)
elif fp:
text, md, raw, img, crops = process_file(fp, mode_label, task_label, custom_prompt, dpi=int(dpi_val), page_range_text=page_range_text, embed_figures=embed, high_accuracy=hiacc, insert_separators=sep_pages, ocr_engine="olmOCR")
else:
return "Error uploading file or image", "", "", None, [], None, None, None
elif ocr_engine_val == "dots.ocr":
# dots.ocr processing
if fp and isinstance(fp, str) and fp.lower().endswith('.pdf'):
text, md, raw, img, crops = process_file(fp, mode_label, task_label, custom_prompt, dpi=int(dpi_val), page_range_text=page_range_text, embed_figures=embed, high_accuracy=hiacc, insert_separators=sep_pages, ocr_engine="dots.ocr")
elif image is not None:
text, md, raw, img, crops = process_image_dotsocr(image)
elif fp:
text, md, raw, img, crops = process_file(fp, mode_label, task_label, custom_prompt, dpi=int(dpi_val), page_range_text=page_range_text, embed_figures=embed, high_accuracy=hiacc, insert_separators=sep_pages, ocr_engine="dots.ocr")
else:
return "Error uploading file or image", "", "", None, [], None, None, None
else:
# DeepSeekOCR processing
if fp and isinstance(fp, str) and fp.lower().endswith('.pdf'):
text, md, raw, img, crops = process_file(fp, mode_label, task_label, custom_prompt, dpi=int(dpi_val), page_range_text=page_range_text, embed_figures=embed, high_accuracy=hiacc, insert_separators=sep_pages, ocr_engine="DeepSeekOCR")
elif image is not None:
text, md, raw, img, crops = process_image(image, mode_label, task_label, custom_prompt, embed_figures=embed, high_accuracy=hiacc)
elif fp:
text, md, raw, img, crops = process_file(fp, mode_label, task_label, custom_prompt, dpi=int(dpi_val), page_range_text=page_range_text, embed_figures=embed, high_accuracy=hiacc, insert_separators=sep_pages, ocr_engine="DeepSeekOCR")
else:
return "Error uploading file or image", "", "", None, [], None, None, None
# Create temp files for download
md_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".md")
txt_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
with open(md_tmp.name, 'w', encoding='utf-8') as f:
f.write(md or "")
with open(txt_tmp.name, 'w', encoding='utf-8') as f:
f.write(text or "")
# Optional ZIP split by '---' separators
zip_path = None
try:
if md:
# Split on standalone '---' separator variants
parts = re.split(r"\n\s*---\s*\n", md)
if len(parts) > 1:
zip_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
with zipfile.ZipFile(zip_tmp.name, 'w', zipfile.ZIP_DEFLATED) as zf:
for idx, part in enumerate(parts, start=1):
fname = f"page_{idx:03d}.md"
zf.writestr(fname, part.strip() + "\n")
zip_path = zip_tmp.name
except Exception:
zip_path = None
return text, md, raw, img, crops, md_tmp.name, txt_tmp.name, zip_path
btn.click(run, [input_img, file_in, ocr_engine, mode, task, prompt, dpi, page_range, embed_fig, high_acc, page_seps],
[text_out, md_out, raw_out, img_out, gallery, dl_md, dl_txt, dl_md_zip])
return demo
# Build the demo
demo = build_blocks(gr.themes.Soft())
if __name__ == "__main__":
demo.queue(max_size=20).launch()