Chhagan005 committed on
Commit
5a86f7e
·
verified ·
1 Parent(s): 30e8898

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +408 -0
app.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import uuid
4
+ import json
5
+ import time
6
+ from threading import Thread
7
+ from typing import Iterable
8
+
9
+ import gradio as gr
10
+ import spaces
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+ import cv2
15
+
16
+ from transformers import (
17
+ Qwen2VLForConditionalGeneration,
18
+ Qwen2_5_VLForConditionalGeneration,
19
+ AutoModelForImageTextToText,
20
+ AutoProcessor,
21
+ TextIteratorStreamer,
22
+ )
23
+ from transformers.image_utils import load_image
24
+ from gradio.themes import Soft
25
+ from gradio.themes.utils import colors, fonts, sizes
26
+
27
# Shade table for the custom "steel_blue" palette (lightest c50 to darkest c950).
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}

# Attach the palette to gradio's colors module so it can be referenced as
# colors.steel_blue (the theme class uses it as its default secondary hue).
colors.steel_blue = colors.Color(name="steel_blue", **_STEEL_BLUE_SHADES)
41
+
42
class SteelBlueTheme(Soft):
    """Gradio theme derived from ``Soft`` with gray surfaces and steel-blue accents.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Soft.__init__``; a fixed set of theme-variable overrides is applied
    afterwards through ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        # Base palette / typography, handed straight to the parent theme.
        base_settings = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_settings)

        # Theme-variable overrides; "*name" values refer to theme variables.
        overrides = {
            # Page and block backgrounds.
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary buttons.
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_800)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_500)",
            # Secondary buttons.
            "button_secondary_text_color": "black",
            "button_secondary_text_color_hover": "white",
            "button_secondary_background_fill": "linear-gradient(90deg, *primary_300, *primary_300)",
            "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
            "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
            "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
            # Sliders.
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Block chrome.
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**overrides)
92
+
93
# Theme instance handed to demo.launch() at the bottom of the file.
steel_blue_theme = SteelBlueTheme()

# Custom stylesheet handed to demo.launch(). It sizes the two title elements
# (#main-title, #output-title), styles the RadioAnimated widget (.ra-* classes,
# including .dark variants) and frames the #gpu-duration-container row.
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.2em !important;
}
/* RadioAnimated Styles */
.ra-wrap{ width: fit-content; }
.ra-inner{
position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
}
.ra-input{ display: none; }
.ra-label{
position: relative; z-index: 2; padding: 8px 16px;
font-family: inherit; font-size: 14px; font-weight: 600;
color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
}
.ra-highlight{
position: absolute; z-index: 1; top: 6px; left: 6px;
height: calc(100% - 12px); border-radius: 9999px;
background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
transition: transform 0.2s, width 0.2s;
}
.ra-input:checked + .ra-label{ color: black; }
/* Dark mode adjustments for Radio */
.dark .ra-inner { background: var(--neutral-800); }
.dark .ra-label { color: var(--neutral-400); }
.dark .ra-highlight { background: var(--neutral-600); }
.dark .ra-input:checked + .ra-label { color: white; }
#gpu-duration-container {
padding: 10px;
border-radius: 8px;
background: var(--background-fill-secondary);
border: 1px solid var(--border-color-primary);
margin-top: 10px;
}
"""

# Bounds for the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1024
# Overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
# NOTE(review): not referenced anywhere else in the visible code.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Use the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Start-up diagnostics describing the torch / CUDA environment.
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

