Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoModel, AutoTokenizer | |
| import torch | |
| import spaces | |
| import os | |
| import sys | |
| import tempfile | |
| import shutil | |
| from PIL import Image, ImageDraw, ImageFont, ImageOps | |
| import fitz | |
| import re | |
| import numpy as np | |
| import base64 | |
| from io import StringIO, BytesIO | |
| from pathlib import Path | |
| import time | |
| from docx import Document | |
| from pptx import Presentation | |
# Hugging Face repository id of the OCR model served by this app.
MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'

# The tokenizer and model are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load in bfloat16 from safetensors, switch to inference mode and move to the GPU.
model = (
    AutoModel.from_pretrained(
        MODEL_NAME,
        _attn_implementation='flash_attention_2',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_safetensors=True,
    )
    .eval()
    .cuda()
)
# Inference presets selectable in the UI, keyed by display name.
# Each preset supplies the base_size / image_size / crop_mode arguments
# that are forwarded to model.infer().
MODEL_CONFIGS = {
    preset: {"base_size": base_size, "image_size": image_size, "crop_mode": crop_mode}
    for preset, base_size, image_size, crop_mode in (
        ("Gundam", 1024, 640, True),
        ("Tiny", 512, 512, False),
        ("Small", 640, 640, False),
        ("Base", 1024, 1024, False),
        ("Large", 1280, 1280, False),
    )
}
# Built-in tasks offered in the UI, keyed by display label.
# "prompt" is the text sent to the model; "has_grounding" marks tasks whose
# output is expected to carry <|ref|>/<|det|> box annotations.
TASK_PROMPTS = {
    label: {"prompt": prompt_text, "has_grounding": grounded}
    for label, prompt_text, grounded in (
        ("π Markdown", "<image>\n<|grounding|>Convert the document to markdown.", True),
        ("π Free OCR", "<image>\nFree OCR.", False),
        ("π Locate", "<image>\nLocate <|ref|>text<|/ref|> in the image.", True),
        ("π Describe", "<image>\nDescribe this image in detail.", False),
        ("βοΈ Custom", "", False),
    )
}
def extract_grounding_references(text):
    """Find every ``<|ref|>label<|/ref|><|det|>boxes<|/det|>`` annotation in ``text``.

    Returns:
        A list of ``(whole_match, label, boxes)`` string tuples, in order of
        appearance. Matching spans newlines.
    """
    annotation = re.compile(r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)', re.DOTALL)
    return [(hit.group(1), hit.group(2), hit.group(3)) for hit in annotation.finditer(text)]
def draw_bounding_boxes(image, refs, extract_images=False):
    """Draw labelled grounding boxes on a copy of ``image``.

    Args:
        image: PIL image the boxes refer to; it is not modified.
        refs: Sequence of tuples whose index 1 is the label and index 2 is the
            textual list of boxes (the shape returned by
            ``extract_grounding_references``).
        extract_images: When true, every box whose label is ``'image'`` is
            also cropped out of ``image``.

    Returns:
        ``(annotated, crops)`` - the annotated copy and the list of crops
        (empty unless ``extract_images`` is true).
    """
    import ast  # local import: only needed to parse the box lists

    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    # Translucent fills go on a separate RGBA layer that is pasted at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 25)
    except OSError:
        # The font file is not present on every system; degrade gracefully.
        font = ImageFont.load_default()
    crops = []
    color_map = {}
    # A private, fixed-seed generator keeps colours reproducible without
    # resetting NumPy's global random state.
    rng = np.random.RandomState(42)
    for ref in refs:
        label = ref[1]
        if label not in color_map:
            color_map[label] = (rng.randint(50, 255), rng.randint(50, 255), rng.randint(50, 255))
        color = color_map[label]
        try:
            # The box list is model-generated text: parse it as a literal
            # instead of executing it with eval().
            coords = ast.literal_eval(ref[2])
        except (ValueError, SyntaxError):
            # Skip annotations whose box list is not a valid literal.
            continue
        color_a = color + (60,)
        width = 5 if label == 'title' else 3
        # The label size does not depend on the box, so measure it once per ref.
        text_bbox = draw.textbbox((0, 0), label, font=font)
        tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
        for box in coords:
            # Box values are divided by 999 and scaled to the image size.
            x1, y1, x2, y2 = int(box[0]/999*img_w), int(box[1]/999*img_h), int(box[2]/999*img_w), int(box[3]/999*img_h)
            if extract_images and label == 'image':
                crops.append(image.crop((x1, y1, x2, y2)))
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            draw2.rectangle([x1, y1, x2, y2], fill=color_a)
            # Label tag sits just above the box, clamped to the top edge.
            ty = max(0, y1 - 20)
            draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
            draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw, crops
