File size: 1,411 Bytes
d382778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os 
from .edm_uncond import get_pretrained_sde_model
from .latent_diff import get_pretrained_ldm_model, get_pretrained_conditioned_ldm_model
from .condition_loader import RandomNumberIterator, UniformNumberIterator, TextFileIterator
def prepare_stuff(args):
    """Load the pretrained model selected by ``args.model``.

    Picks one of the model-loading helpers imported at the top of this module
    and returns whatever it returns when called with ``args``.

    Args:
        args: Configuration object. ``args.model`` must be one of ``'edm'``,
            ``'latent_diff'`` or ``'conditioned_latent_diff'``; the whole
            object is forwarded unchanged to the selected loader.

    Returns:
        The result of ``loader(args)`` for the selected loader.

    Raises:
        NotImplementedError: If ``args.model`` is not a supported model name.
    """
    model_name = args.model
    # Resolve each loader only on its own branch, so an unknown name fails
    # with a clear error rather than touching unrelated loaders.
    if model_name == 'edm':
        prepare_model_fn = get_pretrained_sde_model
    elif model_name == 'latent_diff':
        prepare_model_fn = get_pretrained_ldm_model
    elif model_name == 'conditioned_latent_diff':
        prepare_model_fn = get_pretrained_conditioned_ldm_model
    else:
        raise NotImplementedError(
            f"Unsupported model {model_name!r}; expected one of "
            "'edm', 'latent_diff', 'conditioned_latent_diff'."
        )
    return prepare_model_fn(args)

def prepare_condition_loader(model_type, model, scale, condition, sampling_batch_size, num_samples_per_class=50, num_prompt=5, num_samples_per_prompt=1):
    """Build the conditioning iterator used during sampling, if any.

    Unconditional model types (``'edm'``, ``'latent_diff'``) need no
    conditioning and get ``None``. For ``'conditioned_latent_diff'`` the
    ``condition`` argument selects the iterator:

    * ``'random'``  -> ``RandomNumberIterator(model, scale, sampling_batch_size)``
    * ``'uniform'`` -> ``UniformNumberIterator(model, scale, sampling_batch_size,
      num_samples_per_class)``
    * a path to an existing file -> ``TextFileIterator(model, scale, condition,
      sampling_batch_size, num_prompt, num_samples_per_prompt)``

    Args:
        model_type: Model family name, matching the values accepted by
            ``prepare_stuff``.
        model: Loaded model object, passed through to the iterator.
        scale: Passed through to the iterator (meaning defined by the
            iterator classes).
        condition: ``'random'``, ``'uniform'``, or a file path.
        sampling_batch_size: Passed through to the iterator.
        num_samples_per_class: Used only by ``UniformNumberIterator``.
        num_prompt: Used only by ``TextFileIterator``.
        num_samples_per_prompt: Used only by ``TextFileIterator``.

    Returns:
        An iterator instance, or ``None`` for unconditional model types.

    Raises:
        NotImplementedError: For an unknown ``model_type`` or an unsupported
            ``condition``.
    """
    if model_type in ('edm', 'latent_diff'):
        return None

    if model_type != 'conditioned_latent_diff':
        raise NotImplementedError(
            f"Unsupported model_type {model_type!r}; expected one of "
            "'edm', 'latent_diff', 'conditioned_latent_diff'."
        )

    # Check the keywords before the filesystem, so a stray file named
    # "random" or "uniform" in the working directory cannot shadow them.
    if condition == 'random':
        return RandomNumberIterator(model, scale, sampling_batch_size)
    if condition == 'uniform':
        return UniformNumberIterator(model, scale, sampling_batch_size, num_samples_per_class)

    # Only path-like values are tested against the filesystem; e.g. None would
    # otherwise escape from os.stat as an unrelated TypeError.
    if isinstance(condition, (str, bytes, os.PathLike)) and os.path.isfile(condition):
        return TextFileIterator(model, scale, condition, sampling_batch_size, num_prompt, num_samples_per_prompt)

    raise NotImplementedError(
        f"Unsupported condition {condition!r}; expected 'random', 'uniform' "
        "or a path to an existing prompt file."
    )