File size: 4,970 Bytes
ab2369a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import sys

import torch
import os
import json
import argparse
sys.path.append(os.getcwd())

# NOTE: diffusers ships only the stock schedulers; the *LM* variants
# (DPMSolverMultistepLMScheduler, DDIMLMScheduler) are project-local and are
# imported from the `scheduler` package below.  Importing them from diffusers
# raises ImportError, so only the names diffusers actually exports are pulled in.
from diffusers import StableDiffusionPipeline, PNDMScheduler, UniPCMultistepScheduler

from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler

def main():
    """Sampling script for COCO14 prompts with Stable Diffusion.

    Parses CLI arguments, loads a StableDiffusionPipeline, swaps in the
    requested scheduler, then generates and saves one image per prompt in the
    ``[start_index, start_index + test_num)`` slice of the prompt JSON.
    """
    parser = argparse.ArgumentParser(description="sampling script for COCO14.")
    parser.add_argument('--test_num', type=int, default=1000)
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    parser.add_argument('--guidance', type=float, default=7.5)
    parser.add_argument('--sampler_type', type=str, default='ddim')
    parser.add_argument('--model_id', type=str, default='/xxx/xxx/stable-diffusion-v1-5')
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--lamb', type=float, default=5.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--device', type=str, default='cuda')
    # Generalized: the prompt file used to be hard-coded; default keeps the old path.
    parser.add_argument('--prompt_json', type=str,
                        default='/mnt/chongqinggeminiceph1fs/geminicephfs/mm-base-vision/pazelzhang/make_dataset/fid_3W_json.json')

    args = parser.parse_args()

    start_index = args.start_index
    sampler_type = args.sampler_type
    test_num = args.test_num
    guidance_scale = args.guidance
    num_inference_steps = args.num_inference_steps
    lamb = args.lamb
    kappa = args.kappa
    device = args.device
    model_id = args.model_id

    # Load the Stable Diffusion pipeline (safety checker disabled for FID-style runs).
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float32, safety_checker=None)
    sd_pipe = sd_pipe.to(device)
    print("sd model loaded")

    # Swap in the requested scheduler.  The DPM family shares one setup:
    # a 'dpm++' prefix selects the "dpmsolver++" algorithm, and a '_lm'
    # suffix enables the LM variant.  Same convention for the DDIM family.
    if sampler_type in ('dpm', 'dpm_lm', 'dpm++', 'dpm++_lm'):
        sd_pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.config.solver_order = 3
        sd_pipe.scheduler.config.algorithm_type = (
            "dpmsolver++" if sampler_type.startswith('dpm++') else "dpmsolver")
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = sampler_type.endswith('_lm')
    elif sampler_type in ('ddim', 'ddim_lm'):
        sd_pipe.scheduler = DDIMLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = sampler_type.endswith('_lm')
        sd_pipe.scheduler.kappa = kappa
    elif sampler_type == 'pndm':
        sd_pipe.scheduler = PNDMScheduler.from_config(sd_pipe.scheduler.config)
    elif sampler_type == 'unipc':
        sd_pipe.scheduler = UniPCMultistepScheduler.from_config(sd_pipe.scheduler.config)
    # Any other sampler_type keeps the pipeline's default scheduler (as before).

    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # COCO prompts: presumably {image_id: caption} — TODO confirm against the JSON producer.
    with open(args.prompt_json) as fr:
        COCO_prompts_dict = json.load(fr)

    with torch.no_grad():
        for pi, key in enumerate(COCO_prompts_dict.keys()):
            # Process only the [start_index, start_index + test_num) slice.
            if not (start_index <= pi < start_index + test_num):
                continue
            print(key)
            print(COCO_prompts_dict[key])
            prompt = COCO_prompts_dict[key]
            negative_prompt = None

            for seed in [args.seed]:
                # BUGFIX: the generator previously always seeded with args.seed
                # while the filename recorded the loop variable (fixed at 1), so
                # filenames lied whenever --seed != 1; it was also pinned to
                # 'cuda' regardless of --device.  Both now agree.
                generator = torch.Generator(device=device).manual_seed(seed)
                res = sd_pipe(prompt=prompt, negative_prompt=negative_prompt,
                              num_inference_steps=num_inference_steps,
                              guidance_scale=guidance_scale, generator=generator).images[0]
                res.save(os.path.join(
                    save_dir,
                    f"{pi:05d}_{key}_guidance{guidance_scale}_inference{num_inference_steps}_seed{seed}_{sampler_type}.jpg"))
                print(f"{sampler_type}##{key},done")


if __name__ == '__main__':
    main()