import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import snapshot_download
import spaces
import os
import tempfile
from PIL import Image, ImageDraw
import re

# --- 1. Download Model to a Local Cache, Modify, and Load ---
print("Downloading and setting up model from Hugging Face Hub...")

# Local cache directory for the model snapshot.
CACHE_PATH = "./model_cache"
os.makedirs(CACHE_PATH, exist_ok=True)  # exist_ok avoids the check-then-create race

# Download the model repository to the local directory.
model_path_local = snapshot_download(
    repo_id='strangervisionhf/deepseek-ocr-latest-transformers',
    local_dir=os.path.join(CACHE_PATH, 'deepseek.ocr'),
    max_workers=8,  # Adjusted for typical connection speeds
    local_dir_use_symlinks=False,
)
print(f"✅ Model downloaded to: {model_path_local}")

# --- Remove the specified file after downloading ---
# NOTE(review): presumably removing the repo's custom modeling file forces
# trust_remote_code to use a compatible implementation — confirm intent.
file_to_remove = os.path.join(model_path_local, "modeling_deepseekv2.py")
if os.path.exists(file_to_remove):
    try:
        os.remove(file_to_remove)
        print(f"✅ Successfully removed file: {file_to_remove}")
    except OSError as e:
        print(f"❌ Error removing file {file_to_remove}: {e}")
else:
    print(f"⚠️ File not found, could not remove: {file_to_remove}")

# --- Load tokenizer and model from the local path ---
print("Loading model and tokenizer from local cache...")
MODEL_PATH = model_path_local

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

# Load the model with automatic device mapping and bfloat16 for efficiency.
model = AutoModel.from_pretrained(
    MODEL_PATH,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # Automatically maps model to available GPU(s)/CPU
    trust_remote_code=True,
).eval()
print("✅ Model loaded successfully with automatic device mapping.")


# --- Helper function to find pre-generated result images ---
def find_result_image(path):
    """Return the first annotated result image found in *path*, or None.

    ``model.infer(..., save_results=True)`` writes files whose names contain
    "grounding" or "result" into the output directory. The image is loaded
    eagerly because *path* is a TemporaryDirectory that is deleted right
    after the caller returns; a lazily-loaded PIL image would then point at
    a file that no longer exists.
    """
    for filename in os.listdir(path):
        if "grounding" in filename or "result" in filename:
            image_path = os.path.join(path, filename)
            try:
                img = Image.open(image_path)
                img.load()  # force pixel data into memory before the temp dir vanishes
                return img
            except Exception as e:
                # Include the failing path so the log is actionable.
                print(f"Error opening result image {image_path}: {e}")
    return None


# --- 2. Main Processing Function ---

# Prompt template per task; the "Locate" task is handled separately in
# process_ocr_task because it interpolates the user-supplied reference text.
_TASK_PROMPTS = {
    "📝 Free OCR": "\nFree OCR.",
    "📄 Convert to Markdown": "\n<|grounding|>Convert the document to markdown.",
    "📈 Parse Figure": "\nParse the figure.",
}

# Resolution presets handed to model.infer(); hoisted so the dict is not
# rebuilt on every request.
_SIZE_CONFIGS = {
    "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
    "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
    "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
    "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
    "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
}

# Matches bounding boxes like <|det|>[[x1, y1, x2, y2]]<|/det|>; compiled once
# instead of on every call.
_DET_PATTERN = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")


