# deepseek_ocr / app.py
# Hugging Face Space file by akshayve3 — commit 6bca8e9 ("Update app.py", verified).
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
import spaces
import os
import sys
import tempfile
import shutil
from PIL import Image, ImageDraw, ImageFont, ImageOps
import fitz
import re
import numpy as np
import base64
from io import StringIO, BytesIO
from pathlib import Path
import time
from docx import Document
from pptx import Presentation
# Model and tokenizer are loaded once at import time and shared by every request.
MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = (
    AutoModel.from_pretrained(
        MODEL_NAME,
        _attn_implementation='flash_attention_2',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_safetensors=True,
    )
    .eval()  # inference mode
    .cuda()
)
# Resolution presets selectable in the UI.
# Each row is (preset name, base_size, image_size, crop_mode); the values are forwarded to model.infer.
MODEL_CONFIGS = {
    preset: {"base_size": base_size, "image_size": image_size, "crop_mode": crop_mode}
    for preset, base_size, image_size, crop_mode in (
        ("Gundam", 1024, 640, True),
        ("Tiny", 512, 512, False),
        ("Small", 640, 640, False),
        ("Base", 1024, 1024, False),
        ("Large", 1280, 1280, False),
    )
}
# Task presets shown in the Task dropdown: display name -> model prompt and whether the model is
# expected to emit <|ref|>/<|det|> grounding tags for it.
# NOTE(review): most emoji in these names look mis-encoded (UTF-8 read as a legacy code page).
# The exact strings are compared in process_image/toggle_prompt and used as the UI default value,
# so they must be changed everywhere at once, never individually.
TASK_PROMPTS = {
    task_name: {"prompt": prompt_text, "has_grounding": grounded}
    for task_name, prompt_text, grounded in (
        ("πŸ“‹ Markdown", "<image>\n<|grounding|>Convert the document to markdown.", True),
        ("πŸ“ Free OCR", "<image>\nFree OCR.", False),
        ("πŸ“ Locate", "<image>\nLocate <|ref|>text<|/ref|> in the image.", True),
        ("πŸ” Describe", "<image>\nDescribe this image in detail.", False),
        ("✏️ Custom", "", False),
    )
}
# One grounding annotation; groups are (whole annotation, label, coordinate text). Compiled once.
_GROUNDING_PATTERN = re.compile(r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)', re.DOTALL)


def extract_grounding_references(text):
    """Find every ``<|ref|>LABEL<|/ref|><|det|>COORDS<|/det|>`` annotation in *text*.

    Returns:
        A list of ``(full_match, label, coords_text)`` tuples in order of appearance. Both inner
        groups are non-greedy, and they may span newlines.
    """
    return _GROUNDING_PATTERN.findall(text)
