# Extraction artifacts from the original file listing (commented out so the
# module parses as valid Python):
# File size: 4,140 Bytes
# 9b57ce7 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
from generator import Generator
import json
import os
import torch
import gc
from utils.pipelines import *
import argparse
def parse_args(argv=None):
    """Parse command-line options for the image-generation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, preserving the original no-arg
            call from ``main()`` while making the parser unit-testable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="生成图片")
    parser.add_argument(
        "--json_path",
        type=str,
        help="json路径",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        help="输出目录",
    )
    parser.add_argument("--num_devices", type=int, default=8, help="设备数量")
    parser.add_argument("--batch_size", type=int, default=1, help="批量大小")
    parser.add_argument("--num_machine", type=int, default=1, help="机器数量")
    parser.add_argument("--machine_id", type=int, default=0, help="机器id")
    parser.add_argument(
        "--pipeline_name", type=str, nargs="+", default=None, help="pipeline名称"
    )
    # "--enable_availabel_check" is a typo, but existing launch scripts may
    # depend on it; keep it and accept the correctly spelled flag as an
    # alias writing to the same destination.
    parser.add_argument(
        "--enable_availabel_check",
        "--enable_available_check",
        action="store_true",
        dest="enable_availabel_check",
    )
    parser.add_argument("--reverse", action="store_true")
    return parser.parse_args(argv)
def _machine_slice(num_entries, machine_id, num_machine):
    """Return the (start, end) index range of entries owned by *machine_id*.

    Entries are split as evenly as possible across *num_machine* machines;
    the last machine absorbs any remainder so the whole list is covered.
    Matches the arithmetic of the original per-machine loop exactly.
    """
    start = machine_id * num_entries // num_machine
    end = (
        num_entries
        if machine_id == num_machine - 1
        else (machine_id + 1) * num_entries // num_machine
    )
    return start, end


def main():
    """Run image generation for each requested pipeline on this machine's
    shard of the input JSON entries.

    For every name in ``--pipeline_name`` a ``<name>_pipe`` parameter object
    must exist at module level (brought in by ``utils.pipelines``).
    """
    args = parse_args()
    num_devices = args.num_devices
    if not args.pipeline_name:
        # The flag defaults to None; fail fast with a clear message instead
        # of a TypeError inside the comprehension below.
        raise ValueError("--pipeline_name must be provided")
    pipeline_params = [globals()[f"{name}_pipe"] for name in args.pipeline_name]
    if args.reverse:
        pipeline_params = pipeline_params[::-1]
    # Load the work list once; the original re-read the file per pipeline.
    with open(args.json_path, "r") as f:
        entries = json.load(f)
    # Optional smoke test: push one tiny batch through every pipeline first
    # so a broken pipeline fails before a long generation run starts.
    if args.enable_availabel_check:
        print(f"Checking {len(pipeline_params)} pipelines")
        for pipeline_param in pipeline_params:
            generator = Generator(
                pipe_name=pipeline_param.pipeline_name,
                pipe_type=pipeline_param.pipeline_type,
                pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
                num_devices=num_devices,
            )
            info_dict = entries[: args.batch_size]
            generator.generate(
                info_dict,
                os.path.join(args.out_dir, pipeline_param.generation_path),
                batch_size=args.batch_size,
                num_processes=num_devices,
                seed=42,
                weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
                generation_kwargs=pipeline_param.generation_kwargs,
                base_resolution=pipeline_param.base_resolution,
                force_aspect_ratio=pipeline_param.force_aspect_ratio,
            )
            del generator
            gc.collect()
            torch.cuda.empty_cache()
            print(f"Finished Checking {pipeline_param.pipeline_name}")
    # Compute this machine's shard once instead of looping over all machines.
    start_idx, end_idx = _machine_slice(
        len(entries), args.machine_id, args.num_machine
    )
    for pipeline_param in pipeline_params:
        generator = Generator(
            pipe_name=pipeline_param.pipeline_name,
            pipe_type=pipeline_param.pipeline_type,
            pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
            num_devices=num_devices,
        )
        info_dict = entries[start_idx:end_idx]
        # Sort by aspect ratio so similarly shaped images batch together.
        info_dict = sorted(info_dict, key=lambda x: x["aspect_ratio"])
        print(f"Generating {len(info_dict)} images")
        generator.generate(
            info_dict,
            os.path.join(args.out_dir, pipeline_param.generation_path),
            batch_size=args.batch_size,
            num_processes=num_devices,
            seed=42,
            weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
            generation_kwargs=pipeline_param.generation_kwargs,
            base_resolution=pipeline_param.base_resolution,
            force_aspect_ratio=pipeline_param.force_aspect_ratio,
        )
        print(f"Finished generating {pipeline_param.pipeline_name}")
        # Move weights off the GPU before dropping the reference so the
        # CUDA cache can actually be reclaimed by empty_cache().
        for pipeline in generator.pipelines:
            pipeline.to("cpu")
        del generator
        torch.cuda.empty_cache()
        gc.collect()
# Script entry point: only run generation when executed directly, not on import.
if __name__ == "__main__":
    main()
|