import base64
import io
import os
import tempfile
from typing import Any, Dict

import torch
from PIL import Image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Use the MODEL_ID env var or default to the 5B TI2V model
        model_id = os.environ.get("MODEL_ID", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
        print(f"Loading {model_id}...")

        dtype = torch.bfloat16
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # VAE in float32 for precision, rest in bfloat16 for speed/memory
        vae = AutoencoderKLWan.from_pretrained(
            model_id, subfolder="vae", torch_dtype=torch.float32
        )
        self.pipe = WanImageToVideoPipeline.from_pretrained(
            model_id, vae=vae, torch_dtype=dtype
        )
        # Diffusers pipelines do not accept device_map="auto" (only the
        # "balanced" strategy is supported), so move the pipeline explicitly.
        self.pipe.to(device)
        self.device = device
        print("✓ Model loaded and ready")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Inference Endpoints wrap the payload under "inputs"; fall back to
        # the raw dict so the handler also works when called directly.
        inputs = data.get("inputs", data)

        # Decode the start and end keyframes from base64
        start_img = self._decode_image(inputs["start_image"])
        end_img = self._decode_image(inputs["end_image"])

        prompt = inputs.get("prompt", "Smooth cinematic motion")
        num_frames = int(inputs.get("num_frames", 41))
        guidance = float(inputs.get("guidance_scale", 5.0))
        steps = int(inputs.get("num_inference_steps", 20))

        # Wan requires the frame count to be 4N + 1; snap down, floor at 9
        num_frames = max(9, ((num_frames - 1) // 4) * 4 + 1)

        # Snap dimensions down to multiples of 32 (the pipeline's spatial
        # granularity), then resize both keyframes to the snapped size
        w, h = start_img.size
        width = (w // 32) * 32
        height = (h // 32) * 32
        start_img = start_img.resize((width, height))
        end_img = end_img.resize((width, height))

        with torch.inference_mode():
            output = self.pipe(
                image=start_img,
                last_image=end_img,
                prompt=prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                guidance_scale=guidance,
                num_inference_steps=steps,
            ).frames[0]

        # Export the frames to an MP4 on disk, then return it as base64
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_path = tmp.name
        export_to_video(output, tmp_path, fps=16)
        with open(tmp_path, "rb") as f:
            video_b64 = base64.b64encode(f.read()).decode("utf-8")
        os.unlink(tmp_path)

        return {"video": video_b64}

    def _decode_image(self, b64_str: str) -> Image.Image:
        # Strip a data-URL prefix (e.g. "data:image/png;base64,") if present
        if "," in b64_str:
            b64_str = b64_str.split(",", 1)[1]
        img_bytes = base64.b64decode(b64_str)
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
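
# --- Local smoke test (not part of the endpoint contract) ---
# A minimal sketch of how to exercise the handler locally with the same
# payload shape the endpoint receives as JSON. The file names "start.png",
# "end.png", and "output.mp4" are placeholders for this example only;
# running it requires a GPU and downloads the model weights.
if __name__ == "__main__":
    def _encode_image(path: str) -> str:
        # Read an image file and return its base64-encoded contents
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    payload = {
        "inputs": {
            "start_image": _encode_image("start.png"),  # placeholder path
            "end_image": _encode_image("end.png"),      # placeholder path
            "prompt": "Smooth cinematic motion",
            "num_frames": 41,
            "num_inference_steps": 20,
        }
    }
    result = handler(payload)
    # Decode the base64 MP4 returned by the handler and write it to disk
    with open("output.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))
    print("Wrote output.mp4")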