multimodalart HF Staff commited on
Commit
9f13b69
·
verified ·
1 Parent(s): 4dc6132

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -71
app.py CHANGED
@@ -4,14 +4,14 @@ import sys
4
  import time
5
  import tempfile
6
  import zipfile
 
7
 
8
  # ---------------------------------------------------------------------------
9
- # Install private diffusers fork from bundled zip before anything imports it
10
  # ---------------------------------------------------------------------------
11
  _APP_DIR = os.path.dirname(os.path.abspath(__file__))
12
  ZIP_PATH = os.path.join(_APP_DIR, "helios_diffusers.zip")
13
  EXTRACT_DIR = os.path.join(_APP_DIR, "_helios_diffusers")
14
-
15
  _PKG_ROOT = os.path.join(EXTRACT_DIR, "diffusers-new-model-addition-helios-helios")
16
 
17
  if not os.path.isdir(_PKG_ROOT):
@@ -25,16 +25,12 @@ try:
25
  except subprocess.CalledProcessError as e:
26
  print(f"[setup] pip install failed (exit {e.returncode}), falling back to sys.path")
27
 
28
- # Always ensure the src-layout package is importable
29
  _SRC_DIR = os.path.join(_PKG_ROOT, "src")
30
  if os.path.isdir(_SRC_DIR):
31
  sys.path.insert(0, _SRC_DIR)
32
- print(f"[setup] Added {_SRC_DIR} to sys.path")
33
 
34
  import gradio as gr
35
  import spaces
