prithivMLmods committed on
Commit
4377887
·
verified ·
1 Parent(s): f67f60c

update app

Browse files
Files changed (1) hide show
  1. app.py +159 -34
app.py CHANGED
@@ -24,9 +24,6 @@ from transformers.image_utils import load_image
24
  from gradio.themes import Soft
25
  from gradio.themes.utils import colors, fonts, sizes
26
 
27
- # --- Theme and CSS Definition ---
28
-
29
- # Define the SteelBlue color palette
30
  colors.steel_blue = colors.Color(
31
  name="steel_blue",
32
  c50="#EBF3F8",
@@ -34,7 +31,7 @@ colors.steel_blue = colors.Color(
34
  c200="#A8CCE1",
35
  c300="#7DB3D2",
36
  c400="#529AC3",
37
- c500="#4682B4", # SteelBlue base color
38
  c600="#3E72A0",
39
  c700="#36638C",
40
  c800="#2E5378",
@@ -92,8 +89,7 @@ class SteelBlueTheme(Soft):
92
  color_accent_soft="*primary_100",
93
  block_label_background_fill="*primary_200",
94
  )
95
-
96
- # Instantiate the new theme
97
  steel_blue_theme = SteelBlueTheme()
98
 
99
  css = """
@@ -101,11 +97,44 @@ css = """
101
  font-size: 2.3em !important;
102
  }
103
  #output-title h2 {
104
- font-size: 2.1em !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  }
106
  """
107
 
108
- # Constants for text generation
109
  MAX_MAX_NEW_TOKENS = 4096
110
  DEFAULT_MAX_NEW_TOKENS = 1024
111
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -123,62 +152,139 @@ if torch.cuda.is_available():
123
 
124
  print("Using device:", device)
125
 
126
- # --- Model Loading ---
127
- # Load Nanonets-OCR2-3B
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
129
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
130
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
131
  MODEL_ID_V,
132
- attn_implementation="flash_attention_2",
133
  trust_remote_code=True,
134
  torch_dtype=torch.float16
135
  ).to(device).eval()
136
 
137
- # Load Qwen2-VL-OCR-2B-Instruct
138
  MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
139
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
140
  model_x = Qwen2VLForConditionalGeneration.from_pretrained(
141
  MODEL_ID_X,
142
- attn_implementation="flash_attention_2",
143
  trust_remote_code=True,
144
  torch_dtype=torch.float16
145
  ).to(device).eval()
146
 
147
- # Load Aya-Vision-8b
148
  MODEL_ID_A = "CohereForAI/aya-vision-8b"
149
  processor_a = AutoProcessor.from_pretrained(MODEL_ID_A, trust_remote_code=True)
150
  model_a = AutoModelForImageTextToText.from_pretrained(
151
  MODEL_ID_A,
152
- attn_implementation="flash_attention_2",
153
  trust_remote_code=True,
154
  torch_dtype=torch.float16
155
  ).to(device).eval()
156
 
157
- # Load olmOCR-7B-0725
158
  MODEL_ID_W = "allenai/olmOCR-7B-0725"
159
  processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
160
  model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
161
  MODEL_ID_W,
162
- attn_implementation="flash_attention_2",
163
  trust_remote_code=True,
164
  torch_dtype=torch.float16
165
  ).to(device).eval()
166
 
167
- # Load RolmOCR
168
  MODEL_ID_M = "reducto/RolmOCR"
169
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
170
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
171
  MODEL_ID_M,
172
- attn_implementation="flash_attention_2",
173
  trust_remote_code=True,
174
  torch_dtype=torch.float16
175
  ).to(device).eval()
176
 
 
 
 
 
 
 
 
 
177
 
178
- @spaces.GPU
 
179
  def generate_image(model_name: str, text: str, image: Image.Image,
180
  max_new_tokens: int, temperature: float, top_p: float,
181
- top_k: int, repetition_penalty: float):
182
  """
183
  Generates responses using the selected model for image input.
184
  Yields raw text and Markdown-formatted text.
@@ -241,6 +347,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
241
  time.sleep(0.01)
242
  yield buffer, buffer
