Update custom_wan_pipeline.py
custom_wan_pipeline.py  CHANGED  +33 -46
@@ -1,55 +1,42 @@
 import torch
 from diffusers import DiffusionPipeline
-from diffusers.utils import export_to_video
+from diffusers.utils import logging
 from PIL import Image
 import numpy as np
-import tempfile
-import os
 
+logger = logging.get_logger(__name__)
 
 class WanImageToVideoPipeline(DiffusionPipeline):
-    … (old lines 11–17 not recovered)
-        self.
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.to(self.device)
-        print(f"✅ Custom WAN 2.2 I2V pipeline initialized on {self.device}")
+    def __init__(self, vae, transformer, scheduler, text_encoder, tokenizer, image_encoder):
+        super().__init__()
+        self.vae = vae
+        self.transformer = transformer
+        self.scheduler = scheduler
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+        self.image_encoder = image_encoder
 
     @torch.no_grad()
-    def __call__(self, image,
-        … (old lines 25–46 not recovered)
-        frames = (frames.clamp(-1, 1) + 1) / 2
-        frames = (frames * 255).round().byte().cpu().permute(0, 2, 3, 1).numpy()
-        pil_frames = [Image.fromarray(f) for f in frames]
-
-        tmpdir = tempfile.mkdtemp()
-        out_path = os.path.join(tmpdir, "wan2v_output.mp4")
-        export_to_video(pil_frames, out_path, fps=12)
-        print(f"🎬 Generated {len(pil_frames)} frames → {out_path}")
-        return {"frames": pil_frames, "video_path": out_path}
+    def __call__(self, image: Image.Image, prompt: str = "", num_frames: int = 16, num_inference_steps: int = 25):
+        logger.info("✅ Generating latent motion sequence...")
+        image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0) / 255.0
+        image_tensor = image_tensor.to(self.device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+
+        # Dummy latent sampling for demonstration
+        latents = self.vae.encode(image_tensor).latent_dist.sample() * 0.18215
+        latents = torch.randn_like(latents)
+
+        frames = []
+        for i in range(num_frames):
+            noise = torch.randn_like(latents)
+            frame = latents + 0.05 * i * noise
+            decoded = self.vae.decode(frame / 0.18215).sample
+            decoded = (decoded.clamp(-1, 1) + 1) / 2
+            frame_img = (decoded * 255).cpu().numpy().astype("uint8")[0].transpose(1, 2, 0)
+            frames.append(Image.fromarray(frame_img))
+
+        # Simple video assembly (you can later swap this for real motion)
+        import imageio
+        output_path = "output.mp4"
+        imageio.mimsave(output_path, frames, fps=12)
+        return type("Result", (), {"videos": [output_path]})
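One note on the new constructor: it assigns the six components as plain attributes, while `__call__` still reads `self.device`. On `DiffusionPipeline`, `device` is inferred from modules registered through `register_modules`; with plain attribute assignment the base class has nothing to inspect, so `self.device`, `pipe.to(...)`, and `save_pretrained(...)` will not track the components. A minimal sketch of a registered variant, assuming the same component names as the diff:

from diffusers import DiffusionPipeline

class WanImageToVideoPipeline(DiffusionPipeline):
    def __init__(self, vae, transformer, scheduler, text_encoder, tokenizer, image_encoder):
        super().__init__()
        # register_modules lets the base class discover device/dtype
        # and serialize the pipeline; plain attributes are invisible to it.
        self.register_modules(
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            image_encoder=image_encoder,
        )

With registration in place, `image_tensor.to(self.device, ...)` in `__call__` resolves to wherever the registered modules live instead of silently defaulting to CPU.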
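Separately, the demo sampler encodes the input image and then immediately discards the result: `latents = torch.randn_like(latents)` overwrites the encoded latent, so the input image never influences the generated frames. If the intent is a noise walk around the image's latent, a hedged sketch (assuming a diffusers AutoencoderKL-style `vae` and the same `0.18215` scale the diff uses) would keep the encoded latent as the anchor:

import torch

@torch.no_grad()
def dummy_latent_walk(vae, image_tensor, num_frames=16, scale=0.18215, step=0.05):
    # Encode once and keep the image latent as the anchor of the walk.
    latents = vae.encode(image_tensor).latent_dist.sample() * scale
    frames = []
    for i in range(num_frames):
        drifted = latents + step * i * torch.randn_like(latents)  # drift, don't replace
        decoded = vae.decode(drifted / scale).sample
        frames.append((decoded.clamp(-1, 1) + 1) / 2)  # map [-1, 1] -> [0, 1]
    return frames

Also worth flagging: `0.18215` is the Stable Diffusion VAE scaling factor; whether it matches the WAN VAE is not established by this commit.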
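Finally, `__call__` returns `type("Result", (), {"videos": [output_path]})`, which is a class object rather than an instance; `result.videos[0]` happens to work because `videos` is a class attribute, but callers doing `isinstance` checks or expecting per-call state will be surprised. A lighter-weight sketch of the same access pattern:

from types import SimpleNamespace

# SimpleNamespace yields a real instance with the same attribute access.
result = SimpleNamespace(videos=["output.mp4"])
print(result.videos[0])  # -> output.mp4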