Hugging Face Spaces — commit "Update app.py" (Space status: Build error).
Diff view of changed file: app.py
|
@@ -10,23 +10,19 @@ from PIL import Image
|
|
| 10 |
import numpy as np
|
| 11 |
import cv2
|
| 12 |
import tempfile
|
| 13 |
-
from diffusers.utils import export_to_video #
|
| 14 |
|
| 15 |
class WanAnimateApp:
|
| 16 |
def __init__(self):
|
| 17 |
model_name = "stabilityai/stable-video-diffusion-img2vid-xt"
|
| 18 |
self.pipe = StableVideoDiffusionPipeline.from_pretrained(
|
| 19 |
model_name,
|
| 20 |
-
torch_dtype=torch.
|
| 21 |
-
variant="fp16",
|
| 22 |
-
device_map="
|
| 23 |
)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
self.pipe.enable_model_cpu_offload() # Оптимизация памяти
|
| 27 |
-
self.pipe.enable_vae_slicing()
|
| 28 |
-
self.pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
|
| 29 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 30 |
|
| 31 |
def predict(self, ref_img, video, model_id, model):
|
| 32 |
if ref_img is None or video is None:
|
|
@@ -39,33 +35,33 @@ class WanAnimateApp:
|
|
| 39 |
else:
|
| 40 |
ref_image = Image.open(ref_img).convert("RGB").resize((576, 320))
|
| 41 |
|
| 42 |
-
# Извлечение motion из видео
|
| 43 |
cap = cv2.VideoCapture(video)
|
| 44 |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 45 |
cap.release()
|
| 46 |
motion_hint = f" with dynamic motion from {frame_count} frames"
|
| 47 |
|
| 48 |
-
# Параметры
|
| 49 |
num_frames = 25 if model == "wan-pro" else 14
|
| 50 |
num_steps = 25 if model == "wan-pro" else 15
|
| 51 |
|
| 52 |
-
# Адаптация
|
| 53 |
-
noise_aug_strength = 0.02
|
| 54 |
-
if model_id == "wan2.2-animate-mix":
|
| 55 |
noise_aug_strength = 0.1
|
| 56 |
|
| 57 |
-
# Генерация
|
| 58 |
-
generator = torch.Generator(device="
|
| 59 |
output = self.pipe(
|
| 60 |
ref_image,
|
| 61 |
num_inference_steps=num_steps,
|
| 62 |
num_frames=num_frames,
|
| 63 |
generator=generator,
|
| 64 |
decode_chunk_size=2,
|
| 65 |
-
noise_aug_strength=noise_aug_strength
|
| 66 |
).frames[0]
|
| 67 |
|
| 68 |
-
# Экспорт видео
|
| 69 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
|
| 70 |
export_to_video(output, temp_video.name, fps=7)
|
| 71 |
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import cv2
|
| 12 |
import tempfile
|
| 13 |
+
from diffusers.utils import export_to_video # Для экспорта видео
|
| 14 |
|
| 15 |
class WanAnimateApp:
|
| 16 |
def __init__(self):
|
| 17 |
model_name = "stabilityai/stable-video-diffusion-img2vid-xt"
|
| 18 |
self.pipe = StableVideoDiffusionPipeline.from_pretrained(
|
| 19 |
model_name,
|
| 20 |
+
torch_dtype=torch.float32, # Для CPU — FP32
|
| 21 |
+
variant="fp16", # Вариант остаётся, но dtype переопределяем
|
| 22 |
+
device_map="cpu" # Явно CPU, чтобы избежать ошибок
|
| 23 |
)
|
| 24 |
+
# Нет CUDA-оптимизаций, так как на CPU они не нужны/не работают
|
| 25 |
+
self.pipe.enable_vae_slicing() # Оптимизация для памяти (работает на CPU)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def predict(self, ref_img, video, model_id, model):
|
| 28 |
if ref_img is None or video is None:
|
|
|
|
| 35 |
else:
|
| 36 |
ref_image = Image.open(ref_img).convert("RGB").resize((576, 320))
|
| 37 |
|
| 38 |
+
# Извлечение motion из видео
|
| 39 |
cap = cv2.VideoCapture(video)
|
| 40 |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 41 |
cap.release()
|
| 42 |
motion_hint = f" with dynamic motion from {frame_count} frames"
|
| 43 |
|
| 44 |
+
# Параметры
|
| 45 |
num_frames = 25 if model == "wan-pro" else 14
|
| 46 |
num_steps = 25 if model == "wan-pro" else 15
|
| 47 |
|
| 48 |
+
# Адаптация modes (имитация с SVD)
|
| 49 |
+
noise_aug_strength = 0.02
|
| 50 |
+
if model_id == "wan2.2-animate-mix":
|
| 51 |
noise_aug_strength = 0.1
|
| 52 |
|
| 53 |
+
# Генерация (на CPU)
|
| 54 |
+
generator = torch.Generator(device="cpu").manual_seed(42)
|
| 55 |
output = self.pipe(
|
| 56 |
ref_image,
|
| 57 |
num_inference_steps=num_steps,
|
| 58 |
num_frames=num_frames,
|
| 59 |
generator=generator,
|
| 60 |
decode_chunk_size=2,
|
| 61 |
+
noise_aug_strength=noise_aug_strength
|
| 62 |
).frames[0]
|
| 63 |
|
| 64 |
+
# Экспорт видео
|
| 65 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
|
| 66 |
export_to_video(output, temp_video.name, fps=7)
|
| 67 |
|