| |
| import time |
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from flax.jax_utils import replicate |
|
|
| |
| from jax.experimental.compilation_cache import compilation_cache as cc |
|
|
| from diffusers import FlaxStableDiffusionXLPipeline |
|
|
|
|
# Point JAX's compilation cache at a directory on disk.
cc.initialize_cache("/tmp/sdxl_cache")

# Device count as reported by JAX; replicate_all() creates one PRNG key per device.
NUM_DEVICES = jax.device_count()

# Load the SDXL base pipeline. from_pretrained hands back the pipeline object
# and its parameters as two separate values; the parameters are passed
# explicitly to the pipeline call in generate().
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    revision="refs/pr/95",
    split_head_dim=True,
)

# Cast every parameter leaf to bfloat16 except the "scheduler" entry: it is
# popped before the cast and put back untouched afterwards.
# NOTE(review): presumably the scheduler state is kept in its original dtype
# for precision -- confirm.
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda leaf: leaf.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state

# Default pipeline inputs. The two prompts feed the warm-up call at the bottom
# of the script; the numeric values are the keyword defaults of generate().
default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
default_neg_prompt = "fog, grainy, purple"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25
|
|
|
|
| |
| |
| |
def tokenize_prompt(prompt, neg_prompt):
    """Convert both prompts into pipeline inputs via ``pipeline.prepare_inputs``.

    Args:
        prompt: The positive prompt handed to ``pipeline.prepare_inputs``.
        neg_prompt: The negative prompt handed to ``pipeline.prepare_inputs``.

    Returns:
        A ``(prompt_ids, neg_prompt_ids)`` tuple holding the two
        ``prepare_inputs`` results, in that order.
    """
    # Tuple elements are evaluated left to right: prompt first, then negative prompt.
    return pipeline.prepare_inputs(prompt), pipeline.prepare_inputs(neg_prompt)
|
|
|
|
| |
| |
| |
| |
| |
# Replicate the model parameters once at module level; generate() reuses this
# same replicated copy on every call instead of replicating per request.
p_params = replicate(params)
|
|
|
|
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Replicate the tokenized prompts and derive one PRNG key per device.

    Args:
        prompt_ids: Tokenized prompt, as produced by ``tokenize_prompt``.
        neg_prompt_ids: Tokenized negative prompt.
        seed: Integer seed for the base ``jax.random.PRNGKey``.

    Returns:
        A 3-tuple ``(p_prompt_ids, p_neg_prompt_ids, rng)``: the two inputs
        passed through ``replicate`` and the base key split into
        ``NUM_DEVICES`` keys.
    """
    replicated = tuple(replicate(ids) for ids in (prompt_ids, neg_prompt_ids))
    # Splitting the base key gives each device its own key.
    device_keys = jax.random.split(jax.random.PRNGKey(seed), NUM_DEVICES)
    return (*replicated, device_keys)
|
|
|
|
| |
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Run the full text-to-image pipeline for one prompt pair.

    Tokenizes both prompts, replicates them together with per-device PRNG
    keys, calls the pipeline with ``jit=True`` using the module-level
    replicated parameters, and converts the result with
    ``pipeline.numpy_to_pil``.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        seed: Seed for the PRNG keys (defaults to ``default_seed``).
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``.

    Returns:
        The value returned by ``pipeline.numpy_to_pil``; the script iterates
        over it and calls ``.save`` on each element.
    """
    token_ids, neg_token_ids = tokenize_prompt(prompt, negative_prompt)
    p_ids, p_neg_ids, device_rngs = replicate_all(token_ids, neg_token_ids, seed)

    output = pipeline(
        p_ids,
        p_params,
        device_rngs,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=p_neg_ids,
        guidance_scale=guidance_scale,
        jit=True,
    )
    batched = output.images

    # Merge the two leading axes into a single one while keeping the last
    # three axes as they are, then hand a NumPy copy to the PIL conversion.
    merged_leading = batched.shape[0] * batched.shape[1]
    flat = batched.reshape((merged_leading,) + batched.shape[-3:])
    return pipeline.numpy_to_pil(np.array(flat))
|
|
|
|
| |
| |
# Warm-up: call generate() once with the default prompts. As the printed
# messages indicate, this first call is the one that pays for compilation.
# Durations use time.perf_counter(), a monotonic clock meant for measuring
# intervals; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.perf_counter() - start}")

# Timed run: call generate() again with a different prompt pair.
start = time.perf_counter()
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
neg_prompt = "cartoon, illustration, animation. face. male, female"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.perf_counter() - start}")

# Write each generated image to the current working directory.
# NOTE(review): the files are named "castle_<i>.png" although they come from
# the rhino prompt above; the name is kept as-is so existing consumers of the
# output files are not broken.
for i, image in enumerate(images):
    image.save(f"castle_{i}.png")
|
|