File size: 4,140 Bytes
9b57ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from generator import Generator
import json
import os
import torch
import gc
from utils.pipelines import *
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="生成图片")
    parser.add_argument(
        "--json_path",
        type=str,
        help="json路径",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        help="输出目录",
    )
    parser.add_argument("--num_devices", type=int, default=8, help="设备数量")
    parser.add_argument("--batch_size", type=int, default=1, help="批量大小")
    parser.add_argument("--num_machine", type=int, default=1, help="机器数量")
    parser.add_argument("--machine_id", type=int, default=0, help="机器id")
    parser.add_argument(
        "--pipeline_name", type=str, nargs="+", default=None, help="pipeline名称"
    )
    parser.add_argument("--enable_availabel_check", action="store_true")
    parser.add_argument("--reverse", action="store_true")
    return parser.parse_args()


def main():
    args = parse_args()
    num_devices = args.num_devices
    pipeline_params = [globals()[f"{name}_pipe"] for name in args.pipeline_name]

    if args.reverse:
        pipeline_params = pipeline_params[::-1]

    # first check all pipeline
    if args.enable_availabel_check:
        print(f"Checking {len(pipeline_params)} pipelines")
        for pipeline_param in pipeline_params:
            generator = Generator(
                pipe_name=pipeline_param.pipeline_name,
                pipe_type=pipeline_param.pipeline_type,
                pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
                num_devices=num_devices,
            )

            with open(args.json_path, "r") as f:
                entries = json.load(f)
            info_dict = entries[: args.batch_size]
            generator.generate(
                info_dict,
                os.path.join(args.out_dir, pipeline_param.generation_path),
                batch_size=args.batch_size,
                num_processes=num_devices,
                seed=42,
                weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
                generation_kwargs=pipeline_param.generation_kwargs,
                base_resolution=pipeline_param.base_resolution,
                force_aspect_ratio=pipeline_param.force_aspect_ratio,
            )
            del generator
            gc.collect()
            torch.cuda.empty_cache()
            print(f"Finished Checking {pipeline_param.pipeline_name}")

    for pipeline_param in pipeline_params:
        generator = Generator(
            pipe_name=pipeline_param.pipeline_name,
            pipe_type=pipeline_param.pipeline_type,
            pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
            num_devices=num_devices,
        )

        with open(args.json_path, "r") as f:
            entries = json.load(f)

        for i in range(args.num_machine):
            start_idx = i * len(entries) // args.num_machine
            end_idx = (
                (i + 1) * len(entries) // args.num_machine
                if i != args.num_machine - 1
                else len(entries)
            )
            if i == args.machine_id:
                info_dict = entries[start_idx:end_idx]

        info_dict = sorted(info_dict, key=lambda x: x["aspect_ratio"])

        print(f"Generating {len(info_dict)} images")
        generator.generate(
            info_dict,
            os.path.join(args.out_dir, pipeline_param.generation_path),
            batch_size=args.batch_size,
            num_processes=num_devices,
            seed=42,
            weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
            generation_kwargs=pipeline_param.generation_kwargs,
            base_resolution=pipeline_param.base_resolution,
            force_aspect_ratio=pipeline_param.force_aspect_ratio,
        )

        print(f"Finished generating {pipeline_param.pipeline_name}")

        for pipeline in generator.pipelines:
            pipeline.to("cpu")
        del generator
        torch.cuda.empty_cache()
        gc.collect()


if __name__ == "__main__":
    main()