Gjm1234 committed on
Commit
f6448f0
·
verified ·
1 Parent(s): 50b2c3d

Update custom_wan_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_wan_pipeline.py +33 -46
custom_wan_pipeline.py CHANGED
@@ -1,55 +1,42 @@
1
  import torch
2
  from diffusers import DiffusionPipeline
3
- from diffusers.utils import export_to_video
4
  from PIL import Image
5
  import numpy as np
6
- import tempfile
7
- import os
8
 
 
9
 
10
class WanImageToVideoPipeline(DiffusionPipeline):
    """
    Custom WAN 2.2 I2V pipeline – converts a single still image into a short animated clip.

    Components (``scheduler``, ``transformer``, ``vae`` and optionally an
    ``image_encoder``) are expected to arrive as keyword arguments and are
    exposed as instance attributes.
    """

    def __init__(self, *args, **kwargs):
        # Accepts both positional and keyword args properly
        super().__init__(*args, **kwargs)
        # Expose keyword-supplied components (scheduler, vae, ...) as attributes.
        self.__dict__.update(kwargs)
        # FIX: `DiffusionPipeline.device` is a read-only property, so the old
        # `self.device = torch.device(...)` raised AttributeError on
        # construction. Pick the target device locally, move the pipeline
        # there, and rely on the inherited `device` property afterwards.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(target_device)
        print(f"✅ Custom WAN 2.2 I2V pipeline initialized on {target_device}")

    @torch.no_grad()
    def __call__(self, image, num_inference_steps=25, motion_scale=1.0, guidance_scale=7.5):
        """Generate a short video clip from a single still image.

        Args:
            image: ``PIL.Image.Image`` or a pre-made ``(1, C, H, W)`` float tensor.
            num_inference_steps: number of scheduler denoising steps.
            motion_scale: multiplier applied to the initial latents.
            guidance_scale: accepted for API compatibility; currently unused.

        Returns:
            dict with ``frames`` (list of ``PIL.Image``) and ``video_path``.

        Raises:
            ValueError: if *image* is ``None``.
        """
        if image is None:
            raise ValueError("No image provided for video generation.")

        self.scheduler.set_timesteps(num_inference_steps)

        # Normalise the input to a (1, C, H, W) float tensor in [0, 1].
        if isinstance(image, Image.Image):
            arr = np.array(image.convert("RGB")).astype(np.float32) / 255.0
            tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
        else:
            tensor = image
        tensor = tensor.to(self.device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)

        # Optional image encoder maps pixels to latent space; otherwise the
        # raw pixel tensor is used directly. NOTE(review): decoding raw
        # pixels through the VAE below assumes they are valid latents —
        # confirm against the model's expected latent space.
        latents = tensor.clone()
        if hasattr(self, "image_encoder") and self.image_encoder is not None:
            latents = self.image_encoder(tensor).to(self.device)
        latents = latents * motion_scale

        # Denoising loop driven by the scheduler's timesteps.
        for t in self.scheduler.timesteps:
            noise_pred = self.transformer(latents, t)
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode latents to frames; 0.18215 is the SD-style VAE scaling factor.
        frames = self.vae.decode(latents / 0.18215).sample
        frames = (frames.clamp(-1, 1) + 1) / 2
        frames = (frames * 255).round().byte().cpu().permute(0, 2, 3, 1).numpy()
        pil_frames = [Image.fromarray(f) for f in frames]

        # Write the clip into a fresh temp directory so repeated calls don't
        # clobber each other's output.
        tmpdir = tempfile.mkdtemp()
        out_path = os.path.join(tmpdir, "wan2v_output.mp4")
        export_to_video(pil_frames, out_path, fps=12)
        print(f"🎬 Generated {len(pil_frames)} frames → {out_path}")
        return {"frames": pil_frames, "video_path": out_path}
 
1
  import torch
2
  from diffusers import DiffusionPipeline
3
+ from diffusers.utils import logging
4
  from PIL import Image
5
  import numpy as np
 
 
6
 
7
+ logger = logging.get_logger(__name__)
8
 
9
class WanImageToVideoPipeline(DiffusionPipeline):
    """Minimal WAN 2.2 image-to-video pipeline.

    Encodes one still image with the VAE, then decodes a sequence of
    progressively perturbed latents into frames and assembles them into an
    MP4. The latent trajectory is currently a placeholder random walk (see
    the NOTE in ``__call__``).
    """

    def __init__(self, vae, transformer, scheduler, text_encoder, tokenizer, image_encoder):
        super().__init__()
        # FIX: register the components instead of plain attribute assignment,
        # so DiffusionPipeline machinery (`.to()`, `.device`,
        # `save_pretrained`) can see them. register_modules also sets the
        # same instance attributes, so callers reading e.g. `pipe.vae` are
        # unaffected.
        self.register_modules(
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            image_encoder=image_encoder,
        )

    @torch.no_grad()
    def __call__(self, image: Image.Image, prompt: str = "", num_frames: int = 16, num_inference_steps: int = 25):
        """Turn *image* into a short MP4 clip.

        Args:
            image: input still image (any PIL mode; forced to RGB).
            prompt: accepted for API compatibility; currently unused.
            num_frames: number of frames to synthesise.
            num_inference_steps: accepted for API compatibility; currently unused.

        Returns:
            An object exposing a ``videos`` attribute with the output path.
        """
        logger.info("✅ Generating latent motion sequence...")
        # FIX: force RGB so RGBA / grayscale inputs don't break the
        # 3-channel permute below.
        image_tensor = torch.tensor(np.array(image.convert("RGB"))).permute(2, 0, 1).unsqueeze(0) / 255.0
        image_tensor = image_tensor.to(self.device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)

        # Dummy latent sampling for demonstration.
        # NOTE(review): the encoded latents are immediately replaced by pure
        # noise, so the output does not yet depend on the input image — swap
        # this for a real motion model later.
        latents = self.vae.encode(image_tensor).latent_dist.sample() * 0.18215
        latents = torch.randn_like(latents)

        frames = []
        for i in range(num_frames):
            noise = torch.randn_like(latents)
            # Perturbation grows linearly with the frame index.
            frame = latents + 0.05 * i * noise
            decoded = self.vae.decode(frame / 0.18215).sample
            decoded = (decoded.clamp(-1, 1) + 1) / 2
            # FIX: round before the uint8 cast instead of truncating.
            frame_img = (decoded * 255).round().cpu().numpy().astype("uint8")[0].transpose(1, 2, 0)
            frames.append(Image.fromarray(frame_img))

        # Simple video assembly (you can later swap this for real motion)
        import imageio
        output_path = "output.mp4"
        imageio.mimsave(output_path, frames, fps=12)
        # FIX: return an instance rather than the dynamically built class
        # itself; `result.videos` keeps working either way.
        return type("Result", (), {"videos": [output_path]})()