243
 
 
244
  image_examples = [
245
  ["Perform OCR on the image precisely.", "examples/5.jpg"],
246
  ["Run OCR on the image and ensure high accuracy.", "examples/4.jpg"],
@@ -249,7 +356,6 @@ image_examples = [
249
  ["Convert this page to docling", "examples/3.jpg"],
250
  ]
251
 
252
- # Create the Gradio Interface
253
  with gr.Blocks() as demo:
254
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
255
  with gr.Row():
@@ -271,21 +377,40 @@ with gr.Blocks() as demo:
271
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
272
 
273
  with gr.Column(scale=3):
274
- gr.Markdown("## Output", elem_id="output-title")
275
- output = gr.Textbox(label="Raw Output Stream", interactive=True, lines=11)
276
- with gr.Accordion("(Result.md)", open=False):
277
- markdown_output = gr.Markdown(label="(Result.Md)")
278
-
279
- model_choice = gr.Radio(
280
- choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
281
- "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
282
- label="Select Model",
283
- value="Nanonets-OCR2-3B"
284
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
  image_submit.click(
287
  fn=generate_image,
288
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
289
  outputs=[output, markdown_output]
290
  )
291
 
 
24
  from gradio.themes import Soft
25
  from gradio.themes.utils import colors, fonts, sizes
26
 
 
 
 
27
  colors.steel_blue = colors.Color(
28
  name="steel_blue",
29
  c50="#EBF3F8",
 
31
  c200="#A8CCE1",
32
  c300="#7DB3D2",
33
  c400="#529AC3",
34
+ c500="#4682B4",
35
  c600="#3E72A0",
36
  c700="#36638C",
37
  c800="#2E5378",
 
89
  color_accent_soft="*primary_100",
90
  block_label_background_fill="*primary_200",
91
  )
92
+
 
93
  steel_blue_theme = SteelBlueTheme()
94
 
95
  css = """
 
97
  font-size: 2.3em !important;
98
  }
99
  #output-title h2 {
100
+ font-size: 2.2em !important;
101
+ }
102
+
103
+ /* RadioAnimated Styles */
104
+ .ra-wrap{ width: fit-content; }
105
+ .ra-inner{
106
+ position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
107
+ background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
108
+ }
109
+ .ra-input{ display: none; }
110
+ .ra-label{
111
+ position: relative; z-index: 2; padding: 8px 16px;
112
+ font-family: inherit; font-size: 14px; font-weight: 600;
113
+ color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
114
+ }
115
+ .ra-highlight{
116
+ position: absolute; z-index: 1; top: 6px; left: 6px;
117
+ height: calc(100% - 12px); border-radius: 9999px;
118
+ background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
119
+ transition: transform 0.2s, width 0.2s;
120
+ }
121
+ .ra-input:checked + .ra-label{ color: black; }
122
+
123
+ /* Dark mode adjustments for Radio */
124
+ .dark .ra-inner { background: var(--neutral-800); }
125
+ .dark .ra-label { color: var(--neutral-400); }
126
+ .dark .ra-highlight { background: var(--neutral-600); }
127
+ .dark .ra-input:checked + .ra-label { color: white; }
128
+
129
+ #gpu-duration-container {
130
+ padding: 10px;
131
+ border-radius: 8px;
132
+ background: var(--background-fill-secondary);
133
+ border: 1px solid var(--border-color-primary);
134
+ margin-top: 10px;
135
  }
136
  """
137
 
 
138
  MAX_MAX_NEW_TOKENS = 4096
139
  DEFAULT_MAX_NEW_TOKENS = 1024
140
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
152
 
153
  print("Using device:", device)
154
 
155
+ class RadioAnimated(gr.HTML):
156
+ def __init__(self, choices, value=None, **kwargs):
157
+ if not choices or len(choices) < 2:
158
+ raise ValueError("RadioAnimated requires at least 2 choices.")
159
+ if value is None:
160
+ value = choices[0]
161
+
162
+ uid = uuid.uuid4().hex[:8]
163
+ group_name = f"ra-{uid}"
164
+
165
+ inputs_html = "\n".join(
166
+ f"""
167
+ <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
168
+ <label class="ra-label" for="{group_name}-{i}">{c}</label>
169
+ """
170
+ for i, c in enumerate(choices)
171
+ )
172
+
173
+ html_template = f"""
174
+ <div class="ra-wrap" data-ra="{uid}">
175
+ <div class="ra-inner">
176
+ <div class="ra-highlight"></div>
177
+ {inputs_html}
178
+ </div>
179
+ </div>
180
+ """
181
+
182
+ js_on_load = r"""
183
+ (() => {
184
+ const wrap = element.querySelector('.ra-wrap');
185
+ const inner = element.querySelector('.ra-inner');
186
+ const highlight = element.querySelector('.ra-highlight');
187
+ const inputs = Array.from(element.querySelectorAll('.ra-input'));
188
+
189
+ if (!inputs.length) return;
190
+
191
+ const choices = inputs.map(i => i.value);
192
+
193
+ function setHighlightByIndex(idx) {
194
+ const n = choices.length;
195
+ const pct = 100 / n;
196
+ highlight.style.width = `calc(${pct}% - 6px)`;
197
+ highlight.style.transform = `translateX(${idx * 100}%)`;
198
+ }
199
+
200
+ function setCheckedByValue(val, shouldTrigger=false) {
201
+ const idx = Math.max(0, choices.indexOf(val));
202
+ inputs.forEach((inp, i) => { inp.checked = (i === idx); });
203
+ setHighlightByIndex(idx);
204
+
205
+ props.value = choices[idx];
206
+ if (shouldTrigger) trigger('change', props.value);
207
+ }
208
+
209
+ setCheckedByValue(props.value ?? choices[0], false);
210
+
211
+ inputs.forEach((inp) => {
212
+ inp.addEventListener('change', () => {
213
+ setCheckedByValue(inp.value, true);
214
+ });
215
+ });
216
+ })();
217
+ """
218
+
219
+ super().__init__(
220
+ value=value,
221
+ html_template=html_template,
222
+ js_on_load=js_on_load,
223
+ **kwargs
224
+ )
225
+
226
+ def apply_gpu_duration(val: str):
227
+ return int(val)
228
+
229
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
230
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
231
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
232
  MODEL_ID_V,
233
+ attn_implementation="kernels-community/flash-attn3",
234
  trust_remote_code=True,
235
  torch_dtype=torch.float16
236
  ).to(device).eval()
237
 
 
238
  MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
239
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
240
  model_x = Qwen2VLForConditionalGeneration.from_pretrained(
241
  MODEL_ID_X,
242
+ attn_implementation="kernels-community/flash-attn3",
243
  trust_remote_code=True,
244
  torch_dtype=torch.float16
245
  ).to(device).eval()
246
 
 
247
  MODEL_ID_A = "CohereForAI/aya-vision-8b"
248
  processor_a = AutoProcessor.from_pretrained(MODEL_ID_A, trust_remote_code=True)
249
  model_a = AutoModelForImageTextToText.from_pretrained(
250
  MODEL_ID_A,
251
+ attn_implementation="kernels-community/flash-attn3",
252
  trust_remote_code=True,
253
  torch_dtype=torch.float16
254
  ).to(device).eval()
255
 
 
256
  MODEL_ID_W = "allenai/olmOCR-7B-0725"
257
  processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
258
  model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
259
  MODEL_ID_W,
260
+ attn_implementation="kernels-community/flash-attn3",
261
  trust_remote_code=True,
262
  torch_dtype=torch.float16
263
  ).to(device).eval()
