Ksjsjjdj committed on
Commit
c0d0dd5
·
verified ·
1 Parent(s): 1c1e20c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +306 -87
app.py CHANGED
@@ -2,54 +2,237 @@ import os
2
  import sys
3
  import subprocess
4
  import traceback
 
 
 
 
5
  from pathlib import Path
6
 
7
- def install_dependencies():
8
- commands = [
9
- "pip install spaces-0.1.0-py3-none-any.whl",
10
- "pip install librosa"
11
- ]
12
- for cmd in commands:
13
- os.system(cmd)
14
-
15
- install_dependencies()
16
 
17
  import spaces
18
- import numpy as np
19
- from PIL import Image
20
- import soundfile as sf
21
  import torch
22
- import gradio as gr
23
  import librosa
 
 
 
24
  from huggingface_hub import snapshot_download
 
25
 
26
  try:
27
  import diffusers
 
 
28
  except ImportError:
29
- os.system("pip install diffusers")
30
- import diffusers
31
 
32
- MODEL_ID = "tolgacangoz/Wan2.2-S2V-14B-Diffusers"
33
- try:
34
- LOCAL_DIR = snapshot_download(repo_id=MODEL_ID, repo_type="model")
35
- except Exception:
36
- LOCAL_DIR = MODEL_ID
37
 
38
- pipe = None
 
 
 
 
39
 
40
- def load_audio_for_model(audio_filepath):
 
 
 
 
41
  try:
