File size: 413 Bytes
e8ee465 471e7be cc9a515 463b1bd e8ee465 b0d6809 463b1bd e1aac6c cc9a515 463b1bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch
from diffusers import DiffusionPipeline
class EndpointHandler:
    """Serve a diffusers pipeline loaded with the ``pipeline_wan_i2v`` custom pipeline.

    Instantiated once with the model directory; each request is handled by
    calling the instance with the request payload dict.
    """

    def __init__(self, model_dir, device=None, torch_dtype=None):
        """Load the pipeline from ``model_dir`` and move it to a device.

        Args:
            model_dir: Path (or hub id) passed to ``DiffusionPipeline.from_pretrained``.
            device: Target device. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"`` (the original always used ``"cuda"`` and
                failed on GPU-less hosts).
            torch_dtype: Weight dtype. Defaults to ``torch.float16`` on CUDA
                (the original setting) and ``torch.float32`` elsewhere, since
                half precision is intended for GPU execution.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if torch_dtype is None:
            torch_dtype = (
                torch.float16 if str(device).startswith("cuda") else torch.float32
            )
        self.pipe = DiffusionPipeline.from_pretrained(
            model_dir,
            custom_pipeline="pipeline_wan_i2v",
            torch_dtype=torch_dtype,
        ).to(device)

    def __call__(self, data: dict):
        """Run the pipeline on the image contained in a request payload.

        Args:
            data: Request payload. The input image is read from ``"image"``,
                falling back to the conventional ``"inputs"`` key. An optional
                ``"parameters"`` dict is forwarded to the pipeline as keyword
                arguments (e.g. a prompt). The image value is passed through
                unchanged — presumably a PIL image or another type the custom
                pipeline accepts; TODO confirm against the serving layer.

        Returns:
            Whatever the pipeline returns, unchanged. NOTE(review): this is
            likely a pipeline output object rather than JSON-serializable data;
            confirm the serving layer can encode it.

        Raises:
            KeyError: if the payload has neither ``"image"`` nor ``"inputs"``.
            TypeError: if ``"parameters"`` is present but not a dict.
        """
        if "image" in data:
            image = data["image"]
        elif "inputs" in data:
            image = data["inputs"]
        else:
            raise KeyError("request payload must contain an 'image' (or 'inputs') field")

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            raise TypeError("'parameters' must be a dict of pipeline keyword arguments")

        # The custom pipeline's __call__ may not disable autograd itself;
        # inference mode avoids building a graph and saves memory.
        with torch.inference_mode():
            return self.pipe(image, **params)