| ```python | |
#!/usr/bin/env python3
"""Run the Flax Stable Diffusion pipeline data-parallel across all local JAX devices.

The same prompt is fed to every device: the pipeline weights are replicated,
each device gets its own PRNG key and prompt shard, ``pipeline.__call__`` is
``pmap``-ed over the device axis, and the generated arrays are converted to
PIL images at the end.
"""
import jax
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

# NOTE(review): newer diffusers releases deprecate `use_auth_token` in favour
# of `token`; kept as-is for the diffusers version this script targets.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "fusing/stable-diffusion-flax-new", use_auth_token=True
)
# Drop the safety-checker weights before the params are replicated to the
# devices. NOTE(review): presumably the pmapped __call__ path does not use
# them — confirm against the diffusers version in use.
del params["safety_checker"]

# pmap maps over the leading (device) axis of each argument. Argument 3,
# num_inference_steps, is a plain Python int, so it is broadcast as a static
# value to every device instead of being mapped.
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# One copy of the prompt per device.
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# Replicate the weights and give each device its own RNG key and prompt shard.
# All mapped arguments must share the same leading (device) axis, which
# replicate/shard size to the local device count; derive the number of keys
# from that count instead of hard-coding it so any device count works.
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.local_device_count())
prompt_ids = shard(prompt_ids)

# Run the full denoising loop on every device in parallel.
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

# Flatten every leading axis (the device axis added by pmap plus any
# per-device batch axis) into one batch, keeping the last three per-image axes.
images = np.asarray(images)
images_pil = pipeline.numpy_to_pil(images.reshape((-1,) + images.shape[-3:]))

print("Images should be good")
# images_pil presumably holds one PIL image per sample; persist them with
# e.g. images_pil[0].save(path).
| ``` |