Yuanshi commited on
Commit
bda8013
·
1 Parent(s): f2b3d29

add AutoencoderKLWan

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
  import spaces
7
 
8
  import torch
9
- from diffusers import WanPipeline
10
  from diffusers.utils import export_to_video, load_video
11
  from vibt.wan import load_vibt_weight, encode_video
12
  from vibt.scheduler import ViBTScheduler
@@ -23,12 +23,11 @@ def get_fps(path):
23
 
24
 
25
  base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
26
- pipe = WanPipeline.from_pretrained(
27
- base_model_id,
28
- torch_dtype=torch.bfloat16,
29
- low_cpu_mem_usage=True,
30
- keep_in_fp32_modules=False, # <-- 或者你也可以强制关掉
31
- ).to("cuda")
32
  load_vibt_weight(
33
  pipe.transformer,
34
  "Yuanshi/ViBT",
 
6
  import spaces
7
 
8
  import torch
9
+ from diffusers import WanPipeline, AutoencoderKLWan
10
  from diffusers.utils import export_to_video, load_video
11
  from vibt.wan import load_vibt_weight, encode_video
12
  from vibt.scheduler import ViBTScheduler
 
23
 
24
 
25
  base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
26
+ vae = AutoencoderKLWan.from_pretrained(
27
+ base_model_id, subfolder="vae", torch_dtype=torch.float32
28
+ )
29
+ pipe = WanPipeline.from_pretrained(base_model_id, vae=vae, torch_dtype=torch.bfloat16)
30
+ pipe.to("cuda")
 
31
  load_vibt_weight(
32
  pipe.transformer,
33
  "Yuanshi/ViBT",