|
|
import torch |
|
|
import os |
|
|
import inspect |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
from utils.utils import init_multiple_pipelines |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
# Disable Pillow's decompression-bomb guard so arbitrarily large
# reference images can be opened without raising DecompressionBombError.
Image.MAX_IMAGE_PIXELS = None
|
|
|
|
|
|
|
|
class Generator:
    """Shards text-to-image / text-to-video generation across multiple GPUs.

    One diffusion pipeline is instantiated per device at construction time;
    ``generate`` fans the work out over a thread pool, and each worker
    (``generate_imgs``) processes a contiguous slice of the prompt list on
    its own CUDA device.
    """

    def __init__(
        self, pipe_name, pipe_type, pipe_init_kwargs, num_devices, device_id=None
    ):
        """Build one pipeline per device.

        Args:
            pipe_name: Pipeline identifier forwarded to ``init_multiple_pipelines``.
            pipe_type: ``"t2i"`` (outputs ``.images``) or ``"t2v"`` (outputs
                ``.frames[0]``).
            pipe_init_kwargs: Extra kwargs for pipeline construction.
            num_devices: Number of devices to spread pipelines over.
            device_id: Optional explicit device index (semantics defined by
                ``init_multiple_pipelines``).
        """
        self.pipe_names = pipe_name
        self.pipe_type = pipe_type
        self.pipe_init_kwargs = pipe_init_kwargs
        self.pipelines = init_multiple_pipelines(
            pipe_name, pipe_init_kwargs, num_devices, device_id
        )

    @staticmethod
    def _target_path(generation_path, entry, ext):
        """Return the output path for one prompt entry with extension ``ext``.

        Used for BOTH the skip-if-exists check and the final save, so the two
        can never disagree.  (The original code joined paths with
        ``os.path.join`` in one place and raw ``+`` concatenation in the
        other, so the resume check could silently never match.)
        """
        if entry["image_file"] is not None:
            # NOTE(review): [:-4] assumes a 3-character extension such as
            # ".jpg"/".png" — kept for backward-compatible file naming.
            stem = entry["image_file"][:-4]
        else:
            stem = entry["save_name"]
        return os.path.join(generation_path, stem + ext)

    def generate_imgs(
        self,
        num_device,
        batch_size,
        generation_path,
        info_dict,
        pipeline,
        device_id,
        weight_dtype,
        seed,
        base_resolution,
        force_aspect_ratio,
        generation_kwargs,
    ):
        """Generate outputs for this device's slice of ``info_dict``.

        Args:
            num_device: Total number of worker devices.
            batch_size: Prompts per pipeline call.
            generation_path: Root directory for outputs (.png image + .txt
                sidecar with the prompt).
            info_dict: Full list of entry dicts; this worker takes the slice
                belonging to ``device_id``.  Each entry carries ``caption``
                and either ``image_file`` (reference image path) or
                ``save_name`` (+ ``width``/``height`` or ``aspect_ratio``).
            pipeline: The diffusers-style callable for this device.
            device_id: Index of this worker (selects the CUDA device and the
                prompt slice).
            weight_dtype: Unused here; kept for interface compatibility.
            seed: Base RNG seed; offset by the batch index per batch.
            base_resolution: Target edge length before aspect scaling.
            force_aspect_ratio: If truthy, overrides per-entry aspect ratios.
            generation_kwargs: Extra kwargs for the pipeline call.

        Returns:
            True on completion (failed batches are logged and skipped).
        """
        # Work on a private copy: this method runs concurrently on several
        # threads, and all of them were handed the SAME kwargs dict by
        # ``generate`` — mutating it in place would race.
        generation_kwargs = dict(generation_kwargs)

        target = f"cuda:{device_id % num_device}"
        torch.cuda.set_device(target)
        device = torch.device(target)

        # Contiguous slice of prompts for this device; the last device also
        # absorbs the division remainder.
        num_prompts_per_device = len(info_dict) // num_device
        start_idx = device_id * num_prompts_per_device
        end_idx = (
            start_idx + num_prompts_per_device
            if device_id != (num_device - 1)
            else len(info_dict)
        )
        device_info_dict = info_dict[start_idx:end_idx]

        print(f"Device {device} generating for prompts {start_idx} to {end_idx-1}")
        print("## Prepare generation dataset")

        # Signature inspection is loop-invariant — do it once, not per batch.
        pipeline_params = inspect.signature(pipeline).parameters.keys()
        supports_height = "height" in pipeline_params
        supports_width = "width" in pipeline_params
        if not supports_height:
            print("Warning: Pipeline does not support 'height' parameter, removing from kwargs")
            generation_kwargs.pop("height", None)
        if not supports_width:
            print("Warning: Pipeline does not support 'width' parameter, removing from kwargs")
            generation_kwargs.pop("width", None)

        total_batches = len(device_info_dict) // batch_size + (
            1 if len(device_info_dict) % batch_size != 0 else 0
        )
        for batch_idx in tqdm(
            range(total_batches), desc="Pipeline: " + self.pipe_names
        ):
            batch_info_dict = device_info_dict[
                batch_idx * batch_size : (batch_idx + 1) * batch_size
            ]

            # Resume support: drop entries whose output already exists.
            batch_info_dict = [
                entry
                for entry in batch_info_dict
                if not os.path.exists(
                    self._target_path(generation_path, entry, ".png")
                )
            ]
            if not batch_info_dict:
                continue

            batch_prompts = [entry["caption"] for entry in batch_info_dict]
            batch_image_file = [entry["image_file"] for entry in batch_info_dict]

            # Determine canvas sizes: from the reference images when given,
            # otherwise from the per-entry width/height fields.
            if batch_image_file[0] is not None:
                try:
                    batch_image_sizes = [
                        Image.open(image_file).size
                        for image_file in batch_image_file
                    ]
                except Exception:  # missing/corrupt image — fall back below
                    batch_image_sizes = None
            else:
                batch_image_sizes = [
                    (entry["width"], entry["height"]) for entry in batch_info_dict
                ]

            if batch_image_sizes is None:
                aspect_ratios = [
                    entry["aspect_ratio"] for entry in batch_info_dict
                ]
            else:
                aspect_ratios = [w / h for w, h in batch_image_sizes]

            # Snap the resolution down to a multiple of 64.
            # NOTE(review): only the FIRST entry's aspect ratio is applied to
            # the whole batch — entries in one batch share a canvas size.
            if force_aspect_ratio:
                height = int(base_resolution / force_aspect_ratio // 64 * 64)
                width = int(base_resolution * force_aspect_ratio // 64 * 64)
            else:
                ar_sqrt = aspect_ratios[0] ** 0.5
                height = int(base_resolution / ar_sqrt // 64 * 64)
                width = int(base_resolution * ar_sqrt // 64 * 64)
            if supports_height:
                generation_kwargs["height"] = height
            if supports_width:
                generation_kwargs["width"] = width

            # Per-batch seed offset keeps batches reproducible yet distinct.
            generator = torch.Generator().manual_seed(seed + batch_idx)

            try:
                outputs = pipeline(
                    prompt=batch_prompts, generator=generator, **generation_kwargs
                )
            except Exception as e:
                # Best effort: log and move on to the next batch.
                print(e)
                continue

            if self.pipe_type == "t2i":
                images = outputs.images
            elif self.pipe_type == "t2v":
                images = outputs.frames[0]
            else:
                # Previously an unknown pipe_type crashed with NameError.
                raise ValueError(f"Unsupported pipe_type: {self.pipe_type!r}")

            for img, prompt, image_file, entry in zip(
                images, batch_prompts, batch_image_file, batch_info_dict
            ):
                img_path = self._target_path(generation_path, entry, ".png")
                os.makedirs(os.path.dirname(img_path) or ".", exist_ok=True)
                img.save(img_path)

                # Sidecar .txt with the prompt and the source identifier.
                # Best effort only — the image itself was already saved.
                text_path = self._target_path(generation_path, entry, ".txt")
                try:
                    with open(text_path, "w") as f:
                        f.write(prompt)
                        f.write("\n")
                        f.write(
                            image_file
                            if image_file is not None
                            else entry["save_name"]
                        )
                except OSError:
                    pass
        return True

    def generate(
        self,
        info_dict,
        generation_path,
        num_processes,
        batch_size,
        weight_dtype,
        seed,
        generation_kwargs,
        base_resolution,
        force_aspect_ratio,
    ):
        """Fan generation out over ``num_processes`` threads, one per device.

        Each worker gets its own pipeline from ``self.pipelines`` and its own
        slice of ``info_dict`` (computed inside ``generate_imgs``).  Raises
        whatever a worker raised when its future is resolved.
        """
        with ThreadPoolExecutor(max_workers=num_processes) as executor:
            futures = [
                executor.submit(
                    self.generate_imgs,
                    num_processes,
                    batch_size,
                    generation_path,
                    info_dict,
                    self.pipelines[device_id],
                    device_id,
                    weight_dtype,
                    seed,
                    base_resolution,
                    force_aspect_ratio,
                    generation_kwargs,
                )
                for device_id in range(num_processes)
            ]
            for future in as_completed(futures):
                print(f"Task completed: {future.result()}")
|
|
|