def clean_output(text, include_images=False, remove_labels=False):
    """Strip grounding annotations from model output.

    Each ``<|ref|>...<|/ref|><|det|>...<|/det|>`` annotation is replaced, in
    order of appearance:

    * annotations containing ``<|ref|>image<|/ref|>`` become a numbered
      ``**[Figure N]**`` placeholder when ``include_images`` is true, and are
      removed otherwise;
    * all other annotations are removed when ``remove_labels`` is true, and
      replaced by their label text otherwise.

    Returns:
        The cleaned text with surrounding whitespace stripped, or ``""`` for
        empty input.
    """
    if not text:
        return ""
    grounding_re = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    figures_seen = 0
    for whole, label, _boxes in re.findall(grounding_re, text, re.DOTALL):
        if '<|ref|>image<|/ref|>' in whole:
            if include_images:
                figures_seen += 1
                substitute = f'\n\n**[Figure {figures_seen}]**\n\n'
            else:
                substitute = ''
        else:
            substitute = '' if remove_labels else label
        # Replace only the first remaining occurrence so duplicates are
        # consumed one at a time, in order.
        text = text.replace(whole, substitute, 1)
    return text.strip()
def embed_images(markdown, crops):
    """Replace ``**[Figure N]**`` placeholders with inline base64 PNG images.

    Args:
        markdown: Markdown text containing ``**[Figure N]**`` placeholders
            (numbered from 1).
        crops: Images, in figure order; each must support
            ``save(file_obj, format="PNG")``.

    Returns:
        The markdown with the N-th placeholder replaced by a Markdown image
        whose source is a ``data:image/png;base64,...`` URI. Returned
        unchanged when ``crops`` is empty.
    """
    if not crops:
        return markdown
    for number, img in enumerate(crops, start=1):
        buf = BytesIO()
        img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        # Embed the encoded image itself; previously the placeholder was
        # swapped for blank lines and the encoded data was discarded.
        markdown = markdown.replace(
            f'**[Figure {number}]**',
            f'\n\n![Figure {number}](data:image/png;base64,{b64})\n\n',
            1,
        )
    return markdown
