"""Inference Endpoints handler: interpolates between a start frame and an end
frame with Wan2.2 TI2V-5B and returns the result as a base64-encoded MP4."""
import base64
import io
import os
import tempfile
from typing import Any, Dict

import torch
from PIL import Image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Use the MODEL_ID env var or default to the 5B TI2V model
        model_id = os.environ.get("MODEL_ID", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
        print(f"Loading Wan2.2-TI2V-5B from {model_id}...")
        dtype = torch.bfloat16
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # VAE in float32 for precision, rest in bfloat16 for speed/memory
        vae = AutoencoderKLWan.from_pretrained(
            model_id, subfolder="vae", torch_dtype=torch.float32
        )
        # Diffusers pipelines don't accept device_map="auto" (only the
        # "balanced" strategy), so load normally and move to one device.
        self.pipe = WanImageToVideoPipeline.from_pretrained(
            model_id,
            vae=vae,
            torch_dtype=dtype,
        )
        self.pipe.to(device)
        self.device = device
        print("✓ Model loaded and ready")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Accept both {"inputs": {...}} (Endpoints convention) and flat payloads
        inputs = data.get("inputs", data)
        # Decode start and end images
        start_img = self._decode_image(inputs["start_image"])
        end_img = self._decode_image(inputs["end_image"])
        prompt = inputs.get("prompt", "Smooth cinematic motion")
        num_frames = int(inputs.get("num_frames", 41))
        guidance = float(inputs.get("guidance_scale", 5.0))
        steps = int(inputs.get("num_inference_steps", 20))
        # Wan requires (4N + 1) frames; snap down, with a floor of 9
        num_frames = max(9, ((num_frames - 1) // 4) * 4 + 1)
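        # e.g. 41 stays 41 (4*10 + 1), while a request for 48 snaps down to 45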
        # Dimension snapping: round width/height down to multiples of 32
        w, h = start_img.size
        width = (w // 32) * 32
        height = (h // 32) * 32
        start_img = start_img.resize((width, height))
        end_img = end_img.resize((width, height))
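        # e.g. a 1280x720 input is resized to 1280x704 before generation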
        with torch.inference_mode():
            output = self.pipe(
                image=start_img,
                last_image=end_img,
                prompt=prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                guidance_scale=guidance,
                num_inference_steps=steps,
            ).frames[0]
        # Export video via a temp file; clean it up even if export fails
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            export_to_video(output, tmp_path, fps=16)
            with open(tmp_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
        finally:
            os.unlink(tmp_path)
        return {"video": video_b64}

    def _decode_image(self, b64_str: str) -> Image.Image:
        # Strip a data-URL prefix (e.g. "data:image/png;base64,") if present
        if "," in b64_str:
            b64_str = b64_str.split(",", 1)[1]
        img_bytes = base64.b64decode(b64_str)
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
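

# ---------------------------------------------------------------------------
# Optional local smoke test (illustrative sketch only, not part of the
# endpoint contract). Assumes the weights can be downloaded and a GPU is
# available; the two solid-color frames are hypothetical stand-ins for real
# start/end keyframes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def encode_image(img: Image.Image) -> str:
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    handler = EndpointHandler()
    payload = {
        "inputs": {
            "start_image": encode_image(Image.new("RGB", (1280, 704), "black")),
            "end_image": encode_image(Image.new("RGB", (1280, 704), "white")),
            "prompt": "Smooth cinematic motion",
            "num_frames": 41,
        }
    }
    result = handler(payload)
    with open("output.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))
    print("Wrote output.mp4")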