prithivMLmods commited on
Commit
9b77f2c
·
verified ·
1 Parent(s): 1b53aeb

update [kernels:flash-attn2] (cleaned) ✅

Browse files
Files changed (1) hide show
  1. app.py +153 -27
app.py CHANGED
@@ -101,12 +101,46 @@ css = """
101
  font-size: 2.3em !important;
102
  }
103
  #output-title h2 {
104
- font-size: 2.1em !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  }
106
  """
107
 
108
  MAX_MAX_NEW_TOKENS = 4096
109
- DEFAULT_MAX_NEW_TOKENS = 1024
110
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
111
 
112
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -122,57 +156,130 @@ if torch.cuda.is_available():
122
 
123
  print("Using device:", device)
124
 
125
- MAX_MAX_NEW_TOKENS = 4096
126
- DEFAULT_MAX_NEW_TOKENS = 2048
127
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 
 
128
 
129
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # Load Chandra-OCR
132
  MODEL_ID_V = "datalab-to/chandra"
133
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
134
  model_v = Qwen3VLForConditionalGeneration.from_pretrained(
135
  MODEL_ID_V,
136
- attn_implementation="flash_attention_2",
137
  trust_remote_code=True,
138
  torch_dtype=torch.float16
139
  ).to(device).eval()
140
 
141
- # Load Nanonets-OCR2-3B
142
  MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
143
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
144
  model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
145
  MODEL_ID_X,
146
- attn_implementation="flash_attention_2",
147
  trust_remote_code=True,
148
  torch_dtype=torch.bfloat16,
149
  ).to(device).eval()
150
 
151
- # Load Dots.OCR from the local, patched directory
152
  MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16" # -> alt of [rednote-hilab/dots.ocr]
153
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
154
  model_d = AutoModelForCausalLM.from_pretrained(
155
  MODEL_PATH_D,
156
- attn_implementation="flash_attention_2",
157
  torch_dtype=torch.bfloat16,
158
  device_map="auto",
159
  trust_remote_code=True
160
  ).eval()
161
 
162
- # Load olmOCR-2-7B-1025
163
  MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
164
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
165
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
166
  MODEL_ID_M,
167
- attn_implementation="flash_attention_2",
168
  trust_remote_code=True,
169
  torch_dtype=torch.float16
170
  ).to(device).eval()
171
 
172
- @spaces.GPU
 
 
 
 
 
 
 
 
 
173
  def generate_image(model_name: str, text: str, image: Image.Image,
174
  max_new_tokens: int, temperature: float, top_p: float,
175
- top_k: int, repetition_penalty: float):
176
  """
177
  Generates responses using the selected model for image input.
178
  Yields raw text and Markdown-formatted text.
@@ -259,22 +366,41 @@ with gr.Blocks() as demo:
259
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
260
 
261
  with gr.Column(scale=3):
262
- gr.Markdown("## Output", elem_id="output-title")
263
- output = gr.Textbox(label="Raw Output Stream", interactive=True, lines=15)
264
- with gr.Accordion("(Result.md)", open=False):
265
- markdown_output = gr.Markdown(label="(Result.Md)")
266
 
267
- model_choice = gr.Radio(
268
- choices=["Nanonets-OCR2-3B", "Chandra-OCR", "Dots.OCR", "olmOCR-2-7B-1025"],
269
- label="Select Model",
270
- value="Nanonets-OCR2-3B"
271
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  image_submit.click(
274
  fn=generate_image,
275
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
276
  outputs=[output, markdown_output]
277
  )
278
 
279
  if __name__ == "__main__":
280
- demo.queue(max_size=30).launch(css=css, theme=steel_blue_theme, mcp_server=True, ssr_mode=False, show_error=True)
 
101
  font-size: 2.3em !important;
102
  }
103
  #output-title h2 {
104
+ font-size: 2.2em !important;
105
+ }
106
+
107
+ /* RadioAnimated Styles */
108
+ .ra-wrap{ width: fit-content; }
109
+ .ra-inner{
110
+ position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
111
+ background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
112
+ }
113
+ .ra-input{ display: none; }
114
+ .ra-label{
115
+ position: relative; z-index: 2; padding: 8px 16px;
116
+ font-family: inherit; font-size: 14px; font-weight: 600;
117
+ color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
118
+ }
119
+ .ra-highlight{
120
+ position: absolute; z-index: 1; top: 6px; left: 6px;
121
+ height: calc(100% - 12px); border-radius: 9999px;
122
+ background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
123
+ transition: transform 0.2s, width 0.2s;
124
+ }
125
+ .ra-input:checked + .ra-label{ color: black; }
126
+
127
+ /* Dark mode adjustments for Radio */
128
+ .dark .ra-inner { background: var(--neutral-800); }
129
+ .dark .ra-label { color: var(--neutral-400); }
130
+ .dark .ra-highlight { background: var(--neutral-600); }
131
+ .dark .ra-input:checked + .ra-label { color: white; }
132
+
133
+ #gpu-duration-container {
134
+ padding: 10px;
135
+ border-radius: 8px;
136
+ background: var(--background-fill-secondary);
137
+ border: 1px solid var(--border-color-primary);
138
+ margin-top: 10px;
139
  }
140
  """
