import argparse
import os
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
import torch
from diffusers import DiffusionPipeline
# Diffusion checkpoints this script supports; --pipeline_id must be one of these.
ALL_CKPTS = [
    "runwayml/stable-diffusion-v1-5",
    "segmind/SSD-1B",
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/sdxl-turbo",
]
# Fixed RNG seed so generation is reproducible across runs (used per chunk in generate_images).
SEED = 2024
def load_dataframe():
    """Download the 30k randomly sampled COCO 2014 val captions as a DataFrame."""
    csv_url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/raw/main/coco_30k_randomly_sampled_2014_val.csv"
    return pd.read_csv(csv_url)
def load_pipeline(args):
    """Instantiate the requested diffusion pipeline on GPU in fp16.

    The runway SD v1.5 checkpoint ships with a safety checker; it is
    disabled there so generation behaves uniformly across checkpoints.
    """
    from_pretrained_kwargs = {"torch_dtype": torch.float16}
    if "runway" in args.pipeline_id:
        from_pretrained_kwargs["safety_checker"] = None
    pipeline = DiffusionPipeline.from_pretrained(args.pipeline_id, **from_pretrained_kwargs).to("cuda")
    # Per-batch progress bars are noisy when looping over many chunks.
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
def generate_images(args, dataframe, pipeline):
    """Generate one image per caption, batching captions in chunks.

    Args:
        args: parsed CLI namespace; uses ``pipeline_id``, ``chunk_size``,
            and ``num_inference_steps``.
        dataframe: DataFrame with a ``caption`` column (one prompt per row).
        pipeline: a loaded diffusers pipeline callable.

    Returns:
        A flat list of generated images, in dataframe row order.
    """
    # The original duplicated the whole pipeline call just to pass
    # guidance_scale=0.0 for sdxl-turbo; fold that difference into kwargs.
    extra_kwargs = {}
    if "sdxl-turbo" in args.pipeline_id:
        # sdxl-turbo is distilled for guidance-free sampling.
        extra_kwargs["guidance_scale"] = 0.0

    all_images = []
    for start in range(0, len(dataframe), args.chunk_size):
        captions = dataframe.iloc[start : start + args.chunk_size]["caption"].tolist()
        images = pipeline(
            captions,
            num_inference_steps=args.num_inference_steps,
            # Re-seed per chunk so results do not depend on chunking order.
            generator=torch.manual_seed(SEED),
            **extra_kwargs,
        ).images
        all_images.extend(images)
    return all_images
def serialize_image(image, path):
    """Write a single generated image to ``path`` (format inferred from suffix)."""
    image.save(path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pipeline_id", default="runwayml/stable-diffusion-v1-5", type=str, choices=ALL_CKPTS)
    parser.add_argument("--num_inference_steps", default=30, type=int)
    parser.add_argument("--chunk_size", default=2, type=int)
    parser.add_argument("--root_img_path", default="sdv15", type=str)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    dataset = load_dataframe()
    pipeline = load_pipeline(args)
    images = generate_images(args, dataset, pipeline)

    # exist_ok avoids the check-then-create race of the old exists()/makedirs() pair.
    os.makedirs(args.root_img_path, exist_ok=True)
    image_paths = [os.path.join(args.root_img_path, f"{i}.jpg") for i in range(len(images))]

    # Saving images is I/O-bound, so a thread pool parallelizes it well.
    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        # Drain the map iterator: Executor.map only raises worker exceptions
        # when its results are consumed, so discarding it would hide failures.
        list(executor.map(serialize_image, images, image_paths))