264
 
 
265
  MODEL_ID_M = "reducto/RolmOCR"
266
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
267
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
268
  MODEL_ID_M,
269
+ attn_implementation="kernels-community/flash-attn3",
270
  trust_remote_code=True,
271
  torch_dtype=torch.float16
272
  ).to(device).eval()
273
 
274
+ def calc_timeout_duration(model_name: str, text: str, image: Image.Image,
275
+ max_new_tokens: int, temperature: float, top_p: float,
276
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
277
+ """Calculate GPU timeout duration based on the last argument."""
278
+ try:
279
+ return int(gpu_timeout)
280
+ except:
281
+ return 60
282
 
283
+
284
+ @spaces.GPU(duration=calc_timeout_duration)
285
  def generate_image(model_name: str, text: str, image: Image.Image,
286
  max_new_tokens: int, temperature: float, top_p: float,
287
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
288
  """
289
  Generates responses using the selected model for image input.
290
  Yields raw text and Markdown-formatted text.
 
347
  time.sleep(0.01)
348
  yield buffer, buffer
349
 
350
+
351
  image_examples = [
352
  ["Perform OCR on the image precisely.", "examples/5.jpg"],
353
  ["Run OCR on the image and ensure high accuracy.", "examples/4.jpg"],
 
356
  ["Convert this page to docling", "examples/3.jpg"],
357
  ]
358
 
 
359
  with gr.Blocks() as demo:
360
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
361
  with gr.Row():
 
377
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
378
 
379
  with gr.Column(scale=3):
380
+ gr.Markdown("## Output", elem_id="output-title")
381
+ output = gr.Textbox(label="Raw Output Stream", interactive=True, lines=11)
382
+ with gr.Accordion("(Result.md)", open=False):
383
+ markdown_output = gr.Markdown(label="(Result.Md)")
384
+
385
+ model_choice = gr.Radio(
386
+ choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
387
+ "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
388
+ label="Select Model",
389
+ value="Nanonets-OCR2-3B"
390
+ )
391
+
392
+ with gr.Row(elem_id="gpu-duration-container"):
393
+ with gr.Column():
394
+ gr.Markdown("**GPU Duration (seconds)**")
395
+ radioanimated_gpu_duration = RadioAnimated(
396
+ choices=["60", "90", "120", "180", "240"],
397
+ value="60",
398
+ elem_id="radioanimated_gpu_duration"
399
+ )
400
+ gpu_duration_state = gr.Number(value=60, visible=False)
401
+
402
+ gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
403
+
404
+ radioanimated_gpu_duration.change(
405
+ fn=apply_gpu_duration,
406
+ inputs=radioanimated_gpu_duration,
407
+ outputs=[gpu_duration_state],
408
+ api_visibility="private"
409
+ )
410
 
411
  image_submit.click(
412
  fn=generate_image,
413
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
414
  outputs=[output, markdown_output]
415
  )
416