# app.py — dots.ocr Gradio Space
# (Hugging Face Space upload by Vishwas1, commit 498832d, verified)
import os
import io
import tempfile
import shutil
from pathlib import Path
import gradio as gr
from PIL import Image
import fitz # PyMuPDF
# --- dots.ocr imports ---
# We install dots.ocr in requirements.txt and download weights at startup.
# The repo recommends saving weights under a path WITHOUT dots in the directory name.
# Ref: README notes about "DotsOCR" dir. (See citations)
from dots_ocr import pipeline as dots_pipeline # if module layout differs, adjust to: from dots_ocr.pipeline import DotsOCR
# Root of the weights directory is overridable via the DOTS_WEIGHTS_DIR env var.
# NOTE(review): the upstream README recommends a weights dir name without a "."
# (hence "DotsOCR" rather than "dots.ocr") — confirm against current docs.
WEIGHTS_DIR = Path(os.getenv("DOTS_WEIGHTS_DIR", "weights")) / "DotsOCR"
WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
# Lazily load the model once per Space session
DOTS = None  # populated by load_model() on first use; treat as a process-wide singleton
def load_model():
    """Return the process-wide dots.ocr pipeline, loading it on first call.

    Subsequent calls reuse the cached instance stored in the module-level
    ``DOTS`` global, so the (expensive) weight load happens once per session.
    """
    global DOTS
    if DOTS is not None:
        return DOTS
    # Upstream exposes a simple pipeline loader; if the module layout changes,
    # adjust to the documented API. Prompt presets (e.g. "prompt_ocr",
    # "prompt_grounding_ocr") are passed later, at inference time.
    DOTS = dots_pipeline.load_model(model_dir=str(WEIGHTS_DIR))
    return DOTS
def pdf_to_images(pdf_bytes, dpi=180):
    """Rasterize a PDF given as bytes into a list of RGB PIL images, one per page.

    Parameters
    ----------
    pdf_bytes : bytes
        Raw PDF file contents.
    dpi : int, optional
        Render resolution. The default of 180 balances OCR quality against
        speed/memory; raise it for small print, lower it for huge documents.

    Returns
    -------
    list[PIL.Image.Image]
        One RGB image per page, in page order.
    """
    imgs = []
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page in doc:
            # alpha=False yields packed 3-channel RGB samples, which is exactly
            # what Image.frombytes("RGB", ...) expects.
            pix = page.get_pixmap(alpha=False, dpi=dpi)
            imgs.append(Image.frombytes("RGB", (pix.width, pix.height), pix.samples))
    return imgs
# --- Devanagari normalization helpers ---
import re
import unicodedata

# Zero-width (non-)joiners: OCR engines often emit these between letters.
ZWJ = "\u200D"
ZWNJ = "\u200C"
# Combining nukta (U+093C). Precomposed nukta letters (U+0958..U+095F) are
# Unicode composition exclusions, so NFC canonicalizes them to base + nukta.
NUKTA = "\u093C"


def normalize_devanagari(text: str) -> str:
    """Light-weight cleanup: normalize nukta sequences, strip stray ZWJ/ZWNJ, and fix hyphen linebreaks.

    Steps, in order:
      1. NFC-normalize so nukta letters have one consistent representation
         (fixes the previously-missing "normalize nukta sequences" step).
      2. Drop stray zero-width joiners/non-joiners.
      3. Re-join words hyphenated across line breaks.
      4. Collapse runs of spaces/tabs and excessive blank lines.
    """
    if not text:
        return text
    # 1. Canonical normalization; for Devanagari this decomposes precomposed
    #    nukta letters (composition exclusions) into base + NUKTA consistently.
    text = unicodedata.normalize("NFC", text)
    # 2. Remove stray joiners between letters (keep when used in known conjunct patterns if needed)
    text = text.replace(ZWJ, "").replace(ZWNJ, "")
    # 3. Fix hyphenated line breaks: "मा-\n झी" -> "माझी"
    text = re.sub(r"(\w)-\s*\n\s*(\w)", r"\1\2", text, flags=re.UNICODE)
    # 4. Collapse multiple spaces/newlines
    text = re.sub(r"[ \t]+", " ", text, flags=re.UNICODE)
    text = re.sub(r"\n{3,}", "\n\n", text, flags=re.UNICODE)
    return text.strip()
