prithivMLmods committed on
Commit
eb0db2a
·
verified ·
1 Parent(s): e52b5c8

update [kernels:flash-attn2] (cleaned) ✅

Browse files
Files changed (1) hide show
  1. app.py +162 -18
app.py CHANGED
@@ -27,9 +27,6 @@ from transformers.image_utils import load_image
27
  from gradio.themes import Soft
28
  from gradio.themes.utils import colors, fonts, sizes
29
 
30
- # --- Theme and CSS Definition ---
31
-
32
- # Define the new SteelBlue color palette
33
  colors.steel_blue = colors.Color(
34
  name="steel_blue",
35
  c50="#EBF3F8",
@@ -37,7 +34,7 @@ colors.steel_blue = colors.Color(
37
  c200="#A8CCE1",
38
  c300="#7DB3D2",
39
  c400="#529AC3",
40
- c500="#4682B4", # SteelBlue base color
41
  c600="#3E72A0",
42
  c700="#36638C",
43
  c800="#2E5378",
@@ -50,7 +47,7 @@ class SteelBlueTheme(Soft):
50
  self,
51
  *,
52
  primary_hue: colors.Color | str = colors.gray,
53
- secondary_hue: colors.Color | str = colors.steel_blue, # Use the new color
54
  neutral_hue: colors.Color | str = colors.slate,
55
  text_size: sizes.Size | str = sizes.text_lg,
56
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -96,7 +93,6 @@ class SteelBlueTheme(Soft):
96
  block_label_background_fill="*primary_200",
97
  )
98
 
99
- # Instantiate the new theme
100
  steel_blue_theme = SteelBlueTheme()
101
 
102
  css = """
@@ -106,6 +102,40 @@ css = """
106
  #output-title h2 {
107
  font-size: 2.1em !important;
108
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  """
110
 
111
  MAX_MAX_NEW_TOKENS = 4096
@@ -125,11 +155,86 @@ if torch.cuda.is_available():
125
 
126
  print("Using device:", device)
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  MODEL_ID_X = "Senqiao/VisionThink-Efficient"
129
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True, use_fast=False)
130
  model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
131
  MODEL_ID_X,
132
- attn_implementation="flash_attention_2",
133
  trust_remote_code=True,
134
  torch_dtype=torch.float16
135
  ).to(device).eval()
@@ -138,7 +243,7 @@ MODEL_ID_T = "scb10x/typhoon-ocr-3b"
138
  processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True, use_fast=False)
139
  model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
140
  MODEL_ID_T,
141
- attn_implementation="flash_attention_2",
142
  trust_remote_code=True,
143
  torch_dtype=torch.float16
144
  ).to(device).eval()
@@ -147,7 +252,7 @@ MODEL_ID_O = "allenai/olmOCR-7B-0225-preview"
147
  processor_o = AutoProcessor.from_pretrained(MODEL_ID_O, trust_remote_code=True, use_fast=False)
148
  model_o = Qwen2VLForConditionalGeneration.from_pretrained(
149
  MODEL_ID_O,
150
- attn_implementation="flash_attention_2",
151
  trust_remote_code=True,
152
  torch_dtype=torch.float16
153
  ).to(device).eval()
@@ -157,7 +262,7 @@ SUBFOLDER = "think-preview"
157
  processor_j = AutoProcessor.from_pretrained(MODEL_ID_J, trust_remote_code=True, subfolder=SUBFOLDER, use_fast=False)
158
  model_j = Qwen2_5_VLForConditionalGeneration.from_pretrained(
159
  MODEL_ID_J,
160
- attn_implementation="flash_attention_2",
161
  trust_remote_code=True,
162
  subfolder=SUBFOLDER,
163
  torch_dtype=torch.float16
@@ -166,7 +271,7 @@ model_j = Qwen2_5_VLForConditionalGeneration.from_pretrained(
166
  MODEL_ID_V4 = 'openbmb/MiniCPM-V-4'
167
  model_v4 = AutoModel.from_pretrained(
168
  MODEL_ID_V4,
169
- attn_implementation="flash_attention_2",
170
  trust_remote_code=True,
171
  torch_dtype=torch.bfloat16,
172
  ).eval().to(device)
@@ -196,13 +301,33 @@ def downsample_video(video_path):
196
  vidcap.release()
197
  return frames
198
 
199
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def generate_image(model_name: str, text: str, image: Image.Image,
201
  max_new_tokens: int = 1024,
202
  temperature: float = 0.6,
203
  top_p: float = 0.9,
204
  top_k: int = 50,
205
- repetition_penalty: float = 1.2):
 
206
  if image is None:
207
  yield "Please upload an image.", "Please upload an image."
208
  return
@@ -239,13 +364,14 @@ def generate_image(model_name: str, text: str, image: Image.Image,
239
  time.sleep(0.01)
240
  yield buffer, buffer
241
 
242
- @spaces.GPU
243
  def generate_video(model_name: str, text: str, video_path: str,
244
  max_new_tokens: int = 1024,
245
  temperature: float = 0.6,
246
  top_p: float = 0.9,
247
  top_k: int = 50,
248
- repetition_penalty: float = 1.2):
 
249
  if video_path is None:
250
  yield "Please upload a video.", "Please upload a video."
251
  return
@@ -299,7 +425,6 @@ def generate_video(model_name: str, text: str, video_path: str,
299
  time.sleep(0.01)
300
  yield buffer, buffer
301
 
302
- # Define examples for image and video inference
303
  image_examples = [
304
  ["Describe the safety measures in the image. Conclude (Safe / Unsafe)..", "images/5.jpg"],
305
  ["Convert this page to doc [markdown] precisely.", "images/3.png"],
@@ -349,14 +474,33 @@ with gr.Blocks() as demo:
349
  value="Lumian-VLR-7B-Thinking"
350
  )
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  image_submit.click(
353
  fn=generate_image,
354
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
355
  outputs=[output, markdown_output]
356
  )
357
  video_submit.click(
358
  fn=generate_video,
359
- inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
360
  outputs=[output, markdown_output]
361
  )
362
 
 
27
  from gradio.themes import Soft
28
  from gradio.themes.utils import colors, fonts, sizes
29
 
 
 
 
30
  colors.steel_blue = colors.Color(
31
  name="steel_blue",
32
  c50="#EBF3F8",
 
34
  c200="#A8CCE1",
35
  c300="#7DB3D2",
36
  c400="#529AC3",
37
+ c500="#4682B4",
38
  c600="#3E72A0",
39
  c700="#36638C",
40
  c800="#2E5378",
 
47
  self,
48
  *,
49
  primary_hue: colors.Color | str = colors.gray,
50
+ secondary_hue: colors.Color | str = colors.steel_blue,
51
  neutral_hue: colors.Color | str = colors.slate,
52
  text_size: sizes.Size | str = sizes.text_lg,
53
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
 
93
  block_label_background_fill="*primary_200",
94
  )
95
 
 
96
  steel_blue_theme = SteelBlueTheme()
97
 
98
  css = """
 
102
  #output-title h2 {
103
  font-size: 2.1em !important;
104
  }
105
+
106
+ /* RadioAnimated Styles */
107
+ .ra-wrap{ width: fit-content; }
108
+ .ra-inner{
109
+ position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
110
+ background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
111
+ }
112
+ .ra-input{ display: none; }
113
+ .ra-label{
114
+ position: relative; z-index: 2; padding: 8px 16px;
115
+ font-family: inherit; font-size: 14px; font-weight: 600;
116
+ color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
117
+ }
118
+ .ra-highlight{
119
+ position: absolute; z-index: 1; top: 6px; left: 6px;
120
+ height: calc(100% - 12px); border-radius: 9999px;
121
+ background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
122
+ transition: transform 0.2s, width 0.2s;
123
+ }
124
+ .ra-input:checked + .ra-label{ color: black; }
125
+
126
+ /* Dark mode adjustments for Radio */
127
+ .dark .ra-inner { background: var(--neutral-800); }
128
+ .dark .ra-label { color: var(--neutral-400); }
129
+ .dark .ra-highlight { background: var(--neutral-600); }
130
+ .dark .ra-input:checked + .ra-label { color: white; }
131
+
132
+ #gpu-duration-container {
133
+ padding: 10px;
134
+ border-radius: 8px;
135
+ background: var(--background-fill-secondary);
136
+ border: 1px solid var(--border-color-primary);
137
+ margin-top: 10px;
138
+ }
139
  """
140
 
141
  MAX_MAX_NEW_TOKENS = 4096
 
155
 
156
  print("Using device:", device)
157
 
158
+ # --- RadioAnimated Component ---
159
class RadioAnimated(gr.HTML):
    """Segmented, pill-style radio selector rendered through ``gr.HTML``.

    The component emits one hidden ``<input type="radio">`` plus a
    ``<label>`` per choice and a sliding ``.ra-highlight`` element; the
    ``.ra-*`` classes are styled by the CSS string defined in this file.
    The attached ``js_on_load`` script keeps the checked input, the
    highlight position and ``props.value`` in sync and fires a ``change``
    event when the user picks a different option.

    Args:
        choices: Sequence of at least two options. Each option is rendered
            with ``str()`` and HTML-escaped before being placed in the
            markup.
        value: Initially selected option. Defaults to ``choices[0]`` when
            ``None``.
        **kwargs: Forwarded unchanged to ``gr.HTML.__init__``.

    Raises:
        ValueError: If ``choices`` is empty or has fewer than two entries.
    """

    def __init__(self, choices, value=None, **kwargs):
        # Local import keeps this block self-contained; only `escape` is needed.
        from html import escape

        if not choices or len(choices) < 2:
            raise ValueError("RadioAnimated requires at least 2 choices.")
        if value is None:
            value = choices[0]

        # A short random suffix keeps the radio group name and element ids
        # unique when several instances are created.
        uid = uuid.uuid4().hex[:8]
        group_name = f"ra-{uid}"

        # Escape every choice once so quotes, angle brackets or ampersands in
        # a label cannot break out of the attribute or inject markup. The
        # browser decodes the entities again, so the script still reads the
        # original text from `input.value`.
        safe_choices = [escape(str(c), quote=True) for c in choices]

        inputs_html = "\n".join(
            f"""
            <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
            <label class="ra-label" for="{group_name}-{i}">{c}</label>
            """
            for i, c in enumerate(safe_choices)
        )

        html_template = f"""
        <div class="ra-wrap" data-ra="{uid}">
          <div class="ra-inner">
            <div class="ra-highlight"></div>
            {inputs_html}
          </div>
        </div>
        """

        # Raw string: the script uses JS template literals (`${...}`) that
        # must reach the browser untouched.
        js_on_load = r"""
        (() => {
            const wrap = element.querySelector('.ra-wrap');
            const inner = element.querySelector('.ra-inner');
            const highlight = element.querySelector('.ra-highlight');
            const inputs = Array.from(element.querySelectorAll('.ra-input'));

            if (!inputs.length) return;

            const choices = inputs.map(i => i.value);

            function setHighlightByIndex(idx) {
                const n = choices.length;
                const pct = 100 / n;
                highlight.style.width = `calc(${pct}% - 6px)`;
                highlight.style.transform = `translateX(${idx * 100}%)`;
            }

            function setCheckedByValue(val, shouldTrigger=false) {
                const idx = Math.max(0, choices.indexOf(val));
                inputs.forEach((inp, i) => { inp.checked = (i === idx); });
                setHighlightByIndex(idx);

                props.value = choices[idx];
                if (shouldTrigger) trigger('change', props.value);
            }

            setCheckedByValue(props.value ?? choices[0], false);

            inputs.forEach((inp) => {
                inp.addEventListener('change', () => {
                    setCheckedByValue(inp.value, true);
                });
            });
        })();
        """

        super().__init__(
            value=value,
            html_template=html_template,
            js_on_load=js_on_load,
            **kwargs
        )
229
+
230
def apply_gpu_duration(val: str):
    """Turn the selected GPU-duration choice into an integer.

    Used as the ``change`` handler of the GPU-duration radio selector; the
    returned number is written to the hidden ``gpu_duration_state`` field.

    Args:
        val: The chosen duration as text, e.g. ``"60"``.

    Returns:
        The duration as an ``int``.

    Raises:
        ValueError: If ``val`` is not an integer literal.
    """
    seconds = int(val)
    return seconds
232
+
233
  MODEL_ID_X = "Senqiao/VisionThink-Efficient"
234
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True, use_fast=False)
235
  model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
236
  MODEL_ID_X,
237
+ attn_implementation="kernels-community/flash-attn2",
238
  trust_remote_code=True,
239
  torch_dtype=torch.float16
240
  ).to(device).eval()
 
243
  processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True, use_fast=False)
244
  model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
245
  MODEL_ID_T,
246
+ attn_implementation="kernels-community/flash-attn2",
247
  trust_remote_code=True,
248
  torch_dtype=torch.float16
249
  ).to(device).eval()
 
252
  processor_o = AutoProcessor.from_pretrained(MODEL_ID_O, trust_remote_code=True, use_fast=False)
253
  model_o = Qwen2VLForConditionalGeneration.from_pretrained(
254
  MODEL_ID_O,
255
+ attn_implementation="kernels-community/flash-attn2",
256
  trust_remote_code=True,
257
  torch_dtype=torch.float16
258
  ).to(device).eval()
 
262
  processor_j = AutoProcessor.from_pretrained(MODEL_ID_J, trust_remote_code=True, subfolder=SUBFOLDER, use_fast=False)
263
  model_j = Qwen2_5_VLForConditionalGeneration.from_pretrained(
264
  MODEL_ID_J,
265
+ attn_implementation="kernels-community/flash-attn2",
266
  trust_remote_code=True,
267
  subfolder=SUBFOLDER,
268
  torch_dtype=torch.float16
 
271
  MODEL_ID_V4 = 'openbmb/MiniCPM-V-4'
272
  model_v4 = AutoModel.from_pretrained(
273
  MODEL_ID_V4,
274
+ attn_implementation="kernels-community/flash-attn2",
275
  trust_remote_code=True,
276
  torch_dtype=torch.bfloat16,
277
  ).eval().to(device)
 
301
  vidcap.release()
302
  return frames
303
 
304
+ # --- GPU Timeout Calculation Functions ---
305
def calc_timeout_image(model_name: str, text: str, image: Image.Image,
                       max_new_tokens: int, temperature: float, top_p: float,
                       top_k: int, repetition_penalty: float, gpu_timeout: int):
    """Return the GPU time budget, in seconds, for an image request.

    Passed as ``duration=`` to ``@spaces.GPU`` on ``generate_image`` and
    therefore takes the same parameters as that function. Only
    ``gpu_timeout`` is used; the other parameters exist to keep the
    signatures parallel.

    Args:
        gpu_timeout: Requested duration; anything ``int()`` accepts.

    Returns:
        ``int(gpu_timeout)``, or ``60`` when the value cannot be converted.
    """
    try:
        return int(gpu_timeout)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to the default; a bare `except`
        # would also swallow KeyboardInterrupt / SystemExit.
        return 60
313
+
314
def calc_timeout_video(model_name: str, text: str, video_path: str,
                       max_new_tokens: int, temperature: float, top_p: float,
                       top_k: int, repetition_penalty: float, gpu_timeout: int):
    """Return the GPU time budget, in seconds, for a video request.

    Passed as ``duration=`` to ``@spaces.GPU`` on ``generate_video`` and
    therefore takes the same parameters as that function. Only
    ``gpu_timeout`` is used; the other parameters exist to keep the
    signatures parallel.

    Args:
        gpu_timeout: Requested duration; anything ``int()`` accepts.

    Returns:
        ``int(gpu_timeout)``, or ``60`` when the value cannot be converted.
    """
    try:
        return int(gpu_timeout)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to the default; a bare `except`
        # would also swallow KeyboardInterrupt / SystemExit.
        # NOTE(review): generate_video declares gpu_timeout=90 as its default
        # while this fallback is 60 - confirm which value is intended.
        return 60
322
+
323
+ @spaces.GPU(duration=calc_timeout_image)
324
  def generate_image(model_name: str, text: str, image: Image.Image,
325
  max_new_tokens: int = 1024,
326
  temperature: float = 0.6,
327
  top_p: float = 0.9,
328
  top_k: int = 50,
329
+ repetition_penalty: float = 1.2,
330
+ gpu_timeout: int = 60):
331
  if image is None:
332
  yield "Please upload an image.", "Please upload an image."
333
  return
 
364
  time.sleep(0.01)
365
  yield buffer, buffer
366
 
367
+ @spaces.GPU(duration=calc_timeout_video)
368
  def generate_video(model_name: str, text: str, video_path: str,
369
  max_new_tokens: int = 1024,
370
  temperature: float = 0.6,
371
  top_p: float = 0.9,
372
  top_k: int = 50,
373
+ repetition_penalty: float = 1.2,
374
+ gpu_timeout: int = 90):
375
  if video_path is None:
376
  yield "Please upload a video.", "Please upload a video."
377
  return
 
425
  time.sleep(0.01)
426
  yield buffer, buffer
427
 
 
428
  image_examples = [
429
  ["Describe the safety measures in the image. Conclude (Safe / Unsafe)..", "images/5.jpg"],
430
  ["Convert this page to doc [markdown] precisely.", "images/3.png"],
 
474
  value="Lumian-VLR-7B-Thinking"
475
  )
476
 
477
+ with gr.Row(elem_id="gpu-duration-container"):
478
+ with gr.Column():
479
+ gr.Markdown("**GPU Duration (seconds)**")
480
+ radioanimated_gpu_duration = RadioAnimated(
481
+ choices=["60", "90", "120", "180", "240", "300"],
482
+ value="60",
483
+ elem_id="radioanimated_gpu_duration"
484
+ )
485
+ gpu_duration_state = gr.Number(value=60, visible=False)
486
+
487
+ gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
488
+
489
+ radioanimated_gpu_duration.change(
490
+ fn=apply_gpu_duration,
491
+ inputs=radioanimated_gpu_duration,
492
+ outputs=[gpu_duration_state],
493
+ api_visibility="private"
494
+ )
495
+
496
  image_submit.click(
497
  fn=generate_image,
498
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
499
  outputs=[output, markdown_output]
500
  )
501
  video_submit.click(
502
  fn=generate_video,
503
+ inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
504
  outputs=[output, markdown_output]
505
  )
506