Tsmith2024 committed on
Commit
627335d
·
verified ·
1 Parent(s): 90b1f7c

Fix: use WanImageToVideoPipeline not WanPipeline

Browse files
Files changed (1) hide show
  1. handler.py +18 -83
handler.py CHANGED
@@ -1,107 +1,51 @@
1
- """
2
- HuggingFace Inference Endpoint handler for Wan2.2-TI2V-5B
3
- Accepts first + last frame images, returns interpolated video.
4
-
5
- Input JSON:
6
- {
7
- "inputs": {
8
- "start_image": "<base64 png>",
9
- "end_image": "<base64 png>",
10
- "prompt": "...",
11
- "num_frames": 41,
12
- "guidance_scale": 5.0,
13
- "num_inference_steps": 20
14
- }
15
- }
16
-
17
- Output JSON:
18
- { "video": "<base64 mp4>" }
19
- """
20
-
21
  import base64
22
  import io
23
  import os
24
  import tempfile
25
  from typing import Any, Dict
26
 
27
- import numpy as np
28
  import torch
29
  from PIL import Image
30
- from diffusers import WanPipeline, AutoencoderKLWan
31
  from diffusers.utils import export_to_video
32
 
33
 
34
  class EndpointHandler:
35
  def __init__(self, path: str = ""):
36
- """Load Wan2.2-TI2V-5B from /repository (HF mounts model here)."""
37
-
38
- model_path = path or "/repository"
39
  print(f"Loading Wan2.2-TI2V-5B from {model_path}…")
40
-
41
- dtype = torch.bfloat16
42
  device = "cuda" if torch.cuda.is_available() else "cpu"
43
-
44
- # VAE in float32 for better decoding quality
45
  vae = AutoencoderKLWan.from_pretrained(
46
- model_path,
47
- subfolder="vae",
48
- torch_dtype=torch.float32,
49
  )
50
-
51
- self.pipe = WanPipeline.from_pretrained(
52
- model_path,
53
- vae=vae,
54
- torch_dtype=dtype,
55
  )
56
  self.pipe.to(device)
57
-
58
- # Memory optimisation — helps on 24GB GPUs
59
  self.pipe.enable_attention_slicing()
60
-
61
  self.device = device
62
  print("✓ Model loaded and ready")
63
 
64
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
65
- """
66
- Called on every request.
67
- data = { "inputs": { "start_image": b64, "end_image": b64, "prompt": str, ... } }
68
- """
69
- inputs = data.get("inputs", data) # handle both wrapped and unwrapped
70
-
71
- # Decode images
72
- start_img = self._decode_image(inputs["start_image"])
73
- end_img = self._decode_image(inputs["end_image"])
74
-
75
  prompt = inputs.get("prompt", "Smooth cinematic motion, natural movement")
76
- num_frames = int(inputs.get("num_frames", 41)) # must be 4N+1
77
  guidance = float(inputs.get("guidance_scale", 5.0))
78
  steps = int(inputs.get("num_inference_steps", 20))
79
  fps = int(inputs.get("fps", 16))
