# Source: LightDiffusion-Next, src/sample/samplers.py (author: Aatricks)
# Snapshot: "Deploy ZeroGPU Gradio Space snapshot", commit b701455
"""Functional interface for samplers."""
import torch
from src.sample.BaseSampler import DPMPPSDESampler, EulerSampler, EulerAncestralSampler, DPMPP2MSampler
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, pipeline=False, **kwargs):
    """Run ``DPMPPSDESampler`` through a plain function call.

    A new ``DPMPPSDESampler`` is constructed on every call from ``pipeline``
    plus any extra keyword arguments. Its ``sample`` method then receives the
    remaining arguments unchanged.

    NOTE(review): the function name says DPM++ 2M SDE, but the class used is
    ``DPMPPSDESampler``. Confirm in BaseSampler that this class implements
    the 2M SDE variant.

    Args:
        model: Model object handed to ``sample`` unchanged.
        x: Initial sample handed to ``sample`` unchanged.
        sigmas: Sigma schedule handed to ``sample`` unchanged.
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable`` (presumably a
            progress-display toggle; confirm in BaseSampler).
        pipeline: Passed to the sampler constructor.
        **kwargs: Additional sampler-constructor options.

    Returns:
        Whatever ``DPMPPSDESampler.sample`` returns.
    """
    return DPMPPSDESampler(pipeline=pipeline, **kwargs).sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, pipeline=False, **kwargs):
    """Run ``EulerSampler`` through a plain function call.

    A new ``EulerSampler`` is constructed on every call from ``pipeline``
    plus any extra keyword arguments. Its ``sample`` method then receives the
    remaining arguments unchanged.

    Args:
        model: Model object handed to ``sample`` unchanged.
        x: Initial sample handed to ``sample`` unchanged.
        sigmas: Sigma schedule handed to ``sample`` unchanged.
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable`` (presumably a
            progress-display toggle; confirm in BaseSampler).
        pipeline: Passed to the sampler constructor.
        **kwargs: Additional sampler-constructor options.

    Returns:
        Whatever ``EulerSampler.sample`` returns.
    """
    return EulerSampler(pipeline=pipeline, **kwargs).sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, pipeline=False, **kwargs):
    """Run ``EulerAncestralSampler`` through a plain function call.

    A new ``EulerAncestralSampler`` is constructed on every call from
    ``pipeline`` plus any extra keyword arguments. Its ``sample`` method then
    receives the remaining arguments unchanged.

    Args:
        model: Model object handed to ``sample`` unchanged.
        x: Initial sample handed to ``sample`` unchanged.
        sigmas: Sigma schedule handed to ``sample`` unchanged.
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable`` (presumably a
            progress-display toggle; confirm in BaseSampler).
        pipeline: Passed to the sampler constructor.
        **kwargs: Additional sampler-constructor options.

    Returns:
        Whatever ``EulerAncestralSampler.sample`` returns.
    """
    return EulerAncestralSampler(pipeline=pipeline, **kwargs).sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None, pipeline=False, **kwargs):
    """Run ``DPMPP2MSampler`` through a plain function call.

    A new ``DPMPP2MSampler`` is constructed on every call from ``pipeline``
    plus any extra keyword arguments. Its ``sample`` method then receives the
    remaining arguments unchanged.

    Args:
        model: Model object handed to ``sample`` unchanged.
        x: Initial sample handed to ``sample`` unchanged.
        sigmas: Sigma schedule handed to ``sample`` unchanged.
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable`` (presumably a
            progress-display toggle; confirm in BaseSampler).
        pipeline: Passed to the sampler constructor.
        **kwargs: Additional sampler-constructor options.

    Returns:
        Whatever ``DPMPP2MSampler.sample`` returns.
    """
    return DPMPP2MSampler(pipeline=pipeline, **kwargs).sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )