# Spaces: Running on Zero
import ast
import base64
import contextlib
import os
import re
import shutil
import sys
import tempfile
from io import BytesIO, StringIO

import fitz
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image, ImageDraw, ImageFont, ImageOps
from transformers import AutoModel, AutoTokenizer
# Hub repository that provides both the tokenizer and the model weights.
MODEL_NAME = 'deepseek-ai/DeepSeek-OCR-2'

# The repository ships its own modelling code, hence trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load the weights in bfloat16, switch to inference mode and move to the GPU.
model = AutoModel.from_pretrained(
    MODEL_NAME,
    _attn_implementation='flash_attention_2',
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    use_safetensors=True,
).eval().cuda()
# Resolution / tiling settings forwarded unchanged to ``model.infer``.
BASE_SIZE = 1024
IMAGE_SIZE = 768
CROP_MODE = True

# Available tasks: prompt text plus whether the output is expected to carry
# <|ref|>/<|det|> grounding tags.
TASK_PROMPTS = {
    "🧾 OCR": {
        "prompt": "<image>\nExtract all text from this image.",
        "has_grounding": False,
    },
}

# Markdown shown at the top of the page.
INTRO_MD = """
# 🚀 OCR Tester
**Upload an image or PDF to extract text with OCR.**
"""

# Markdown shown inside the collapsible "Info" section.
INFO_MD = """
### Notes
- One OCR prompt is used for all uploads.
- `<image>` is the placeholder where visual tokens are inserted.
"""
def extract_grounding_references(text):
    """Find every ``<|ref|>label<|/ref|><|det|>coords<|/det|>`` span in *text*.

    Returns:
        A list of ``(full_span, label, coords)`` string tuples, in the order
        the spans occur. ``.`` also matches newlines, so a span may cover
        several lines.
    """
    grounding_re = re.compile(
        r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)',
        re.DOTALL,
    )
    return [hit.groups() for hit in grounding_re.finditer(text)]
def draw_bounding_boxes(image, refs, extract_images=False):
    """Draw labelled grounding boxes on a copy of *image*.

    Args:
        image: PIL image the boxes refer to; it is left untouched.
        refs: Iterable of ``(full_span, label, coords)`` tuples, where
            ``coords`` is the text of a list of ``[x1, y1, x2, y2]`` boxes.
            Each coordinate is scaled by ``/ 999`` times the image size.
        extract_images: When true, every box labelled ``'image'`` is also
            cropped out of *image* and collected.

    Returns:
        ``(annotated_image, crops)``: the annotated copy and the list of
        cropped regions (empty unless *extract_images* is true).
    """
    img_w, img_h = image.size
    annotated = image.copy()
    outline_draw = ImageDraw.Draw(annotated)
    # Translucent fills go on a separate RGBA layer that is pasted at the end.
    overlay = Image.new('RGBA', annotated.size, (0, 0, 0, 0))
    fill_draw = ImageDraw.Draw(overlay)

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 15)
    except OSError:
        # The DejaVu font is not installed everywhere; keep drawing with
        # Pillow's built-in font instead of failing the whole request.
        font = ImageFont.load_default()

    crops = []
    color_map = {}
    # A private generator yields the same colour sequence as seeding the
    # global NumPy RNG with 42, without disturbing other users of np.random.
    rng = np.random.RandomState(42)

    for ref in refs:
        label = ref[1]
        if label not in color_map:
            color_map[label] = (rng.randint(50, 255), rng.randint(50, 255), rng.randint(50, 255))
        color = color_map[label]

        # SECURITY: the coordinates come from model output, which is derived
        # from the uploaded document, so they must not be passed to eval().
        # literal_eval only accepts plain Python literals.
        try:
            boxes = ast.literal_eval(ref[2])
        except (ValueError, SyntaxError, TypeError):
            continue  # Malformed coordinate text: skip this reference.

        color_a = color + (60,)
        for box in boxes:
            x1, y1, x2, y2 = (
                int(box[0] / 999 * img_w),
                int(box[1] / 999 * img_h),
                int(box[2] / 999 * img_w),
                int(box[3] / 999 * img_h),
            )
            # Pillow raises when a rectangle's corners are out of order.
            x1, x2 = min(x1, x2), max(x1, x2)
            y1, y2 = min(y1, y2), max(y1, y2)

            if extract_images and label == 'image':
                crops.append(image.crop((x1, y1, x2, y2)))

            width = 5 if label == 'title' else 3
            outline_draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            fill_draw.rectangle([x1, y1, x2, y2], fill=color_a)

            # Label tag just above the box, clamped to the top edge.
            text_bbox = outline_draw.textbbox((0, 0), label, font=font)
            tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
            ty = max(0, y1 - 20)
            outline_draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
            outline_draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))

    annotated.paste(overlay, (0, 0), overlay)
    return annotated, crops