def process_image(image, mode, task, custom_prompt):
    """Run the OCR model on one image and post-process its output.

    Args:
        image: PIL image, or ``None``.
        mode: Key into ``MODEL_CONFIGS`` selecting the inference preset.
        task: Key into ``TASK_PROMPTS`` (the custom and locate tasks build
            their prompt from ``custom_prompt``).
        custom_prompt: User text for the custom/locate tasks; ``None`` is
            treated as empty.

    Returns:
        ``(cleaned_text, markdown, raw_output, boxed_image, crops)``.
        ``boxed_image`` is ``None`` and ``crops`` is empty when no grounding
        boxes were drawn. When input is missing, the first element carries a
        short message and the rest are empty.
    """
    if image is None:
        return "Error: Upload image", "", "", None, []
    # The prompt box may deliver None; normalise once so .strip() cannot fail.
    custom_prompt = (custom_prompt or "").strip()
    if task in ["βοΈ Custom", "π Locate"] and not custom_prompt:
        return "Enter prompt", "", "", None, []
    if image.mode in ('RGBA', 'LA', 'P'):
        image = image.convert('RGB')
    image = ImageOps.exif_transpose(image)
    config = MODEL_CONFIGS[mode]
    if task == "βοΈ Custom":
        prompt = f"<image>\n{custom_prompt}"
        has_grounding = '<|grounding|>' in custom_prompt
    elif task == "π Locate":
        prompt = f"<image>\nLocate <|ref|>{custom_prompt}<|/ref|> in the image."
        has_grounding = True
    else:
        prompt = TASK_PROMPTS[task]["prompt"]
        has_grounding = TASK_PROMPTS[task]["has_grounding"]

    # The model is given a file path, so the image goes through a temp JPEG.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    tmp.close()
    out_dir = tempfile.mkdtemp()
    captured = StringIO()
    try:
        image.save(tmp.name, 'JPEG', quality=95)
        # The result text is read back from what infer() prints to stdout.
        saved_stdout = sys.stdout
        sys.stdout = captured
        try:
            model.infer(tokenizer=tokenizer, prompt=prompt, image_file=tmp.name, output_path=out_dir,
                        base_size=config["base_size"], image_size=config["image_size"], crop_mode=config["crop_mode"])
        finally:
            # Always restore stdout, even if inference raises.
            sys.stdout = saved_stdout
    finally:
        # Temp artefacts are removed on both success and failure.
        os.unlink(tmp.name)
        shutil.rmtree(out_dir, ignore_errors=True)

    # Drop the diagnostic lines printed alongside the result.
    noise_markers = ['image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size']
    result = '\n'.join(
        line for line in captured.getvalue().split('\n')
        if not any(marker in line for marker in noise_markers)
    ).strip()
    if not result:
        return "No text", "", "", None, []
    cleaned = clean_output(result, False, False)
    markdown = clean_output(result, True, True)
    img_out = None
    crops = []
    if has_grounding and '<|ref|>' in result:
        refs = extract_grounding_references(result)
        if refs:
            img_out, crops = draw_bounding_boxes(image, refs, True)
            markdown = embed_images(markdown, crops)
    return cleaned, markdown, result, img_out, crops
def docx_to_images(path):
    """Render each non-empty paragraph of a .docx file onto its own image.

    Args:
        path: Path of the Word document.

    Returns:
        A list of 800x1100 white RGB images, one per paragraph that contains
        non-whitespace text, with the paragraph text drawn at (50, 50).
    """
    # Load the font once instead of once per paragraph; fall back to the
    # built-in font when the DejaVu file is not installed.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
    except OSError:
        font = ImageFont.load_default()
    images = []
    for para in Document(path).paragraphs:
        if not para.text.strip():
            continue
        img = Image.new('RGB', (800, 1100), color='white')
        # NOTE(review): text is drawn without wrapping, so a long paragraph
        # may extend beyond the 800 px canvas - confirm whether wrapping is wanted.
        ImageDraw.Draw(img).text((50, 50), para.text, fill='black', font=font)
        images.append(img)
    return images
def pptx_to_images(path):
    """Render the text of each slide in a .pptx file onto one image per slide.

    Args:
        path: Path of the PowerPoint presentation.

    Returns:
        A list of 960x720 white RGB images, one per slide. The text of every
        shape with non-whitespace text is drawn at x=50, starting at y=50 and
        moving down 100 px per shape.
    """
    # Load the font once instead of once per slide; fall back to the
    # built-in font when the DejaVu file is not installed.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
    except OSError:
        font = ImageFont.load_default()
    images = []
    for slide in Presentation(path).slides:
        img = Image.new('RGB', (960, 720), color='white')
        draw = ImageDraw.Draw(img)
        y = 50
        for shape in slide.shapes:
            if hasattr(shape, "text") and shape.text.strip():
                # NOTE(review): there is no bound on y, so slides with many
                # text shapes can draw below the 720 px canvas.
                draw.text((50, y), shape.text, fill='black', font=font)
                y += 100
        images.append(img)
    return images
