Spaces:
Running on Zero
Running on Zero
File size: 1,264 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | """Functional interface for samplers."""
import torch
from src.sample.BaseSampler import DPMPPSDESampler, EulerSampler, EulerAncestralSampler, DPMPP2MSampler
def sample_dpmpp_2m_sde(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    pipeline=False,
    **kwargs,
):
    """Sample using a freshly constructed ``DPMPPSDESampler``.

    The sampler is created from ``pipeline`` plus any extra keyword
    arguments. Its ``sample`` method then receives the model, the input,
    the sigma sequence and the optional hooks.

    NOTE(review): despite the ``2m_sde`` in this function's name, it
    delegates to ``DPMPPSDESampler`` (not a class with "2M" in its name).
    Confirm this is the intended sampler.

    Args:
        model: Passed positionally to ``sample`` (presumably the denoising
            model -- TODO confirm).
        x: Passed positionally to ``sample`` (presumably the starting
            latent -- TODO confirm).
        sigmas: Passed positionally to ``sample`` (presumably the noise
            schedule -- TODO confirm).
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable``.
        pipeline: Forwarded to the sampler constructor.
        **kwargs: Forwarded to the sampler constructor.

    Returns:
        Whatever ``DPMPPSDESampler.sample`` returns.
    """
    # A new sampler per call: this wrapper keeps no state between calls.
    sde_sampler = DPMPPSDESampler(pipeline=pipeline, **kwargs)
    return sde_sampler.sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )
def sample_euler(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    pipeline=False,
    **kwargs,
):
    """Sample using a freshly constructed ``EulerSampler``.

    The sampler is created from ``pipeline`` plus any extra keyword
    arguments. Its ``sample`` method then receives the model, the input,
    the sigma sequence and the optional hooks.

    Args:
        model: Passed positionally to ``sample`` (presumably the denoising
            model -- TODO confirm).
        x: Passed positionally to ``sample`` (presumably the starting
            latent -- TODO confirm).
        sigmas: Passed positionally to ``sample`` (presumably the noise
            schedule -- TODO confirm).
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable``.
        pipeline: Forwarded to the sampler constructor.
        **kwargs: Forwarded to the sampler constructor.

    Returns:
        Whatever ``EulerSampler.sample`` returns.
    """
    # A new sampler per call: this wrapper keeps no state between calls.
    euler = EulerSampler(pipeline=pipeline, **kwargs)
    return euler.sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )
def sample_euler_ancestral(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    pipeline=False,
    **kwargs,
):
    """Sample using a freshly constructed ``EulerAncestralSampler``.

    The sampler is created from ``pipeline`` plus any extra keyword
    arguments. Its ``sample`` method then receives the model, the input,
    the sigma sequence and the optional hooks.

    Args:
        model: Passed positionally to ``sample`` (presumably the denoising
            model -- TODO confirm).
        x: Passed positionally to ``sample`` (presumably the starting
            latent -- TODO confirm).
        sigmas: Passed positionally to ``sample`` (presumably the noise
            schedule -- TODO confirm).
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable``.
        pipeline: Forwarded to the sampler constructor.
        **kwargs: Forwarded to the sampler constructor.

    Returns:
        Whatever ``EulerAncestralSampler.sample`` returns.
    """
    # A new sampler per call: this wrapper keeps no state between calls.
    ancestral = EulerAncestralSampler(pipeline=pipeline, **kwargs)
    return ancestral.sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )
def sample_dpmpp_2m(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    pipeline=False,
    **kwargs,
):
    """Sample using a freshly constructed ``DPMPP2MSampler``.

    The sampler is created from ``pipeline`` plus any extra keyword
    arguments. Its ``sample`` method then receives the model, the input,
    the sigma sequence and the optional hooks.

    Args:
        model: Passed positionally to ``sample`` (presumably the denoising
            model -- TODO confirm).
        x: Passed positionally to ``sample`` (presumably the starting
            latent -- TODO confirm).
        sigmas: Passed positionally to ``sample`` (presumably the noise
            schedule -- TODO confirm).
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable``.
        pipeline: Forwarded to the sampler constructor.
        **kwargs: Forwarded to the sampler constructor.

    Returns:
        Whatever ``DPMPP2MSampler.sample`` returns.
    """
    # A new sampler per call: this wrapper keeps no state between calls.
    two_m = DPMPP2MSampler(pipeline=pipeline, **kwargs)
    return two_m.sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )
|