141
 
142
  MAX_MAX_NEW_TOKENS = 4096
143
+ DEFAULT_MAX_NEW_TOKENS = 2048
144
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
145
 
146
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
156
 
157
  print("Using device:", device)
158
 
159
+ class RadioAnimated(gr.HTML):
160
+ def __init__(self, choices, value=None, **kwargs):
161
+ if not choices or len(choices) < 2:
162
+ raise ValueError("RadioAnimated requires at least 2 choices.")
163
+ if value is None:
164
+ value = choices[0]
165
 
166
+ uid = uuid.uuid4().hex[:8]
167
+ group_name = f"ra-{uid}"
168
+
169
+ inputs_html = "\n".join(
170
+ f"""
171
+ <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
172
+ <label class="ra-label" for="{group_name}-{i}">{c}</label>
173
+ """
174
+ for i, c in enumerate(choices)
175
+ )
176
+
177
+ html_template = f"""
178
+ <div class="ra-wrap" data-ra="{uid}">
179
+ <div class="ra-inner">
180
+ <div class="ra-highlight"></div>
181
+ {inputs_html}
182
+ </div>
183
+ </div>
184
+ """
185
+
186
+ js_on_load = r"""
187
+ (() => {
188
+ const wrap = element.querySelector('.ra-wrap');
189
+ const inner = element.querySelector('.ra-inner');
190
+ const highlight = element.querySelector('.ra-highlight');
191
+ const inputs = Array.from(element.querySelectorAll('.ra-input'));
192
+
193
+ if (!inputs.length) return;
194
+
195
+ const choices = inputs.map(i => i.value);
196
+
197
+ function setHighlightByIndex(idx) {
198
+ const n = choices.length;
199
+ const pct = 100 / n;
200
+ highlight.style.width = `calc(${pct}% - 6px)`;
201
+ highlight.style.transform = `translateX(${idx * 100}%)`;
202
+ }
203
+
204
+ function setCheckedByValue(val, shouldTrigger=false) {
205
+ const idx = Math.max(0, choices.indexOf(val));
206
+ inputs.forEach((inp, i) => { inp.checked = (i === idx); });
207
+ setHighlightByIndex(idx);
208
+
209
+ props.value = choices[idx];
210
+ if (shouldTrigger) trigger('change', props.value);
211
+ }
212
+
213
+ setCheckedByValue(props.value ?? choices[0], false);
214
+
215
+ inputs.forEach((inp) => {
216
+ inp.addEventListener('change', () => {
217
+ setCheckedByValue(inp.value, true);
218
+ });
219
+ });
220
+ })();
221
+ """
222
+
223
+ super().__init__(
224
+ value=value,
225
+ html_template=html_template,
226
+ js_on_load=js_on_load,
227
+ **kwargs
228
+ )
229
+
230
+ def apply_gpu_duration(val: str):
231
+ return int(val)
232
 
 
233
  MODEL_ID_V = "datalab-to/chandra"
