happyfrom committed
Commit 8d76c77 · 1 Parent(s): bfd65f4

change gpu duration

Files changed (2):
  1. app.py +87 -45
  2. diffsynth/pipelines/wan_video.py +3 -3
app.py CHANGED
@@ -6,18 +6,12 @@ import spaces
 from huggingface_hub import snapshot_download
 from diffsynth import ModelManager, save_video, WanVideoPipeline
 
-# ====== Public WAN model repo ======
-WAN_REPO_ID = "dsr2026/wan"  # public model repo
-WAN_LOCAL_DIR = "../wan_model"  # where to cache downloads
+WAN_REPO_ID = "dsr2026/wan"
+WAN_LOCAL_DIR = "../wan_model"
 
-# ====== Optional LoRA ======
-LORA_PATH = "./step=02400.lora_only.ckpt"
-
-
-# ====== Outputs ======
+LORA_PATH = "./step=02400.lora_only.ckpt"  # "" to disable
 OUT_DIR = "outputs"
 
-# ====== Fixed inference params (demo) ======
 NEGATIVE_PROMPT = (
     "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
     "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, "
@@ -31,13 +25,35 @@ NUM_INFERENCE_STEPS = 50
 FPS = 16
 QUALITY = 5
 
-# ====== Global cache ======
 _PIPE = None
 _MODEL_FILES = None
 
+DSR_EXAMPLES = [
+    "In a quiet forest clearing, a squirrel is on the left of a lamp, then the squirrel scampers to the right of the lamp.",
+    "At the edge of a sunny meadow, a dog is on the left of a bucket, then the dog runs to the right of the bucket.",
+    "On a rocky hillside with moss, a fox is on the left of a chair, then the fox sprints to the right of the chair.",
+]
+
+INTRODUCTION = '''
+This demo is for SpatialAlign: Aligning Dynamic Spatial Relationships in Video Generation.
+
+Users can specify a Dynamic Spatial Relationship prompt to generate videos, following the template:
+
+〈scene〉, the 〈animal〉 is 〈initial SSR〉 the 〈static object〉,
+then the 〈animal〉 〈verb〉 〈final SSR〉 the 〈static object〉.
+
+Here, the SSR can be chosen from ['the left of', 'the right of', 'the top of'].
+
+For the initial SSR, an 'on' should be put in front.
+
+For the final SSR, a 'to' should be put in front.
+
+Examples are provided below for reference.
+
+'''
+
 
 def _ensure_wan_downloaded():
-    """Download/cache WAN weights from a PUBLIC repo and return file paths."""
     global _MODEL_FILES
     if _MODEL_FILES is not None:
         return _MODEL_FILES
@@ -56,38 +72,28 @@ def _ensure_wan_downloaded():
 
     missing = [p for p in model_files if not os.path.exists(p)]
     if missing:
-        raise FileNotFoundError(
-            "Missing model files after snapshot_download:\n"
-            + "\n".join(missing)
-            + "\nCheck your model repo filenames."
-        )
+        raise FileNotFoundError("Missing model files:\n" + "\n".join(missing))
 
     _MODEL_FILES = model_files
     return _MODEL_FILES
 
-@spaces.GPU
+
+@spaces.GPU(duration=240)
 def generate(prompt: str, seed: int):
     global _PIPE
 
     if not prompt or not prompt.strip():
-        return "Please input a prompt.", None
+        return None
 
     if _PIPE is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        try:
-            model_files = _ensure_wan_downloaded()
-        except Exception as e:
-            return f"[Model download/load error]\n{repr(e)}", None
+        model_files = _ensure_wan_downloaded()
 
         mm = ModelManager(device="cpu")
         mm.load_models(model_files, torch_dtype=torch.bfloat16)
 
-        if LORA_PATH:
-            if os.path.exists(LORA_PATH):
-                mm.load_lora(LORA_PATH, lora_alpha=1.0)
-            else:
-                print(f"[WARN] LoRA not found, skip: {LORA_PATH}")
+        if LORA_PATH and os.path.exists(LORA_PATH):
+            mm.load_lora(LORA_PATH, lora_alpha=1.0)
 
         pipe = WanVideoPipeline.from_model_manager(mm, torch_dtype=torch.bfloat16, device=device)
         pipe.enable_vram_management(num_persistent_param_in_dit=None)
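The decorator change above is the commit's namesake: on Hugging Face ZeroGPU Spaces, a bare @spaces.GPU grants each decorated call a short default GPU window, and duration=240 asks for up to 240 seconds per call. The longer budget matters because generate builds the whole pipeline lazily on its first call, so a cold start (snapshot download, bfloat16 model load, 50 denoising steps) has to fit inside a single allocation. A minimal sketch of the pattern, with _build_pipeline as a hypothetical stand-in for the ModelManager/WanVideoPipeline setup:

    import spaces

    _PIPE = None  # module-level cache: warm calls reuse the loaded pipeline

    def _build_pipeline():
        # Hypothetical stand-in for the real setup (snapshot_download,
        # ModelManager.load_models, WanVideoPipeline.from_model_manager).
        ...

    @spaces.GPU(duration=240)  # hold the ZeroGPU allocation for up to 240 s
    def generate(prompt: str, seed: int):
        global _PIPE
        if _PIPE is None:              # first (cold) call pays the load cost;
            _PIPE = _build_pipeline()  # the generous budget covers it
        ...  # run inference with _PIPE and return the output path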
 
@@ -105,20 +111,56 @@ def generate(prompt: str, seed: int):
         num_frames=NUM_FRAMES,
     )
     save_video(video, out_path, fps=FPS, quality=QUALITY)
-
-    return f"Saved: {out_path}", out_path
-
-
-with gr.Blocks(title="WAN Demo (Public)") as demo:
-    gr.Markdown("## WAN Demo (public model)\nInput prompt + seed to generate a video.")
-
-    prompt = gr.Textbox(label="prompt", lines=3, placeholder="Describe a scene...")
-    seed = gr.Number(label="seed", value=0, precision=0)
-
-    btn = gr.Button("Generate")
-    log = gr.Textbox(label="log", lines=4)
-    vid = gr.Video(label="output")
-
-    btn.click(generate, inputs=[prompt, seed], outputs=[log, vid])
-
-demo.queue().launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
+    return out_path
+
+
+CSS = """
+/* Make example buttons look like clickable prompt cards */
+#examples-col button {
+    white-space: normal !important;
+    text-align: left !important;
+    line-height: 1.35 !important;
+    padding: 12px !important;
+    border-radius: 10px !important;
+}
+#main-row { align-items: flex-start !important; }
+"""
+
+
+# Gradio 6.x: pass css=... to launch() to avoid the warning.
+with gr.Blocks(title="SpatialAlign Demo") as demo:
+    gr.Markdown("## SpatialAlign Demo")
+
+    # Create the example buttons first (to keep layout order), then bind
+    # their click handlers after `prompt` is defined.
+    example_buttons = []
+
+    with gr.Row(elem_id="main-row"):
+        with gr.Column(scale=4):
+            gr.Markdown("### Introduction")
+            gr.Markdown(INTRODUCTION)
+
+        with gr.Column(scale=3, elem_id="examples-col"):
+            gr.Markdown("### Prompt Examples (click to fill prompt)")
+            for p in DSR_EXAMPLES:
+                b = gr.Button(p)
+                example_buttons.append((b, p))
+
+        with gr.Column(scale=3):
+            gr.Markdown("### Generate")
+            prompt = gr.Textbox(label="prompt", lines=6, placeholder="Describe a dynamic spatial relationship...")
+            seed = gr.Number(label="seed", value=0, precision=0)
+            btn = gr.Button("Generate")
+            vid = gr.Video(label="output")
+            btn.click(generate, inputs=[prompt, seed], outputs=vid)
+
+    # Bind events after `prompt` exists (fixes NameError; keeps layout order).
+    for b, p in example_buttons:
+        b.click(fn=lambda _p=p: _p, inputs=None, outputs=prompt)
+
+demo.queue().launch(
+    server_name="0.0.0.0",
+    server_port=7860,
+    ssr_mode=False,
+    css=CSS,
+)
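One non-obvious line in the new UI wiring is b.click(fn=lambda _p=p: _p, ...). Python closures capture variables, not values, so lambda: p created inside the loop would make every example button fill in the last prompt; binding p as a default argument freezes the current value at definition time. A self-contained illustration of the difference:

    late = [lambda: p for p in ("a", "b", "c")]
    print([f() for f in late])      # ['c', 'c', 'c'] -- all see the final p

    frozen = [lambda _p=p: _p for p in ("a", "b", "c")]
    print([f() for f in frozen])    # ['a', 'b', 'c'] -- value captured per button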
diffsynth/pipelines/wan_video.py CHANGED
@@ -451,9 +451,9 @@ class WanVideoPipeline(BasePipeline):
         # Scheduler
         latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
 
-        m = psutil.virtual_memory()
-        print("RAM used:", m.percent, "%")
-        print(torch.cuda.memory_summary())
+        # m = psutil.virtual_memory()
+        # print("RAM used:", m.percent, "%")
+        # print(torch.cuda.memory_summary())
 
         if vace_reference_image is not None:
             latents = latents[:, :, 1:]
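The pipeline edit comments out the per-step memory diagnostics rather than deleting them; printing a full torch.cuda.memory_summary() on every one of the 50 denoising steps floods the logs. If the numbers are ever needed again, one option is to gate the logging behind an opt-in flag instead of comments. A sketch, where the WAN_DEBUG_MEM environment variable is an invented name, not something this repo defines:

    import os
    import psutil
    import torch

    # Hypothetical opt-in switch; enable with WAN_DEBUG_MEM=1.
    DEBUG_MEM = os.environ.get("WAN_DEBUG_MEM") == "1"

    def log_memory(step: int) -> None:
        """Print RAM/VRAM usage for one denoising step, only when enabled."""
        if not DEBUG_MEM:
            return
        print(f"[step {step}] RAM used: {psutil.virtual_memory().percent}%")
        if torch.cuda.is_available():
            print(torch.cuda.memory_summary())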