def run_ocr(files, task, lang_hint, return_markdown, show_boxes):
    """OCR the uploaded files and return (combined_text, md_path, txt_path, overlays).

    Parameters
    ----------
    files : list[str]
        File paths — ``gr.File(type="filepath")`` delivers plain path strings.
    task : str
        Currently only "OCR"; reserved for future task choices.
    lang_hint : str
        Script/language hint from the dropdown; "Auto" disables hinting.
    return_markdown : bool
        Request Markdown (with layout) instead of plain text from the model.
    show_boxes : bool
        Also collect layout-overlay images (slower; uses the grounding prompt).
    """
    model = load_model()

    # Collect one PIL image per page (PDFs) or per file (images).
    images = []
    for f in files or []:
        # gr.File(type="filepath") yields str paths; tolerate wrapper objects
        # that carry the path in a .name attribute.
        path = getattr(f, "name", f)
        if str(path).lower().endswith(".pdf"):
            with open(path, "rb") as fh:
                images.extend(pdf_to_images(fh.read()))
        else:
            images.append(Image.open(path).convert("RGB"))

    # Choose prompt preset by task.
    # From repo notes: switch between tasks via prompt like "prompt_ocr",
    # "prompt_layout_only_en", "prompt_grounding_ocr".
    prompt = "prompt_grounding_ocr" if show_boxes else "prompt_ocr"
    # Optional language/script hint to guide decoding; appended to the prompt
    # text since upstream does not expose a dedicated language argument.
    if lang_hint and lang_hint != "Auto":
        prompt = f"{prompt} ({lang_hint})"

    # Inference, page by page.
    results = []
    for img in images:
        out = model.infer(img, prompt=prompt, as_markdown=return_markdown, with_layout=return_markdown)
        # out can be dict with 'markdown', 'text', 'boxes' depending on upstream;
        # prefer the requested representation, fall back to the other, never None.
        if return_markdown:
            text_md = out.get("markdown") or out.get("text") or ""
        else:
            text_md = out.get("text") or out.get("markdown") or ""
        # Normalize Devanagari when applicable.
        if lang_hint and ("Devanagari" in lang_hint or "Hindi" in lang_hint
                          or "Marathi" in lang_hint or "Sanskrit" in lang_hint):
            text_md = normalize_devanagari(text_md)
        results.append({
            "text": text_md,
            "boxes_image": out.get("overlay_image") if show_boxes else None,
        })

    # Aggregate outputs.
    combined_text = "\n\n".join(r["text"] for r in results if r["text"])

    # Save downloadable files. Plain text is always written (the UI always
    # shows its download slot); Markdown only when requested.
    md_path = None
    if return_markdown:
        md_path = "output.md"
        with open(md_path, "w", encoding="utf-8") as f:
            f.write(combined_text)
    txt_path = "output.txt"
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(re.sub(r"\n{3,}", "\n\n", combined_text))

    overlay_gallery = [r["boxes_image"] for r in results if r["boxes_image"] is not None]
    return combined_text, (md_path if return_markdown else None), txt_path, overlay_gallery
# --- Gradio UI: wire the inputs through run_ocr to the output widgets ---
with gr.Blocks(title="Dots OCR — Indic-ready") as demo:
    gr.Markdown(
        "# dots.ocr — Multilingual OCR\n"
        "Upload PDFs or images. For Hindi/Marathi/Sanskrit, choose **Devanagari**.\n"
        "Outputs: Markdown + Plain text. (Best on GPU)"
    )
    with gr.Row():
        # type="filepath" -> run_ocr receives a list of plain path strings.
        files = gr.File(label="PDF(s) or image(s)", file_count="multiple", type="filepath")
        return_markdown = gr.Checkbox(True, label="Return Markdown")
    with gr.Row():
        # Only one task for now; kept as a Radio so more can be added later.
        task = gr.Radio(choices=["OCR"], value="OCR", label="Task")
        lang_hint = gr.Dropdown(
            label="Language/Script hint",
            choices=["Auto", "Devanagari (Hindi/Marathi/Sanskrit)", "Gujarati", "Bengali", "Tamil", "Telugu", "Kannada", "Malayalam", "Punjabi (Gurmukhi)", "Urdu (Arabic)"],
            value="Devanagari (Hindi/Marathi/Sanskrit)"
        )
        show_boxes = gr.Checkbox(False, label="Show layout boxes (slower)")
    btn = gr.Button("Run")
    out_text = gr.Textbox(label="Recognized text (Markdown or Plain Text)", lines=20)
    out_md = gr.File(label="Download Markdown", visible=True)
    out_txt = gr.File(label="Download Plain Text", visible=True)
    # Hidden by default; revealed when the "Show layout boxes" checkbox is on.
    overlay = gr.Gallery(label="Layout overlays", visible=False)

    def _toggle(v):
        # Mirror the checkbox state onto the gallery's visibility.
        return gr.update(visible=v)

    show_boxes.change(_toggle, inputs=show_boxes, outputs=overlay)
    btn.click(run_ocr, inputs=[files, task, lang_hint, return_markdown, show_boxes],
              outputs=[out_text, out_md, out_txt, overlay])

if __name__ == "__main__":
    demo.launch()