| import os | |
| from .edm_uncond import get_pretrained_sde_model | |
| from .latent_diff import get_pretrained_ldm_model, get_pretrained_conditioned_ldm_model | |
| from .condition_loader import RandomNumberIterator, UniformNumberIterator, TextFileIterator | |
def prepare_stuff(args):
    """Load the pretrained model selected by ``args.model``.

    Supported values of ``args.model`` and the loader each one dispatches to:

    * ``'edm'`` -> ``get_pretrained_sde_model``
    * ``'latent_diff'`` -> ``get_pretrained_ldm_model``
    * ``'conditioned_latent_diff'`` -> ``get_pretrained_conditioned_ldm_model``

    Args:
        args: Object exposing a ``model`` attribute. It is passed unchanged
            to the selected loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the supported
            names listed above.
    """
    if args.model == 'edm':
        prepare_model_fn = get_pretrained_sde_model
    elif args.model == 'latent_diff':
        prepare_model_fn = get_pretrained_ldm_model
    elif args.model == 'conditioned_latent_diff':
        prepare_model_fn = get_pretrained_conditioned_ldm_model
    else:
        # Name the rejected value so a mistyped option is easy to diagnose.
        raise NotImplementedError(
            "unsupported model {!r}; expected 'edm', 'latent_diff' or "
            "'conditioned_latent_diff'".format(args.model)
        )
    return prepare_model_fn(args)
def prepare_condition_loader(
    model_type,
    model,
    scale,
    condition,
    sampling_batch_size,
    num_samples_per_class=50,
    num_prompt=5,
    num_samples_per_prompt=1,
):
    """Build the condition iterator matching ``model_type`` and ``condition``.

    Args:
        model_type: ``'edm'``, ``'latent_diff'`` or
            ``'conditioned_latent_diff'``.
        model: Forwarded to the iterator constructor.
        scale: Forwarded to the iterator constructor.
        condition: Either a path to an existing file, or one of the keywords
            ``'random'`` / ``'uniform'``. Only consulted when ``model_type``
            is ``'conditioned_latent_diff'``.
        sampling_batch_size: Forwarded to the iterator constructor.
        num_samples_per_class: Forwarded to ``UniformNumberIterator`` only.
        num_prompt: Forwarded to ``TextFileIterator`` only.
        num_samples_per_prompt: Forwarded to ``TextFileIterator`` only.

    Returns:
        ``None`` when ``model_type`` is ``'edm'`` or ``'latent_diff'``;
        otherwise a ``TextFileIterator``, ``RandomNumberIterator`` or
        ``UniformNumberIterator`` instance.

    Raises:
        NotImplementedError: If ``model_type`` is not recognised, or if
            ``condition`` is neither an existing file nor a known keyword.
    """
    if model_type in ('edm', 'latent_diff'):
        return None

    if model_type != 'conditioned_latent_diff':
        raise NotImplementedError(
            "unsupported model_type {!r}; expected 'edm', 'latent_diff' or "
            "'conditioned_latent_diff'".format(model_type)
        )

    # The file check runs first, so an existing file whose name happens to be
    # 'random' or 'uniform' takes precedence over the keyword of that name.
    if os.path.isfile(condition):
        return TextFileIterator(
            model, scale, condition, sampling_batch_size, num_prompt, num_samples_per_prompt
        )
    if condition == 'random':
        return RandomNumberIterator(model, scale, sampling_batch_size)
    if condition == 'uniform':
        return UniformNumberIterator(model, scale, sampling_batch_size, num_samples_per_class)

    raise NotImplementedError(
        "unsupported condition {!r}; expected a path to an existing file, "
        "'random' or 'uniform'".format(condition)
    )