| | import time |
| |
|
| | import jax |
| | import jax.numpy as jnp |
| | import numpy as np |
| | from flax.jax_utils import replicate |
| | from jax import pmap |
| |
|
| | |
| | from jax.experimental.compilation_cache import compilation_cache as cc |
| |
|
| | from diffusers import FlaxStableDiffusionXLPipeline |
| |
|
| |
|
# Point JAX's compilation cache at a local directory.
cc.initialize_cache("/tmp/sdxl_cache")

# Number of devices JAX reports; used later as the number of PRNG keys to split.
NUM_DEVICES = jax.device_count()

# `from_pretrained` hands back the pipeline object and its parameter tree as
# two separate values; the parameters are passed around explicitly below.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    revision="refs/pr/95",
    split_head_dim=True,
)

# Cast every parameter leaf to bfloat16, except the "scheduler" entry: it is
# removed before the cast and put back afterwards, so it keeps its original
# dtype.
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda leaf: leaf.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state

# Default generation inputs used for ahead-of-time compilation.
default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
default_neg_prompt = "fog, grainy, purple"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25

# Output resolution handed to the pipeline at compile time.
width = 1024
height = 1024
| |
|
| |
|
| | |
| | |
| | |
def tokenize_prompt(prompt, neg_prompt):
    """Run the prompt and the negative prompt through ``pipeline.prepare_inputs``.

    Args:
        prompt: Text describing what to generate.
        neg_prompt: Text describing what to avoid.

    Returns:
        A ``(prompt_ids, neg_prompt_ids)`` tuple holding whatever
        ``pipeline.prepare_inputs`` returns for each text, in that order.
    """
    return pipeline.prepare_inputs(prompt), pipeline.prepare_inputs(neg_prompt)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
# Replicate the parameter tree once at module level with flax's `replicate`,
# so it is not rebuilt for every generation call.
p_params = replicate(params)
| |
|
| |
|
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Replicate the token ids and derive ``NUM_DEVICES`` PRNG keys from ``seed``.

    Args:
        prompt_ids: Tokenized prompt.
        neg_prompt_ids: Tokenized negative prompt.
        seed: Integer seed used to create the root PRNG key.

    Returns:
        A ``(p_prompt_ids, p_neg_prompt_ids, rng)`` tuple: both id arrays
        passed through flax's ``replicate``, and the root key split into
        ``NUM_DEVICES`` keys.
    """
    replicated_ids = [replicate(ids) for ids in (prompt_ids, neg_prompt_ids)]
    device_keys = jax.random.split(jax.random.PRNGKey(seed), NUM_DEVICES)
    return (*replicated_ids, device_keys)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
def aot_compile(
    prompt=default_prompt,
    negative_prompt=default_neg_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Lower and compile the pmapped ``pipeline._generate`` ahead of time.

    The arguments only serve as example inputs for lowering. The values at
    positions 3, 4, 5 and 9 of the call (``num_inference_steps``, ``height``,
    ``width`` and the trailing ``False``) are declared static, so they are
    fixed in the compiled executable.

    Args:
        prompt: Example prompt text.
        negative_prompt: Example negative prompt text.
        seed: Seed for the example PRNG keys.
        guidance_scale: Scalar guidance scale, repeated along the leading
            axis of the replicated prompt ids.
        num_inference_steps: Step count fixed into the compiled executable.

    Returns:
        The compiled executable produced by ``.lower(...).compile()``.
    """
    token_ids, neg_token_ids = tokenize_prompt(prompt, negative_prompt)
    token_ids, neg_token_ids, rng = replicate_all(token_ids, neg_token_ids, seed)

    # One guidance value per leading-axis entry, shaped as a float32 column.
    guidance = jnp.array([guidance_scale] * token_ids.shape[0], dtype=jnp.float32)[:, None]

    pmapped_generate = pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9])
    example_args = (
        token_ids,
        p_params,
        rng,
        num_inference_steps,  # position 3: static
        height,  # position 4: static
        width,  # position 5: static
        guidance,
        None,  # presumably the optional initial latents -- confirm against the pipeline signature
        neg_token_ids,
        False,  # position 9: static; presumably `return_latents` -- confirm
    )
    return pmapped_generate.lower(*example_args).compile()
| |
|
| |
|
# Build the compiled generation function once, using the default inputs, and
# report how long the ahead-of-time compilation took (wall-clock seconds).
start = time.time()
print("Compiling ...")
p_generate = aot_compile()
print(f"Compiled in {time.time() - start}")
| |
|
| |
|
| | |
def generate(prompt, negative_prompt, seed=default_seed, guidance_scale=default_guidance_scale):
    """Generate images for ``prompt`` with the pre-compiled ``p_generate``.

    Only the non-static arguments are passed to ``p_generate``; the step
    count, height and width are not supplied here because they were fixed
    when the function was compiled.

    Args:
        prompt: Text describing what to generate.
        negative_prompt: Text describing what to avoid.
        seed: Seed for the PRNG keys.
        guidance_scale: Scalar guidance scale.

    Returns:
        Whatever ``pipeline.numpy_to_pil`` returns for the flattened image
        batch (iterated and saved as images by the caller).
    """
    token_ids, neg_token_ids = tokenize_prompt(prompt, negative_prompt)
    token_ids, neg_token_ids, rng = replicate_all(token_ids, neg_token_ids, seed)

    # One guidance value per leading-axis entry, shaped as a float32 column.
    guidance = jnp.array([guidance_scale] * token_ids.shape[0], dtype=jnp.float32)[:, None]

    raw_images = p_generate(token_ids, p_params, rng, guidance, None, neg_token_ids)

    # Merge the two leading axes into one batch axis and keep the last three
    # axes untouched before converting to PIL.
    # NOTE(review): the leading axes are presumably (device, per-device batch).
    outer, inner = raw_images.shape[0], raw_images.shape[1]
    flat_images = raw_images.reshape((outer * inner, *raw_images.shape[-3:]))
    return pipeline.numpy_to_pil(np.array(flat_images))
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
# Both timed runs below use the same prompt pair.
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
neg_prompt = "cartoon, illustration, animation. face. male, female"

# First call of the compiled function, timed on its own.
# NOTE(review): presumably this call carries one-time dispatch overhead that
# the second call does not -- compare the two printed timings to confirm.
start = time.time()
images = generate(prompt, neg_prompt)
print(f"First inference in {time.time() - start}")

# Second call with identical inputs, timed the same way.
start = time.time()
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")

# Write the images from the second run to the current working directory.
# NOTE(review): the files are named "castle_*" although the prompt above
# describes a rhino; the name is kept unchanged here.
for i, pil_image in enumerate(images):
    pil_image.save(f"castle_{i}.png")
| |
|