BestWishYsh commited on
Commit
2f1331b
verified
1 Parent(s): ba1a1ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -29
app.py CHANGED
@@ -1,20 +1,14 @@
1
- import os
2
- import subprocess
3
- import sys
4
- import time
5
  import tempfile
6
- import zipfile
7
- import torch
8
 
9
  import gradio as gr
10
  import spaces
11
- from diffusers import (
12
- AutoencoderKLWan,
13
- HeliosPyramidPipeline,
14
- HeliosDMDScheduler
15
- )
16
  from diffusers.utils import export_to_video, load_image, load_video
17
 
 
18
  # ---------------------------------------------------------------------------
19
  # Pre-load model
20
  # ---------------------------------------------------------------------------
@@ -23,16 +17,14 @@ MODEL_ID = "BestWishYsh/Helios-Distilled"
23
  vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
24
  scheduler = HeliosDMDScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")
25
  pipe = HeliosPyramidPipeline.from_pretrained(
26
- MODEL_ID,
27
- vae=vae,
28
- scheduler=scheduler,
29
- torch_dtype=torch.bfloat16,
30
- is_distilled=True
31
  )
32
 
33
  pipe.to("cuda")
34
- # pipe.transformer.set_attention_backend("flash_hub")
35
- pipe.transformer.set_attention_backend("_flash_3_hub")
 
 
36
 
37
  # @spaces.GPU(duration=1500)
38
  # def compile_transformer():
@@ -49,6 +41,7 @@ pipe.transformer.set_attention_backend("_flash_3_hub")
49
  # compiled_transformer = compile_transformer()
50
  # spaces.aoti_apply(compiled_transformer, pipe.transformer)
51
 
 
52
  # ---------------------------------------------------------------------------
53
  # Generation
54
  # ---------------------------------------------------------------------------
@@ -102,6 +95,7 @@ def generate_video(
102
  info = f"Generated in {elapsed:.1f}s 路 {num_frames} frames 路 {height}脳{width}"
103
  return tmp.name, info
104
 
 
105
  # ---------------------------------------------------------------------------
106
  # UI Setup
107
  # ---------------------------------------------------------------------------
@@ -113,18 +107,19 @@ def update_conditional_visibility(mode):
113
  else:
114
  return gr.update(visible=False), gr.update(visible=False)
115
 
 
116
  CSS = """
117
  #header { text-align: center; margin-bottom: 1.5em; }
118
  #header h1 { font-size: 2.2em; margin-bottom: 0.2em; }
119
  .logo { max-height: 100px; margin: 0 auto 10px auto; display: block; }
120
  .link-buttons { display: flex; justify-content: center; gap: 15px; margin-top: 10px; }
121
- .link-buttons a {
122
- background-color: #2b3137;
123
- color: #ffffff !important;
124
- padding: 8px 20px;
125
- border-radius: 6px;
126
- text-decoration: none;
127
- font-weight: 600;
128
  font-size: 1em;
129
  transition: all 0.2s ease-in-out;
130
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
@@ -133,7 +128,7 @@ CSS = """
133
  .contain { max-width: 1350px; margin: 0 auto !important; }
134
  """
135
 
136
- with gr.Blocks(title="Helios Video Generation") as demo:
137
  gr.HTML(
138
  """
139
  <div style='display: flex; align-items: center; justify-content: center; width: 100%;'>
@@ -176,7 +171,7 @@ with gr.Blocks(title="Helios Video Generation") as demo:
176
  "of hard and soft corals in shades of red, orange, and green. The photo captures "
177
  "the fish from a slightly elevated angle, emphasizing its lively movements and the "
178
  "vivid colors of its surroundings. A close-up shot with dynamic movement."
179
- )
180
  )
181
  with gr.Accordion("Advanced Settings", open=False):
182
  with gr.Row():
@@ -198,7 +193,18 @@ with gr.Blocks(title="Helios Video Generation") as demo:
198
  mode.change(fn=update_conditional_visibility, inputs=[mode], outputs=[image_input, video_input])
199
  generate_btn.click(
200
  fn=generate_video,
201
- inputs=[mode, prompt, image_input, video_input, height, width, num_frames, num_inference_steps, seed, is_amplify_first_chunk],
 
 
 
 
 
 
 
 
 
 
 
202
  outputs=[video_output, info_output],
203
  )
