| | from diffusers import DiffusionPipeline |
| | import gradio as gr |
| | import torch |
| | import time |
| | import psutil |
| |
|
| |
|
# Wall-clock reference taken at import; used later to report UI build time.
start_time = time.time()

# Human-readable hardware label shown in the UI. `inference` looks for the
# substring "GPU" in it to decide between the "cuda" and "cpu" torch devices.
if torch.cuda.is_available():
    device = "GPU 🔥"
else:
    device = "CPU 🥶"
| |
|
| |
|
def error_str(error, title="Error"):
    """Render *error* as a Markdown snippet headed by a level-4 *title*.

    A falsy *error* (``None``, ``""``, ...) yields an empty string, which is
    convenient for clearing a Markdown output component.
    """
    if not error:
        return ""
    return f"#### {title}\n{error}"
| |
|
| |
|
def inference(
    repo_id,
    discuss_nr,
    prompt,
):
    """Load a repo's diffusers weights from a Hub PR revision and run *prompt*.

    Args:
        repo_id: Hub repository id, e.g. ``"CompVis/stable-diffusion-v1-4"``.
        discuss_nr: Number of the discussion/PR on that repo; the revision
            ``refs/pr/<discuss_nr>`` is loaded.
        prompt: Text prompt passed to the pipeline.

    Returns:
        A two-element tuple matching the ``[gallery, error_output]`` outputs:
        ``(images, "")`` on success, or ``(None, markdown_error)`` if loading
        or running the pipeline raised.
    """
    # Log host memory: a full pipeline is loaded on every request.
    print(psutil.virtual_memory())

    # Fixed seed so results for the same prompt are reproducible.
    seed = 0
    torch_device = "cuda" if "GPU" in device else "cpu"

    generator = torch.Generator(torch_device).manual_seed(seed)

    # Half precision only on GPU; CPU runs in float32.
    dtype = torch.float16 if torch_device == "cuda" else torch.float32

    # Textbox values are strings; tolerate stray surrounding whitespace.
    repo_id = str(repo_id).strip()
    discuss_nr = str(discuss_nr).strip()

    pipe = None
    try:
        revision = f"refs/pr/{discuss_nr}"
        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
        pipe.to(torch_device)

        images = pipe(prompt, generator=generator, num_inference_steps=25).images
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing.
        url = f"https://huggingface.co/{repo_id}/discussions/{discuss_nr}"
        message = f"There is a problem with your diffusers weights of the PR: {url}. Error message: \n"
        # str(e): concatenating a str with the exception object raises TypeError.
        return None, error_str(message + str(e))
    finally:
        # A fresh pipeline is loaded per request; release it so repeated
        # checks do not accumulate model weights in (GPU) memory.
        del pipe
        if torch_device == "cuda":
            torch.cuda.empty_cache()

    # One value per output component; "" clears any previous error message.
    return images, ""
| |
|
| |
|
# Build the UI: repo id / discussion number / prompt inputs, a gallery for the
# generated images and a Markdown area for errors, all wired to `inference`.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        f"""
        <div class="diffusion">
          <p>
            Space to test whether `diffusers` PRs work.
          </p>
          <p>
            Running on <b>{device}</b>
          </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=55):
            with gr.Group():
                repo_id = gr.Textbox(
                    label="Repo id on Hub",
                    placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4 for https://huggingface.co/CompVis/stable-diffusion-v1-4",
                )
                discuss_nr = gr.Textbox(
                    label="Discussion number",
                    placeholder="Number of the discussion that should be checked, e.g. 171 for https://huggingface.co/CompVis/stable-diffusion-v1-4/discussions/171",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    # Gradio 3 uses `value` for the initial text; `default` is
                    # not a Textbox parameter, so the prompt was never prefilled.
                    value="An astronaut riding a horse on Mars.",
                    placeholder="Enter prompt.",
                )
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[2], height="auto")

            error_output = gr.Markdown()

            generate = gr.Button(value="Generate").style(
                rounded=(False, True, True, False)
            )

    inputs = [
        repo_id,
        discuss_nr,
        prompt,
    ]
    # `inference` must return one value per output, in this order.
    outputs = [gallery, error_output]
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)

print(f"Space built in {time.time() - start_time:.2f} seconds")

# Serve one request at a time: each request loads a whole pipeline.
demo.queue(concurrency_count=1)
demo.launch()
| |
|