import os

import torch

# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
auth_token = os.getenv("HF_TOKEN")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is requested only on CUDA; on CPU no dtype override is set (None).
torch_dtype = torch.float16 if device == "cuda" else None


from diffusers import StableDiffusionPipeline

# Hub repository id of the checkpoint to load.
model_id = "CompVis/stable-diffusion-v1-4"

# Load the "fp16" revision of the weights once, at import time, and move the
# resulting pipeline to the selected device.
# NOTE(review): confirm that `auth_token` is the keyword the installed
# diffusers release expects for the Hub token (releases have used
# `use_auth_token` and `token`); an unrecognised keyword may not be forwarded.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    auth_token=auth_token,
    revision="fp16",
    torch_dtype=torch_dtype,
).to(device)


def predict(prompt):
    """Generate an image for a text prompt using the module-level pipeline.

    Args:
        prompt: Text description handed to ``pipe`` unchanged.

    Returns:
        The first entry of the pipeline result's ``images`` sequence.
    """
    return pipe(prompt).images[0]


import gradio as gr

# One two-line text box in, one image out, with a single example prompt;
# each submission is routed through `predict`.
gradio_ui = gr.Interface(
    fn=predict,
    title="Stable Diffusion Demo",
    description="Enter a description of an image you'd like to generate!",
    inputs=[
        gr.Textbox(lines=2, label="Paste some text here"),
    ],
    outputs=["image"],
    examples=[["a photograph of an astronaut riding a horse"]],
)

# Launch the interface as soon as this module is executed.
gradio_ui.launch()