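"""Custom handler for Hugging Face Inference Endpoints: first/last-frame
video interpolation with Wan2.2 TI2V-5B.

Expected request payload (images are base64-encoded, optionally with a
data-URI prefix):

    {
        "inputs": {
            "start_image": "<base64>",
            "end_image": "<base64>",
            "prompt": "Smooth cinematic motion",
            "num_frames": 41,
            "guidance_scale": 5.0,
            "num_inference_steps": 20
        }
    }

Response: {"video": "<base64-encoded MP4>"}.
"""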
import base64
import io
import os
import tempfile
import torch
from typing import Any, Dict
from PIL import Image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Use the MODEL_ID env var, defaulting to the 5B TI2V model
        model_id = os.environ.get("MODEL_ID", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
        print(f"Loading {model_id}...")
        
        dtype = torch.bfloat16
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # Keep the VAE in float32 for decode precision; load the rest in
        # bfloat16 for speed and memory.
        vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
        self.pipe = WanImageToVideoPipeline.from_pretrained(
            model_id,
            vae=vae,
            torch_dtype=dtype,
        )
        # DiffusionPipeline.from_pretrained does not accept device_map="auto";
        # move the whole pipeline to the target device instead.
        self.pipe.to(device)
        self.device = device
        print("✓ Model loaded and ready")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.get("inputs", data)
        
        # Decode the required start and end images; fail fast with a clear
        # message instead of a bare KeyError.
        for key in ("start_image", "end_image"):
            if key not in inputs:
                raise ValueError(f"Missing required input: '{key}'")
        start_img = self._decode_image(inputs["start_image"])
        end_img = self._decode_image(inputs["end_image"])
        
        prompt = inputs.get("prompt", "Smooth cinematic motion")
        num_frames = int(inputs.get("num_frames", 41))
        guidance = float(inputs.get("guidance_scale", 5.0))
        steps = int(inputs.get("num_inference_steps", 20))
        
        # Wan requires num_frames of the form 4k + 1; snap down to the nearest
        # valid count (e.g. 44 -> 41) with a floor of 9 frames.
        num_frames = max(9, ((num_frames - 1) // 4) * 4 + 1)

        # Snap dimensions down to multiples of 32 (the TI2V-5B VAE/patch
        # granularity), with a floor of 32 so tiny inputs can't collapse to 0.
        w, h = start_img.size
        width = max(32, (w // 32) * 32)
        height = max(32, (h // 32) * 32)

        start_img = start_img.resize((width, height), Image.LANCZOS)
        end_img = end_img.resize((width, height), Image.LANCZOS)

        with torch.inference_mode():
            output = self.pipe(
                image=start_img,
                last_image=end_img,
                prompt=prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                guidance_scale=guidance,
                num_inference_steps=steps,
            ).frames[0]

        # Encode the frames as an MP4 and return it base64-encoded; the
        # try/finally ensures the temp file is removed even if export fails.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_path = tmp.name

        try:
            export_to_video(output, tmp_path, fps=16)
            with open(tmp_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
        finally:
            os.unlink(tmp_path)

        return {"video": video_b64}
        return {"video": video_b64}

    def _decode_image(self, b64_str: str) -> Image.Image:
        # Strip an optional data-URI prefix ("data:image/png;base64,...").
        if "," in b64_str:
            b64_str = b64_str.split(",", 1)[1]
        img_bytes = base64.b64decode(b64_str)
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
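
if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the Inference
    # Endpoints contract). Assumes two test images exist at the
    # hypothetical paths start.png / end.png next to this file, and that
    # the host has a GPU with enough memory for the 5B model.
    def _encode(path: str) -> str:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    result = handler({
        "inputs": {
            "start_image": _encode("start.png"),
            "end_image": _encode("end.png"),
            "prompt": "Smooth cinematic motion",
            "num_frames": 41,
        }
    })
    with open("output.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))
    print("Wrote output.mp4")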