prithivMLmods committed on
Commit
91b63e4
·
verified ·
1 Parent(s): f51f005

update [kernels:flash-attn2] (cleaned) ✅

Browse files
Files changed (1) hide show
  1. app.py +145 -13
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoProcessor, AutoModelForImageTextToText
4
  import spaces
5
  from molmo_utils import process_vision_info
@@ -28,7 +29,7 @@ class OrangeRedTheme(Soft):
28
  self,
29
  *,
30
  primary_hue: colors.Color | str = colors.gray,
31
- secondary_hue: colors.Color | str = colors.orange_red, # Use the new color
32
  neutral_hue: colors.Color | str = colors.slate,
33
  text_size: sizes.Size | str = sizes.text_lg,
34
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -76,11 +77,124 @@ class OrangeRedTheme(Soft):
76
 
77
  orange_red_theme = OrangeRedTheme()
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  MODEL_ID = "allenai/SAGE-MM-Qwen3-VL-4B-SFT_RL"
80
 
81
  print(f"Loading {MODEL_ID}...")
82
  processor = AutoProcessor.from_pretrained(
83
  MODEL_ID,
 
84
  trust_remote_code=True,
85
  dtype="auto",
86
  device_map="auto"
@@ -94,16 +208,21 @@ model = AutoModelForImageTextToText.from_pretrained(
94
  )
95
  print("Model loaded successfully.")
96
 
97
- @spaces.GPU
98
- def process_video(user_text, video_path, max_new_tokens):
 
 
 
 
 
 
 
99
  if not video_path:
100
  return "Please upload a video."
101
 
102
- # Use default prompt if user input is empty
103
  if not user_text.strip():
104
  user_text = "Describe this video in detail."
105
 
106
- # Construct messages for Molmo/Qwen
107
  messages = [
108
  {
109
  "role": "user",
@@ -114,8 +233,6 @@ def process_video(user_text, video_path, max_new_tokens):
114
  }
115
  ]
116
 
117
- # Process Vision Info using molmo_utils
118
- # This samples frames and handles resizing logic automatically
119
  try:
120
  _, videos, video_kwargs = process_vision_info(messages)
121
  videos, video_metadatas = zip(*videos)
@@ -149,10 +266,6 @@ def process_video(user_text, video_path, max_new_tokens):
149
 
150
  return generated_text
151
 
152
- css = """
153
- #main-title h1 {font-size: 2.4em !important;}
154
- """
155
-
156
  with gr.Blocks() as demo:
157
  gr.Markdown("# **SAGE-MM-Video-Reasoning**", elem_id="main-title")
158
  gr.Markdown("Upload a video to get a detailed explanation or ask specific questions using [SAGE-MM-Qwen3-VL](https://huggingface.co/allenai/SAGE-MM-Qwen3-VL-4B-SFT_RL).")
@@ -180,7 +293,19 @@ with gr.Blocks() as demo:
180
  vid_btn = gr.Button("Analyze Video", variant="primary")
181
 
182
  with gr.Column():
183
- vid_text_out = gr.Textbox(label="Model Response", interactive=True, lines=23)
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  gr.Examples(
186
  examples=[
@@ -194,9 +319,16 @@ with gr.Blocks() as demo:
194
  label="Video Examples"
195
  )
196
 
 
 
 
 
 
 
 
197
  vid_btn.click(
198
  fn=process_video,
199
- inputs=[vid_prompt, vid_input, max_tokens_slider],
200
  outputs=[vid_text_out]
201
  )
202
 
 
1
  import gradio as gr
2
  import torch
3
+ import uuid
4
  from transformers import AutoProcessor, AutoModelForImageTextToText
5
  import spaces
6
  from molmo_utils import process_vision_info
 
29
  self,
30
  *,
31
  primary_hue: colors.Color | str = colors.gray,
32
+ secondary_hue: colors.Color | str = colors.orange_red,
33
  neutral_hue: colors.Color | str = colors.slate,
34
  text_size: sizes.Size | str = sizes.text_lg,
35
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
 
77
 
78
  orange_red_theme = OrangeRedTheme()
79
 
80
+ css = """
81
+ #main-title h1 {font-size: 2.4em !important;}
82
+
83
+ /* RadioAnimated Styles */
84
+ .ra-wrap{ width: fit-content; }
85
+ .ra-inner{
86
+ position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
87
+ background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
88
+ }
89
+ .ra-input{ display: none; }
90
+ .ra-label{
91
+ position: relative; z-index: 2; padding: 8px 16px;
92
+ font-family: inherit; font-size: 14px; font-weight: 600;
93
+ color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
94
+ }
95
+ .ra-highlight{
96
+ position: absolute; z-index: 1; top: 6px; left: 6px;
97
+ height: calc(100% - 12px); border-radius: 9999px;
98
+ background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
99
+ transition: transform 0.2s, width 0.2s;
100
+ }
101
+ .ra-input:checked + .ra-label{ color: black; }
102
+
103
+ /* Dark mode adjustments for Radio */
104
+ .dark .ra-inner { background: var(--neutral-800); }
105
+ .dark .ra-label { color: var(--neutral-400); }
106
+ .dark .ra-highlight { background: var(--neutral-600); }
107
+ .dark .ra-input:checked + .ra-label { color: white; }
108
+
109
+ #gpu-duration-container {
110
+ padding: 10px;
111
+ border-radius: 8px;
112
+ background: var(--background-fill-secondary);
113
+ border: 1px solid var(--border-color-primary);
114
+ margin-top: 10px;
115
+ }
116
+ """
117
+
118
class RadioAnimated(gr.HTML):
    """An animated segmented-control replacement for ``gr.Radio``.

    Renders a pill-shaped radio group (styled by the ``.ra-*`` CSS rules
    defined at module level) with a sliding highlight, built on top of
    ``gr.HTML``'s ``html_template`` / ``js_on_load`` hooks.  The component's
    value is the selected choice string, and a ``change`` event fires when
    the user picks a different option.
    """

    def __init__(self, choices, value=None, **kwargs):
        """
        Args:
            choices: Sequence of at least two option strings.
            value: Initially selected choice; defaults to ``choices[0]``.
            **kwargs: Forwarded to ``gr.HTML`` (``elem_id``, ``visible``, ...).

        Raises:
            ValueError: If fewer than two choices are given.
        """
        # Local import: only needed here, keeps the module's top-level
        # imports untouched.
        import html as _html

        if not choices or len(choices) < 2:
            raise ValueError("RadioAnimated requires at least 2 choices.")
        if value is None:
            value = choices[0]

        # Unique group name so multiple instances on one page do not share
        # browser radio-button state.
        uid = uuid.uuid4().hex[:8]
        group_name = f"ra-{uid}"

        # Escape each choice for both attribute and text contexts so a
        # choice containing quotes or angle brackets cannot break the
        # generated markup.  The browser un-escapes entities when reading
        # ``input.value``, so the JS side still sees the original string.
        inputs_html = "\n".join(
            f"""
            <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{_html.escape(str(c), quote=True)}">
            <label class="ra-label" for="{group_name}-{i}">{_html.escape(str(c))}</label>
            """
            for i, c in enumerate(choices)
        )

        html_template = f"""
        <div class="ra-wrap" data-ra="{uid}">
            <div class="ra-inner">
                <div class="ra-highlight"></div>
                {inputs_html}
            </div>
        </div>
        """

        # Client-side behavior: position the sliding highlight and sync the
        # checked input with the component's props.value.
        js_on_load = r"""
        (() => {
            const wrap = element.querySelector('.ra-wrap');
            const inner = element.querySelector('.ra-inner');
            const highlight = element.querySelector('.ra-highlight');
            const inputs = Array.from(element.querySelectorAll('.ra-input'));

            if (!inputs.length) return;

            const choices = inputs.map(i => i.value);

            function setHighlightByIndex(idx) {
                const n = choices.length;
                const pct = 100 / n;
                highlight.style.width = `calc(${pct}% - 6px)`;
                highlight.style.transform = `translateX(${idx * 100}%)`;
            }

            function setCheckedByValue(val, shouldTrigger=false) {
                const idx = Math.max(0, choices.indexOf(val));
                inputs.forEach((inp, i) => { inp.checked = (i === idx); });
                setHighlightByIndex(idx);

                props.value = choices[idx];
                if (shouldTrigger) trigger('change', props.value);
            }

            setCheckedByValue(props.value ?? choices[0], false);

            inputs.forEach((inp) => {
                inp.addEventListener('change', () => {
                    setCheckedByValue(inp.value, true);
                });
            });
        })();
        """

        super().__init__(
            value=value,
            html_template=html_template,
            js_on_load=js_on_load,
            **kwargs
        )
188
+
189
def apply_gpu_duration(val: str) -> int:
    """Convert the RadioAnimated choice string (e.g. "90") to an int.

    Used as the ``change`` handler feeding ``gpu_duration_state``.  Falls
    back to 90 (the default duration) when *val* is not a valid integer
    string, so a malformed UI event cannot crash the handler — consistent
    with the fallback in ``calc_timeout_video``.
    """
    try:
        return int(val)
    except (TypeError, ValueError):
        return 90
191
+
192
  MODEL_ID = "allenai/SAGE-MM-Qwen3-VL-4B-SFT_RL"
193
 
194
  print(f"Loading {MODEL_ID}...")
195
  processor = AutoProcessor.from_pretrained(
196
  MODEL_ID,
197
+ attn_implementation="kernels-community/flash-attn2",
198
  trust_remote_code=True,
199
  dtype="auto",
200
  device_map="auto"
 
208
  )
209
  print("Model loaded successfully.")
210
 
211
def calc_timeout_video(user_text: str, video_path: str, max_new_tokens: int, gpu_timeout: int) -> int:
    """Return the GPU timeout (seconds) for a ``process_video`` call.

    Receives the same positional arguments as ``process_video`` because
    ``spaces.GPU(duration=...)`` invokes the callback with the decorated
    function's arguments; only *gpu_timeout* is used here.

    Falls back to 90 seconds when *gpu_timeout* cannot be coerced to int.
    """
    try:
        return int(gpu_timeout)
    # Narrow except: a bare `except:` would also swallow KeyboardInterrupt.
    except (TypeError, ValueError):
        return 90
217
+
218
+ @spaces.GPU(duration=calc_timeout_video)
219
+ def process_video(user_text, video_path, max_new_tokens, gpu_timeout: int = 120):
220
  if not video_path:
221
  return "Please upload a video."
222
 
 
223
  if not user_text.strip():
224
  user_text = "Describe this video in detail."
225
 
 
226
  messages = [
227
  {
228
  "role": "user",
 
233
  }
234
  ]
235
 
 
 
236
  try:
237
  _, videos, video_kwargs = process_vision_info(messages)
238
  videos, video_metadatas = zip(*videos)
 
266
 
267
  return generated_text
268
 
 
 
 
 
269
  with gr.Blocks() as demo:
270
  gr.Markdown("# **SAGE-MM-Video-Reasoning**", elem_id="main-title")
271
  gr.Markdown("Upload a video to get a detailed explanation or ask specific questions using [SAGE-MM-Qwen3-VL](https://huggingface.co/allenai/SAGE-MM-Qwen3-VL-4B-SFT_RL).")
 
293
  vid_btn = gr.Button("Analyze Video", variant="primary")
294
 
295
  with gr.Column():
296
+ vid_text_out = gr.Textbox(label="Model Response", interactive=True, lines=18)
297
+
298
+ with gr.Row(elem_id="gpu-duration-container"):
299
+ with gr.Column():
300
+ gr.Markdown("**GPU Duration (seconds)**")
301
+ radioanimated_gpu_duration = RadioAnimated(
302
+ choices=["90", "120", "180", "240", "300"],
303
+ value="90",
304
+ elem_id="radioanimated_gpu_duration"
305
+ )
306
+ gpu_duration_state = gr.Number(value=90, visible=False)
307
+
308
+ gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
309
 
310
  gr.Examples(
311
  examples=[
 
319
  label="Video Examples"
320
  )
321
 
322
+ radioanimated_gpu_duration.change(
323
+ fn=apply_gpu_duration,
324
+ inputs=radioanimated_gpu_duration,
325
+ outputs=[gpu_duration_state],
326
+ api_visibility="private"
327
+ )
328
+
329
  vid_btn.click(
330
  fn=process_video,
331
+ inputs=[vid_prompt, vid_input, max_tokens_slider, gpu_duration_state],
332
  outputs=[vid_text_out]
333
  )
334