36
- import torch
37
-
38
  from diffusers import (
39
  AutoencoderKLWan,
40
  HeliosPyramidPipeline,
@@ -43,7 +39,7 @@ from diffusers import (
43
  from diffusers.utils import export_to_video, load_image, load_video
44
 
45
  # ---------------------------------------------------------------------------
46
- # Pre-load model at import time (cached on ZeroGPU)
47
  # ---------------------------------------------------------------------------
48
  MODEL_ID = "BestWishYsh/Helios-Distilled"
49
 
@@ -58,13 +54,53 @@ pipe = HeliosPyramidPipeline.from_pretrained(
58
  )
59
  pipe.to("cuda")
60
 
61
- compiled_transformer = spaces.aoti_load("helios_distilled_transformer.pt2")
62
- spaces.aoti_apply(compiled_transformer, pipe.transformer)
63
-
64
- #pipe.transformer.set_attention_backend("_flash_3_hub")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # ---------------------------------------------------------------------------
67
- # Generation (decorated for ZeroGPU)
68
  # ---------------------------------------------------------------------------
69
  @spaces.GPU(duration=300)
70
  def generate_video(
@@ -80,7 +116,6 @@ def generate_video(
80
  is_amplify_first_chunk: bool,
81
  progress=gr.Progress(track_tqdm=True),
82
  ):
83
- """Run the Helios-Distilled pipeline and return the generated video."""
84
  if not prompt:
85
  raise gr.Error("Please provide a prompt.")
86
 
@@ -102,7 +137,6 @@ def generate_video(
102
  "is_amplify_first_chunk": is_amplify_first_chunk,
103
  }
104
 
105
- # Conditional inputs
106
  if mode == "Image-to-Video" and image_input is not None:
107
  img = load_image(image_input).resize((int(width), int(height)))
108
  kwargs["image"] = img
@@ -118,12 +152,10 @@ def generate_video(
118
  info = f"Generated in {elapsed:.1f}s · {num_frames} frames · {height}×{width}"
119
  return tmp.name, info
120
 
121
-
122
  # ---------------------------------------------------------------------------
123
- # Visibility toggle for conditional inputs
124
  # ---------------------------------------------------------------------------
125
  def update_conditional_visibility(mode):
126
- """Show image input for I2V, video input for V2V, hide both for T2V."""
127
  if mode == "Image-to-Video":
128
  return gr.update(visible=True), gr.update(visible=False)
129
  elif mode == "Video-to-Video":
@@ -131,14 +163,9 @@ def update_conditional_visibility(mode):
131
  else:
132
  return gr.update(visible=False), gr.update(visible=False)
133
 
134
-
135
- # ---------------------------------------------------------------------------
136
- # Gradio UI
137
- # ---------------------------------------------------------------------------
138
  CSS = """
139
  #header { text-align: center; margin-bottom: 0.5em; }
140
  #header h1 { font-size: 2.2em; margin-bottom: 0; }
141
- #header p { opacity: 0.7; margin-top: 0.2em; }
142
  .contain { max-width: 1350px; margin: 0 auto !important; }
143
  """
144
 
@@ -147,88 +174,48 @@ with gr.Blocks(css=CSS, title="Helios Video Generation", theme=gr.themes.Soft())
147
  """
148
  <div id="header">
149
  <h1>🎬 Helios 14B distilled</h1>
150
- <p></p>
151
  </div>
152
  """
153
  )
154
 
155
  with gr.Row():
156
- # ---- Left column: Controls ----
157
  with gr.Column(scale=1):
158
  mode = gr.Radio(
159
  choices=["Text-to-Video", "Image-to-Video", "Video-to-Video"],
160
  value="Text-to-Video",
161
  label="Generation Mode",
162
  )
163
-
164
- # Conditional inputs placed above the prompt, visible based on mode
165
- image_input = gr.Image(
166
- label="Image (for I2V)", type="filepath", visible=False
167
- )
168
- video_input = gr.Video(
169
- label="Video (for V2V)", visible=False
170
- )
171
-
172
  prompt = gr.Textbox(
173
  label="Prompt",
174
  lines=4,
175
- placeholder="Describe the video you want to generate…",
176
- value=(
177
- "A vibrant tropical fish swimming gracefully among colorful coral "
178
- "reefs in a clear, turquoise ocean. The fish has bright blue and yellow "
179
- "scales with a small, distinctive orange spot on its side, its fins "
180
- "moving fluidly. A close-up shot with dynamic movement."
181
- ),
182
  )
183
-
184
  with gr.Accordion("Advanced Settings", open=False):
185
  with gr.Row():
186
  height = gr.Number(value=384, label="Height", precision=0, interactive=False)
187
  width = gr.Number(value=640, label="Width", precision=0, interactive=False)
188
  with gr.Row():
189
- num_frames = gr.Slider(33, 240, value=33, step=33, label="Num Frames (must be multiple of 33)")
190
- num_inference_steps = gr.Slider(
191
- 1, 10, value=2, step=1, label="Steps (per pyramid stage)"
192
- )
193
  with gr.Row():
194
  seed = gr.Number(value=42, label="Seed", precision=0)
195
- is_amplify_first_chunk = gr.Checkbox(
196
- label="Amplify First Chunk", value=True
197
- )
198
 
199
  generate_btn = gr.Button("🚀 Generate Video", variant="primary", size="lg")
200
 
201
- # ---- Right column: Output ----
202
  with gr.Column(scale=1):
203
  video_output = gr.Video(label="Generated Video", autoplay=True)
204
  info_output = gr.Textbox(label="Info", interactive=False)
205
 
206
- # ---- Toggle conditional input visibility on mode change ----
207
- mode.change(
208
- fn=update_conditional_visibility,
209
- inputs=[mode],
210
- outputs=[image_input, video_input],
211
- )
212
-
213
- # ---- Generation ----
214
  generate_btn.click(
215
  fn=generate_video,
216
- inputs=[
217
- mode,
218
- prompt,
219
- image_input,
220
- video_input,
221
- height,
222
- width,
223
- num_frames,
224
- num_inference_steps,
225
- seed,
226
- is_amplify_first_chunk,
227
- ],
228
  outputs=[video_output, info_output],
229
  )
230
 
231
- # ---- Examples ----
232
  gr.Examples(
233
  examples=[
234
  [
@@ -257,6 +244,5 @@ with gr.Blocks(css=CSS, title="Helios Video Generation", theme=gr.themes.Soft())
257
  label="Example Prompts",
258
  )
259
 
260
-
261
  if __name__ == "__main__":
262
  demo.launch()
 
4
  import time
5
  import tempfile
6
  import zipfile
7
+ import torch
8
 
9
  # ---------------------------------------------------------------------------
10
+ # Install private diffusers fork
11
  # ---------------------------------------------------------------------------
12
  _APP_DIR = os.path.dirname(os.path.abspath(__file__))
13
  ZIP_PATH = os.path.join(_APP_DIR, "helios_diffusers.zip")
14
  EXTRACT_DIR = os.path.join(_APP_DIR, "_helios_diffusers")
 
15
  _PKG_ROOT = os.path.join(EXTRACT_DIR, "diffusers-new-model-addition-helios-helios")
16
 
17
  if not os.path.isdir(_PKG_ROOT):
 
25
  except subprocess.CalledProcessError as e:
26
  print(f"[setup] pip install failed (exit {e.returncode}), falling back to sys.path")
27
 
 
28
  _SRC_DIR = os.path.join(_PKG_ROOT, "src")
29
  if os.path.isdir(_SRC_DIR):
30
  sys.path.insert(0, _SRC_DIR)
 
31
 
32
  import gradio as gr
33
  import spaces
 
 
34
  from diffusers import (
35
  AutoencoderKLWan,
36
  HeliosPyramidPipeline,
 
39
  from diffusers.utils import export_to_video, load_image, load_video
40
 
41
  # ---------------------------------------------------------------------------
42
+ # Pre-load model
43
  # ---------------------------------------------------------------------------
44
  MODEL_ID = "BestWishYsh/Helios-Distilled"
45
 
 
54
  )
55
  pipe.to("cuda")
56
 
57
+ # ---------------------------------------------------------------------------
58
+ # 🔥 AOT LOADING LOGIC 🔥
59
+ # ---------------------------------------------------------------------------
60
+ AOT_FILENAME = "helios_distilled_transformer.pt2"
61
+ AOT_PATH = os.path.join(_APP_DIR, AOT_FILENAME)
62
+
63
+ def load_aot_model(path, original_module):
64
+ """
65
+ Loads a raw AOTI package (.pt2) and patches the original module.
66
+ """
67
+ print(f"[AOT] Loading AOTI package from {path}...")
68
+
69
+ # 1. Load the compiled runner
70
+ # This returns a torch._inductor.codecache.PyTorchCompiledModule
71
+ compiled_model = torch._inductor.aoti_load_package(path)
72
+
73
+ # 2. We need to load constants (weights) into it.
74
+ # Since we exported with 'package_constants_on_disk': True, weights are inside the pt2.
75
+ # However, to be safe, we usually need to map them.
76
+ # But for a simple load, let's try the direct callable first.
77
+
78
+ # 3. Patch the forward method
79
+ # We create a wrapper to handle the call signature if needed,
80
+ # but AOTI usually preserves the signature of the exported graph.
81
+ original_module.forward = compiled_model
82
+
83
+ # 4. Clear old weights to save VRAM (optional but recommended)
84
+ # BE CAREFUL: This deletes the original weights. If AOT failed to embed them, this breaks things.
85
+ # Since we used default AOTI export, weights are embedded in the .so or .pt2
86
+ original_module.to("meta")
87
+
88
+ print("[AOT] Model patched successfully!")
89
+
90
+ if os.path.exists(AOT_PATH):
91
+ try:
92
+ load_aot_model(AOT_PATH, pipe.transformer)
93
+ except Exception as e:
94
+ print(f"[AOT] ❌ Failed to load compiled graph: {e}")
95
+ # Restore device if failed
96
+ pipe.to("cuda")
97
+ pipe.transformer.set_attention_backend("_flash_3_hub")
98
+ else:
99
+ print(f"[AOT] ⚠️ No compiled graph found at {AOT_PATH}.")
100
+ pipe.transformer.set_attention_backend("_flash_3_hub")
101
 
102
  # ---------------------------------------------------------------------------
103
+ # Generation
104
  # ---------------------------------------------------------------------------
105
  @spaces.GPU(duration=300)
106
  def generate_video(
 
116
  is_amplify_first_chunk: bool,
117
  progress=gr.Progress(track_tqdm=True),
118
  ):
 
119
  if not prompt:
120
  raise gr.Error("Please provide a prompt.")
121
 
 
137
  "is_amplify_first_chunk": is_amplify_first_chunk,
138
  }
139
 
 
140
  if mode == "Image-to-Video" and image_input is not None:
141
  img = load_image(image_input).resize((int(width), int(height)))
142
  kwargs["image"] = img
 
152
  info = f"Generated in {elapsed:.1f}s · {num_frames} frames · {height}×{width}"
153
  return tmp.name, info
154
 
 
155
  # ---------------------------------------------------------------------------
156
+ # UI Setup
157
  # ---------------------------------------------------------------------------
158
  def update_conditional_visibility(mode):
 
159
  if mode == "Image-to-Video":
160
  return gr.update(visible=True), gr.update(visible=False)
161
  elif mode == "Video-to-Video":
 
163
  else:
164
  return gr.update(visible=False), gr.update(visible=False)
165
 
 
 
 
 
166
  CSS = """
167
  #header { text-align: center; margin-bottom: 0.5em; }
168
  #header h1 { font-size: 2.2em; margin-bottom: 0; }
 
169
  .contain { max-width: 1350px; margin: 0 auto !important; }
170
  """
171
 
 
174
  """
175
  <div id="header">
176
  <h1>🎬 Helios 14B distilled</h1>
 
177
  </div>
178
  """
179
  )
180
 
181
  with gr.Row():
 
182
  with gr.Column(scale=1):
183
  mode = gr.Radio(
184
  choices=["Text-to-Video", "Image-to-Video", "Video-to-Video"],
185
  value="Text-to-Video",
186
  label="Generation Mode",
187
  )
188
+ image_input = gr.Image(label="Image (for I2V)", type="filepath", visible=False)
189
+ video_input = gr.Video(label="Video (for V2V)", visible=False)
 
 
 
 
 
 
 
190
  prompt = gr.Textbox(
191
  label="Prompt",
192
  lines=4,
193
+ value="A vibrant tropical fish swimming gracefully...",
 
 
 
 
 
 
194
  )
 
195
  with gr.Accordion("Advanced Settings", open=False):
196
  with gr.Row():
197
  height = gr.Number(value=384, label="Height", precision=0, interactive=False)
198
  width = gr.Number(value=640, label="Width", precision=0, interactive=False)
199
  with gr.Row():
200
+ num_frames = gr.Slider(33, 240, value=33, step=33, label="Num Frames")
201
+ num_inference_steps = gr.Slider(1, 10, value=2, step=1, label="Steps per stage")
 
 
202
  with gr.Row():
203
  seed = gr.Number(value=42, label="Seed", precision=0)
204
+ is_amplify_first_chunk = gr.Checkbox(label="Amplify First Chunk", value=True)
 
 
205
 
206
  generate_btn = gr.Button("🚀 Generate Video", variant="primary", size="lg")
207
 
 
208
  with gr.Column(scale=1):
209
  video_output = gr.Video(label="Generated Video", autoplay=True)
210
  info_output = gr.Textbox(label="Info", interactive=False)
211
 
212
+ mode.change(fn=update_conditional_visibility, inputs=[mode], outputs=[image_input, video_input])
 
 
 
 
 
 
 
213
  generate_btn.click(
214
  fn=generate_video,
215
+ inputs=[mode, prompt, image_input, video_input, height, width, num_frames, num_inference_steps, seed, is_amplify_first_chunk],
 
 
 
 
 
 
 
 
 
 
 
216
  outputs=[video_output, info_output],
217
  )
218
 
 
219
  gr.Examples(
220
  examples=[
221
  [
 
244
  label="Example Prompts",
245
  )
246
 
 
247
  if __name__ == "__main__":
248
  demo.launch()