Yuanshi committed on
Commit
6359789
·
1 Parent(s): ab20f51
Files changed (1) hide show
  1. app.py +97 -41
app.py CHANGED
@@ -5,6 +5,32 @@ import random
5
  import gradio as gr
6
  import spaces
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  @dataclass(frozen=True)
10
  class SliderConfig:
@@ -29,47 +55,55 @@ GAMMA_SLIDER = SliderConfig(
29
  maximum=10.0,
30
  step=0.5,
31
  value=5.0,
32
- info="Scheduler adjustment parameter."
33
  )
34
 
35
  STEP_SLIDER = SliderConfig(
36
  label="Inference Steps",
37
- minimum=10,
38
- maximum=50,
39
  step=1,
40
- value=28,
41
- info="More steps improve quality but take longer."
42
  )
43
 
44
  GUIDANCE_SLIDER = SliderConfig(
45
  label="Guidance Scale (CFG)",
46
  minimum=1.0,
47
- maximum=20.0,
48
  step=0.5,
49
- value=1.5,
50
- info="Controls adherence to the text prompt."
51
  )
52
 
53
 
54
  STYLE_CHOICES = [
55
- "Oil painting style, vivid colors",
56
- "Neon cyberpunk, futuristic city",
57
- "Minimalist sketch, soft shading",
58
- "Anime aesthetic, bold lines",
 
 
 
 
 
 
 
59
  ]
60
 
61
 
62
  EXAMPLE_INPUTS = [
63
- ["assets/video_00000000.mp4", "Oil painting style, vivid colors"],
64
- ["assets/video_00000007.mp4", "Neon cyberpunk, futuristic city"],
65
- ["assets/video_00000107.mp4", "Minimalist sketch, soft shading"],
 
66
  ]
67
 
68
 
69
  PRESET_MODES = {
70
- "Fast": PresetConfig(shift_gamma=5.0, steps=10, guidance_scale=1.),
71
- "Balanced": PresetConfig(shift_gamma=5.0, steps=10, guidance_scale=1.5),
72
- "Quality": PresetConfig(shift_gamma=5.0, steps=20, guidance_scale=1.5),
73
  }
74
 
75
 
@@ -112,7 +146,30 @@ def run_stylization(
112
  f"Seed={resolved_seed}"
113
  )
114
 
115
- return input_video_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  def _resolve_seed(seed_value, randomize):
@@ -175,26 +232,31 @@ def build_demo() -> gr.Blocks:
175
 
176
  with gr.Row():
177
  with gr.Column(scale=1) as control_col:
 
 
 
 
 
 
178
  with gr.Tabs():
179
  with gr.Tab("Quick Generate"):
180
- prompt_quick = gr.Dropdown(
181
- label="Style Instruction",
182
- choices=STYLE_CHOICES,
183
- value=STYLE_CHOICES[0],
184
- allow_custom_value=False,
185
- )
186
-
187
  with gr.Row():
188
- fast_btn = gr.Button("⚡ Fast Generate", variant="primary")
189
- balanced_btn = gr.Button("🎯 Balanced Generate", variant="primary")
190
- quality_btn = gr.Button("🌟 Quality Generate", variant="primary")
 
 
 
 
 
 
191
 
192
  _bind_preset_button(
193
  button=fast_btn,
194
  preset_key="Fast",
195
  inputs=[
196
  input_video,
197
- prompt_quick,
198
  ],
199
  output=output_video,
200
  extra_kwargs={"seed": None, "randomize_seed": True},
@@ -204,7 +266,7 @@ def build_demo() -> gr.Blocks:
204
  preset_key="Balanced",
205
  inputs=[
206
  input_video,
207
- prompt_quick,
208
  ],
209
  output=output_video,
210
  extra_kwargs={"seed": None, "randomize_seed": True},
@@ -214,19 +276,13 @@ def build_demo() -> gr.Blocks:
214
  preset_key="Quality",
215
  inputs=[
216
  input_video,
217
- prompt_quick,
218
  ],
219
  output=output_video,
220
  extra_kwargs={"seed": None, "randomize_seed": True},
221
  )
222
 
223
  with gr.Tab("Advanced Settings"):
