add AutoencoderKLWan
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ import gradio as gr
|
|
| 6 |
import spaces
|
| 7 |
|
| 8 |
import torch
|
| 9 |
-
from diffusers import WanPipeline
|
| 10 |
from diffusers.utils import export_to_video, load_video
|
| 11 |
from vibt.wan import load_vibt_weight, encode_video
|
| 12 |
from vibt.scheduler import ViBTScheduler
|
|
@@ -23,12 +23,11 @@ def get_fps(path):
|
|
| 23 |
|
| 24 |
|
| 25 |
base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
| 26 |
-
|
| 27 |
-
base_model_id,
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
).to("cuda")
|
| 32 |
load_vibt_weight(
|
| 33 |
pipe.transformer,
|
| 34 |
"Yuanshi/ViBT",
|
|
|
|
| 6 |
import spaces
|
| 7 |
|
| 8 |
import torch
|
| 9 |
+
from diffusers import WanPipeline, AutoencoderKLWan
|
| 10 |
from diffusers.utils import export_to_video, load_video
|
| 11 |
from vibt.wan import load_vibt_weight, encode_video
|
| 12 |
from vibt.scheduler import ViBTScheduler
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
| 26 |
+
vae = AutoencoderKLWan.from_pretrained(
|
| 27 |
+
base_model_id, subfolder="vae", torch_dtype=torch.float32
|
| 28 |
+
)
|
| 29 |
+
pipe = WanPipeline.from_pretrained(base_model_id, vae=vae, torch_dtype=torch.bfloat16)
|
| 30 |
+
pipe.to("cuda")
|
|
|
|
| 31 |
load_vibt_weight(
|
| 32 |
pipe.transformer,
|
| 33 |
"Yuanshi/ViBT",
|