def process_pdf(path, mode, task, custom_prompt):
    """OCR every page of a PDF and merge the per-page results.

    Each page is rendered with a 300/72 zoom matrix (300 DPI assuming 72
    points per inch) and passed to ``process_image``. Pages whose text is
    empty or ``"No text"`` contribute nothing to the result.

    Args:
        path: Path of the PDF file.
        mode, task, custom_prompt: Forwarded unchanged to ``process_image``.

    Returns:
        ``(text, markdown, raw, box_images, crops, total_pages)``. ``text``
        and ``markdown`` are ``"No text in PDF"`` when no page produced text.
    """
    texts, markdowns, raws, all_crops = [], [], [], []
    box_images = []
    doc = fitz.open(path)
    try:
        total_pages = len(doc)
        # The same render matrix is used for every page; build it once.
        render_matrix = fitz.Matrix(300/72, 300/72)
        for page_no, page in enumerate(doc, start=1):
            pix = page.get_pixmap(matrix=render_matrix, alpha=False)
            img = Image.open(BytesIO(pix.tobytes("png")))
            text, md, raw, box_img, crops = process_image(img, mode, task, custom_prompt)
            if text and text != "No text":
                texts.append(f"### Page {page_no}\n\n{text}")
                markdowns.append(f"### Page {page_no}\n\n{md}")
                raws.append(f"=== Page {page_no} ===\n{raw}")
                all_crops.extend(crops)
                box_images.append(box_img)
    finally:
        # Release the document even when a page fails to process.
        doc.close()
    return ("\n\n---\n\n".join(texts) if texts else "No text in PDF",
            "\n\n---\n\n".join(markdowns) if markdowns else "No text in PDF",
            "\n\n".join(raws), box_images, all_crops, total_pages)
def save_outputs(doc_name, text_content, md_content, raw_content, box_images, cropped_images):
    """Write one document's results into a new numbered folder under ``outputs/``.

    The folder is named ``NN_<doc_name>``, where ``NN`` is one more than the
    number of directories already in ``outputs/``. It receives the three
    text files plus ``boxes/`` and ``cropped/`` sub-folders holding the
    non-``None`` images as JPEGs.

    Returns:
        The folder path as a string.
    """
    root = Path("outputs")
    root.mkdir(exist_ok=True)
    next_index = sum(1 for entry in root.iterdir() if entry.is_dir()) + 1
    target = root / f"{next_index:02d}_{doc_name}"
    target.mkdir(exist_ok=True)

    text_files = (
        ("text_output.txt", text_content),
        ("clean_output.md", md_content),
        ("raw_output.txt", raw_content),
    )
    for filename, content in text_files:
        (target / filename).write_text(content, encoding='utf-8')

    # (sub-folder, images, file-name template); numbering starts at 1 and
    # follows list position, so None entries leave gaps.
    image_sets = (
        ("boxes", box_images, "page_{:02d}_box.jpg"),
        ("cropped", cropped_images, "crop_{:02d}.jpg"),
    )
    for subdir, pictures, template in image_sets:
        folder = target / subdir
        folder.mkdir(exist_ok=True)
        for number, picture in enumerate(pictures, start=1):
            if picture is not None:
                picture.save(folder / template.format(number))
    return str(target)
def _ocr_rendered_pages(images, unit, mode, task, custom_prompt):
    """OCR a list of rendered page images and merge the per-page results.

    ``unit`` is the heading word used for each page ("Page" or "Slide").
    Returns ``(text, markdown, raw, box_images, crops)``.
    """
    texts, mds, raws, box_images, crops = [], [], [], [], []
    for number, img in enumerate(images, start=1):
        text, md, raw, box_img, page_crops = process_image(img, mode, task, custom_prompt)
        texts.append(f"### {unit} {number}\n\n{text}")
        mds.append(f"### {unit} {number}\n\n{md}")
        raws.append(f"=== {unit} {number} ===\n{raw}")
        box_images.append(box_img)
        crops.extend(page_crops)
    return ("\n\n---\n\n".join(texts), "\n\n---\n\n".join(mds), "\n\n".join(raws),
            box_images, crops)


