Spaces: Runtime error
import jax
import numpy as np
import torch
from diffusers import DiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
model_path = "sabman/map-diffuser-v3"

# Use the GPU when one is available, otherwise fall back to CPU. Hard-coding
# "cuda" makes start-up fail on CPU-only hardware.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ``from_flax=True`` loads Flax weights and converts them to PyTorch on load.
# ``safety_checker=None`` disables the safety checker for this pipeline.
pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    from_flax=True,
    safety_checker=None,
).to(device)

# Example prompt:
# "create a map with traffic signals, busway and residential buildings, in water color style"
def generate_images(prompt, *, num_inference_steps=20, width=512):
    """Generate one image for ``prompt`` using the module-level ``pipeline``.

    Args:
        prompt: Text prompt passed straight to the diffusion pipeline.
        num_inference_steps: Number of denoising steps (default 20).
        width: Output image width in pixels (default 512).

    Returns:
        The first (and only) generated image, as a PIL image.
    """
    # With output_type="pil" (requested explicitly) the PyTorch pipeline's
    # ``.images`` is already a list of PIL images. A NumPy reshape /
    # ``numpy_to_pil`` step would only be needed for the Flax pipeline's
    # sharded array output. Applied to a list, it raises AttributeError.
    result = pipeline(
        prompt,
        width=width,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        output_type="pil",
    )
    return result.images[0]