42
- wav, sr = librosa.load(audio_filepath, sr=16000)
43
- return wav, sr
44
- except Exception:
45
- return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- def to_pil(image):
48
- if image is None: return None
49
- if isinstance(image, Image.Image): return image.convert("RGB")
50
- if isinstance(image, str): return Image.open(image).convert("RGB")
51
- arr = np.array(image)
52
- return Image.fromarray(arr).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  def merge_audio_video(video_path, audio_path, output_path):
55
  cmd = [
@@ -65,74 +248,110 @@ def merge_audio_video(video_path, audio_path, output_path):
65
  subprocess.run(cmd, check=True)
66
  return output_path
67
 
68
- @spaces.GPU(duration=120)
69
- def generate_video(image_input, audio_filepath, prompt):
70
- global pipe
71
-
 
 
 
 
 
 
 
 
 
 
72
  if image_input is None or audio_filepath is None:
73
- raise gr.Error("Error inputs")
 
 
 
 
74
 
75
- try:
76
- if pipe is None:
77
- from diffusers import WanSpeechToVideoPipeline
78
-
79
- pipe = WanSpeechToVideoPipeline.from_pretrained(
80
- LOCAL_DIR,
81
- use_safetensors=True,
82
- torch_dtype=torch.float32
83
- ).to("cpu")
84
-
85
- audio_values, sample_rate = load_audio_for_model(audio_filepath)
86
- init_image = to_pil(image_input)
87
-
88
- w, h = init_image.size
89
- w = (w // 16) * 16
90
- h = (h // 16) * 16
91
- init_image = init_image.resize((w, h), Image.LANCZOS)
92
 
 
93
  out = pipe(
94
  image=init_image,
95
  audio=audio_values,
96
  num_inference_steps=25,
97
  guidance_scale=4.0,
98
  sampling_rate=sample_rate,
99
- prompt=prompt
 
100
  )
101
-
102
- frames = out.frames[0]
103
-
104
- temp_mute_video = "temp_mute.mp4"
105
- final_video = "output_s2v.mp4"
106
-
107
- from diffusers.utils import export_to_video
108
- export_to_video(frames, temp_mute_video, fps=16)
109
-
110
- final_output = merge_audio_video(temp_mute_video, audio_filepath, final_video)
111
-
112
- return final_output
113
 
114
- except Exception as e:
115
- traceback.print_exc()
116
- raise gr.Error(str(e))
 
117
 
118
- with gr.Blocks(title="Wan2.1 Speech to Video") as demo:
119
- gr.Markdown("# Wan2.2-S2V Generador de Video")
120
 
121
- with gr.Row():
122
- with gr.Column():
123
- img_input = gr.Image(label="Imagen de referencia", type="pil")
124
- audio_input = gr.Audio(label="Audio (.wav)", type="filepath")
125
- prompt_input = gr.Textbox(label="Prompt")
126
- btn = gr.Button("Generar Video", variant="primary")
127
-
128
- with gr.Column():
129
- video_output = gr.Video(label="Resultado")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- btn.click(
132
- fn=generate_video,
133
- inputs=[img_input, audio_input, prompt_input],
134
- outputs=video_output
135
- )
 
 
 
 
 
 
 
 
 
 
136
 
137
  if __name__ == "__main__":
138
- demo.launch()
 
2
  import sys
3
  import subprocess
4
  import traceback
5
+ import gc
6
+ import tempfile
7
+ import random
8
+ import time
9
  from pathlib import Path
10
 
11
+ os.system("pip install spaces-0.1.0-py3-none-any.whl moviepy==1.0.3 imageio[ffmpeg] librosa soundfile diffusers accelerate")
 
 
 
 
 
 
 
 
12
 
13
  import spaces
 
 
 
14
  import torch
15
+ import numpy as np
16
  import librosa
17
+ import soundfile as sf
18
+ from PIL import Image
19
+ from moviepy.editor import VideoFileClip, concatenate_videoclips
20
  from huggingface_hub import snapshot_download
21
+ import gradio as gr
22
 
23
  try:
24
  import diffusers
25
+ from diffusers import AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline, UniPCMultistepScheduler, WanSpeechToVideoPipeline
26
+ from diffusers.utils import export_to_video
27
  except ImportError:
28
+ pass
 
29
 
30
# Hub repo IDs for the two model families served by this app.
MODEL_ID_TI2V = "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers"  # text/image-to-video (5B fast variant)
MODEL_ID_S2V = "tolgacangoz/Wan2.2-S2V-14B-Diffusers"  # speech-to-video (14B)

# Global pipeline registry, populated once by load_models_at_startup().
# Values stay None when a pipeline fails to load; handlers check for that.
MODELS = {
    "ti2v_text": None,
    "ti2v_image": None,
    "s2v": None
}

# Preferred compute device; CPU fallback keeps the app importable without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
+
41
def _place_on_device(pipeline):
    """Best-effort device placement for a diffusers pipeline.

    On CUDA, prefer model CPU-offload (lower peak VRAM); otherwise move the
    whole pipeline to DEVICE. A RuntimeError (typically OOM) falls back to
    plain CPU so the app still starts.
    """
    try:
        if DEVICE == "cuda":
            pipeline.enable_model_cpu_offload()
        else:
            pipeline.to(DEVICE)
    except RuntimeError:
        pipeline.to("cpu")


def load_models_at_startup():
    """Populate the global MODELS registry with the TI2V and S2V pipelines.

    The TI2V pair and the S2V pipeline are loaded in independent try blocks
    so a failure in one family (missing weights, OOM, ...) does not prevent
    the other tab from working. Failures are logged with a traceback instead
    of being silently swallowed, so startup problems are diagnosable.
    """
    global MODELS

    try:
        # Both TI2V pipelines share a single float32 VAE to save memory.
        vae = AutoencoderKLWan.from_pretrained(MODEL_ID_TI2V, subfolder="vae", torch_dtype=torch.float32)

        text_pipe = WanPipeline.from_pretrained(MODEL_ID_TI2V, vae=vae, torch_dtype=torch.bfloat16)
        text_pipe.scheduler = UniPCMultistepScheduler.from_config(text_pipe.scheduler.config, flow_shift=8.0)
        _place_on_device(text_pipe)
        MODELS["ti2v_text"] = text_pipe

        image_pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID_TI2V, vae=vae, torch_dtype=torch.bfloat16)
        image_pipe.scheduler = UniPCMultistepScheduler.from_config(image_pipe.scheduler.config, flow_shift=8.0)
        _place_on_device(image_pipe)
        MODELS["ti2v_image"] = image_pipe
    except Exception:
        # Keep the app usable (the S2V tab may still work); log for diagnosis.
        traceback.print_exc()

    try:
        s2v_pipe = WanSpeechToVideoPipeline.from_pretrained(
            MODEL_ID_S2V,
            torch_dtype=torch.bfloat16,
        )
        _place_on_device(s2v_pipe)
        MODELS["s2v"] = s2v_pipe
    except Exception:
        traceback.print_exc()
92
+
93
+ load_models_at_startup()
94
+
95
def auto_duration_estimator(mode, input_data, duration_val):
    """Estimate the GPU allocation (in seconds) needed for one generation call.

    Args:
        mode: ``"s2v"`` (then ``input_data`` is an audio file path) or any
            other value for TI2V (then ``input_data`` is the uploaded image
            list).
        input_data: audio path or image list, depending on ``mode``.
        duration_val: requested output video duration in seconds (TI2V only).

    Returns:
        int: estimated wall-clock seconds, including a fixed startup overhead.
    """
    base_overhead = 45  # model warm-up / scheduling slack
    if mode == "s2v":
        audio_path = input_data
        if audio_path:
            try:
                # librosa >= 0.10 takes ``path=`` (``filename=`` is deprecated).
                dur = librosa.get_duration(path=audio_path)
                return int(base_overhead + (dur * 15))
            except Exception:
                # Unreadable audio: fall back to a conservative estimate.
                return 120
        return 120

    # TI2V: each image yields at least ~2 s of video, so the effective total
    # duration is the larger of the user request and 2 s per image.
    num_images = len(input_data) if input_data else 0
    if num_images > 0:
        total_seconds = max(duration_val, num_images * 2)
    else:
        total_seconds = duration_val
    return int(base_overhead + (total_seconds * 12))
113
+
114
def fast_stitch_videos(video_paths):
    """Concatenate MP4 clips losslessly using ffmpeg's concat demuxer.

    Stream copy (``-c copy``) avoids re-encoding; this is valid because all
    clips are generated by this app with identical codec/fps/resolution.

    Returns:
        The stitched file path; the single input when only one clip exists;
        ``None`` for an empty/missing list; or, as a best-effort fallback on
        ffmpeg failure, the most recent clip so the UI still shows something.
    """
    if not video_paths:
        return None
    if len(video_paths) == 1:
        return video_paths[0]

    list_path = None
    try:
        # The concat demuxer reads its inputs from a text listing file.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
            for path in video_paths:
                f.write(f"file '{path}'\n")
            list_path = f.name

        # Reserve an output filename; ffmpeg overwrites it via ``-y``.
        with tempfile.NamedTemporaryFile(suffix="_stitched_stream.mp4", delete=False) as tmp:
            out_path = tmp.name

        cmd = [
            "ffmpeg", "-y", "-f", "concat", "-safe", "0",
            "-i", list_path, "-c", "copy", out_path
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return out_path
    except Exception:
        # Best effort: return the latest chunk rather than nothing.
        traceback.print_exc()
        return video_paths[-1]
    finally:
        # Remove the listing file even when ffmpeg fails (the original
        # implementation leaked it on the error path).
        if list_path is not None:
            try:
                os.remove(list_path)
            except OSError:
                pass
136
+
137
# Duration lambda: args[0] is input_files, args[5] is duration_seconds
# (positions match the function signature below).
@spaces.GPU(duration=lambda *args: auto_duration_estimator("ti2v", args[0], args[5]))
def generate_ti2v_gpu_stream(input_files, prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
    """Stream TI2V generation: yields (stitched_video_path, last_frame, seed)
    after every generated chunk so the UI updates progressively.

    With input images: one image-to-video chunk per image. Without images:
    text-to-video chunks until ``duration_seconds`` is covered.
    """
    global MODELS
    text_to_video_pipe = MODELS.get("ti2v_text")
    image_to_video_pipe = MODELS.get("ti2v_image")

    if not text_to_video_pipe or not image_to_video_pipe:
        raise gr.Error("Models failed to load at startup. Check system memory.")

    # Wan requires spatial dims divisible by 32; round down, floor at 32.
    MOD_VALUE = 32
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    master_seed = random.randint(0, 2**32 - 1) if randomize_seed else int(seed)

    video_clips_paths = []
    pil_images = []

    if input_files:
        # gr.File may hand back a single item or a list; items may be
        # tempfile-like objects (with .name) or plain path strings.
        files_list = input_files if isinstance(input_files, list) else [input_files]
        for f in files_list:
            try:
                path = f.name if hasattr(f, "name") else f
                img = Image.open(path).convert("RGB")
                pil_images.append(img)
            except:
                # NOTE(review): bare except — unreadable uploads are skipped.
                continue

    SAFE_CHUNK_DURATION = 4.0  # seconds per generated chunk (VRAM-safe)
    FIXED_FPS = 24

    last_preview_frame = None

    if len(pil_images) > 0:
        # Image-driven path: split the requested duration evenly, >= 2 s each.
        seconds_per_image = max(2.0, duration_seconds / len(pil_images))

        for i, img in enumerate(pil_images):
            current_chunk_duration = min(seconds_per_image, SAFE_CHUNK_DURATION)
            num_frames = int(current_chunk_duration * FIXED_FPS)

            # Per-image seed offset keeps chunks distinct yet reproducible.
            local_seed = master_seed + i
            generator = torch.Generator(device=DEVICE).manual_seed(local_seed)
            resized_image = img.resize((target_w, target_h))

            try:
                with torch.inference_mode():
                    output_frames = image_to_video_pipe(
                        image=resized_image,
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        height=target_h,
                        width=target_w,
                        num_frames=num_frames,
                        guidance_scale=float(guidance_scale),
                        num_inference_steps=int(steps),
                        generator=generator
                    ).frames[0]

                # delete=False: the path must outlive this block for stitching.
                with tempfile.NamedTemporaryFile(suffix=f"_img_{i}.mp4", delete=False) as tmp:
                    export_to_video(output_frames, tmp.name, fps=FIXED_FPS)
                    video_clips_paths.append(tmp.name)

                if len(output_frames) > 0:
                    last_preview_frame = output_frames[-1]

                # Re-stitch everything so far and push a UI update.
                current_stitched = fast_stitch_videos(video_clips_paths)
                yield current_stitched, last_preview_frame, master_seed

            except Exception:
                # A failed chunk is dropped; remaining images still render.
                continue
    else:
        # Text-only path: cover duration_seconds with fixed-size chunks.
        num_chunks = int(np.ceil(duration_seconds / SAFE_CHUNK_DURATION))
        frames_per_chunk = int(SAFE_CHUNK_DURATION * FIXED_FPS)

        for i in range(num_chunks):
            # Spread chunk seeds apart so consecutive chunks differ visibly.
            chunk_seed = master_seed + (i * 100)
            generator = torch.Generator(device=DEVICE).manual_seed(chunk_seed)

            with torch.inference_mode():
                output_frames = text_to_video_pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    height=target_h,
                    width=target_w,
                    num_frames=frames_per_chunk,
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(steps),
                    generator=generator
                ).frames[0]

            with tempfile.NamedTemporaryFile(suffix=f"_chunk_{i}.mp4", delete=False) as tmp:
                export_to_video(output_frames, tmp.name, fps=FIXED_FPS)
                video_clips_paths.append(tmp.name)

            if len(output_frames) > 0:
                last_preview_frame = output_frames[-1]

            current_stitched = fast_stitch_videos(video_clips_paths)
            yield current_stitched, last_preview_frame, master_seed
236
 
237
  def merge_audio_video(video_path, audio_path, output_path):
238
  cmd = [
 
248
  subprocess.run(cmd, check=True)
249
  return output_path
250
 
251
def load_audio_for_model(audio_filepath):
    """Load audio resampled to the 16 kHz mono rate the S2V model expects.

    Returns:
        tuple: ``(waveform, sample_rate)`` on success, or ``(None, None)``
        when the file cannot be decoded — the caller maps that sentinel to a
        user-facing error.
    """
    try:
        wav, sr = librosa.load(audio_filepath, sr=16000)
        return wav, sr
    except Exception:
        # Narrowed from a bare ``except:`` (which also caught
        # KeyboardInterrupt/SystemExit); log so decode failures are visible.
        traceback.print_exc()
        return None, None
257
+
258
# Duration lambda: args[1] is the audio filepath; TI2V duration is unused (0).
@spaces.GPU(duration=lambda *args: auto_duration_estimator("s2v", args[1], 0))
def generate_s2v_gpu(image_input, audio_filepath, prompt, seed, randomize_seed):
    """Generate a talking video from a reference image plus an audio track.

    Returns (final_video_path, seed_used); raises gr.Error on missing model,
    missing inputs, or undecodable audio.
    """
    global MODELS
    pipe = MODELS.get("s2v")
    if not pipe:
        raise gr.Error("S2V Model not initialized.")

    if image_input is None or audio_filepath is None:
        raise gr.Error("Inputs Missing")

    # 16 kHz waveform; (None, None) signals an unreadable file.
    audio_values, sample_rate = load_audio_for_model(audio_filepath)
    if audio_values is None:
        raise gr.Error("Invalid Audio")

    # Snap dimensions down to multiples of 16 as the pipeline requires.
    init_image = image_input.convert("RGB")
    w, h = init_image.size
    w = (w // 16) * 16
    h = (h // 16) * 16
    init_image = init_image.resize((w, h), Image.LANCZOS)

    current_seed = random.randint(0, 2**32 - 1) if randomize_seed else int(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(current_seed)

    with torch.inference_mode():
        out = pipe(
            image=init_image,
            audio=audio_values,
            num_inference_steps=25,
            guidance_scale=4.0,
            sampling_rate=sample_rate,
            prompt=prompt,
            generator=generator
        )

    frames = out.frames[0]

    # Reserve unique output paths; delete=False so they survive the block.
    with tempfile.NamedTemporaryFile(suffix="_temp_mute.mp4", delete=False) as tmp_vid:
        temp_mute_path = tmp_vid.name

    with tempfile.NamedTemporaryFile(suffix="_output_s2v.mp4", delete=False) as tmp_final:
        final_video_path = tmp_final.name

    # Export the silent frames, then mux the original audio back in.
    export_to_video(frames, temp_mute_path, fps=30)
    final_output = merge_audio_video(temp_mute_path, audio_filepath, final_video_path)

    return final_output, current_seed
304
 
305
# ---------------------------------------------------------------------------
# Gradio UI: two tabs — streaming TI2V generation and one-shot S2V generation.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Wan 2.2 Unified Streaming Video Platform")

    with gr.Tabs():
        with gr.TabItem("Text & Image to Video (Streaming & Long Duration)"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Multiple images → one video chunk per image.
                    ti2v_files = gr.File(label="Input Images", file_count="multiple", type="filepath", file_types=["image"])
                    ti2v_prompt = gr.Textbox(label="Prompt", value="Cinematic view, realistic lighting, 4k", lines=2)
                    ti2v_duration = gr.Slider(minimum=2, maximum=300, step=1, value=5, label="Total Duration (s)")

                    with gr.Accordion("Advanced", open=False):
                        ti2v_neg = gr.Textbox(label="Negative Prompt", value="low quality, distortion, text, watermark", lines=2)
                        ti2v_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=42)
                        ti2v_rand = gr.Checkbox(label="Random Seed", value=True)
                        with gr.Row():
                            # Positional Slider args: minimum, maximum, step, value.
                            ti2v_h = gr.Slider(256, 1024, 32, 832, label="Height")
                            ti2v_w = gr.Slider(256, 1024, 32, 832, label="Width")
                        ti2v_steps = gr.Slider(2, 10, 1, 4, label="Steps")
                        ti2v_scale = gr.Slider(1.0, 8.0, 0.1, 5.0, label="CFG")

                    btn_ti2v = gr.Button("Start Streaming Generation", variant="primary")

                with gr.Column(scale=2):
                    with gr.Row():
                        out_ti2v = gr.Video(label="Live Video Stream", autoplay=True)
                        out_preview_ti2v = gr.Image(label="Last Frame Preview", interactive=False)
                    out_seed_ti2v = gr.Number(label="Seed Used")

            # Generator handler: each yield updates all three outputs.
            btn_ti2v.click(
                fn=generate_ti2v_gpu_stream,
                inputs=[ti2v_files, ti2v_prompt, ti2v_h, ti2v_w, ti2v_neg, ti2v_duration, ti2v_scale, ti2v_steps, ti2v_seed, ti2v_rand],
                outputs=[out_ti2v, out_preview_ti2v, out_seed_ti2v]
            )

        with gr.TabItem("Speech to Video (S2V)"):
            with gr.Row():
                with gr.Column(scale=1):
                    s2v_img = gr.Image(label="Reference Image", type="pil")
                    s2v_audio = gr.Audio(label="Audio Input", type="filepath")
                    s2v_prompt = gr.Textbox(label="Prompt", value="Realistic movement, talking face")
                    s2v_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=42)
                    s2v_rand = gr.Checkbox(label="Random Seed", value=True)
                    btn_s2v = gr.Button("Generate S2V", variant="primary")

                with gr.Column(scale=2):
                    out_s2v = gr.Video(label="Result")
                    out_seed_s2v = gr.Number(label="Seed Used")

            btn_s2v.click(generate_s2v_gpu, [s2v_img, s2v_audio, s2v_prompt, s2v_seed, s2v_rand], [out_s2v, out_seed_s2v])

if __name__ == "__main__":
    # queue() is required for streaming (generator) handlers and GPU scheduling.
    demo.queue().launch()