| import threading |
| from collections import deque |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import gradio as gr |
| from PIL import Image |
|
|
| from constants import DESCRIPTION, LOGO |
| from gradio_examples import EXAMPLES |
| from model import get_pipeline |
| from utils import replace_background |
|
|
# Capacity used for both per-session deques (pending prompts and finished
# generations). A deque created with ``maxlen`` discards its oldest entry when
# a new one is appended to a full queue.
MAX_QUEUE_SIZE = 4


# Generation pipeline, built once when this module is imported and called by
# ``inference`` for every queued request.
pipeline = get_pipeline()
|
|
|
|
@dataclass
class GenerationState:
    """Per-session container pairing a request queue with a result queue."""

    # Pending requests; each entry is an (image, prompt, seed, strength) tuple.
    prompts: deque
    # Finished output images waiting to be shown in the UI.
    generations: deque
|
|
|
|
def get_initial_state() -> GenerationState:
    """Create a fresh state whose two queues are empty and size-bounded.

    Returns:
        A ``GenerationState`` with independent ``prompts`` and ``generations``
        deques, each capped at ``MAX_QUEUE_SIZE`` entries.
    """
    pending_prompts = deque(maxlen=MAX_QUEUE_SIZE)
    finished_images = deque(maxlen=MAX_QUEUE_SIZE)
    return GenerationState(prompts=pending_prompts, generations=finished_images)
|
|
|
|
def load_initial_state(request: gr.Request) -> GenerationState:
    """Log the connecting client and hand back a brand-new generation state.

    Args:
        request: The Gradio request object for the page load; only
            ``request.client.host`` is read, for logging.

    Returns:
        A fresh ``GenerationState`` from ``get_initial_state``.
    """
    client_host = request.client.host
    print("Loading initial state for", client_host)

    thread_count = threading.active_count()
    print("Total number of active threads", thread_count)

    return get_initial_state()
|
|
|
|
async def put_to_queue(
    image: Optional[Image.Image],
    prompt: str,
    seed: int,
    strength: float,
    state: GenerationState,
):
    """Enqueue a generation request when both a prompt and an image exist.

    Args:
        image: The input image, or ``None`` when nothing has been drawn.
        prompt: The text prompt; an empty prompt is ignored.
        seed: Seed value forwarded with the request.
        strength: Strength value forwarded with the request.
        state: The session state whose ``prompts`` queue receives the request.

    Returns:
        The same ``state`` object, whether or not a request was enqueued.
    """
    # Nothing to generate without both a non-empty prompt and an image.
    if not prompt or image is None:
        return state

    request_entry = (image, prompt, seed, strength)
    state.prompts.append(request_entry)
    return state
|
|
|
|
def inference(state: GenerationState) -> GenerationState:
    """Process the oldest queued request, if any, and store its result.

    Pops one ``(image, prompt, seed, strength)`` entry from ``state.prompts``,
    runs it through the module-level ``pipeline`` and appends the output image
    to ``state.generations``. When no request is pending the state is returned
    untouched.

    Args:
        state: The session state holding the prompt and generation queues.

    Returns:
        The same ``state`` object (not an image), so it can be written back to
        the Gradio state component.
    """
    pending_prompts = state.prompts
    if not pending_prompts:
        return state

    image, prompt, seed, strength = pending_prompts.popleft()

    # Remember the caller's resolution: the pipeline input is resized to
    # 512x512 and the output is resized back afterwards.
    original_image_size = image.size
    model_input = replace_background(image.resize((512, 512)))

    result = pipeline(
        prompt=prompt,
        image=model_input,
        strength=strength,
        seed=seed,
        guidance_scale=1,
        num_inference_steps=4,
    )

    output_image = result.images[0].resize(original_image_size)
    state.generations.append(output_image)

    return state
|
|
|
|
def update_output_image(state: GenerationState):
    """Surface the oldest finished generation in the output image, if any.

    Args:
        state: The session state whose ``generations`` queue is consumed.

    Returns:
        A ``(image_update, state)`` pair. ``image_update`` carries the popped
        image as its value, or is an empty ``gr.update()`` when the queue is
        empty.
    """
    finished = state.generations

    if finished:
        refresh = gr.update(value=finished.popleft())
    else:
        refresh = gr.update()

    return refresh, state
|
|
|
|
# Build the UI and wire its events.
# NOTE(review): the container nesting below (which widgets share a Row/Column)
# and the whitespace inside the multi-line ``info`` strings follow the
# conventional layout for this kind of demo — confirm against the intended
# design.
with gr.Blocks(css="style.css", title="Realtime Latent Consistency Model") as demo:
    # Session state holding the prompt queue and the generation queue.
    generation_state = gr.State(get_initial_state())

    gr.HTML(f'<div style="width: 70px;">{LOGO}</div>')
    gr.Markdown(DESCRIPTION)
    with gr.Row(variant="default"):
        input_image = gr.Image(
            tool="color-sketch",
            source="canvas",
            label="Initial Image",
            type="pil",
            height=512,
            width=512,
            brush_radius=40.0,
        )

        output_image = gr.Image(
            label="Generated Image",
            type="pil",
            interactive=False,
            elem_id="output_image",
        )
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0])

            with gr.Accordion(label="Advanced Options", open=False):
                with gr.Row():
                    with gr.Column():
                        strength = gr.Slider(
                            label="Strength",
                            minimum=0.1,
                            maximum=1.0,
                            step=0.05,
                            value=0.8,
                            info="""
                            Strength of the initial image that will be applied during inference.
                            """,
                        )
                    with gr.Column():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=2**31 - 1,
                            step=1,
                            randomize=True,
                            info="""
                            Seed for the random number generator.
                            """,
                        )

    # On page load, replace the state with a fresh one.
    demo.load(
        load_initial_state,
        outputs=[generation_state],
    )
    # Re-run every 0.1 s: process one queued request per invocation.
    demo.load(
        inference,
        inputs=[generation_state],
        outputs=[generation_state],
        every=0.1,
    )
    # Re-run every 0.1 s: move one finished generation into the output image.
    demo.load(
        update_output_image,
        inputs=[generation_state],
        outputs=[output_image, generation_state],
        every=0.1,
    )
    # Any change to the image, prompt, strength, or seed enqueues a request.
    for event in [input_image.change, prompt_box.change, strength.change, seed.change]:
        event(
            put_to_queue,
            [input_image, prompt_box, seed, strength, generation_state],
            [generation_state],
            show_progress=False,
            queue=True,
        )

    gr.Markdown("## Example Prompts")
    gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue first, then start the server on the queued app.
    queued_demo = demo.queue(concurrency_count=20, api_open=False)
    queued_demo.launch(max_threads=1024)
|
|