prithivMLmods commited on
Commit
887babd
·
verified ·
1 Parent(s): 278014f

update app

Browse files
Files changed (1) hide show
  1. app.py +245 -0
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import requests
4
+ from transformers import AutoModel, AutoTokenizer
5
+ import spaces
6
+ from typing import Iterable
7
+ import os
8
+ import tempfile
9
+ from PIL import Image, ImageDraw
10
+ import re
11
+ from gradio.themes import Soft
12
+ from gradio.themes.utils import colors, fonts, sizes
13
+ from docling_core.types.doc import DoclingDocument, DocTagsDocument
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
18
+ print("torch.__version__ =", torch.__version__)
19
+ print("torch.version.cuda =", torch.version.cuda)
20
+ print("cuda available:", torch.cuda.is_available())
21
+ print("cuda device count:", torch.cuda.device_count())
22
+ if torch.cuda.is_available():
23
+ print("current device:", torch.cuda.current_device())
24
+ print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
25
+
26
+ print("Using device:", device)
27
+
28
+
29
+ colors.steel_blue = colors.Color(
30
+ name="steel_blue",
31
+ c50="#EBF3F8",
32
+ c100="#D3E5F0",
33
+ c200="#A8CCE1",
34
+ c300="#7DB3D2",
35
+ c400="#529AC3",
36
+ c500="#4682B4",
37
+ c600="#3E72A0",
38
+ c700="#36638C",
39
+ c800="#2E5378",
40
+ c900="#264364",
41
+ c950="#1E3450",
42
+ )
43
+
44
+ class SteelBlueTheme(Soft):
45
+ def __init__(
46
+ self,
47
+ *,
48
+ primary_hue: colors.Color | str = colors.gray,
49
+ secondary_hue: colors.Color | str = colors.steel_blue,
50
+ neutral_hue: colors.Color | str = colors.slate,
51
+ text_size: sizes.Size | str = sizes.text_lg,
52
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
53
+ fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
54
+ ),
55
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
56
+ fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
57
+ ),
58
+ ):
59
+ super().__init__(
60
+ primary_hue=primary_hue,
61
+ secondary_hue=secondary_hue,
62
+ neutral_hue=neutral_hue,
63
+ text_size=text_size,
64
+ font=font,
65
+ font_mono=font_mono,
66
+ )
67
+ super().set(
68
+ background_fill_primary="*primary_50",
69
+ background_fill_primary_dark="*primary_900",
70
+ body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
71
+ body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
72
+ button_primary_text_color="white",
73
+ button_primary_text_color_hover="white",
74
+ button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
75
+ button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
76
+ button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
77
+ button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
78
+ slider_color="*secondary_500",
79
+ slider_color_dark="*secondary_600",
80
+ block_title_text_weight="600",
81
+ block_border_width="3px",
82
+ block_shadow="*shadow_drop_lg",
83
+ button_primary_shadow="*shadow_drop_lg",
84
+ button_large_padding="11px",
85
+ color_accent_soft="*primary_100",
86
+ block_label_background_fill="*primary_200",
87
+ )
88
+
89
+ steel_blue_theme = SteelBlueTheme()
90
+
91
+ css = """
92
+ #main-title h1 {
93
+ font-size: 2.3em !important;
94
+ }
95
+ #output-title h2 {
96
+ font-size: 2.1em !important;
97
+ }
98
+ """
99
+
100
+ print("Determining device...")
101
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102
+ print(f"✅ Using device: {device}")
103
+
104
+ print("Loading model and tokenizer...")
105
+ model_name = "prithivMLmods/DeepSeek-OCR-transformers-5.0.0.dev0" # -> Latest transformers version used for the model. (https://huggingface.co/deepseek-ai/DeepSeek-OCR)
106
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
107
+
108
+ model = AutoModel.from_pretrained(
109
+ model_name,
110
+ #_attn_implementation="flash_attention_2",
111
+ trust_remote_code=True,
112
+ use_safetensors=True,
113
+ ).to(device).eval() # Move to device and set to eval mode
114
+
115
+ if device.type == 'cuda':
116
+ model = model.to(torch.bfloat16)
117
+
118
+ print("✅ Model loaded successfully to device and in eval mode.")
119
+
120
+ def find_result_image(path):
121
+ for filename in os.listdir(path):
122
+ if "grounding" in filename or "result" in filename:
123
+ try:
124
+ image_path = os.path.join(path, filename)
125
+ return Image.open(image_path)
126
+ except Exception as e:
127
+ print(f"Error opening result image {filename}: {e}")
128
+ return None
129
+
130
+ @spaces.GPU
131
+ def process_ocr_task(image, model_size, task_type, ref_text):
132
+ """
133
+ Processes an image with DeepSeek-OCR. The model is already on the correct device.
134
+ """
135
+ if image is None:
136
+ return "Please upload an image first.", None
137
+
138
+ print("✅ Model is already on the designated device.")
139
+
140
+ with tempfile.TemporaryDirectory() as output_path:
141
+ # Build the prompt
142
+ if task_type == "Free OCR":
143
+ prompt = "<image>\nFree OCR."
144
+ elif task_type == "Convert to Markdown":
145
+ prompt = "<image>\n<|grounding|>Convert the document to markdown."
146
+ elif task_type == "Parse Figure":
147
+ prompt = "<image>\nParse the figure."
148
+ elif task_type == "Locate Object by Reference":
149
+ if not ref_text or ref_text.strip() == "":
150
+ raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
151
+ prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
152
+ else:
153
+ prompt = "<image>\nFree OCR."
154
+
155
+ temp_image_path = os.path.join(output_path, "temp_image.png")
156
+ image.save(temp_image_path)
157
+
158
+ size_configs = {
159
+ "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
160
+ "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
161
+ "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
162
+ "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
163
+ "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
164
+ }
165
+ config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
166
+
167
+ print(f"🏃 Running inference with prompt: {prompt}")
168
+ text_result = model.infer(
169
+ tokenizer,
170
+ prompt=prompt,
171
+ image_file=temp_image_path,
172
+ output_path=output_path,
173
+ base_size=config["base_size"],
174
+ image_size=config["image_size"],
175
+ crop_mode=config["crop_mode"],
176
+ save_results=True,
177
+ test_compress=True,
178
+ eval_mode=True,
179
+ )
180
+
181
+ print(f"====\n📄 Text Result: {text_result}\n====")
182
+
183
+ result_image_pil = None
184
+ pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
185
+ matches = list(pattern.finditer(text_result))
186
+
187
+ if matches:
188
+ print(f"✅ Found {len(matches)} bounding box(es). Drawing on the original image.")
189
+ image_with_bboxes = image.copy()
190
+ draw = ImageDraw.Draw(image_with_bboxes)
191
+ w, h = image.size
192
+
193
+ for match in matches:
194
+ coords_norm = [int(c) for c in match.groups()]
195
+ x1_norm, y1_norm, x2_norm, y2_norm = coords_norm
196
+
197
+ x1 = int(x1_norm / 1000 * w)
198
+ y1 = int(y1_norm / 1000 * h)
199
+ x2 = int(x2_norm / 1000 * w)
200
+ y2 = int(y2_norm / 1000 * h)
201
+
202
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
203
+
204
+ result_image_pil = image_with_bboxes
205
+ else:
206
+ print("⚠️ No bounding box coordinates found in text result. Falling back to search for a result image file.")
207
+ result_image_pil = find_result_image(output_path)
208
+
209
+ return text_result, result_image_pil
210
+
211
+ # url = "https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR3/resolve/main/examples/3.jpg?download=true"
212
+ # example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
213
+
214
+ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
215
+ gr.Markdown("# **DeepSeek OCR [exp]**", elem_id="main-title")
216
+
217
+
218
+ with gr.Row():
219
+ with gr.Column(scale=1):
220
+ image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
221
+ model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"], value="Large", label="Resolution Size")
222
+ task_type = gr.Dropdown(choices=["Free OCR", "Convert to Markdown", "Parse Figure", "Locate Object by Reference"], value="Convert to Markdown", label="Task Type")
223
+ ref_text_input = gr.Textbox(label="Reference Text (for Locate task)", placeholder="e.g., the teacher, 20-10, a red car...", visible=False)
224
+ submit_btn = gr.Button("Process Image", variant="primary")
225
+
226
+ examples = gr.Examples(
227
+ examples=["examples/1.jpg", "examples/2.jpg", "examples/3.jpg"],
228
+ inputs=image_input, label="Examples"
229
+ )
230
+
231
+ with gr.Column(scale=2):
232
+ output_text = gr.Textbox(label="Output (OCR)", lines=8, show_copy_button=True)
233
+ output_image = gr.Image(label="Layout Detection (If Any)", type="pil")
234
+
235
+ with gr.Accordion("Note", open=False):
236
+ gr.Markdown("Inference using Huggingface transformers on NVIDIA GPUs. This app is running with transformers version 4.57.1 and torch version 2.6.0.")
237
+
238
+ def toggle_ref_text_visibility(task):
239
+ return gr.Textbox(visible=True) if task == "Locate Object by Reference" else gr.Textbox(visible=False)
240
+
241
+ task_type.change(fn=toggle_ref_text_visibility, inputs=task_type, outputs=ref_text_input)
242
+ submit_btn.click(fn=process_ocr_task, inputs=[image_input, model_size, task_type, ref_text_input], outputs=[output_text, output_image])
243
+
244
+ if __name__ == "__main__":
245
+ demo.queue(max_size=20).launch(share=True, mcp_server=True, ssr_mode=False)