fffiloni commited on
Commit
c552215
Β·
verified Β·
1 Parent(s): b9092f0

feat(gradio): add user-controlled seed with random mode for video generation

Browse files
Files changed (1) hide show
  1. app_wip.py +24 -9
app_wip.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import sys
4
  import uuid
5
  import shutil
 
6
 
7
  import gradio as gr
8
  import torch
@@ -166,6 +167,7 @@ def reward_forcing_inference(
166
  num_output_frames: int,
167
  use_ema: bool,
168
  output_root: str,
 
169
  progress: gr.Progress,
170
  ):
171
  """
@@ -180,7 +182,12 @@ def reward_forcing_inference(
180
 
181
  # --------------------- Device & randomness ---------------------
182
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
183
- set_seed(0)
 
 
 
 
 
184
 
185
  free_vram = get_cuda_free_memory_gb(device)
186
  logs += f"Free VRAM {free_vram} GB\n"
@@ -290,7 +297,11 @@ def reward_forcing_inference(
290
 
291
  @spaces.GPU(duration=200)
292
  def gradio_generate(
293
- prompt: str, duration: str, use_ema: bool, progress=gr.Progress()
 
 
 
 
294
  ):
295
  """
296
  Triggered by Gradio:
@@ -316,6 +327,7 @@ def gradio_generate(
316
  num_output_frames=num_output_frames,
317
  use_ema=use_ema,
318
  output_root=OUTPUT_ROOT,
 
319
  progress=progress,
320
  )
321
 
@@ -378,6 +390,8 @@ with gr.Blocks(title="Reward Forcing β€” Text-to-Video Demo") as demo:
378
 
379
  πŸ’‘ This model performs best on **detailed prompts with multiple actions or transformations**.
380
 
 
 
381
  > ⏳ The first run may take a little longer while the model loads β€” generation is faster afterwards.
382
  """
383
  )
@@ -389,12 +403,6 @@ with gr.Blocks(title="Reward Forcing β€” Text-to-Video Demo") as demo:
389
  lines=4,
390
  )
391
 
392
- gr.Examples(
393
- examples=examples,
394
- inputs=prompt_in,
395
- label="Example prompts",
396
- )
397
-
398
  with gr.Row():
399
  duration = gr.Radio(
400
  ["5s (21 frames)", "30s (120 frames)"],
@@ -402,6 +410,7 @@ with gr.Blocks(title="Reward Forcing β€” Text-to-Video Demo") as demo:
402
  label="Duration",
403
  )
404
  use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")
 
405
 
406
  generate_btn = gr.Button("πŸš€ Generate Video", variant="primary")
407
 
@@ -409,9 +418,15 @@ with gr.Blocks(title="Reward Forcing β€” Text-to-Video Demo") as demo:
409
  video_out = gr.Video(label="Generated Video")
410
  logs_out = gr.Textbox(label="Logs", lines=12, interactive=False)
411
 
 
 
 
 
 
 
412
  generate_btn.click(
413
  fn=gradio_generate,
414
- inputs=[prompt_in, duration, use_ema],
415
  outputs=[video_out, logs_out],
416
  )
417
 
 
3
  import sys
4
  import uuid
5
  import shutil
6
+ import random
7
 
8
  import gradio as gr
9
  import torch
 
167
  num_output_frames: int,
168
  use_ema: bool,
169
  output_root: str,
170
+ seed: int,
171
  progress: gr.Progress,
172
  ):
173
  """
 
182
 
183
  # --------------------- Device & randomness ---------------------
184
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
185
+
186
+ if seed == -1:
187
+ seed = random.randint(0, 2**32 - 1)
188
+
189
+ set_seed(seed)
190
+ logs += f"Seed: {seed}\n"
191
 
192
  free_vram = get_cuda_free_memory_gb(device)
193
  logs += f"Free VRAM {free_vram} GB\n"
 
297
 
298
  @spaces.GPU(duration=200)
299
  def gradio_generate(
300
+ prompt: str,
301
+ duration: str,
302
+ use_ema: bool,
303
+ seed: int,
304
+ progress=gr.Progress(),
305
  ):
306
  """
307
  Triggered by Gradio:
 
327
  num_output_frames=num_output_frames,
328
  use_ema=use_ema,
329
  output_root=OUTPUT_ROOT,
330
+ seed=int(seed),
331
  progress=progress,
332
  )
333
 
 
390
 
391
  πŸ’‘ This model performs best on **detailed prompts with multiple actions or transformations**.
392
 
393
+ 🎲 Set a fixed seed for reproducible results, or use **-1** for a random seed each time.
394
+
395
  > ⏳ The first run may take a little longer while the model loads β€” generation is faster afterwards.
396
  """
397
  )
 
403
  lines=4,
404
  )
405
 
 
 
 
 
 
 
406
  with gr.Row():
407
  duration = gr.Radio(
408
  ["5s (21 frames)", "30s (120 frames)"],
 
410
  label="Duration",
411
  )
412
  use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")
413
+ seed_in = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
414
 
415
  generate_btn = gr.Button("πŸš€ Generate Video", variant="primary")
416
 
 
418
  video_out = gr.Video(label="Generated Video")
419
  logs_out = gr.Textbox(label="Logs", lines=12, interactive=False)
420
 
421
+ gr.Examples(
422
+ examples=examples,
423
+ inputs=prompt_in,
424
+ label="Example prompts",
425
+ )
426
+
427
  generate_btn.click(
428
  fn=gradio_generate,
429
+ inputs=[prompt_in, duration, use_ema, seed_in],
430
  outputs=[video_out, logs_out],
431
  )
432