import base64
import io
import os
import tempfile
import torch
from typing import Any, Dict
from PIL import Image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Use the MODEL_ID env var or default to the 5B TI2V model
        model_id = os.environ.get("MODEL_ID", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
        print(f"Loading {model_id}...")
        dtype = torch.bfloat16
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # VAE in float32 for precision, rest in bfloat16 for speed/memory
        vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
        # diffusers pipelines don't support device_map="auto" (only the
        # "balanced" strategy), so load normally and move to the device.
        self.pipe = WanImageToVideoPipeline.from_pretrained(
            model_id,
            vae=vae,
            torch_dtype=dtype,
        )
        self.pipe.to(device)
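        # Optional alternative (a suggestion, not part of the original handler):
        # on GPUs that cannot hold the full pipeline, diffusers' CPU offload
        # can replace the .to(device) call above:
        #   self.pipe.enable_model_cpu_offload()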
        self.device = device
        print("✓ Model loaded and ready")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.get("inputs", data)
        # Decode start and end images
        start_img = self._decode_image(inputs["start_image"])
        end_img = self._decode_image(inputs["end_image"])
        prompt = inputs.get("prompt", "Smooth cinematic motion")
        num_frames = int(inputs.get("num_frames", 41))
        guidance = float(inputs.get("guidance_scale", 5.0))
        steps = int(inputs.get("num_inference_steps", 20))
        # Wan requires 4N + 1 frames; snap down to the nearest valid count (min 9)
        num_frames = max(9, ((num_frames - 1) // 4) * 4 + 1)
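        # Worked examples of the snapping above: 41 -> 41 (already valid),
        # 44 -> 41, 20 -> 17, and anything below 9 is clamped up to 9.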
        # Snap dimensions down to multiples of 32, guarding against a
        # degenerate 0 size for inputs smaller than 32 px
        w, h = start_img.size
        width = max(32, (w // 32) * 32)
        height = max(32, (h // 32) * 32)
        start_img = start_img.resize((width, height))
        end_img = end_img.resize((width, height))
        with torch.inference_mode():
            output = self.pipe(
                image=start_img,
                last_image=end_img,
                prompt=prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                guidance_scale=guidance,
                num_inference_steps=steps,
            ).frames[0]
        # Export frames to a temporary MP4, return it base64-encoded; the
        # try/finally ensures the temp file is removed even if export fails
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            export_to_video(output, tmp_path, fps=16)
            with open(tmp_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
        finally:
            os.unlink(tmp_path)
        return {"video": video_b64}

    def _decode_image(self, b64_str: str) -> Image.Image:
        # Accept both raw base64 and data-URL payloads ("data:image/...;base64,...")
        if "," in b64_str:
            b64_str = b64_str.split(",", 1)[1]
        img_bytes = base64.b64decode(b64_str)
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
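

if __name__ == "__main__":
    # Minimal local smoke test (an illustrative addition, not part of the
    # deployed endpoint): builds the payload shape this handler expects.
    # The image paths below are placeholders.
    def _encode_image(path: str) -> str:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    result = handler({
        "inputs": {
            "start_image": _encode_image("start.png"),  # placeholder path
            "end_image": _encode_image("end.png"),      # placeholder path
            "prompt": "Smooth cinematic motion",
            "num_frames": 41,
            "num_inference_steps": 20,
        }
    })
    with open("output.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))
    print("wrote output.mp4")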