HPSv3 / generate /generator.py
sdsdgwe's picture
update
9b57ce7
import torch
import os
import inspect
from PIL import Image
from tqdm import tqdm
from utils.utils import init_multiple_pipelines
from concurrent.futures import ThreadPoolExecutor, as_completed
Image.MAX_IMAGE_PIXELS = None
class Generator:
def __init__(
self, pipe_name, pipe_type, pipe_init_kwargs, num_devices, device_id=None
):
self.pipe_names = pipe_name
self.pipe_type = pipe_type
self.pipe_init_kwargs = pipe_init_kwargs
self.pipelines = init_multiple_pipelines(
pipe_name, pipe_init_kwargs, num_devices, device_id
)
def generate_imgs(
self,
num_device,
batch_size,
generation_path,
info_dict,
pipeline,
device_id,
weight_dtype,
seed,
base_resolution,
force_aspect_ratio,
generation_kwargs,
):
torch.cuda.set_device(f"cuda:{device_id%num_device}")
device = torch.device(f"cuda:{device_id%num_device}")
num_prompts_per_device = len(info_dict) // num_device
start_idx = device_id * num_prompts_per_device
end_idx = (
start_idx + num_prompts_per_device
if device_id != (num_device - 1)
else len(info_dict)
)
device_info_dict = info_dict[start_idx:end_idx]
print(f"Device {device} generating for prompts {start_idx} to {end_idx-1}")
print("## Prepare generation dataset")
total_batches = len(device_info_dict) // batch_size + (
1 if len(device_info_dict) % batch_size != 0 else 0
)
for batch_idx in tqdm(
range(total_batches), desc="Pipeline: " + self.pipe_names
):
batch_info_dict = device_info_dict[
batch_idx * batch_size : (batch_idx + 1) * batch_size
]
save_paths = []
for info_dict in batch_info_dict:
if info_dict["image_file"] is not None:
save_paths.append(
os.path.join(generation_path, info_dict["image_file"][:-4] + ".png")
)
else:
save_paths.append(
os.path.join(generation_path, info_dict["save_name"] + ".png")
)
exists_idx = []
for i, save_path in enumerate(save_paths):
if os.path.exists(save_path):
exists_idx.append(i)
batch_info_dict = [
batch_info_dict[i]
for i in range(len(batch_info_dict))
if i not in exists_idx
]
if len(batch_info_dict) == 0:
continue
batch_prompts = [info_dict["caption"] for info_dict in batch_info_dict]
batch_image_file = [
info_dict["image_file"] for info_dict in batch_info_dict
]
if batch_image_file[0] is not None:
try:
batch_image_sizes = [
Image.open(image_file).size for image_file in batch_image_file
]
except:
batch_image_sizes = None
else:
batch_image_sizes = [
(batch_info_dict[i]["width"], batch_info_dict[i]["height"])
for i in range(len(batch_info_dict))
]
if batch_image_sizes is None:
aspect_ratios = [
info_dict["aspect_ratio"] for info_dict in batch_info_dict
]
else:
aspect_ratios = [size[0] / size[1] for size in batch_image_sizes]
if force_aspect_ratio:
height = int(base_resolution / force_aspect_ratio // 64 * 64)
width = int(base_resolution * force_aspect_ratio // 64 * 64)
else:
# 根据aspect_ratios调整base_resolution, 得到height和width, 保证调整后的乘积大概等于base_resolution**2
height = int(base_resolution / aspect_ratios[0] ** (0.5) // 64 * 64)
width = int(base_resolution * aspect_ratios[0] ** (0.5) // 64 * 64)
generation_kwargs.update({"height": height, "width": width})
generator = torch.Generator().manual_seed(seed + batch_idx)
pipeline_signature = inspect.signature(pipeline)
pipeline_params = pipeline_signature.parameters.keys()
if 'height' not in pipeline_params:
generation_kwargs.pop('height', None)
print(f"Warning: Pipeline does not support 'height' parameter, removing from kwargs")
if 'width' not in pipeline_params:
generation_kwargs.pop('width', None)
print(f"Warning: Pipeline does not support 'width' parameter, removing from kwargs")
try:
outputs = pipeline(
prompt=batch_prompts, generator=generator, **generation_kwargs
)
except Exception as e:
print(e)
continue
if self.pipe_type == "t2i":
images = outputs.images
elif self.pipe_type == "t2v":
images = outputs.frames[0]
for img_idx, (img, prompt, image_file, info_dict) in enumerate(
zip(images, batch_prompts, batch_image_file, batch_info_dict)
):
if image_file is None:
img_path = os.path.join(
generation_path, info_dict["save_name"] + ".png"
)
else:
img_path = generation_path + image_file[:-4] + ".png"
if not os.path.exists(os.path.dirname(img_path)):
os.makedirs(os.path.dirname(img_path), exist_ok=True)
img.save(img_path)
if image_file is None:
text_path = os.path.join(
generation_path, info_dict["save_name"] + ".txt"
)
else:
text_path = generation_path + image_file[:-4] + ".txt"
try:
with open(text_path, "w") as f:
f.write(prompt)
f.write("\n")
f.write(
image_file
if image_file is not None
else info_dict["save_name"]
)
except:
pass
return True
def generate(
self,
info_dict,
generation_path,
num_processes,
batch_size,
weight_dtype,
seed,
generation_kwargs,
base_resolution,
force_aspect_ratio,
):
with ThreadPoolExecutor(max_workers=num_processes) as executor:
futures = [
executor.submit(
self.generate_imgs,
num_processes,
batch_size,
generation_path,
info_dict,
self.pipelines[device_id],
device_id,
weight_dtype,
seed,
base_resolution,
force_aspect_ratio,
generation_kwargs,
)
for device_id in range(num_processes)
]
for future in as_completed(futures):
print(f"Task completed: {future.result()}")