def clean_output(text, include_images=False):
    """Strip grounding tags from raw model output.

    Spans tagged as ``image`` are replaced by a numbered
    ``**[Figure N]**`` placeholder when *include_images* is true and removed
    otherwise. For any other tagged span, every line containing the span is
    deleted. Finally ``\\coloneqq`` / ``\\eqqcolon`` are rewritten as
    ``:=`` / ``=:`` and surrounding whitespace is trimmed.

    Returns:
        The cleaned text, or ``""`` when *text* is empty or None.
    """
    if not text:
        return ""

    grounding_re = re.compile(
        r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)',
        re.DOTALL,
    )
    figure_count = 0
    for whole, _label, _coords in grounding_re.findall(text):
        is_figure = '<|ref|>image<|/ref|>' in whole
        if is_figure and include_images:
            figure_count += 1
            text = text.replace(whole, f'\n\n**[Figure {figure_count}]**\n\n', 1)
        elif is_figure:
            text = text.replace(whole, '', 1)
        else:
            # Remove the entire line(s) that carry the tag, newline included.
            tagged_line = rf'(?m)^[^\n]*{re.escape(whole)}[^\n]*\n?'
            text = re.sub(tagged_line, '', text)

    for macro, plain in (('\\coloneqq', ':='), ('\\eqqcolon', '=:')):
        text = text.replace(macro, plain)
    return text.strip()
def embed_images(markdown, crops):
    """Replace ``**[Figure N]**`` placeholders with inline PNG images.

    Args:
        markdown: Markdown text containing ``**[Figure N]**`` placeholders
            (N starts at 1).
        crops: Images exposing a PIL-style ``save(fp, format=...)`` method;
            the i-th crop fills placeholder ``i + 1``.

    Returns:
        The markdown with each matching placeholder replaced by a markdown
        image whose source is a base64 ``data:`` URI. Returned unchanged when
        *crops* is empty.
    """
    if not crops:
        return markdown

    for index, crop in enumerate(crops, start=1):
        buf = BytesIO()
        crop.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        # The encoded PNG has to be referenced in the replacement; otherwise
        # the placeholder is merely blanked out and no picture is embedded.
        markdown = markdown.replace(
            f'**[Figure {index}]**',
            f'\n\n![Figure {index}](data:image/png;base64,{b64})\n\n',
            1,
        )
    return markdown
