# NOTE(review): the lines "Spaces:" / "Runtime error" were Hugging Face Spaces
# page-banner residue from the scrape, not source code — preserved here as a comment.
| import random | |
| import time | |
| import torch | |
| from diffusers import FluxKontextPipeline | |
| from PIL import Image | |
| from utils import get_args | |
| from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel | |
| from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel | |
| import gradio as gr | |
# Upper bound for the seed slider / random-seed button.
MAX_SEED = 1000000000

args = get_args()

if args.precision == "bf16":
    # Baseline: original FLUX.1-Kontext weights in bfloat16, no quantization.
    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
else:
    assert args.precision in ["int4", "fp4"]
    init_kwargs = {}
    # Quantized path: swap in the Nunchaku SVDQuant transformer for the
    # requested precision (int4 or fp4, rank-32 low-rank branch).
    quant_transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
    )
    init_kwargs["transformer"] = quant_transformer
    if args.use_qencoder:
        # Optionally also replace the T5-XXL text encoder with an AWQ int4 build.
        quant_t5 = NunchakuT5EncoderModel.from_pretrained(
            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
        )
        init_kwargs["text_encoder_2"] = quant_t5
    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **init_kwargs
    )
    pipeline = pipeline.to("cuda")
    # Tag the pipeline with the active precision (presumably for display/debugging;
    # the UI header below formats args.precision directly).
    pipeline.precision = args.precision
def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
    """Run one FLUX.1-Kontext image-editing inference and time it.

    Args:
        image: Gradio ImageEditor payload; its "composite" entry is the edited
            PIL image (assumed — TODO confirm against the ImageEditor value dict).
        prompt: Text instruction describing the edit.
        num_inference_steps: Number of diffusion denoising steps.
        guidance_scale: Classifier-free guidance strength.
        seed: RNG seed, applied via a torch.Generator for reproducibility.

    Returns:
        Tuple of (edited image, human-readable latency string such as
        "123.45ms" or "1.23s").
    """
    # BUGFIX: the original annotation was `tuple[Image, str]` — `Image` is the
    # PIL *module*; the image class is `Image.Image`.
    img = image["composite"].convert("RGB")
    start_time = time.time()
    result_image = pipeline(
        prompt=prompt,
        image=img,
        height=img.height,
        width=img.width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    latency = time.time() - start_time
    # Report sub-second runs in milliseconds for readability.
    if latency < 1:
        latency_str = f"{latency * 1000:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    torch.cuda.empty_cache()
    return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
    # Header: HTML template filled with the active precision and device info.
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()

    # Summarize the device the demo is running on.
    if torch.cuda.device_count() > 0:
        props = torch.cuda.get_device_properties(0)
        total_gib = props.total_memory / (1024**3)  # bytes -> GiB
        gpu_name = torch.cuda.get_device_name(0)
        device_info = f"Running on {gpu_name} with {total_gib:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."

    header = gr.HTML(
        DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
    )

    with gr.Row(elem_id="main_row"):
        # Left column: input image, prompt, and generation controls.
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.ImageEditor(
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Input",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4
                    )
                    randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    num_inference_steps = gr.Slider(
                        label="Inference Steps", minimum=10, maximum=50, step=1, value=28
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5
                    )

        # Right column: generated image, latency readout, and usage notes.
        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt")
            gr.Markdown("**2**. Upload an image")
            gr.Markdown("**3**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]

    # "Random Seed" first picks a fresh seed, then re-runs inference with it.
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    # Both pressing Enter in the prompt box and clicking "Run" trigger inference.
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )
if __name__ == "__main__":
    # Enable request queueing (single GPU serves one job at a time), then launch
    # with a public share link and an optional reverse-proxy root path.
    app = demo.queue()
    app.launch(debug=True, share=True, root_path=args.gradio_root_path)