@spaces.GPU
def process_ocr_task(image, model_size, task_type, ref_text):
    """
    Processes an image with DeepSeek-OCR. Model is already loaded on the
    correct device.

    Args:
        image: PIL image uploaded by the user (None if nothing uploaded).
        model_size: key into _SIZE_CONFIGS selecting the resolution preset.
        task_type: one of the task labels from the UI dropdown.
        ref_text: reference text for the "Locate" task (ignored otherwise).

    Returns:
        (text_result, result_image): the raw model text output and, when
        available, a PIL image with detections drawn (or a saved result
        image found in the output directory), else None.

    Raises:
        gr.Error: if the "Locate" task is selected without reference text.
    """
    if image is None:
        return "Please upload an image first.", None

    # No need to move the model; device_map="auto" handled it at load time.
    print("✅ Model is already on the designated device(s).")

    with tempfile.TemporaryDirectory() as output_path:
        # Build the prompt.
        if task_type == "🔍 Locate Object by Reference":
            if not ref_text or not ref_text.strip():
                raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
            prompt = f"\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
        else:
            prompt = _TASK_PROMPTS.get(task_type, "\nFree OCR.")

        temp_image_path = os.path.join(output_path, "temp_image.png")
        image.save(temp_image_path)

        # Fall back to the recommended preset for unknown size labels.
        config = _SIZE_CONFIGS.get(model_size, _SIZE_CONFIGS["Gundam (Recommended)"])

        print(f"🏃 Running inference with prompt: {prompt}")
        text_result = model.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=True,
        )
        print(f"====\n📄 Text Result: {text_result}\n====")

        # --- Logic to draw bounding boxes ---
        result_image_pil = None
        matches = list(_DET_PATTERN.finditer(text_result))
        if matches:
            print(f"✅ Found {len(matches)} bounding box(es). Drawing on the original image.")
            image_with_bboxes = image.copy()
            draw = ImageDraw.Draw(image_with_bboxes)
            w, h = image.size
            for match in matches:
                # Coordinates are emitted on a 0-1000 normalized grid;
                # scale them to the original image's pixel dimensions.
                x1_norm, y1_norm, x2_norm, y2_norm = (int(c) for c in match.groups())
                x1 = int(x1_norm / 1000 * w)
                y1 = int(y1_norm / 1000 * h)
                x2 = int(x2_norm / 1000 * w)
                y2 = int(y2_norm / 1000 * h)
                draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            result_image_pil = image_with_bboxes
        else:
            print("⚠️ No bounding box coordinates found in text result. Falling back to search for a result image file.")
            result_image_pil = find_result_image(output_path)

        return text_result, result_image_pil
# --- 3. Build the Gradio Interface ---
with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
    # Usage instructions shown above the controls.
    gr.Markdown(
        """
        # 🐳 Full Demo of DeepSeek-OCR 🐳
        **💡 How to use:**
        1. **Upload an image** using the upload box.
        2. Select a **Resolution**. `Gundam` is recommended for most documents.
        3. Choose a **Task Type**:
        - **📝 Free OCR**: Extracts raw text from the image.
        - **📄 Convert to Markdown**: Converts the document into Markdown, preserving structure.
        - **📈 Parse Figure**: Extracts structured data from charts and figures.
        - **🔍 Locate Object by Reference**: Finds a specific object/text.
        4. If this helpful, please give it a like! 🙏 ❤️
        """
    )

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="🖼️ Upload Image",
                sources=["upload", "clipboard"],
            )
            model_size = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                value="Gundam (Recommended)",
                label="⚙️ Resolution Size",
            )
            task_type = gr.Dropdown(
                choices=["📝 Free OCR", "📄 Convert to Markdown", "📈 Parse Figure", "🔍 Locate Object by Reference"],
                value="📄 Convert to Markdown",
                label="🚀 Task Type",
            )
            ref_text_input = gr.Textbox(
                label="📝 Reference Text (for Locate task)",
                placeholder="e.g., the teacher, 20-10, a red car...",
                visible=False,
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        # Right column: results.
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
            output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

    # --- UI Interaction Logic ---
    def toggle_ref_text_visibility(task):
        """Show the reference-text box only when the 'Locate' task is chosen."""
        return gr.Textbox(visible=(task == "🔍 Locate Object by Reference"))

    task_type.change(
        fn=toggle_ref_text_visibility,
        inputs=task_type,
        outputs=ref_text_input,
    )
    submit_btn.click(
        fn=process_ocr_task,
        inputs=[image_input, model_size, task_type, ref_text_input],
        outputs=[output_text, output_image],
    )

# --- 4. Launch the App ---
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)