Spaces:
Runtime error
Runtime error
| import jax | |
| import numpy as np | |
| from flax.jax_utils import replicate | |
| from flax.training.common_utils import shard | |
| from diffusers import DiffusionPipeline | |
# torch is already required by the diffusers pipeline's .to(...) below;
# imported here only to probe for a GPU.
import torch  # noqa: E402

# Hugging Face Hub repository holding the map-generation model.
model_path = "sabman/map-diffuser-v3"

# Use the GPU when one is present. Hard-coding "cuda" raises at import time
# on CPU-only hardware and takes the whole app down with it.
_device = "cuda" if torch.cuda.is_available() else "cpu"

# from_flax=True loads the weights from a Flax checkpoint (presumably the repo
# ships Flax weights only — confirm). safety_checker=None loads the pipeline
# without its safety-checker component, so generated images are not filtered.
pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    from_flax=True,
    safety_checker=None,
).to(_device)
def generate_images(prompt):
    """Generate a single map image for ``prompt`` using the module-level pipeline.

    Runs ``pipeline`` for 20 inference steps at a width of 512 pixels (height
    is left to the pipeline's default) and requests exactly one image.

    Example prompt::

        "create a map with traffic signals, busway and residential buildings,
        in water color style"

    Args:
        prompt: Text description of the map to draw.

    Returns:
        The first (and only) entry of the pipeline output's ``images``;
        presumably a ``PIL.Image.Image`` with the default output type — confirm.
    """
    num_inference_steps = 20
    # No seed/generator is passed, so each call samples fresh noise and the
    # result is not reproducible across calls.
    result = pipeline(
        prompt,
        width=512,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
    )
    return result.images[0]