File size: 413 Bytes
e8ee465
471e7be
cc9a515
 
463b1bd
e8ee465
 
b0d6809
463b1bd
e1aac6c
cc9a515
 
463b1bd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from diffusers import DiffusionPipeline

class EndpointHandler:
    """Inference Endpoints request handler wrapping an image-to-video diffusion pipeline.

    The pipeline is loaded once at construction time (custom pipeline
    ``pipeline_wan_i2v``, float16 weights) and reused for every request.
    """

    def __init__(self, model_dir: str, device: str = "cuda"):
        """Load the diffusion pipeline from ``model_dir`` and move it to ``device``.

        Args:
            model_dir: Path (or identifier) passed to
                ``DiffusionPipeline.from_pretrained``.
            device: Device the pipeline is moved to via ``.to(device)``.
                Defaults to ``"cuda"``, matching the previous hard-coded value.
        """
        self.pipe = DiffusionPipeline.from_pretrained(
            model_dir,
            custom_pipeline="pipeline_wan_i2v",
            torch_dtype=torch.float16,
        ).to(device)

    def __call__(self, data: dict):
        """Run the pipeline on a single request payload.

        Args:
            data: Request payload. The input image is taken from the
                ``"image"`` key if present, otherwise from ``"inputs"`` (the
                key used by the Inference Endpoints handler contract). An
                optional ``"parameters"`` dict is forwarded to the pipeline as
                keyword arguments.

        Returns:
            The pipeline's return value, unchanged.

        Raises:
            KeyError: if the payload has neither an ``"image"`` nor an
                ``"inputs"`` key.
            TypeError: if ``"parameters"`` is present but is not a dict.
        """
        # "image" wins so existing clients keep working; "inputs" is the
        # standard Inference Endpoints payload key.
        if "image" in data:
            image = data["image"]
        elif "inputs" in data:
            image = data["inputs"]
        else:
            raise KeyError("request payload must contain an 'image' or 'inputs' field")

        # A missing or null "parameters" means "no extra pipeline kwargs".
        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            raise TypeError("'parameters' must be a dict of pipeline keyword arguments")

        # NOTE(review): the raw pipeline output object is returned as-is; it
        # may not be serializable by the endpoint toolkit — confirm how the
        # generated frames should be encoded for the HTTP response.
        return self.pipe(image, **params)