| import base64 |
| import io |
| import os |
| import tempfile |
| import torch |
| from typing import Any, Dict |
| from PIL import Image |
| from diffusers import AutoencoderKLWan, WanImageToVideoPipeline |
| from diffusers.utils import export_to_video |
|
|
class EndpointHandler:
    """Inference Endpoint handler for Wan2.2-TI2V-5B first/last-frame video generation.

    Loads a diffusers ``WanImageToVideoPipeline`` once at startup and, per
    request, generates a short video that interpolates between a supplied
    start frame and end frame, returning it as a base64-encoded MP4.
    """

    def __init__(self, path: str = ""):
        """Load the pipeline once at endpoint startup.

        Args:
            path: Local model directory supplied by the endpoint runtime.
                Unused here — the model id is read from the ``MODEL_ID``
                environment variable instead (kept for interface compatibility).
        """
        model_id = os.environ.get("MODEL_ID", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
        print(f"Loading Wan2.2-TI2V-5B from {model_id}...")

        dtype = torch.bfloat16
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # The VAE is deliberately kept in float32 while the rest of the
        # pipeline runs in bfloat16 — presumably for decode stability.
        vae = AutoencoderKLWan.from_pretrained(
            model_id, subfolder="vae", torch_dtype=torch.float32
        )
        self.pipe = WanImageToVideoPipeline.from_pretrained(
            model_id,
            vae=vae,
            torch_dtype=dtype,
            # NOTE(review): many diffusers releases accept only
            # device_map="balanced" for pipelines and raise on "auto" —
            # confirm against the pinned diffusers version.
            device_map="auto",
        )
        self.device = device
        print("✓ Model loaded and ready")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a video between two keyframes.

        Expected payload (under ``"inputs"``, or at the top level):
            start_image / end_image: base64-encoded images, optionally as
                data URLs (required).
            prompt: text prompt (default ``"Smooth cinematic motion"``).
            num_frames: requested frame count; snapped to ``4k + 1``, min 9.
            guidance_scale: CFG scale (default 5.0).
            num_inference_steps: denoising steps (default 20).
            fps: output video frame rate (default 16 — previously hard-coded).

        Returns:
            ``{"video": <base64-encoded MP4 bytes>}``

        Raises:
            ValueError: if ``start_image`` or ``end_image`` is missing.
        """
        inputs = data.get("inputs", data)

        # Fail fast with a clear message instead of a raw KeyError.
        for key in ("start_image", "end_image"):
            if key not in inputs:
                raise ValueError(f"Missing required input: '{key}'")

        start_img = self._decode_image(inputs["start_image"])
        end_img = self._decode_image(inputs["end_image"])

        prompt = inputs.get("prompt", "Smooth cinematic motion")
        num_frames = int(inputs.get("num_frames", 41))
        guidance = float(inputs.get("guidance_scale", 5.0))
        steps = int(inputs.get("num_inference_steps", 20))
        fps = int(inputs.get("fps", 16))  # new knob; default preserves old behavior

        # Frame count must be of the form 4k + 1 for this pipeline; snap
        # down to the nearest valid value, with a floor of 9 frames.
        num_frames = max(9, ((num_frames - 1) // 4) * 4 + 1)

        # Spatial dims must be multiples of 32. Clamp to at least 32 so an
        # input image smaller than 32 px cannot produce a zero-sized resize.
        w, h = start_img.size
        width = max(32, (w // 32) * 32)
        height = max(32, (h // 32) * 32)

        start_img = start_img.resize((width, height))
        end_img = end_img.resize((width, height))

        with torch.inference_mode():
            output = self.pipe(
                image=start_img,
                last_image=end_img,
                prompt=prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                guidance_scale=guidance,
                num_inference_steps=steps,
            ).frames[0]

        # Encode via a temp file; guarantee cleanup even if export or the
        # read fails (the original leaked the file on any exception here).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            export_to_video(output, tmp_path, fps=fps)
            with open(tmp_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
        finally:
            os.unlink(tmp_path)

        return {"video": video_b64}

    def _decode_image(self, b64_str: str) -> Image.Image:
        """Decode a base64 string (optionally a data URL) into an RGB image."""
        if "," in b64_str:
            # Strip a "data:image/...;base64," prefix if present.
            b64_str = b64_str.split(",", 1)[1]
        img_bytes = base64.b64decode(b64_str)
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
|