print("Using device:", device)
151
+
152
class RadioAnimated(gr.HTML):
    """Radio group rendered as custom HTML with a sliding highlight.

    The component builds one hidden ``<input type="radio">`` plus ``<label>``
    pair per choice inside a ``.ra-wrap`` / ``.ra-inner`` container and
    attaches a script that keeps the ``.ra-highlight`` element aligned with
    the checked option, writes the selection to ``props.value`` and fires a
    ``change`` event when the user picks a different option.

    Args:
        choices: Sequence of at least two options; each is rendered via ``str()``.
        value: Initially selected option. Defaults to ``choices[0]``.
        **kwargs: Forwarded unchanged to ``gr.HTML``.

    Raises:
        ValueError: If fewer than two choices are supplied.
    """

    def __init__(self, choices, value=None, **kwargs):
        # Function-scope import so the block does not rely on a file-level one.
        from html import escape

        if not choices or len(choices) < 2:
            raise ValueError("RadioAnimated requires at least 2 choices.")
        if value is None:
            value = choices[0]

        # A short random suffix keeps radio-group names and element ids unique
        # when several instances are placed on the same page.
        uid = uuid.uuid4().hex[:8]
        group_name = f"ra-{uid}"

        option_markup = []
        for i, c in enumerate(choices):
            # Escape the choice text: it lands in an attribute value and in
            # element content, where a raw quote, "<" or "&" would corrupt
            # the markup. Plain values such as "60" are unaffected.
            safe = escape(str(c), quote=True)
            option_markup.append(
                f"""
                <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{safe}">
                <label class="ra-label" for="{group_name}-{i}">{safe}</label>
                """
            )
        inputs_html = "\n".join(option_markup)

        html_template = f"""
        <div class="ra-wrap" data-ra="{uid}">
          <div class="ra-inner">
            <div class="ra-highlight"></div>
            {inputs_html}
          </div>
        </div>
        """

        # Client-side behaviour. `element`, `props` and `trigger` are not
        # defined here — presumably supplied by Gradio's js_on_load runtime;
        # confirm against the gr.HTML documentation.
        js_on_load = r"""
        (() => {
          const wrap = element.querySelector('.ra-wrap');
          const inner = element.querySelector('.ra-inner');
          const highlight = element.querySelector('.ra-highlight');
          const inputs = Array.from(element.querySelectorAll('.ra-input'));
          if (!inputs.length) return;
          const choices = inputs.map(i => i.value);
          function setHighlightByIndex(idx) {
            const n = choices.length;
            const pct = 100 / n;
            highlight.style.width = `calc(${pct}% - 6px)`;
            highlight.style.transform = `translateX(${idx * 100}%)`;
          }
          function setCheckedByValue(val, shouldTrigger=false) {
            const idx = Math.max(0, choices.indexOf(val));
            inputs.forEach((inp, i) => { inp.checked = (i === idx); });
            setHighlightByIndex(idx);
            props.value = choices[idx];
            if (shouldTrigger) trigger('change', props.value);
          }
          setCheckedByValue(props.value ?? choices[0], false);
          inputs.forEach((inp) => {
            inp.addEventListener('change', () => {
              setCheckedByValue(inp.value, true);
            });
          });
        })();
        """

        super().__init__(
            value=value,
            html_template=html_template,
            js_on_load=js_on_load,
            **kwargs
        )
215
+
216
def apply_gpu_duration(val: str):
    """Convert the selected GPU-duration choice into an integer number of seconds."""
    seconds = int(val)
    return seconds
