LD3 / models /__init__.py
vinhtong97's picture
Upload folder using huggingface_hub
d382778 verified
import os
from .edm_uncond import get_pretrained_sde_model
from .latent_diff import get_pretrained_ldm_model, get_pretrained_conditioned_ldm_model
from .condition_loader import RandomNumberIterator, UniformNumberIterator, TextFileIterator
def prepare_stuff(args):
    """Build the pretrained model selected by ``args.model``.

    Picks the loader that matches ``args.model``, calls it with the whole
    ``args`` object and returns its result unchanged.

    Args:
        args: Configuration object exposing a ``model`` attribute whose value
            is one of ``'edm'``, ``'latent_diff'`` or
            ``'conditioned_latent_diff'``. The full object is forwarded to the
            selected loader.

    Returns:
        Whatever the selected ``get_pretrained_*`` loader returns.

    Raises:
        NotImplementedError: If ``args.model`` is not a supported model type.
    """
    model = args.model
    # Keep the lookup lazy (if/elif rather than a dict) so each loader name is
    # only resolved on its own branch.
    if model == 'edm':
        prepare_model_fn = get_pretrained_sde_model
    elif model == 'latent_diff':
        prepare_model_fn = get_pretrained_ldm_model
    elif model == 'conditioned_latent_diff':
        prepare_model_fn = get_pretrained_conditioned_ldm_model
    else:
        raise NotImplementedError(
            f"Unsupported model type {model!r}; expected one of "
            "'edm', 'latent_diff', 'conditioned_latent_diff'."
        )
    return prepare_model_fn(args)
def prepare_condition_loader(model_type, model, scale, condition, sampling_batch_size,
                             num_samples_per_class=50, num_prompt=5, num_samples_per_prompt=1):
    """Create the conditioning iterator for ``model_type``, if it needs one.

    ``'edm'`` and ``'latent_diff'`` take no condition source and get ``None``.
    For ``'conditioned_latent_diff'`` the ``condition`` argument selects the
    iterator:

    * a path to an existing file -> ``TextFileIterator``
    * ``'random'``               -> ``RandomNumberIterator``
    * ``'uniform'``              -> ``UniformNumberIterator``

    Args:
        model_type: One of ``'edm'``, ``'latent_diff'`` or
            ``'conditioned_latent_diff'``.
        model: Forwarded unchanged to the chosen iterator.
        scale: Forwarded unchanged to the chosen iterator.
        condition: File path, ``'random'`` or ``'uniform'``. Only read for
            ``'conditioned_latent_diff'``.
        sampling_batch_size: Forwarded to the chosen iterator.
        num_samples_per_class: Forwarded to ``UniformNumberIterator`` only.
        num_prompt: Forwarded to ``TextFileIterator`` only.
        num_samples_per_prompt: Forwarded to ``TextFileIterator`` only.

    Returns:
        The constructed iterator, or ``None`` for the model types listed
        above as taking no condition source.

    Raises:
        NotImplementedError: If ``model_type`` is unsupported, or if
            ``condition`` is neither an existing file nor a known keyword.
    """
    if model_type in ('edm', 'latent_diff'):
        return None
    if model_type != 'conditioned_latent_diff':
        raise NotImplementedError(
            f"Unsupported model type {model_type!r}; expected one of "
            "'edm', 'latent_diff', 'conditioned_latent_diff'."
        )

    # os.path.isfile(None) raises TypeError instead of returning False, so only
    # probe the filesystem for path-like values; anything else falls through to
    # the NotImplementedError below.
    # NOTE: the file check runs before the keyword checks, so an existing file
    # literally named 'random' or 'uniform' is treated as a text file.
    if isinstance(condition, (str, bytes, os.PathLike)) and os.path.isfile(condition):
        return TextFileIterator(model, scale, condition, sampling_batch_size,
                                num_prompt, num_samples_per_prompt)
    if condition == 'random':
        return RandomNumberIterator(model, scale, sampling_batch_size)
    if condition == 'uniform':
        return UniformNumberIterator(model, scale, sampling_batch_size, num_samples_per_class)
    raise NotImplementedError(
        f"Unsupported condition {condition!r} for 'conditioned_latent_diff'; "
        "expected an existing file path, 'random' or 'uniform'."
    )