204
 
@@ -261,4 +267,5 @@ with gr.Blocks(title="Helios Video Generation") as demo:
261
  )
262
 
263
  if __name__ == "__main__":
264
- demo.launch(share=True, css=CSS, theme=gr.themes.Soft())
 
 
 
 
 
 
1
  import tempfile
2
+ import time
 
3
 
4
  import gradio as gr
5
  import spaces
6
+ import torch
7
+
8
+ from diffusers import AutoencoderKLWan, HeliosDMDScheduler, HeliosPyramidPipeline
 
 
9
  from diffusers.utils import export_to_video, load_image, load_video
10
 
11
+
12
  # ---------------------------------------------------------------------------
13
  # Pre-load model
14
  # ---------------------------------------------------------------------------
 
17
  vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
18
  scheduler = HeliosDMDScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")
19
  pipe = HeliosPyramidPipeline.from_pretrained(
20
+ MODEL_ID, vae=vae, scheduler=scheduler, torch_dtype=torch.bfloat16, is_distilled=True
 
 
 
 
21
  )
22
 
23
  pipe.to("cuda")
24
+ try:
25
+ pipe.transformer.set_attention_backend("_flash_3_hub")
26
+ except Exception:
27
+ pipe.transformer.set_attention_backend("flash_hub")
28
 
29
  # @spaces.GPU(duration=1500)
30
  # def compile_transformer():
 
41
  # compiled_transformer = compile_transformer()
42
  # spaces.aoti_apply(compiled_transformer, pipe.transformer)
43
 
44
+
45
  # ---------------------------------------------------------------------------
46
  # Generation
47
  # ---------------------------------------------------------------------------
 
95
  info = f"Generated in {elapsed:.1f}s 路 {num_frames} frames 路 {height}脳{width}"
96
  return tmp.name, info
97
 
98
+
99
  # ---------------------------------------------------------------------------
100
  # UI Setup
101
  # ---------------------------------------------------------------------------
 
107
  else:
108
  return gr.update(visible=False), gr.update(visible=False)
109
 
110
+
111
  CSS = """
112
  #header { text-align: center; margin-bottom: 1.5em; }
113
  #header h1 { font-size: 2.2em; margin-bottom: 0.2em; }
114
  .logo { max-height: 100px; margin: 0 auto 10px auto; display: block; }
115
  .link-buttons { display: flex; justify-content: center; gap: 15px; margin-top: 10px; }
116
+ .link-buttons a {
117
+ background-color: #2b3137;
118
+ color: #ffffff !important;
119
+ padding: 8px 20px;
120
+ border-radius: 6px;
121
+ text-decoration: none;
122
+ font-weight: 600;
123
  font-size: 1em;
124
  transition: all 0.2s ease-in-out;
125
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
 
128
  .contain { max-width: 1350px; margin: 0 auto !important; }
129
  """
130
 
131
+ with gr.Blocks(css=CSS, title="Helios Video Generation", theme=gr.themes.Soft()) as demo:
132
  gr.HTML(
133
  """
134
  <div style='display: flex; align-items: center; justify-content: center; width: 100%;'>
 
171
  "of hard and soft corals in shades of red, orange, and green. The photo captures "
172
  "the fish from a slightly elevated angle, emphasizing its lively movements and the "
173
  "vivid colors of its surroundings. A close-up shot with dynamic movement."
174
+ ),
175
  )
176
  with gr.Accordion("Advanced Settings", open=False):
177
  with gr.Row():
 
193
  mode.change(fn=update_conditional_visibility, inputs=[mode], outputs=[image_input, video_input])
194
  generate_btn.click(
195
  fn=generate_video,
196
+ inputs=[
197
+ mode,
198
+ prompt,
199
+ image_input,
200
+ video_input,
201
+ height,
202
+ width,
203
+ num_frames,
204
+ num_inference_steps,
205
+ seed,
206
+ is_amplify_first_chunk,
207
+ ],
208
  outputs=[video_output, info_output],
209
  )
210
 
 
267
  )
268
 
269
  if __name__ == "__main__":
270
+ # demo.launch(share=True, allowed_paths=["./examples"])
271
+ demo.launch(share=True)