import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.training import checkpoints
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from diffusers import FlaxUNet2DConditionModel, FlaxAutoencoderKL, FlaxDDIMScheduler
from codi.controlnet_flax import FlaxControlNetModel
from codi.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
from transformers import CLIPTokenizer, FlaxCLIPTextModel
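
# Load the frozen Stable Diffusion v1.4 components (UNet, VAE, text encoder,
# tokenizer) from their Flax-format weights.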
MODEL_NAME = "CompVis/stable-diffusion-v1-4"

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    MODEL_NAME,
    subfolder="unet",
    revision="flax",
    dtype=jnp.float32,
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    MODEL_NAME,
    subfolder="vae",
    revision="flax",
    dtype=jnp.float32,
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
    MODEL_NAME,
    subfolder="text_encoder",
    revision="flax",
    dtype=jnp.float32,
)
tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_NAME,
    subfolder="tokenizer",
    revision="flax",
)
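
# Build a ControlNet whose architecture mirrors the UNet's configuration, so the
# distilled CoDi weights restored below are shape-compatible with the base model.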
controlnet = FlaxControlNetModel(
    in_channels=unet.config.in_channels,
    down_block_types=unet.config.down_block_types,
    only_cross_attention=unet.config.only_cross_attention,
    block_out_channels=unet.config.block_out_channels,
    layers_per_block=unet.config.layers_per_block,
    attention_head_dim=unet.config.attention_head_dim,
    cross_attention_dim=unet.config.cross_attention_dim,
    use_linear_projection=unet.config.use_linear_projection,
    flip_sin_to_cos=unet.config.flip_sin_to_cos,
    freq_shift=unet.config.freq_shift,
)
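
# DDIM scheduler with the standard Stable Diffusion v1 beta schedule.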
scheduler = FlaxDDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    trained_betas=None,
    set_alpha_to_one=True,
    steps_offset=0,
)
scheduler_state = scheduler.create_state()
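
# Assemble the pipeline. The two positional None arguments fill the
# safety-checker and feature-extractor slots (as in the upstream diffusers
# pipeline signature), leaving both disabled for this demo.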
pipeline = FlaxStableDiffusionControlNetPipeline(
    vae,
    text_encoder,
    tokenizer,
    unet,
    controlnet,
    scheduler,
    None,
    None,
    dtype=jnp.float32,
)
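
# Restore the distilled CoDi ControlNet weights from a local Orbax checkpoint.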
controlnet_params = checkpoints.restore_checkpoint("checkpoint_100000.orbax", target=None)

pipeline_params = {
    "vae": vae_params,
    "unet": unet_params,
    "text_encoder": text_encoder.params,
    "scheduler": scheduler_state,
    "controlnet": controlnet_params,
}
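# Replicate all parameters across local devices for data-parallel (pmap) sampling.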
pipeline_params = replicate(pipeline_params)
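

# Generate one image per device: the RNG is split and the tokenized prompts are
# sharded so the pipeline's jitted, pmapped sampling runs on every device at once.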
def infer(seed, prompt, negative_prompt, steps, cfgr):
    rng = jax.random.PRNGKey(int(seed))
    num_samples = jax.device_count()
    rng = jax.random.split(rng, num_samples)

    prompt_ids = pipeline.prepare_text_inputs([prompt] * num_samples)
    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)

    output = pipeline(
        prompt_ids=prompt_ids,
        image=None,
        params=pipeline_params,
        prng_seed=rng,
        num_inference_steps=int(steps),
        guidance_scale=float(cfgr),
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    # Drop the device axis and convert the raw arrays to PIL images.
    output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
    return output_images
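

# Gradio UI.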
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## CoDi: Conditional Diffusion Distillation for Higher-Fidelity and Faster Image Generation")
    gr.Markdown("[\[Paper\]](https://arxiv.org/abs/2310.01407) [\[Project Page\]](https://fast-codi.github.io) [\[Code\]](https://github.com/fast-codi/CoDi)")

    with gr.Tab("CoDi on Text-to-Image"):
        with gr.Row():
            gr.Radio(["CompVis/stable-diffusion-v1-4"], value="CompVis/stable-diffusion-v1-4", label="baseline model", info="Choose the undistilled baseline model")
            gr.Radio(["CoDi/text-to-image-v0-1 (366M)"], value="CoDi/text-to-image-v0-1 (366M)", label="distilled codi", info="Choose the distilled conditional model")
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(label="Prompt")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="monochrome, lowres, bad anatomy, worst quality, low quality")
                seed = gr.Number(label="Seed", value=0)
            output = gr.Gallery(label="Output Images")
        with gr.Row():
            num_inference_steps = gr.Slider(2, 50, value=4, step=1, label="Steps")
            guidance_scale = gr.Slider(2.0, 14.0, value=7.5, step=0.5, label="Guidance Scale")
            submit_btn = gr.Button(value="Submit")
        inputs = [
            seed,
            prompt_input,
            negative_prompt,
            num_inference_steps,
            guidance_scale,
        ]
        submit_btn.click(fn=infer, inputs=inputs, outputs=[output])

        with gr.Row():
            gr.Examples(
                examples=["oranges", "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
                inputs=prompt_input,
                fn=infer,
            )

demo.launch()