| from huggingface_hub import from_pretrained_keras |
| from keras_cv import models |
| import gradio as gr |
|
|
| from tensorflow import keras |
|
|
# Use mixed precision (float16 compute, float32 variables) for every model
# built below; it is set first so the models constructed afterwards pick it up.
# NOTE(review): float16 compute is typically slow on CPU-only hosts — confirm
# this app is deployed on a GPU.
keras.mixed_precision.set_global_policy("mixed_float16")

# Side length, in pixels, of the square images the pipeline generates.
resolution = 512

# Base KerasCV Stable Diffusion pipeline at the chosen output resolution.
sd_dreambooth_model = models.StableDiffusion(
    img_height=resolution,
    img_width=resolution,
)

# DreamBooth-fine-tuned diffusion backbone downloaded from the Hugging Face Hub.
db_diffusion_model = from_pretrained_keras(
    "keras-dreambooth/dreambooth_diffusion_model"
)

# Replace the pipeline's diffusion backbone with the fine-tuned one.
# NOTE(review): this writes the private `_diffusion_model` attribute of
# StableDiffusion; presumably the default backbone is only built lazily when
# first accessed, so this swap happens before it is used — verify against the
# pinned keras_cv version.
sd_dreambooth_model._diffusion_model = db_diffusion_model
|
|
| |
def infer(prompt, batch_size=2):
    """Generate images for ``prompt`` with the DreamBooth-tuned pipeline.

    Args:
        prompt: Text prompt passed to ``text_to_image``; mention the
            ``sks dog`` concept to get the fine-tuned subject.
        batch_size: Number of images generated per request. Defaults to 2,
            the value this demo has always used; the Gradio interface only
            supplies ``prompt``, so it keeps this default.

    Returns:
        Whatever ``sd_dreambooth_model.text_to_image`` returns for the batch,
        passed through unchanged to the output gallery.
    """
    generated_images = sd_dreambooth_model.text_to_image(
        prompt, batch_size=batch_size
    )
    return generated_images
| |
def _build_output_gallery():
    """Create the output gallery with the same layout on old and new Gradio.

    Gradio 3.x configures the gallery layout with ``.style(grid=...)``; newer
    releases take a ``columns`` constructor argument instead, and Gradio 4
    removed ``.style()`` entirely. Pick whichever API the installed version
    provides.
    """
    import inspect

    # Images per row for increasing screen-size breakpoints; the last value
    # applies to all larger screens (same tuple the original passed to grid=).
    layout = (1, 2)
    if "columns" in inspect.signature(gr.Gallery.__init__).parameters:
        return gr.Gallery(label="Outputs", columns=layout)
    return gr.Gallery(label="Outputs").style(grid=layout)


output = _build_output_gallery()

title = "Dreambooth Demo on Dog Images"
# Tell users to include the learned concept phrase itself; the previous
# "{sks dog}" wording suggested typing the braces literally.
description = (
    "This is a dreambooth model fine-tuned on dog images. "
    'To try it, include the phrase "sks dog" in your prompt.'
)
examples = [["sks dog in space"]]

demo = gr.Interface(
    infer,
    inputs=["text"],
    outputs=[output],
    title=title,
    description=description,
    examples=examples,
)
demo.queue().launch()