218
+
219
def _load_processor_and_model(model_cls, model_id):
    """Load the processor and the fp16, eval-mode model for ``model_id``.

    The processor is loaded first, then the model is loaded with the
    flash-attention kernel, moved to the module-level ``device`` and put in
    evaluation mode. Returns ``(processor, model)``.
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = model_cls.from_pretrained(
        model_id,
        attn_implementation="kernels-community/flash-attn2",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    model = model.to(device).eval()
    return processor, model


# Nanonets-OCR2-3B
MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
processor_v, model_v = _load_processor_and_model(
    Qwen2_5_VLForConditionalGeneration, MODEL_ID_V
)

# Qwen2-VL-OCR-2B
MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor_x, model_x = _load_processor_and_model(
    Qwen2VLForConditionalGeneration, MODEL_ID_X
)

# Aya-Vision-8B
MODEL_ID_A = "CohereForAI/aya-vision-8b"
processor_a, model_a = _load_processor_and_model(
    AutoModelForImageTextToText, MODEL_ID_A
)

# olmOCR-7B-0725
MODEL_ID_W = "allenai/olmOCR-7B-0725"
processor_w, model_w = _load_processor_and_model(
    Qwen2_5_VLForConditionalGeneration, MODEL_ID_W
)

# RolmOCR-7B
MODEL_ID_M = "reducto/RolmOCR"
processor_m, model_m = _load_processor_and_model(
    Qwen2_5_VLForConditionalGeneration, MODEL_ID_M
)
263
+
264
def calc_timeout_duration(model_name: str, text: str, image: "Image.Image",
                          max_new_tokens: int, temperature: float, top_p: float,
                          top_k: int, repetition_penalty: float, gpu_timeout: int):
    """Return the GPU time budget, in seconds, for one ``generate_image`` call.

    The signature mirrors ``generate_image`` because this function is passed
    as the ``duration`` callable of ``spaces.GPU``; only ``gpu_timeout`` is
    actually used.

    Args:
        gpu_timeout: Requested duration; anything ``int()`` accepts
            (int, float, numeric string).

    Returns:
        ``int(gpu_timeout)``, or 60 when the value cannot be converted.
    """
    # The ``image`` annotation is quoted so it is not evaluated at definition time.
    try:
        return int(gpu_timeout)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to the default (None, non-numeric
        # text, NaN, infinity). A bare ``except`` would also swallow
        # KeyboardInterrupt and SystemExit.
        return 60
272
+
273
+
274
@spaces.GPU(duration=calc_timeout_duration)
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float, gpu_timeout: int):
    """Stream a model response for one image and one text prompt.

    Args:
        model_name: UI label of the model to use (e.g. "RolmOCR-7B").
        text: Prompt placed after the image in the chat message.
        image: Input image; ``None`` produces a "please upload" message.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling parameters forwarded to ``model.generate``.
        gpu_timeout: Not used in the body; it is read by
            ``calc_timeout_duration`` through the ``spaces.GPU`` decorator.

    Yields:
        ``(raw_text, markdown_text)`` tuples holding the text accumulated so
        far; both elements are the same string.
    """
    # Lookup table from UI label to (processor, model); replaces the
    # if/elif chain and keeps every label in one place.
    registry = {
        "RolmOCR-7B": (processor_m, model_m),
        "Qwen2-VL-OCR-2B": (processor_x, model_x),
        "Nanonets-OCR2-3B": (processor_v, model_v),
        "Aya-Vision-8B": (processor_a, model_a),
        "olmOCR-7B-0725": (processor_w, model_w),
    }
    selected = registry.get(model_name)
    if selected is None:
        yield "Invalid model selected.", "Invalid model selected."
        return
    processor, model = selected

    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text},
        ]
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True).to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        # Defensive coercion: these two must be integers for generate(),
        # while UI numeric inputs may arrive as floats.
        "max_new_tokens": int(max_new_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": repetition_penalty,
    }
    # generate() runs in a worker thread so this generator can consume the
    # streamer and yield partial output while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Strip the end-of-turn marker in case it survives decoding.
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer
    # The streamer is exhausted, so generation has finished; join the worker
    # so it does not outlive this call.
    thread.join()
339
+
340
+
341
# Example rows for gr.Examples: [query text, image path], matching the
# inputs=[image_query, image_upload] order used below.
image_examples = [
    ["Perform OCR on the image precisely.", "examples/5.jpg"],
    ["Run OCR on the image and ensure high accuracy.", "examples/4.jpg"],
    ["Conduct OCR on the image with exact text recognition.", "examples/2.jpg"],
    ["Perform precise OCR extraction on the image.", "examples/1.jpg"],
    ["Convert this page to docling", "examples/3.jpg"],
]

# UI definition.
# NOTE(review): the nesting below was reconstructed from a flattened diff view
# in which indentation was lost — confirm the layout against the real file.
with gr.Blocks() as demo:
    gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
    with gr.Row():
        # Left column: prompt, image, submit button, examples, sampling options.
        with gr.Column(scale=2):
            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
            image_upload = gr.Image(type="pil", label="Upload Image", height=290)

            image_submit = gr.Button("Submit", variant="primary")
            gr.Examples(
                examples=image_examples,
                inputs=[image_query, image_upload]
            )

            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)

        # Right column: outputs, model selector and GPU-duration selector.
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            output = gr.Textbox(label="Raw Output Stream", interactive=True, lines=11)
            with gr.Accordion("(Result.md)", open=False):
                markdown_output = gr.Markdown(label="(Result.Md)")

            # Choice labels must match the model names checked in generate_image.
            model_choice = gr.Radio(
                choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
                         "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
                label="Select Model",
                value="Nanonets-OCR2-3B"
            )

            with gr.Row(elem_id="gpu-duration-container"):
                with gr.Column():
                    gr.Markdown("**GPU Duration (seconds)**")
                    radioanimated_gpu_duration = RadioAnimated(
                        choices=["60", "90", "120", "180", "240"],
                        value="60",
                        elem_id="radioanimated_gpu_duration"
                    )
                    # Hidden numeric copy of the selected duration; this is the
                    # component passed to generate_image as gpu_timeout.
                    gpu_duration_state = gr.Number(value=60, visible=False)

            gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")

    # Copy the radio selection (a string) into the hidden number as an int.
    radioanimated_gpu_duration.change(
        fn=apply_gpu_duration,
        inputs=radioanimated_gpu_duration,
        outputs=[gpu_duration_state],
        api_visibility="private"
    )

    # Input order must match generate_image's positional parameters; the
    # generator's (raw, markdown) tuples fill the two output components.
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
        outputs=[output, markdown_output]
    )

if __name__ == "__main__":
    # Queue up to 50 pending requests, then launch with the custom CSS and theme.
    demo.queue(max_size=50).launch(css=css, theme=steel_blue_theme, mcp_server=True, ssr_mode=False, show_error=True)