# NOTE(review): the page header says this Space runs on ZeroGPU. There a GPU is
# only attached while a function decorated with ``spaces.GPU`` executes; the
# decorator is documented to have no effect in other environments.
@spaces.GPU
def process_image(image):
    """Run OCR on a single PIL image.

    The recognised text is taken from what ``model.infer`` writes to stdout;
    known progress/debug lines are filtered out of that capture.

    Args:
        image: PIL image to process, or None.

    Returns:
        A 5-tuple ``(cleaned_text, markdown, raw_text, boxed_image, crops)``.
        ``boxed_image`` is None and ``crops`` is empty unless the task has
        grounding enabled and the output contains grounding tags.
    """
    if image is None:
        return "Error: Upload an image", "", "", None, []

    # Apply the EXIF orientation first: converting the mode creates a new
    # image object that may no longer carry the orientation tag.
    image = ImageOps.exif_transpose(image)
    # The temporary file is a JPEG, which cannot store alpha/palette (and
    # several other) modes, so normalise everything to RGB.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    task = TASK_PROMPTS["🧾 OCR"]
    prompt = task["prompt"]
    has_grounding = task["has_grounding"]

    # One scratch directory holds both the input JPEG and the output folder,
    # so a single cleanup in ``finally`` covers every exit path.
    work_dir = tempfile.mkdtemp()
    captured = StringIO()
    try:
        image_path = os.path.join(work_dir, 'input.jpg')
        out_dir = os.path.join(work_dir, 'output')
        os.makedirs(out_dir)
        image.save(image_path, 'JPEG', quality=95)

        # redirect_stdout restores sys.stdout even if inference raises.
        with contextlib.redirect_stdout(captured):
            model.infer(
                tokenizer=tokenizer,
                prompt=prompt,
                image_file=image_path,
                output_path=out_dir,
                base_size=BASE_SIZE,
                image_size=IMAGE_SIZE,
                crop_mode=CROP_MODE,
                save_results=False,
            )
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)

    # Drop blank lines and lines containing any known debug marker.
    debug_filters = ('PATCHES', '====', 'BASE:', 'directly resize', 'NO PATCHES', 'torch.Size', '%|')
    kept_lines = [
        line for line in captured.getvalue().split('\n')
        if line.strip() and not any(marker in line for marker in debug_filters)
    ]
    result = '\n'.join(kept_lines).strip()
    if not result:
        return "No text detected", "", "", None, []

    cleaned = clean_output(result, False)
    markdown = clean_output(result, True)

    img_out = None
    crops = []
    if has_grounding and '<|ref|>' in result:
        refs = extract_grounding_references(result)
        if refs:
            img_out, crops = draw_bounding_boxes(image, refs, True)
            markdown = embed_images(markdown, crops)

    return cleaned, markdown, result, img_out, crops
def process_pdf(path, page_num):
    """Render one PDF page at 300 DPI and run OCR on it.

    Args:
        path: Path of the PDF file.
        page_num: 1-based page number.

    Returns:
        The 5-tuple produced by ``process_image``, or an error tuple of the
        same shape when *page_num* is outside the document.
    """
    # The context manager closes the document on every exit path, including
    # errors raised while loading or rendering the page.
    with fitz.open(path) as doc:
        total_pages = len(doc)
        if page_num < 1 or page_num > total_pages:
            return f"Invalid page number. PDF has {total_pages} pages.", "", "", None, []
        page = doc.load_page(page_num - 1)
        # PDF user space is 72 units per inch, so 300/72 scales to 300 DPI.
        pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
        png_bytes = pix.tobytes("png")

    img = Image.open(BytesIO(png_bytes))
    return process_image(img)
def process_file(path, page_num):
    """Dispatch an uploaded file to the PDF or the image pipeline.

    Args:
        path: File path; a ``.pdf`` suffix (case-insensitive) selects the PDF
            pipeline, anything else is opened as an image.
        page_num: 1-based page number, used for PDFs only.

    Returns:
        The 5-tuple of the selected pipeline, or an error tuple of the same
        shape when *path* is empty.
    """
    if not path:
        return "Error: Upload a file", "", "", None, []

    is_pdf = path.lower().endswith('.pdf')
    return process_pdf(path, page_num) if is_pdf else process_image(Image.open(path))
def unpack_multimodal(value):
    """Return the path of the first file in a multimodal textbox value.

    *value* is expected to be a dict with a ``"files"`` list whose entries
    are path strings, dicts carrying ``"path"``/``"name"``, or objects with a
    ``name`` attribute. Anything else yields None.
    """
    files = value.get("files") if (value and isinstance(value, dict)) else None
    if not files:
        return None

    head = files[0]
    if isinstance(head, dict):
        return head.get("path") or head.get("name")
    return head if isinstance(head, str) else getattr(head, "name", None)
def get_pdf_page_count(file_path):
    """Return the number of pages of a PDF, or 1 for anything that is not one.

    A path counts as a PDF when it is non-empty and ends in ``.pdf``
    (case-insensitive).
    """
    if file_path and file_path.lower().endswith('.pdf'):
        pdf = fitz.open(file_path)
        pages = len(pdf)
        pdf.close()
        return pages
    return 1