def process_single_file(file_path, mode, task, custom_prompt):
    """Process one uploaded file, save its outputs and build a summary.

    The file type is chosen by extension: ``.pdf``, ``.docx`` and ``.pptx``
    are handled page by page; anything else is opened as a single image.

    Args:
        file_path: Path of the file to process.
        mode, task, custom_prompt: Forwarded unchanged to the OCR functions.

    Returns:
        ``(text, markdown, raw, box_images, crops, summary)``.
    """
    start_time = time.time()
    file_name = Path(file_path).stem
    ext = Path(file_path).suffix.lower()
    if ext == '.pdf':
        text, md, raw, box_images, crops, total_pages = process_pdf(file_path, mode, task, custom_prompt)
    elif ext in ('.docx', '.pptx'):
        # Both office formats are rendered to images first and then share
        # the same per-page OCR loop; only the heading word differs.
        if ext == '.docx':
            images, unit = docx_to_images(file_path), "Page"
        else:
            images, unit = pptx_to_images(file_path), "Slide"
        text, md, raw, box_images, crops = _ocr_rendered_pages(images, unit, mode, task, custom_prompt)
        total_pages = len(images)
    else:
        # Close the file handle once processing is done.
        with Image.open(file_path) as img:
            text, md, raw, box_img, crops = process_image(img, mode, task, custom_prompt)
        box_images = [box_img] if box_img is not None else []
        total_pages = 1
    elapsed_time = time.time() - start_time
    folder_path = save_outputs(file_name, text, md, raw, box_images, crops)
    summary = f"π File: {file_name}\nπ Pages/Slides: {total_pages}\nπΌοΈ Cropped Images: {len(crops)}\nβ±οΈ Processing Time: {elapsed_time:.2f}s\nπ Saved to: {folder_path}"
    return text, md, raw, box_images, crops, summary
def process_multiple_files(files, mode, task, custom_prompt):
    """Process every uploaded file and combine the results for the UI.

    Args:
        files: Uploaded files; each entry may be a plain path string or an
            object exposing the path as ``.name``. ``None`` or an empty list
            yields the "no files" result.
        mode, task, custom_prompt: Forwarded unchanged to
            ``process_single_file``.

    Returns:
        ``(text, markdown, raw, box_images, crops, summary)`` where the text
        fields are the per-file results joined by a separator line and the
        image lists are concatenated across files.
    """
    if not files:
        return "No files uploaded", "", "", [], [], "No files to process"
    all_texts, all_mds, all_raws, all_boxes, all_crops = [], [], [], [], []
    summaries = []
    total_start = time.time()
    for file in files:
        # Accept both path strings and file-like objects carrying .name;
        # a plain str has no .name attribute.
        file_path = getattr(file, "name", file)
        text, md, raw, boxes, crops, summary = process_single_file(file_path, mode, task, custom_prompt)
        all_texts.append(text)
        all_mds.append(md)
        all_raws.append(raw)
        all_boxes.extend(boxes)
        all_crops.extend(crops)
        summaries.append(summary)
    total_time = time.time() - total_start
    separator = "\n\n========================================\n\n"
    combined_text = separator.join(all_texts)
    combined_md = separator.join(all_mds)
    combined_raw = separator.join(all_raws)
    final_summary = f"β Processed {len(files)} file(s)\nβ±οΈ Total Time: {total_time:.2f}s\n\n" + "\n\n".join(summaries)
    return combined_text, combined_md, combined_raw, all_boxes, all_crops, final_summary
def toggle_prompt(task):
    """Return the update for the prompt textbox that matches the selected task.

    The custom and locate tasks show the textbox with a task-specific label
    and placeholder; every other task hides it.
    """
    prompt_fields = {
        "βοΈ Custom": ("Custom Prompt", "Add <|grounding|> for boxes"),
        "π Locate": ("Text to Locate", "Enter text"),
    }
    if task not in prompt_fields:
        return gr.update(visible=False)
    field_label, field_placeholder = prompt_fields[task]
    return gr.update(visible=True, label=field_label, placeholder=field_placeholder)
def show_view(view_type):
    """Build visibility updates that show only the selected output view.

    Returns a 5-tuple of updates in the fixed order text, markdown, raw,
    boxes, crops; only the entry equal to ``view_type`` is visible.
    """
    view_order = ("text", "markdown", "raw", "boxes", "crops")
    return tuple(gr.update(visible=(view_type == candidate)) for candidate in view_order)
