prithivMLmods commited on
Commit
07b42b2
·
verified ·
1 Parent(s): 09b73c5

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -406
app.py DELETED
@@ -1,406 +0,0 @@
1
- import os
2
- import sys
3
- import random
4
- import uuid
5
- import json
6
- import time
7
- from threading import Thread
8
- from typing import Iterable
9
- from huggingface_hub import snapshot_download
10
-
11
- import gradio as gr
12
- import spaces
13
- import torch
14
- import numpy as np
15
- from PIL import Image
16
- import cv2
17
-
18
- from transformers import (
19
- Qwen2_5_VLForConditionalGeneration,
20
- Qwen3VLForConditionalGeneration,
21
- AutoModelForImageTextToText,
22
- AutoModelForCausalLM,
23
- AutoProcessor,
24
- TextIteratorStreamer,
25
- )
26
-
27
- from transformers.image_utils import load_image
28
- from gradio.themes import Soft
29
- from gradio.themes.utils import colors, fonts, sizes
30
-
31
- colors.steel_blue = colors.Color(
32
- name="steel_blue",
33
- c50="#EBF3F8",
34
- c100="#D3E5F0",
35
- c200="#A8CCE1",
36
- c300="#7DB3D2",
37
- c400="#529AC3",
38
- c500="#4682B4",
39
- c600="#3E72A0",
40
- c700="#36638C",
41
- c800="#2E5378",
42
- c900="#264364",
43
- c950="#1E3450",
44
- )
45
-
46
- class SteelBlueTheme(Soft):
47
- def __init__(
48
- self,
49
- *,
50
- primary_hue: colors.Color | str = colors.gray,
51
- secondary_hue: colors.Color | str = colors.steel_blue,
52
- neutral_hue: colors.Color | str = colors.slate,
53
- text_size: sizes.Size | str = sizes.text_lg,
54
- font: fonts.Font | str | Iterable[fonts.Font | str] = (
55
- fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
56
- ),
57
- font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
58
- fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
59
- ),
60
- ):
61
- super().__init__(
62
- primary_hue=primary_hue,
63
- secondary_hue=secondary_hue,
64
- neutral_hue=neutral_hue,
65
- text_size=text_size,
66
- font=font,
67
- font_mono=font_mono,
68
- )
69
- super().set(
70
- background_fill_primary="*primary_50",
71
- background_fill_primary_dark="*primary_900",
72
- body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
73
- body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
74
- button_primary_text_color="white",
75
- button_primary_text_color_hover="white",
76
- button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
77
- button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
78
- button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
79
- button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
80
- button_secondary_text_color="black",
81
- button_secondary_text_color_hover="white",
82
- button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
83
- button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
84
- button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
85
- button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
86
- slider_color="*secondary_500",
87
- slider_color_dark="*secondary_600",
88
- block_title_text_weight="600",
89
- block_border_width="3px",
90
- block_shadow="*shadow_drop_lg",
91
- button_primary_shadow="*shadow_drop_lg",
92
- button_large_padding="11px",
93
- color_accent_soft="*primary_100",
94
- block_label_background_fill="*primary_200",
95
- )
96
-
97
- steel_blue_theme = SteelBlueTheme()
98
-
99
- css = """
100
- #main-title h1 {
101
- font-size: 2.3em !important;
102
- }
103
- #output-title h2 {
104
- font-size: 2.2em !important;
105
- }
106
-
107
- /* RadioAnimated Styles */
108
- .ra-wrap{ width: fit-content; }
109
- .ra-inner{
110
- position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
111
- background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
112
- }
113
- .ra-input{ display: none; }
114
- .ra-label{
115
- position: relative; z-index: 2; padding: 8px 16px;
116
- font-family: inherit; font-size: 14px; font-weight: 600;
117
- color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
118
- }
119
- .ra-highlight{
120
- position: absolute; z-index: 1; top: 6px; left: 6px;
121
- height: calc(100% - 12px); border-radius: 9999px;
122
- background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
123
- transition: transform 0.2s, width 0.2s;
124
- }
125
- .ra-input:checked + .ra-label{ color: black; }
126
-
127
- /* Dark mode adjustments for Radio */
128
- .dark .ra-inner { background: var(--neutral-800); }
129
- .dark .ra-label { color: var(--neutral-400); }
130
- .dark .ra-highlight { background: var(--neutral-600); }
131
- .dark .ra-input:checked + .ra-label { color: white; }
132
-
133
- #gpu-duration-container {
134
- padding: 10px;
135
- border-radius: 8px;
136
- background: var(--background-fill-secondary);
137
- border: 1px solid var(--border-color-primary);
138
- margin-top: 10px;
139
- }
140
- """
141
-
142
- MAX_MAX_NEW_TOKENS = 4096
143
- DEFAULT_MAX_NEW_TOKENS = 2048
144
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
145
-
146
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
147
-
148
- print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
149
- print("torch.__version__ =", torch.__version__)
150
- print("torch.version.cuda =", torch.version.cuda)
151
- print("cuda available:", torch.cuda.is_available())
152
- print("cuda device count:", torch.cuda.device_count())
153
- if torch.cuda.is_available():
154
- print("current device:", torch.cuda.current_device())
155
- print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
156
-
157
- print("Using device:", device)
158
-
159
- class RadioAnimated(gr.HTML):
160
- def __init__(self, choices, value=None, **kwargs):
161
- if not choices or len(choices) < 2:
162
- raise ValueError("RadioAnimated requires at least 2 choices.")
163
- if value is None:
164
- value = choices[0]
165
-
166
- uid = uuid.uuid4().hex[:8]
167
- group_name = f"ra-{uid}"
168
-
169
- inputs_html = "\n".join(
170
- f"""
171
- <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
172
- <label class="ra-label" for="{group_name}-{i}">{c}</label>
173
- """
174
- for i, c in enumerate(choices)
175
- )
176
-
177
- html_template = f"""
178
- <div class="ra-wrap" data-ra="{uid}">
179
- <div class="ra-inner">
180
- <div class="ra-highlight"></div>
181
- {inputs_html}
182
- </div>
183
- </div>
184
- """
185
-
186
- js_on_load = r"""
187
- (() => {
188
- const wrap = element.querySelector('.ra-wrap');
189
- const inner = element.querySelector('.ra-inner');
190
- const highlight = element.querySelector('.ra-highlight');
191
- const inputs = Array.from(element.querySelectorAll('.ra-input'));
192
-
193
- if (!inputs.length) return;
194
-
195
- const choices = inputs.map(i => i.value);
196
-
197
- function setHighlightByIndex(idx) {
198
- const n = choices.length;
199
- const pct = 100 / n;
200
- highlight.style.width = `calc(${pct}% - 6px)`;
201
- highlight.style.transform = `translateX(${idx * 100}%)`;
202
- }
203
-
204
- function setCheckedByValue(val, shouldTrigger=false) {
205
- const idx = Math.max(0, choices.indexOf(val));
206
- inputs.forEach((inp, i) => { inp.checked = (i === idx); });
207
- setHighlightByIndex(idx);
208
-
209
- props.value = choices[idx];
210
- if (shouldTrigger) trigger('change', props.value);
211
- }
212
-
213
- setCheckedByValue(props.value ?? choices[0], false);
214
-
215
- inputs.forEach((inp) => {
216
- inp.addEventListener('change', () => {
217
- setCheckedByValue(inp.value, true);
218
- });
219
- });
220
- })();
221
- """
222
-
223
- super().__init__(
224
- value=value,
225
- html_template=html_template,
226
- js_on_load=js_on_load,
227
- **kwargs
228
- )
229
-
230
- def apply_gpu_duration(val: str):
231
- return int(val)
232
-
233
- MODEL_ID_V = "datalab-to/chandra"
234
- processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
235
- model_v = Qwen3VLForConditionalGeneration.from_pretrained(
236
- MODEL_ID_V,
237
- #attn_implementation="kernels-community/flash-attn2",
238
- trust_remote_code=True,
239
- torch_dtype=torch.float16
240
- ).to(device).eval()
241
-
242
- MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
243
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
244
- model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
245
- MODEL_ID_X,
246
- #attn_implementation="kernels-community/flash-attn2",
247
- trust_remote_code=True,
248
- torch_dtype=torch.bfloat16,
249
- ).to(device).eval()
250
-
251
- MODEL_PATH_D = "prithivMLmods/Dots.OCR-1.5-Transformers.v5.0" # -> alt of [prithivMLmods/Dots.OCR-Latest-BF16] / [rednote-hilab/dots.ocr]
252
- processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
253
- model_d = AutoModelForCausalLM.from_pretrained(
254
- MODEL_PATH_D,
255
- attn_implementation="flash_attention_2",
256
- torch_dtype=torch.bfloat16,
257
- device_map="auto",
258
- trust_remote_code=True
259
- ).eval()
260
-
261
- MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
262
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
263
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
264
- MODEL_ID_M,
265
- #attn_implementation="kernels-community/flash-attn2",
266
- trust_remote_code=True,
267
- torch_dtype=torch.float16
268
- ).to(device).eval()
269
-
270
- def calc_timeout_image(model_name: str, text: str, image: Image.Image,
271
- max_new_tokens: int, temperature: float, top_p: float,
272
- top_k: int, repetition_penalty: float, gpu_timeout: int):
273
- """Calculate GPU timeout duration for image inference."""
274
- try:
275
- return int(gpu_timeout)
276
- except:
277
- return 60
278
-
279
- @spaces.GPU(duration=calc_timeout_image)
280
- def generate_image(model_name: str, text: str, image: Image.Image,
281
- max_new_tokens: int, temperature: float, top_p: float,
282
- top_k: int, repetition_penalty: float, gpu_timeout: int = 60):
283
- """
284
- Generates responses using the selected model for image input.
285
- Yields raw text and Markdown-formatted text.
286
- """
287
- if model_name == "olmOCR-2-7B-1025":
288
- processor = processor_m
289
- model = model_m
290
- elif model_name == "Nanonets-OCR2-3B":
291
- processor = processor_x
292
- model = model_x
293
- elif model_name == "Chandra-OCR":
294
- processor = processor_v
295
- model = model_v
296
- elif model_name == "Dots.OCR":
297
- processor = processor_d
298
- model = model_d
299
- else:
300
- yield "Invalid model selected.", "Invalid model selected."
301
- return
302
-
303
- if image is None:
304
- yield "Please upload an image.", "Please upload an image."
305
- return
306
-
307
- messages = [{
308
- "role": "user",
309
- "content": [
310
- {"type": "image"},
311
- {"type": "text", "text": text},
312
- ]
313
- }]
314
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
315
-
316
- inputs = processor(
317
- text=[prompt_full],
318
- images=[image],
319
- return_tensors="pt",
320
- padding=True).to(device)
321
-
322
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
323
- generation_kwargs = {
324
- **inputs,
325
- "streamer": streamer,
326
- "max_new_tokens": max_new_tokens,
327
- "do_sample": True,
328
- "temperature": temperature,
329
- "top_p": top_p,
330
- "top_k": top_k,
331
- "repetition_penalty": repetition_penalty,
332
- }
333
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
334
- thread.start()
335
- buffer = ""
336
- for new_text in streamer:
337
- buffer += new_text
338
- buffer = buffer.replace("<|im_end|>", "")
339
- time.sleep(0.01)
340
- yield buffer, buffer
341
-
342
- image_examples = [
343
- ["Convert to Markdown.", "examples/3.jpg"],
344
- ["Perform OCR on the image. [Markdown]", "examples/1.jpg"],
345
- ["Extract the contents. [Markdown].", "examples/2.jpg"],
346
- ]
347
-
348
- with gr.Blocks() as demo:
349
- gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
350
- with gr.Row():
351
- with gr.Column(scale=2):
352
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
353
- image_upload = gr.Image(type="pil", label="Upload Image", height=290)
354
-
355
- image_submit = gr.Button("Submit", variant="primary")
356
- gr.Examples(
357
- examples=image_examples,
358
- inputs=[image_query, image_upload]
359
- )
360
-
361
- with gr.Accordion("Advanced options", open=False):
362
- max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
363
- temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
364
- top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
365
- top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
366
- repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
367
-
368
- with gr.Column(scale=3):
369
- gr.Markdown("## Output", elem_id="output-title")
370
- output = gr.Textbox(label="Raw Output Stream", interactive=True, lines=15)
371
- with gr.Accordion("(Result.md)", open=False):
372
- markdown_output = gr.Markdown(label="(Result.Md)")
373
-
374
- model_choice = gr.Radio(
375
- choices=["Nanonets-OCR2-3B", "Chandra-OCR", "Dots.OCR", "olmOCR-2-7B-1025"],
376
- label="Select Model",
377
- value="Nanonets-OCR2-3B"
378
- )
379
-
380
- with gr.Row(elem_id="gpu-duration-container"):
381
- with gr.Column():
382
- gr.Markdown("**GPU Duration (seconds)**")
383
- radioanimated_gpu_duration = RadioAnimated(
384
- choices=["60", "90", "120", "180", "240", "300"],
385
- value="60",
386
- elem_id="radioanimated_gpu_duration"
387
- )
388
- gpu_duration_state = gr.Number(value=60, visible=False)
389
-
390
- gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
391
-
392
- radioanimated_gpu_duration.change(
393
- fn=apply_gpu_duration,
394
- inputs=radioanimated_gpu_duration,
395
- outputs=[gpu_duration_state],
396
- api_visibility="private"
397
- )
398
-
399
- image_submit.click(
400
- fn=generate_image,
401
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
402
- outputs=[output, markdown_output]
403
- )
404
-
405
- if __name__ == "__main__":
406
- demo.queue(max_size=50).launch(css=css, theme=steel_blue_theme, mcp_server=True, ssr_mode=False, show_error=True)