matthewkram commited on
Commit
d8a502c
·
verified ·
1 Parent(s): 682984a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -10,23 +10,19 @@ from PIL import Image
10
  import numpy as np
11
  import cv2
12
  import tempfile
13
- from diffusers.utils import export_to_video # Добавил для экспорта видео
14
 
15
  class WanAnimateApp:
16
  def __init__(self):
17
  model_name = "stabilityai/stable-video-diffusion-img2vid-xt"
18
  self.pipe = StableVideoDiffusionPipeline.from_pretrained(
19
  model_name,
20
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
21
- variant="fp16",
22
- device_map="auto" # Изменил на "auto" для GPU на HF
23
  )
24
- if torch.cuda.is_available():
25
- self.pipe.to("cuda")
26
- self.pipe.enable_model_cpu_offload() # Оптимизация памяти
27
- self.pipe.enable_vae_slicing()
28
- self.pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
29
- torch.backends.cuda.matmul.allow_tf32 = True
30
 
31
  def predict(self, ref_img, video, model_id, model):
32
  if ref_img is None or video is None:
@@ -39,33 +35,33 @@ class WanAnimateApp:
39
  else:
40
  ref_image = Image.open(ref_img).convert("RGB").resize((576, 320))
41
 
42
- # Извлечение motion из видео (frame_count для hint)
43
  cap = cv2.VideoCapture(video)
44
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
45
  cap.release()
46
  motion_hint = f" with dynamic motion from {frame_count} frames"
47
 
48
- # Параметры на основе model
49
  num_frames = 25 if model == "wan-pro" else 14
50
  num_steps = 25 if model == "wan-pro" else 15
51
 
52
- # Адаптация для modes (имитация Wan2.2 с SVD)
53
- noise_aug_strength = 0.02 # Базовый шум
54
- if model_id == "wan2.2-animate-mix": # Для "mix" — больше шума для "замены"
55
  noise_aug_strength = 0.1
56
 
57
- # Генерация
58
- generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(42)
59
  output = self.pipe(
60
  ref_image,
61
  num_inference_steps=num_steps,
62
  num_frames=num_frames,
63
  generator=generator,
64
  decode_chunk_size=2,
65
- noise_aug_strength=noise_aug_strength # Добавил для вариации
66
  ).frames[0]
67
 
68
- # Экспорт видео (без subprocess, используем export_to_video)
69
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
70
  export_to_video(output, temp_video.name, fps=7)
71
 
 
10
  import numpy as np
11
  import cv2
12
  import tempfile
13
+ from diffusers.utils import export_to_video # Для экспорта видео
14
 
15
  class WanAnimateApp:
16
  def __init__(self):
17
  model_name = "stabilityai/stable-video-diffusion-img2vid-xt"
18
  self.pipe = StableVideoDiffusionPipeline.from_pretrained(
19
  model_name,
20
+ torch_dtype=torch.float32, # Для CPU FP32
21
+ variant="fp16", # Вариант остаётся, но dtype переопределяем
22
+ device_map="cpu" # Явно CPU, чтобы избежать ошибок
23
  )
24
+ # Нет CUDA-оптимизаций, так как на CPU они не нужны/не работают
25
+ self.pipe.enable_vae_slicing() # Оптимизация для памяти (работает на CPU)
 
 
 
 
26
 
27
  def predict(self, ref_img, video, model_id, model):
28
  if ref_img is None or video is None:
 
35
  else:
36
  ref_image = Image.open(ref_img).convert("RGB").resize((576, 320))
37
 
38
+ # Извлечение motion из видео
39
  cap = cv2.VideoCapture(video)
40
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
41
  cap.release()
42
  motion_hint = f" with dynamic motion from {frame_count} frames"
43
 
44
+ # Параметры
45
  num_frames = 25 if model == "wan-pro" else 14
46
  num_steps = 25 if model == "wan-pro" else 15
47
 
48
+ # Адаптация modes (имитация с SVD)
49
+ noise_aug_strength = 0.02
50
+ if model_id == "wan2.2-animate-mix":
51
  noise_aug_strength = 0.1
52
 
53
+ # Генерация (на CPU)
54
+ generator = torch.Generator(device="cpu").manual_seed(42)
55
  output = self.pipe(
56
  ref_image,
57
  num_inference_steps=num_steps,
58
  num_frames=num_frames,
59
  generator=generator,
60
  decode_chunk_size=2,
61
+ noise_aug_strength=noise_aug_strength
62
  ).frames[0]
63
 
64
+ # Экспорт видео
65
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
66
  export_to_video(output, temp_video.name, fps=7)
67