224
- prompt_adv = gr.Dropdown(
225
- label="Style Instruction",
226
- choices=STYLE_CHOICES,
227
- value=STYLE_CHOICES[0],
228
- allow_custom_value=True,
229
- )
230
  with gr.Row():
231
  shift_gamma = _create_slider(GAMMA_SLIDER)
232
  guidance_scale = _create_slider(GUIDANCE_SLIDER)
@@ -235,7 +291,7 @@ def build_demo() -> gr.Blocks:
235
  num_steps = _create_slider(STEP_SLIDER)
236
  randomize_seed_adv = gr.Checkbox(
237
  label="Randomize Seed",
238
- value=False,
239
  info="Checked = new random seed each run. Uncheck to provide your own seed.",
240
  )
241
 
@@ -251,7 +307,7 @@ def build_demo() -> gr.Blocks:
251
  fn=run_stylization,
252
  inputs=[
253
  input_video,
254
- prompt_adv,
255
  shift_gamma,
256
  num_steps,
257
  guidance_scale,
@@ -264,7 +320,7 @@ def build_demo() -> gr.Blocks:
264
  with gr.Column(scale=1):
265
  gr.Examples(
266
  examples=EXAMPLE_INPUTS,
267
- inputs=[input_video, prompt_quick, prompt_adv],
268
  label="Example inputs",
269
  )
270
 
 
5
  import gradio as gr
6
  import spaces
7
 
8
+ import torch
9
+ from diffusers import WanPipeline
10
+ from diffusers.utils import export_to_video, load_video
11
+ from vibt.wan import load_vibt_weight, encode_video
12
+ from vibt.scheduler import ViBTScheduler
13
+ import tempfile
14
+ import os
15
+ import cv2
16
+
17
+
18
+ def get_fps(path):
19
+ cap = cv2.VideoCapture(path)
20
+ fps = cap.get(cv2.CAP_PROP_FPS)
21
+ cap.release()
22
+ return fps
23
+
24
+
25
+ base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
26
+ pipe = WanPipeline.from_pretrained(base_model_id, torch_dtype=torch.bfloat16).to("cuda")
27
+ load_vibt_weight(
28
+ pipe.transformer,
29
+ "Yuanshi/ViBT",
30
+ "video/video_stylization.safetensors",
31
+ )
32
+ pipe.scheduler = ViBTScheduler.from_scheduler(pipe.scheduler)
33
+
34
 
35
  @dataclass(frozen=True)
36
  class SliderConfig:
 
55
  maximum=10.0,
56
  step=0.5,
57
  value=5.0,
58
+ info="Scheduler adjustment parameter.",
59
  )
60
 
61
  STEP_SLIDER = SliderConfig(
62
  label="Inference Steps",
63
+ minimum=6,
64
+ maximum=28,
65
  step=1,
66
+ value=10,
67
+ info="More steps improve quality but take longer.",
68
  )
69
 
70
  GUIDANCE_SLIDER = SliderConfig(
71
  label="Guidance Scale (CFG)",
72
  minimum=1.0,
73
+ maximum=5.0,
74
  step=0.5,
75
+ value=2,
76
+ info="Controls adherence to the text prompt.",
77
  )
78
 
79
 
80
  STYLE_CHOICES = [
81
+ "Make it Illustration style.",
82
+ "Make it a drawing by Van Gogh.",
83
+ "Make it a pencil sketch style.",
84
+ "Make it watercolor drawing style.",
85
+ "Make it a Pixel Art.",
86
+ "Make it a Japanese anime style, cel shading.",
87
+ "Make it the style of Neon Light Art.",
88
+ "Make it papercut style.",
89
+ "Make it a blueprint.",
90
+ "Make it Comic Book Style.",
91
+ "Render the subject as a classical sculpture carved from a single block of pristine white marble.",
92
  ]
93
 
94
 
95
  EXAMPLE_INPUTS = [
96
+ ["assets/video_00000000.mp4", STYLE_CHOICES[0]],
97
+ ["assets/video_00000007.mp4", STYLE_CHOICES[1]],
98
+ ["assets/video_00000019.mp4", STYLE_CHOICES[2]],
99
+ ["assets/video_00000071.mp4", STYLE_CHOICES[3]],
100
  ]