234
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
235
  model_v = Qwen3VLForConditionalGeneration.from_pretrained(
236
  MODEL_ID_V,
237
+ attn_implementation="kernels-community/flash-attn2",
238
  trust_remote_code=True,
239
  torch_dtype=torch.float16
240
  ).to(device).eval()
241
 
 
242
  MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
243
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
244
  model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
245
  MODEL_ID_X,
246
+ attn_implementation="kernels-community/flash-attn2",
247
  trust_remote_code=True,
248
  torch_dtype=torch.bfloat16,
249
  ).to(device).eval()
250
 
 
251
  MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16" # -> alt of [rednote-hilab/dots.ocr]
252
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
253
  model_d = AutoModelForCausalLM.from_pretrained(
254
  MODEL_PATH_D,
255
+ attn_implementation="kernels-community/flash-attn2",
256
  torch_dtype=torch.bfloat16,
257
  device_map="auto",
258
  trust_remote_code=True
259
  ).eval()
260
 
 
261
  MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
262
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
263
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
264
  MODEL_ID_M,
265
+ attn_implementation="kernels-community/flash-attn2",
266
  trust_remote_code=True,
267
  torch_dtype=torch.float16
268
  ).to(device).eval()
269
 
270
+ def calc_timeout_image(model_name: str, text: str, image: Image.Image,
271
+ max_new_tokens: int, temperature: float, top_p: float,
272
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
273
+ """Calculate GPU timeout duration for image inference."""
274
+ try:
275
+ return int(gpu_timeout)
276
+ except:
277
+ return 60
278
+
279
+ @spaces.GPU(duration=calc_timeout_image)
280
  def generate_image(model_name: str, text: str, image: Image.Image,
281
  max_new_tokens: int, temperature: float, top_p: float,
282
+ top_k: int, repetition_penalty: float, gpu_timeout: int = 60):
283
  """
284
  Generates responses using the selected model for image input.
285
  Yields raw text and Markdown-formatted text.
 
366
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
367
 
368
  with gr.Column(scale=3):
369
+ gr.Markdown("## Output", elem_id="output-title")
370
+ output = gr.Textbox(label="Raw Output Stream", interactive=True, lines=15)
371
+ with gr.Accordion("(Result.md)", open=False):
372
+ markdown_output = gr.Markdown(label="(Result.Md)")
373
 
374
+ model_choice = gr.Radio(
375
+ choices=["Nanonets-OCR2-3B", "Chandra-OCR", "Dots.OCR", "olmOCR-2-7B-1025"],
376
+ label="Select Model",
377
+ value="Nanonets-OCR2-3B"
378
+ )
379
+
380
+ with gr.Row(elem_id="gpu-duration-container"):
381
+ with gr.Column():
382
+ gr.Markdown("**GPU Duration (seconds)**")
383
+ radioanimated_gpu_duration = RadioAnimated(
384
+ choices=["60", "90", "120", "180", "240", "300"],
385
+ value="60",
386
+ elem_id="radioanimated_gpu_duration"
387
+ )
388
+ gpu_duration_state = gr.Number(value=60, visible=False)
389
+
390
+ gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
391
+
392
+ radioanimated_gpu_duration.change(
393
+ fn=apply_gpu_duration,
394
+ inputs=radioanimated_gpu_duration,
395
+ outputs=[gpu_duration_state],
396
+ api_visibility="private"
397
+ )
398
 
399
  image_submit.click(
400
  fn=generate_image,
401
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
402
  outputs=[output, markdown_output]
403
  )
404
 
405
  if __name__ == "__main__":
406
+ demo.queue(max_size=50).launch(css=css, theme=steel_blue_theme, mcp_server=True, ssr_mode=False, show_error=True)