80
-
81
- # Ensure num_frames follows 4N+1 pattern
82
  num_frames = max(9, ((num_frames - 1) // 4) * 4 + 1)
83
-
84
- # Size from input image (snap to multiples of 32)
85
- w, h = start_img.size
86
- width = (w // 32) * 32
87
- height = (h // 32) * 32
88
-
89
- # Build first+last frame conditioning using TI2V mask approach
90
- # First frame = start_img, last frame = end_img, middle = grey
91
- frames = [start_img.resize((width, height))]
92
- grey = Image.new("RGB", (width, height), (128, 128, 128))
93
- frames.extend([grey] * (num_frames - 2))
94
- frames.append(end_img.resize((width, height)))
95
-
96
- # Mask: 0 = conditioned (first/last), 1 = free generation (middle)
97
- mask_black = Image.new("L", (width, height), 0)
98
- mask_white = Image.new("L", (width, height), 255)
99
- mask = [mask_black] + [mask_white] * (num_frames - 2) + [mask_black]
100
-
101
  with torch.inference_mode():
102
  output = self.pipe(
103
- image=frames,
104
- mask=mask,
105
  prompt=prompt,
106
  negative_prompt="",
107
  height=height,
@@ -110,25 +54,16 @@ class EndpointHandler:
110
  guidance_scale=guidance,
111
  num_inference_steps=steps,
112
  ).frames[0]
113
-
114
- # Export to temp MP4 and encode as base64
115
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
116
  tmp_path = tmp.name
117
-
118
  export_to_video(output, tmp_path, fps=fps)
119
-
120
  with open(tmp_path, "rb") as f:
121
  video_b64 = base64.b64encode(f.read()).decode("utf-8")
122
-
123
  os.unlink(tmp_path)
124
-
125
  return {"video": video_b64}
126
 
127
  @staticmethod
128
  def _decode_image(b64_str: str) -> Image.Image:
129
- """Decode base64 string to PIL Image."""
130
- # Strip data URI prefix if present
131
  if "," in b64_str:
132
  b64_str = b64_str.split(",", 1)[1]
133
- img_bytes = base64.b64decode(b64_str)
134
- return Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
2
  import io
3
  import os
4
  import tempfile
5
  from typing import Any, Dict
6
 
 
7
  import torch
8
  from PIL import Image
9
+ from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
10
  from diffusers.utils import export_to_video
11
 
12
 
13
  class EndpointHandler:
14
  def __init__(self, path: str = ""):
15
+ model_path = path or os.environ.get("MODEL_ID", "/repository")
 
 
16
  print(f"Loading Wan2.2-TI2V-5B from {model_path}…")
17
+ dtype = torch.bfloat16
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
19
  vae = AutoencoderKLWan.from_pretrained(
20
+ model_path, subfolder="vae", torch_dtype=torch.float32,
 
 
21
  )
22
+ self.pipe = WanImageToVideoPipeline.from_pretrained(
23
+ model_path, vae=vae, torch_dtype=dtype,
 
 
 
24
  )
25
  self.pipe.to(device)
 
 
26
  self.pipe.enable_attention_slicing()
 
27
  self.device = device
28
  print("✓ Model loaded and ready")
29
 
30
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
31
+ inputs = data.get("inputs", data)
32
+ start_img = self._decode_image(inputs["start_image"])
33
+ end_img = self._decode_image(inputs["end_image"])
 
 
 
 
 
 
 
34
  prompt = inputs.get("prompt", "Smooth cinematic motion, natural movement")
35
+ num_frames = int(inputs.get("num_frames", 41))
36
  guidance = float(inputs.get("guidance_scale", 5.0))
37
  steps = int(inputs.get("num_inference_steps", 20))
38
  fps = int(inputs.get("fps", 16))
 
 
39
  num_frames = max(9, ((num_frames - 1) // 4) * 4 + 1)
40
+ w, h = start_img.size
41
+ width = (w // 32) * 32
42
+ height = (h // 32) * 32
43
+ start_img = start_img.resize((width, height))
44
+ end_img = end_img.resize((width, height))
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  with torch.inference_mode():
46
  output = self.pipe(
47
+ image=start_img,
48
+ last_image=end_img,
49
  prompt=prompt,
50
  negative_prompt="",
51
  height=height,
 
54
  guidance_scale=guidance,
55
  num_inference_steps=steps,
56
  ).frames[0]
 
 
57
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
58
  tmp_path = tmp.name
 
59
  export_to_video(output, tmp_path, fps=fps)
 
60
  with open(tmp_path, "rb") as f:
61
  video_b64 = base64.b64encode(f.read()).decode("utf-8")
 
62
  os.unlink(tmp_path)
 
63
  return {"video": video_b64}
64
 
65
  @staticmethod
66
  def _decode_image(b64_str: str) -> Image.Image:
 
 
67
  if "," in b64_str:
68
  b64_str = b64_str.split(",", 1)[1]
69
+ return Image.open(io.BytesIO(base64.b64decode(b64_str))).convert("RGB")