with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR Multi-file") as demo:
    gr.Markdown("""
    # π DeepSeek-OCR Multi-file Processor
    Upload multiple files (PDF, DOCX, PPTX, Images) and process them with document-wise folder structure.
    """)
    with gr.Row():
        # Left column: inputs, run button and processing summary.
        with gr.Column(scale=1):
            files_in = gr.File(label="π Upload Files", file_count="multiple", type="filepath")
            mode = gr.Dropdown(list(MODEL_CONFIGS), value="Gundam", label="βοΈ Mode")
            task = gr.Dropdown(list(TASK_PROMPTS), value="π Markdown", label="π Task")
            prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
            btn = gr.Button("π Process All Files", variant="primary", size="lg")
            gr.Markdown("---")
            summary_out = gr.Textbox(label="π Processing Summary", lines=8)
        # Right column: view selector plus one container per output view.
        with gr.Column(scale=2):
            # View selection buttons in one row
            with gr.Row():
                text_btn = gr.Button("π Text", variant="secondary", size="sm")
                md_btn = gr.Button("π Markdown", variant="secondary", size="sm")
                raw_btn = gr.Button("π Raw", variant="secondary", size="sm")
                boxes_btn = gr.Button("π― Boxes", variant="secondary", size="sm")
                crops_btn = gr.Button("βοΈ Crops", variant="secondary", size="sm")
            # Output containers (only one visible at a time)
            with gr.Column(visible=True) as text_container:
                gr.Markdown("### π Text Output")
                text_out = gr.Textbox(lines=25, show_copy_button=True, show_label=False)
            with gr.Column(visible=False) as md_container:
                gr.Markdown("### π Markdown Output")
                md_out = gr.Markdown("")
            with gr.Column(visible=False) as raw_container:
                gr.Markdown("### π Raw Output")
                raw_out = gr.Textbox(lines=25, show_copy_button=True, show_label=False)
            with gr.Column(visible=False) as boxes_container:
                gr.Markdown("### π― Bounding Boxes")
                boxes_gallery = gr.Gallery(show_label=False, columns=3, height=600)
            with gr.Column(visible=False) as crops_container:
                gr.Markdown("### βοΈ Cropped Images")
                crops_gallery = gr.Gallery(show_label=False, columns=4, height=600)
    with gr.Accordion("βΉοΈ Info", open=False):
        gr.Markdown("""
        ### Modes
        - **Gundam**: 1024 base + 640 tiles with cropping - Best balance
        - **Tiny**: 512Γ512, no crop - Fastest
        - **Small**: 640Γ640, no crop - Quick
        - **Base**: 1024Γ1024, no crop - Standard
        - **Large**: 1280Γ1280, no crop - Highest quality
        ### Tasks
        - **Markdown**: Convert document to structured markdown (grounding β )
        - **Free OCR**: Simple text extraction
        - **Locate**: Find specific things in image (grounding β )
        - **Describe**: General image description
        - **Custom**: Your own prompt (add `<|grounding|>` for boxes)
        ### Supported Formats
        - π PDF files
        - π Word documents (.docx)
        - π PowerPoint presentations (.pptx)
        - πΌοΈ Images (JPG, PNG, etc.)
        """)
    # Event handlers
    task.change(toggle_prompt, [task], [prompt])
    btn.click(
        process_multiple_files,
        [files_in, mode, task, prompt],
        [text_out, md_out, raw_out, boxes_gallery, crops_gallery, summary_out]
    )

    # View toggle buttons: each one shows its own container and hides the rest.
    view_containers = [text_container, md_container, raw_container, boxes_container, crops_container]

    def _view_selector(view_type):
        """Return a zero-argument callback that switches to ``view_type``."""
        return lambda: show_view(view_type)

    view_buttons = (
        (text_btn, "text"),
        (md_btn, "markdown"),
        (raw_btn, "raw"),
        (boxes_btn, "boxes"),
        (crops_btn, "crops"),
    )
    for view_button, view_name in view_buttons:
        view_button.click(_view_selector(view_name), None, view_containers)
# Script entry point: enable a request queue of up to 20 pending jobs, then
# start the app with a public share link.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(share=True)