101
 
102
 
103
  PRESET_MODES = {
104
+ "Fast": PresetConfig(shift_gamma=5.0, steps=6, guidance_scale=2),
105
+ "Balanced": PresetConfig(shift_gamma=5.0, steps=10, guidance_scale=2),
106
+ "Quality": PresetConfig(shift_gamma=5.0, steps=20, guidance_scale=2),
107
  }
108
 
109
 
 
146
  f"Seed={resolved_seed}"
147
  )
148
 
149
+ source_video = load_video(input_video_path)
150
+ source_video = [each.resize((832, 480)) for each in source_video][:81]
151
+ if len(source_video) < 81:
152
+ source_video += [source_video[-1]] * (81 - len(source_video))
153
+ source_fps = get_fps(input_video_path)
154
+
155
+ source_latents = encode_video(pipe, source_video)
156
+
157
+ pipe.scheduler.set_parameters(
158
+ noise_scale=1.0, shift_gamma=shift_gamma, seed=resolved_seed
159
+ )
160
+
161
+ output = pipe(
162
+ prompt=prompt,
163
+ num_inference_steps=steps,
164
+ guidance_scale=guidance_scale,
165
+ latents=source_latents,
166
+ ).frames[0]
167
+
168
+ tmp_dir = tempfile.mkdtemp()
169
+ out_path = os.path.join(tmp_dir, f"{random.randint(0, 2**31 - 1)}.mp4")
170
+ export_to_video(output, out_path, fps=source_fps)
171
+ print(out_path)
172
+ return out_path
173
 
174
 
175
  def _resolve_seed(seed_value, randomize):
 
232
 
233
  with gr.Row():
234
  with gr.Column(scale=1) as control_col:
235
+ prompt = gr.Dropdown(
236
+ label="Style Instruction",
237
+ choices=STYLE_CHOICES,
238
+ value=STYLE_CHOICES[0],
239
+ allow_custom_value=True,
240
+ )
241
  with gr.Tabs():
242
  with gr.Tab("Quick Generate"):
 
 
 
 
 
 
 
243
  with gr.Row():
244
+ fast_btn = gr.Button(
245
+ "⚡ Fast Generate", variant="primary"
246
+ )
247
+ balanced_btn = gr.Button(
248
+ "🎯 Balanced Generate", variant="primary"
249
+ )
250
+ quality_btn = gr.Button(
251
+ "🌟 Quality Generate", variant="primary"
252
+ )
253
 
254
  _bind_preset_button(
255
  button=fast_btn,
256
  preset_key="Fast",
257
  inputs=[
258
  input_video,
259
+ prompt,
260
  ],
261
  output=output_video,
262
  extra_kwargs={"seed": None, "randomize_seed": True},
 
266
  preset_key="Balanced",
267
  inputs=[
268
  input_video,
269
+ prompt,
270
  ],
271
  output=output_video,
272
  extra_kwargs={"seed": None, "randomize_seed": True},
 
276
  preset_key="Quality",
277
  inputs=[
278
  input_video,
279
+ prompt,
280
  ],
281
  output=output_video,
282
  extra_kwargs={"seed": None, "randomize_seed": True},
283
  )
284
 
285
  with gr.Tab("Advanced Settings"):
 
 
 
 
 
 
286
  with gr.Row():
287
  shift_gamma = _create_slider(GAMMA_SLIDER)
288
  guidance_scale = _create_slider(GUIDANCE_SLIDER)
 
291
  num_steps = _create_slider(STEP_SLIDER)
292
  randomize_seed_adv = gr.Checkbox(
293
  label="Randomize Seed",
294
+ value=True,
295
  info="Checked = new random seed each run. Uncheck to provide your own seed.",
296
  )
297
 
 
307
  fn=run_stylization,
308
  inputs=[
309
  input_video,
310
+ prompt,
311
  shift_gamma,
312
  num_steps,
313
  guidance_scale,
 
320
  with gr.Column(scale=1):
321
  gr.Examples(
322
  examples=EXAMPLE_INPUTS,
323
+ inputs=[input_video, prompt],
324
  label="Example inputs",
325
  )
326