File size: 1,411 Bytes
d382778 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import os
from .edm_uncond import get_pretrained_sde_model
from .latent_diff import get_pretrained_ldm_model, get_pretrained_conditioned_ldm_model
from .condition_loader import RandomNumberIterator, UniformNumberIterator, TextFileIterator
def prepare_stuff(args):
    """Call the pretrained-model loader selected by ``args.model``.

    Args:
        args: Object exposing a ``model`` attribute. The whole object is
            passed unchanged to the chosen loader.

    Returns:
        Whatever the selected loader returns:
        ``get_pretrained_sde_model`` for ``'edm'``,
        ``get_pretrained_ldm_model`` for ``'latent_diff'``,
        ``get_pretrained_conditioned_ldm_model`` for
        ``'conditioned_latent_diff'``.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the three
            names listed above.
    """
    model_name = args.model
    if model_name == 'edm':
        return get_pretrained_sde_model(args)
    if model_name == 'latent_diff':
        return get_pretrained_ldm_model(args)
    if model_name == 'conditioned_latent_diff':
        return get_pretrained_conditioned_ldm_model(args)
    # Name the rejected value so a mistyped model option is easy to spot.
    raise NotImplementedError(
        f"Unsupported model {model_name!r}; expected 'edm', 'latent_diff' "
        f"or 'conditioned_latent_diff'."
    )
def prepare_condition_loader(
    model_type,
    model,
    scale,
    condition,
    sampling_batch_size,
    num_samples_per_class=50,
    num_prompt=5,
    num_samples_per_prompt=1,
):
    """Build the condition iterator matching ``model_type`` and ``condition``.

    Args:
        model_type: ``'edm'``, ``'latent_diff'`` or
            ``'conditioned_latent_diff'``.
        model: Forwarded as the first argument to the chosen iterator.
        scale: Forwarded as the second argument to the chosen iterator.
        condition: Path of an existing file, or the keyword ``'random'``
            or ``'uniform'``.
        sampling_batch_size: Forwarded to every iterator.
        num_samples_per_class: Forwarded to ``UniformNumberIterator`` only.
        num_prompt: Forwarded to ``TextFileIterator`` only.
        num_samples_per_prompt: Forwarded to ``TextFileIterator`` only.

    Returns:
        ``None`` for ``'edm'`` and ``'latent_diff'``. For
        ``'conditioned_latent_diff'``: a ``TextFileIterator`` when
        ``condition`` is an existing file, a ``RandomNumberIterator`` for
        ``'random'``, or a ``UniformNumberIterator`` for ``'uniform'``.

    Raises:
        NotImplementedError: If ``model_type`` is not recognised, or if
            ``condition`` is neither an existing file nor a known keyword.
    """
    # No condition loader is built for these model types.
    if model_type in ('edm', 'latent_diff'):
        return None

    if model_type != 'conditioned_latent_diff':
        raise NotImplementedError(
            f"Unsupported model_type {model_type!r}; expected 'edm', "
            f"'latent_diff' or 'conditioned_latent_diff'."
        )

    # The file check runs before the keyword checks, so an existing file
    # literally named 'random' or 'uniform' is read as a text file.
    if os.path.isfile(condition):
        return TextFileIterator(
            model, scale, condition, sampling_batch_size,
            num_prompt, num_samples_per_prompt,
        )
    if condition == 'random':
        return RandomNumberIterator(model, scale, sampling_batch_size)
    if condition == 'uniform':
        return UniformNumberIterator(
            model, scale, sampling_batch_size, num_samples_per_class,
        )
    raise NotImplementedError(
        f"Unsupported condition {condition!r}; expected the path of an "
        f"existing file, 'random' or 'uniform'."
    )