# Author: Gjm1234
# Commit message: "Update handler.py"
# Revision: 463b1bd (verified)
import torch
from diffusers import DiffusionPipeline
class EndpointHandler:
    """Inference handler that wraps a diffusers pipeline loaded once at startup.

    The pipeline is built in ``__init__`` and reused for every request handled
    by ``__call__``.
    """

    def __init__(self, model_dir, *, device="cuda", torch_dtype=torch.float16):
        """Load the pipeline from ``model_dir`` and move it to ``device``.

        Args:
            model_dir: Path (or hub id) passed to
                ``DiffusionPipeline.from_pretrained``.
            device: Device the pipeline is moved to; defaults to ``"cuda"``
                (the previously hard-coded value).
            torch_dtype: Dtype used when loading weights; defaults to
                ``torch.float16`` (the previously hard-coded value).
        """
        self.pipe = DiffusionPipeline.from_pretrained(
            model_dir,
            # Load the pipeline code named "pipeline_wan_i2v" through
            # diffusers' custom_pipeline mechanism.
            custom_pipeline="pipeline_wan_i2v",
            torch_dtype=torch_dtype,
        ).to(device)

    def __call__(self, data):
        """Run the pipeline on one request payload.

        Args:
            data: Request mapping. The input image is read from ``"image"``;
                if that key is absent, ``"inputs"`` (the standard Inference
                Endpoints payload key) is used instead. An optional
                ``"parameters"`` mapping is forwarded to the pipeline as
                keyword arguments.

        Returns:
            The pipeline's return value, unchanged.

        Raises:
            KeyError: If neither ``"image"`` nor ``"inputs"`` is present.
        """
        # "image" keeps priority so existing callers behave exactly as before.
        if "image" in data:
            image = data["image"]
        elif "inputs" in data:
            image = data["inputs"]
        else:
            raise KeyError(
                "request payload must contain an 'image' or 'inputs' field"
            )
        # `or {}` also covers an explicit "parameters": null in the payload.
        params = data.get("parameters") or {}
        # NOTE(review): the raw pipeline output is returned as-is; it may need
        # converting (e.g. frames -> encoded video/base64) before it can be
        # serialized as an HTTP response -- confirm against the serving stack.
        return self.pipe(image, **params)