def draw_bounding_boxes(image, refs, extract_images=False):
    """Draw the model's grounding boxes over an RGB copy of *image*.

    Args:
        image: PIL image that the boxes refer to.
        refs: ``(full_match, label, coords_text)`` tuples as returned by
            ``extract_grounding_references``. *coords_text* is a list literal of
            ``[x1, y1, x2, y2]`` boxes whose values are scaled from 0..999 to pixels here.
            A single flat ``[x1, y1, x2, y2]`` is also accepted.
        extract_images: when True, also crop every box labelled ``'image'`` out of *image*.

    Returns:
        ``(annotated_image, crops)``. *annotated_image* has outlined, tinted and labelled boxes.
        *crops* lists the cropped figures, one per ``'image'`` box, and is empty unless
        *extract_images* is set. Refs with unparseable coordinates are skipped.
    """
    import ast  # local: only needed here, for safe parsing of model-generated coordinates

    img_w, img_h = image.size
    # Colors below are RGB tuples, which Pillow rejects on e.g. grayscale ('L') images.
    base = image if image.mode == 'RGB' else image.convert('RGB')
    img_draw = base.copy()
    draw = ImageDraw.Draw(img_draw)
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 25)
    except OSError:
        font = ImageFont.load_default()  # font file missing: degrade instead of failing the request
    crops = []
    color_map = {}
    # A private generator seeded like the former global np.random.seed(42) yields the same colors.
    # It does so without resetting NumPy's process-wide RNG as a side effect.
    rng = np.random.RandomState(42)
    for _, label, coords_text in refs:
        # coords_text is model output, which the uploaded document can influence. Parse it as a
        # literal rather than eval()-ing it. Truncated or malformed coordinates skip this ref.
        try:
            coords = ast.literal_eval(coords_text.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            continue
        if not isinstance(coords, (list, tuple)):
            continue
        if len(coords) == 4 and all(isinstance(v, (int, float)) for v in coords):
            coords = [coords]  # a single flat box rather than a list of boxes
        if label not in color_map:
            color_map[label] = (rng.randint(50, 255), rng.randint(50, 255), rng.randint(50, 255))
        color = color_map[label]
        color_a = color + (60,)  # translucent fill for the overlay
        width = 5 if label == 'title' else 3
        for box in coords:
            try:
                bx1, by1, bx2, by2 = (float(v) for v in box[:4])
            except (TypeError, ValueError):
                continue
            # Sort each axis: Pillow raises if a rectangle's second corner precedes the first.
            x1, x2 = sorted((int(bx1 / 999 * img_w), int(bx2 / 999 * img_w)))
            y1, y2 = sorted((int(by1 / 999 * img_h), int(by2 / 999 * img_h)))
            if extract_images and label == 'image':
                # Keep crops at least 1x1 px; zero-size images cannot be encoded later.
                crops.append(base.crop((x1, y1, max(x2, x1 + 1), max(y2, y1 + 1))))
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            draw2.rectangle([x1, y1, x2, y2], fill=color_a)
            text_bbox = draw.textbbox((0, 0), label, font=font)
            tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
            ty = max(0, y1 - 20)  # label sits just above the box, clamped to the top edge
            draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
            draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw, crops
def clean_output(text, include_images=False, remove_labels=False):
    """Strip DeepSeek-OCR grounding annotations from *text*.

    Each ``<|ref|>LABEL<|/ref|><|det|>COORDS<|/det|>`` run is rewritten as follows:

    * ``image`` annotations become ``**[Figure N]**`` placeholders, numbered from 1 in order,
      when *include_images* is true, and are dropped otherwise;
    * any other annotation is dropped when *remove_labels* is true, or else replaced by LABEL.

    Returns:
        The rewritten text with surrounding whitespace stripped; ``""`` for empty or None input.
    """
    if not text:
        return ""
    figures_seen = 0
    annotations = re.findall(r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)', text, re.DOTALL)
    for annotation, label, _coords in annotations:
        is_figure = '<|ref|>image<|/ref|>' in annotation
        if is_figure and include_images:
            figures_seen += 1
            replacement = f'\n\n**[Figure {figures_seen}]**\n\n'
        elif is_figure or remove_labels:
            replacement = ''
        else:
            replacement = label
        # Rewrite only the first remaining occurrence, following findall's left-to-right order.
        text = text.replace(annotation, replacement, 1)
    return text.strip()
def embed_images(markdown, crops):
    """Inline cropped figures into *markdown* as base64 PNG data URIs.

    The first ``**[Figure N]**`` placeholder is replaced by the N-th crop (1-based). Placeholders
    without a matching crop, and crops without a matching placeholder, are left alone.
    """
    if not crops:
        return markdown
    for number, crop in enumerate(crops, start=1):
        png_buffer = BytesIO()
        crop.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
        placeholder = f'**[Figure {number}]**'
        inline_image = f'\n\n![Figure {number}](data:image/png;base64,{encoded})\n\n'
        markdown = markdown.replace(placeholder, inline_image, 1)
    return markdown
def _build_prompt(task, custom_prompt):
    """Return ``(prompt, has_grounding)`` for *task*, splicing in *custom_prompt* where needed."""
    if task == "✏️ Custom":
        return f"<image>\n{custom_prompt.strip()}", '<|grounding|>' in custom_prompt
    if task == "πŸ“ Locate":
        return f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image.", True
    preset = TASK_PROMPTS[task]
    return preset["prompt"], preset["has_grounding"]


# Printed lines containing any of these markers are treated as progress/debug noise, not OCR output.
_INFER_NOISE_MARKERS = ('image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size')


def _run_model(image, prompt, config):
    """Run ``model.infer`` on *image* and return the text it printed, minus noise lines.

    The result is read from what ``model.infer`` prints to stdout. stdout is restored and the
    temporary JPEG and output directory are removed even if inference raises.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    tmp_path = tmp.name
    tmp.close()  # close before writing by name (required on some platforms)
    out_dir = tempfile.mkdtemp()
    captured = StringIO()
    try:
        image.save(tmp_path, 'JPEG', quality=95)
        saved_stdout = sys.stdout
        sys.stdout = captured
        try:
            model.infer(tokenizer=tokenizer, prompt=prompt, image_file=tmp_path, output_path=out_dir,
                        base_size=config["base_size"], image_size=config["image_size"],
                        crop_mode=config["crop_mode"])
        finally:
            sys.stdout = saved_stdout
    finally:
        Path(tmp_path).unlink(missing_ok=True)
        shutil.rmtree(out_dir, ignore_errors=True)
    kept = [line for line in captured.getvalue().split('\n')
            if not any(marker in line for marker in _INFER_NOISE_MARKERS)]
    return '\n'.join(kept).strip()


@spaces.GPU(duration=60)
def process_image(image, mode, task, custom_prompt):
    """OCR a single PIL image with DeepSeek-OCR.

    Args:
        image: PIL image, or None.
        mode: key of MODEL_CONFIGS selecting the resolution preset.
        task: key of TASK_PROMPTS. The Custom and Locate tasks also read *custom_prompt*.
        custom_prompt: free-form prompt (Custom) or text to find (Locate). May be empty or None
            for the other tasks.

    Returns:
        ``(cleaned_text, markdown, raw_output, boxed_image_or_None, crops)``. On invalid input or
        empty model output, the first element is a short status message ("Error: Upload image",
        "Enter prompt" or "No text") and the rest are empty.
    """
    if image is None:
        return "Error: Upload image", "", "", None, []
    custom_prompt = custom_prompt or ""
    if task in ["✏️ Custom", "πŸ“ Locate"] and not custom_prompt.strip():
        return "Enter prompt", "", "", None, []
    # Apply EXIF orientation while the source metadata is still attached, then normalize every
    # non-RGB mode (RGBA, P, L, CMYK, I;16, ...). JPEG saving and RGB box drawing both need RGB.
    image = ImageOps.exif_transpose(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    config = MODEL_CONFIGS[mode]
    prompt, has_grounding = _build_prompt(task, custom_prompt)
    result = _run_model(image, prompt, config)
    if not result:
        return "No text", "", "", None, []
    cleaned = clean_output(result, False, False)  # labels kept, figures dropped
    markdown = clean_output(result, True, True)  # labels dropped, figure placeholders kept
    img_out = None
    crops = []
    if has_grounding and '<|ref|>' in result:
        refs = extract_grounding_references(result)
        if refs:
            img_out, crops = draw_bounding_boxes(image, refs, True)
            markdown = embed_images(markdown, crops)
    return cleaned, markdown, result, img_out, crops
# Typography and margins shared by the DOCX and PPTX rasterizers below.
_DOC_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
_DOC_FONT_SIZE = 20
_DOC_MARGIN = 50
_DOC_LINE_SPACING = 4


def _load_doc_font():
    """Load the body font, falling back to Pillow's built-in font if the DejaVu file is missing."""
    try:
        return ImageFont.truetype(_DOC_FONT_PATH, _DOC_FONT_SIZE)
    except OSError:
        return ImageFont.load_default()


def _wrap_to_width(draw, text, font, max_width):
    """Split *text* into lines that each fit within *max_width* pixels when drawn with *font*."""
    def fits(candidate):
        return draw.textlength(candidate, font=font) <= max_width

    lines = []
    for source_line in text.splitlines() or ['']:
        current = ''
        for word in source_line.split(' '):
            candidate = f'{current} {word}' if current else word
            if fits(candidate):
                current = candidate
                continue
            if current:
                lines.append(current)
            current = ''
            # The word alone may still be too wide (e.g. CJK text without spaces): break it by character.
            for ch in word:
                if current and not fits(current + ch):
                    lines.append(current)
                    current = ch
                else:
                    current += ch
        lines.append(current)
    return lines


def _render_text_page(text, width, min_height, font):
    """Draw *text* word-wrapped in black on a white page *width* px wide.

    The page is *min_height* px tall, and grows taller only when the wrapped text needs it.
    """
    measure = ImageDraw.Draw(Image.new('RGB', (1, 1)))
    body = '\n'.join(_wrap_to_width(measure, text, font, width - 2 * _DOC_MARGIN))
    height = min_height
    if body.strip():
        bottom = measure.multiline_textbbox((_DOC_MARGIN, _DOC_MARGIN), body, font=font,
                                            spacing=_DOC_LINE_SPACING)[3]
        height = max(min_height, int(bottom) + _DOC_MARGIN)
    page = Image.new('RGB', (width, height), color='white')
    if body.strip():
        ImageDraw.Draw(page).multiline_text((_DOC_MARGIN, _DOC_MARGIN), body, fill='black',
                                            font=font, spacing=_DOC_LINE_SPACING)
    return page


def docx_to_images(path):
    """Render each non-empty paragraph of a .docx file as its own page image for OCR.

    Only ``Document.paragraphs`` is read. Each page is 800 px wide and at least 1100 px tall, and
    the text is word-wrapped so that none of it is clipped at the right or bottom edge.
    """
    font = _load_doc_font()
    return [_render_text_page(para.text, 800, 1100, font)
            for para in Document(path).paragraphs if para.text.strip()]


def pptx_to_images(path):
    """Render the text of each slide in a .pptx file as one page image for OCR.

    Every slide yields exactly one image (blank if the slide has no text), so images stay aligned
    with slide numbers. Text from shapes that expose a non-empty ``text`` attribute is stacked
    top-down and separated by a blank line. Images are 960 px wide and at least 720 px tall.
    """
    font = _load_doc_font()
    images = []
    for slide in Presentation(path).slides:
        texts = [shape.text for shape in slide.shapes
                 if hasattr(shape, "text") and shape.text.strip()]
        images.append(_render_text_page('\n\n'.join(texts), 960, 720, font))
    return images
@spaces.GPU(duration=300)
def process_pdf(path, mode, task, custom_prompt):
    """OCR every page of the PDF at *path* with ``process_image``.

    Returns:
        ``(text, markdown, raw, box_images, crops, total_pages)``. Pages whose OCR yields no text
        are left out of the three text outputs. Every page still appends its (possibly None) box
        image, so list positions line up with page numbers.
    """
    texts, markdowns, raws, all_crops, box_images = [], [], [], [], []
    render_matrix = fitz.Matrix(300 / 72, 300 / 72)  # PDF user space is 72 units/inch -> 300 DPI
    # The context manager closes the document even if OCR of some page raises.
    with fitz.open(path) as doc:
        total_pages = len(doc)
        for page_number, page in enumerate(doc, start=1):
            pix = page.get_pixmap(matrix=render_matrix, alpha=False)
            img = Image.open(BytesIO(pix.tobytes("png")))
            text, md, raw, box_img, crops = process_image(img, mode, task, custom_prompt)
            if text and text != "No text":
                texts.append(f"### Page {page_number}\n\n{text}")
                markdowns.append(f"### Page {page_number}\n\n{md}")
                raws.append(f"=== Page {page_number} ===\n{raw}")
                all_crops.extend(crops)
            box_images.append(box_img)
    return ("\n\n---\n\n".join(texts) if texts else "No text in PDF",
            "\n\n---\n\n".join(markdowns) if markdowns else "No text in PDF",
            "\n\n".join(raws), box_images, all_crops, total_pages)
def _save_jpeg(img, path):
    """Save *img* to *path* as JPEG, first converting modes JPEG cannot store (e.g. RGBA, P) to RGB."""
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    img.save(path)


def save_outputs(doc_name, text_content, md_content, raw_content, box_images, cropped_images,
                 base_dir="outputs"):
    """Persist one document's OCR results under a new numbered folder and return its path.

    Layout of ``<base_dir>/<NN>_<doc_name>/``:
    ``text_output.txt``, ``clean_output.md``, ``raw_output.txt``,
    ``boxes/page_XX_box.jpg`` (XX = 1-based position in *box_images*; None entries are skipped
    but still consume their number) and ``cropped/crop_XX.jpg``.

    Args:
        base_dir: parent directory for all result folders. Defaults to ``outputs``, relative to
            the current working directory.
    """
    root = Path(base_dir)
    root.mkdir(parents=True, exist_ok=True)
    # Next number = highest existing "NN_" prefix + 1. Counting directories instead would reuse
    # a number after a folder is deleted and could land on an existing folder.
    used = [int(d.name.split('_', 1)[0]) for d in root.iterdir()
            if d.is_dir() and d.name.split('_', 1)[0].isdecimal()]
    folder_num = max(used, default=0) + 1
    while True:
        doc_folder = root / f"{folder_num:02d}_{doc_name}"
        try:
            doc_folder.mkdir()  # never reuse an existing folder: that would overwrite its results
            break
        except FileExistsError:
            folder_num += 1
    (doc_folder / "text_output.txt").write_text(text_content, encoding='utf-8')
    (doc_folder / "clean_output.md").write_text(md_content, encoding='utf-8')
    (doc_folder / "raw_output.txt").write_text(raw_content, encoding='utf-8')
    boxes_dir = doc_folder / "boxes"
    boxes_dir.mkdir(exist_ok=True)
    for i, img in enumerate(box_images, start=1):
        if img is not None:
            _save_jpeg(img, boxes_dir / f"page_{i:02d}_box.jpg")
    cropped_dir = doc_folder / "cropped"
    cropped_dir.mkdir(exist_ok=True)
    for i, img in enumerate(cropped_images, start=1):
        if img is not None:
            _save_jpeg(img, cropped_dir / f"crop_{i:02d}.jpg")
    return str(doc_folder)
def _ocr_rendered_pages(images, unit, mode, task, custom_prompt):
    """OCR pre-rendered page images, labelling each section "<unit> N" (N is 1-based).

    Returns:
        ``(text, markdown, raw, box_images, crops, page_count)``.
    """
    texts, mds, raws, box_images, crops = [], [], [], [], []
    for number, img in enumerate(images, start=1):
        text, md, raw, box_img, page_crops = process_image(img, mode, task, custom_prompt)
        texts.append(f"### {unit} {number}\n\n{text}")
        mds.append(f"### {unit} {number}\n\n{md}")
        raws.append(f"=== {unit} {number} ===\n{raw}")
        box_images.append(box_img)
        crops.extend(page_crops)
    return ("\n\n---\n\n".join(texts), "\n\n---\n\n".join(mds), "\n\n".join(raws),
            box_images, crops, len(images))


def process_single_file(file_path, mode, task, custom_prompt):
    """OCR one uploaded file, save its outputs to disk, and return them with a summary.

    The file type is chosen by extension: .pdf pages, .docx paragraphs and .pptx slides are
    rasterized first. Anything else is opened directly as an image.

    Returns:
        ``(text, markdown, raw, box_images, crops, summary)``.
    """
    start_time = time.perf_counter()
    source = Path(file_path)
    file_name = source.stem
    ext = source.suffix.lower()
    if ext == '.pdf':
        text, md, raw, box_images, crops, total_pages = process_pdf(file_path, mode, task, custom_prompt)
    elif ext == '.docx':
        text, md, raw, box_images, crops, total_pages = _ocr_rendered_pages(
            docx_to_images(file_path), "Page", mode, task, custom_prompt)
    elif ext == '.pptx':
        text, md, raw, box_images, crops, total_pages = _ocr_rendered_pages(
            pptx_to_images(file_path), "Slide", mode, task, custom_prompt)
    else:
        # Context manager closes the underlying file handle once OCR is done.
        with Image.open(file_path) as img:
            text, md, raw, box_img, crops = process_image(img, mode, task, custom_prompt)
        box_images = [box_img] if box_img else []
        total_pages = 1
    elapsed_time = time.perf_counter() - start_time
    folder_path = save_outputs(file_name, text, md, raw, box_images, crops)
    summary = f"πŸ“„ File: {file_name}\nπŸ“Š Pages/Slides: {total_pages}\nπŸ–ΌοΈ Cropped Images: {len(crops)}\n⏱️ Processing Time: {elapsed_time:.2f}s\nπŸ“ Saved to: {folder_path}"
    return text, md, raw, box_images, crops, summary
def process_multiple_files(files, mode, task, custom_prompt):
    """Gradio handler: OCR every uploaded file and merge the results for display.

    Args:
        files: uploaded files. Plain path strings, which is what ``gr.File(type="filepath")``
            yields, or objects exposing the path as ``.name``.

    Returns:
        ``(text, markdown, raw, box_images, crops, summary)``. Results from successful files are
        concatenated. A file that fails to process is reported in the summary instead of
        aborting the batch.
    """
    import traceback  # local: only used to log per-file failures

    if not files:
        return "No files uploaded", "", "", [], [], "No files to process"
    all_texts, all_mds, all_raws, all_boxes, all_crops = [], [], [], [], []
    summaries = []
    total_start = time.perf_counter()
    for file in files:
        file_path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name
        try:
            text, md, raw, boxes, crops, summary = process_single_file(file_path, mode, task, custom_prompt)
        except gr.Error:
            raise  # Gradio's user-facing error type: let it surface unchanged
        except Exception as exc:
            # UI boundary: log the failure, record it, and keep processing the other files.
            traceback.print_exc()
            summaries.append(f"Failed: {Path(file_path).name} ({type(exc).__name__}: {exc})")
            continue
        all_texts.append(text)
        all_mds.append(md)
        all_raws.append(raw)
        # Pages without grounding contribute None, which a Gallery cannot display.
        all_boxes.extend(box for box in boxes if box is not None)
        all_crops.extend(crops)
        summaries.append(summary)
    total_time = time.perf_counter() - total_start
    combined_text = "\n\n========================================\n\n".join(all_texts)
    combined_md = "\n\n========================================\n\n".join(all_mds)
    combined_raw = "\n\n========================================\n\n".join(all_raws)
    final_summary = f"βœ… Processed {len(files)} file(s)\n⏱️ Total Time: {total_time:.2f}s\n\n" + "\n\n".join(summaries)
    return combined_text, combined_md, combined_raw, all_boxes, all_crops, final_summary
def toggle_prompt(task):
    """Return a Gradio update for the prompt textbox that matches the selected *task*.

    Custom and Locate show the textbox with their own label and placeholder. Every other task
    hides it.
    """
    settings = {
        "✏️ Custom": ("Custom Prompt", "Add <|grounding|> for boxes"),
        "πŸ“ Locate": ("Text to Locate", "Enter text"),
    }.get(task)
    if settings is None:
        return gr.update(visible=False)
    label, placeholder = settings
    return gr.update(visible=True, label=label, placeholder=placeholder)
def show_view(view_type):
    """Return visibility updates for the output panes in the order text, markdown, raw, boxes, crops.

    Only the pane whose name equals *view_type* is made visible.
    """
    pane_names = ("text", "markdown", "raw", "boxes", "crops")
    return tuple(gr.update(visible=(view_type == pane)) for pane in pane_names)
# Gradio UI. Left column: inputs and the processing summary. Right column: a row of view buttons
# plus five output panes, of which exactly one is visible at a time.
with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR Multi-file") as demo:
    gr.Markdown("""
# πŸš€ DeepSeek-OCR Multi-file Processor
Upload multiple files (PDF, DOCX, PPTX, Images) and process them with document-wise folder structure.
""")
    with gr.Row():
        with gr.Column(scale=1):
            files_in = gr.File(label="πŸ“ Upload Files", file_count="multiple", type="filepath")
            mode = gr.Dropdown(list(MODEL_CONFIGS), value="Gundam", label="βš™οΈ Mode")
            task = gr.Dropdown(list(TASK_PROMPTS), value="πŸ“‹ Markdown", label="πŸ“ Task")
            prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
            btn = gr.Button("πŸ”„ Process All Files", variant="primary", size="lg")
            gr.Markdown("---")
            summary_out = gr.Textbox(label="πŸ“Š Processing Summary", lines=8)
        with gr.Column(scale=2):
            with gr.Row():
                text_btn = gr.Button("πŸ“„ Text", variant="secondary", size="sm")
                md_btn = gr.Button("πŸ“‹ Markdown", variant="secondary", size="sm")
                raw_btn = gr.Button("πŸ” Raw", variant="secondary", size="sm")
                boxes_btn = gr.Button("πŸŽ― Boxes", variant="secondary", size="sm")
                crops_btn = gr.Button("βœ‚οΈ Crops", variant="secondary", size="sm")
            with gr.Column(visible=True) as text_container:
                gr.Markdown("### πŸ“„ Text Output")
                text_out = gr.Textbox(lines=25, show_copy_button=True, show_label=False)
            with gr.Column(visible=False) as md_container:
                gr.Markdown("### πŸ“‹ Markdown Output")
                md_out = gr.Markdown("")
            with gr.Column(visible=False) as raw_container:
                gr.Markdown("### πŸ” Raw Output")
                raw_out = gr.Textbox(lines=25, show_copy_button=True, show_label=False)
            with gr.Column(visible=False) as boxes_container:
                gr.Markdown("### πŸŽ― Bounding Boxes")
                boxes_gallery = gr.Gallery(show_label=False, columns=3, height=600)
            with gr.Column(visible=False) as crops_container:
                gr.Markdown("### βœ‚οΈ Cropped Images")
                crops_gallery = gr.Gallery(show_label=False, columns=4, height=600)
    with gr.Accordion("ℹ️ Info", open=False):
        gr.Markdown("""
### Modes
- **Gundam**: 1024 base + 640 tiles with cropping - Best balance
- **Tiny**: 512Γ—512, no crop - Fastest
- **Small**: 640Γ—640, no crop - Quick
- **Base**: 1024Γ—1024, no crop - Standard
- **Large**: 1280Γ—1280, no crop - Highest quality
### Tasks
- **Markdown**: Convert document to structured markdown (grounding βœ…)
- **Free OCR**: Simple text extraction
- **Locate**: Find specific things in image (grounding βœ…)
- **Describe**: General image description
- **Custom**: Your own prompt (add `<|grounding|>` for boxes)
### Supported Formats
- πŸ“„ PDF files
- πŸ“ Word documents (.docx)
- πŸ“Š PowerPoint presentations (.pptx)
- πŸ–ΌοΈ Images (JPG, PNG, etc.)
""")

    task.change(toggle_prompt, [task], [prompt])
    btn.click(
        process_multiple_files,
        [files_in, mode, task, prompt],
        [text_out, md_out, raw_out, boxes_gallery, crops_gallery, summary_out],
    )

    # Same order as show_view's returned updates.
    view_containers = [text_container, md_container, raw_container, boxes_container, crops_container]

    def _switch_to(view_name):
        """Build a zero-argument click handler that shows only the *view_name* pane."""
        return lambda: show_view(view_name)

    for view_button, view_name in ((text_btn, "text"), (md_btn, "markdown"), (raw_btn, "raw"),
                                   (boxes_btn, "boxes"), (crops_btn, "crops")):
        view_button.click(_switch_to(view_name), None, view_containers)
if __name__ == "__main__":
    # Queue requests (at most 20 waiting) and serve; share=True asks Gradio for a public share link.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(share=True)