def load_image(file_path, page_num=1):
    """Load a preview image for an uploaded file.

    Args:
        file_path: Path of an image or PDF; empty/None yields None.
        page_num: 1-based page to render for PDFs. It is clamped into the
            document's range; a missing or non-numeric value means page 1.

    Returns:
        A PIL image (the PDF page rendered at 300 DPI, or the opened image
        file), or None when there is nothing to show.
    """
    if not file_path:
        return None
    if not file_path.lower().endswith('.pdf'):
        return Image.open(file_path)

    # The page field can be empty, which would make int() raise.
    try:
        requested = int(page_num)
    except (TypeError, ValueError, OverflowError):
        requested = 1

    # The context manager closes the document even if rendering fails.
    with fitz.open(file_path) as doc:
        page_total = len(doc)
        if page_total == 0:
            return None
        page_idx = max(0, min(requested - 1, page_total - 1))
        page = doc.load_page(page_idx)
        pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
        png_bytes = pix.tobytes("png")

    return Image.open(BytesIO(png_bytes))
def update_page_selector(file_path):
    """Build the update for the page-number input after an upload.

    The selector is hidden unless *file_path* ends in ``.pdf``; for a PDF it
    is shown, reset to page 1 and bounded by the document's page count.
    """
    if not file_path or not file_path.lower().endswith('.pdf'):
        return gr.update(visible=False)

    page_count = get_pdf_page_count(file_path)
    return gr.update(
        visible=True,
        maximum=page_count,
        value=1,
        minimum=1,
        label=f"Select Page (1-{page_count})",
    )
def load_image_from_multimodal(value, page_num=1):
    """Resolve the uploaded file in *value* and load its preview image."""
    return load_image(unpack_multimodal(value), page_num)
def update_page_selector_from_multimodal(value):
    """Resolve the uploaded file in *value* and refresh the page selector."""
    return update_page_selector(unpack_multimodal(value))
# NOTE(review): the original indentation was lost; the nesting below (two
# columns in one row, the Info accordion under the row) is reconstructed from
# the order of the ``with`` statements. Confirm against the deployed layout.
with gr.Blocks(title="DeepSeek-OCR-2") as demo:
    gr.Markdown(INTRO_MD)

    with gr.Row():
        # Left column: upload, preview, page picker and the run button.
        with gr.Column(scale=1):
            multimodal_in = gr.MultimodalTextbox(
                label="Input (Image/PDF)",
                file_types=["image", ".pdf"],
                placeholder="Drop an image or PDF here",
            )
            input_img = gr.Image(label="Input Image", type="pil", height=300, interactive=False)
            page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
            btn = gr.Button("Extract", variant="primary", size="lg")

        # Right column: one tab per output of ``process_file``.
        with gr.Column(scale=2):
            with gr.Tabs() as tabs:
                with gr.Tab("Text", id="tab_text"):
                    text_out = gr.Textbox(lines=20, buttons=["copy"], show_label=False)
                with gr.Tab("Markdown Preview", id="tab_markdown"):
                    md_out = gr.Markdown("")
                with gr.Tab("Boxes", id="tab_boxes"):
                    img_out = gr.Image(type="pil", height=500, show_label=False)
                with gr.Tab("Cropped Images", id="tab_crops"):
                    gallery = gr.Gallery(show_label=False, columns=3, height=400)
                with gr.Tab("Raw Text", id="tab_raw"):
                    raw_out = gr.Textbox(lines=20, buttons=["copy"], show_label=False)

    with gr.Accordion("ℹ️ Info", open=False):
        gr.Markdown(INFO_MD)

    # Keep the preview and the page selector in sync with the upload.
    multimodal_in.change(load_image_from_multimodal, [multimodal_in, page_selector], [input_img])
    multimodal_in.change(update_page_selector_from_multimodal, [multimodal_in], [page_selector])
    page_selector.change(load_image_from_multimodal, [multimodal_in, page_selector], [input_img])

    def run(multimodal_value, page_num):
        """Run OCR on the uploaded file; returns the five output values."""
        file_path = unpack_multimodal(multimodal_value)
        if not file_path:
            return "Error: Upload a file or image", "", "", None, []
        # An empty page field arrives as None, which int() cannot convert;
        # treat it as the first page.
        page = int(page_num) if page_num is not None else 1
        return process_file(file_path, page)

    submit_event = btn.click(
        run,
        [multimodal_in, page_selector],
        [text_out, md_out, raw_out, img_out, gallery],
    )
if __name__ == "__main__":
    # Enable request queueing (at most 20 waiting), then start the server
    # with the Soft theme.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(theme=gr.themes.Soft())