# Source snapshot metadata: commit b701455, original file size 1,264 bytes.
"""Functional interface for samplers."""
import torch
from src.sample.BaseSampler import DPMPPSDESampler, EulerSampler, EulerAncestralSampler, DPMPP2MSampler

def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, pipeline=False, **kwargs):
    """Run sampling through a freshly built ``DPMPPSDESampler``.

    A new sampler instance is created on every call. Nothing is cached
    between calls.

    NOTE(review): the function name says "2m_sde" but the backing class is
    ``DPMPPSDESampler``. Confirm that the class implements the 2M SDE variant.

    Args:
        model: Passed unchanged as the first argument of ``sample``.
        x: Passed unchanged as the second argument of ``sample``
            (presumably the initial latent; confirm with the sampler).
        sigmas: Passed unchanged as the third argument of ``sample``
            (presumably the noise schedule).
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable``.
        pipeline: Forwarded to the sampler constructor as ``pipeline``.
        **kwargs: All remaining keyword arguments go to the
            ``DPMPPSDESampler`` constructor.

    Returns:
        Whatever ``DPMPPSDESampler.sample`` returns.
    """
    return DPMPPSDESampler(pipeline=pipeline, **kwargs).sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )

def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, pipeline=False, **kwargs):
    """Run sampling through a freshly built ``EulerSampler``.

    A new sampler instance is created on every call. Nothing is cached
    between calls.

    Args:
        model: Passed unchanged as the first argument of ``sample``.
        x: Passed unchanged as the second argument of ``sample``
            (presumably the initial latent; confirm with the sampler).
        sigmas: Passed unchanged as the third argument of ``sample``
            (presumably the noise schedule).
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable``.
        pipeline: Forwarded to the sampler constructor as ``pipeline``.
        **kwargs: All remaining keyword arguments go to the
            ``EulerSampler`` constructor.

    Returns:
        Whatever ``EulerSampler.sample`` returns.
    """
    return EulerSampler(pipeline=pipeline, **kwargs).sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )

def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, pipeline=False, **kwargs):
    """Run sampling through a freshly built ``EulerAncestralSampler``.

    A new sampler instance is created on every call. Nothing is cached
    between calls.

    Args:
        model: Passed unchanged as the first argument of ``sample``.
        x: Passed unchanged as the second argument of ``sample``
            (presumably the initial latent; confirm with the sampler).
        sigmas: Passed unchanged as the third argument of ``sample``
            (presumably the noise schedule).
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable``.
        pipeline: Forwarded to the sampler constructor as ``pipeline``.
        **kwargs: All remaining keyword arguments go to the
            ``EulerAncestralSampler`` constructor.

    Returns:
        Whatever ``EulerAncestralSampler.sample`` returns.
    """
    return EulerAncestralSampler(pipeline=pipeline, **kwargs).sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )

def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None, pipeline=False, **kwargs):
    """Run sampling through a freshly built ``DPMPP2MSampler``.

    A new sampler instance is created on every call. Nothing is cached
    between calls.

    Args:
        model: Passed unchanged as the first argument of ``sample``.
        x: Passed unchanged as the second argument of ``sample``
            (presumably the initial latent; confirm with the sampler).
        sigmas: Passed unchanged as the third argument of ``sample``
            (presumably the noise schedule).
        extra_args: Forwarded to ``sample`` as ``extra_args``.
        callback: Forwarded to ``sample`` as ``callback``.
        disable: Forwarded to ``sample`` as ``disable``.
        pipeline: Forwarded to the sampler constructor as ``pipeline``.
        **kwargs: All remaining keyword arguments go to the
            ``DPMPP2MSampler`` constructor.

    Returns:
        Whatever ``DPMPP2MSampler.sample`` returns.
    """
    return DPMPP2MSampler(pipeline=pipeline, **kwargs).sample(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
    )