Spaces:
Runtime error
Runtime error
File size: 5,784 Bytes
04eaca9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import random
import time
import torch
from diffusers import FluxKontextPipeline
from PIL import Image
from utils import get_args
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
import gradio as gr
MAX_SEED = 1000000000
args = get_args()
if args.precision == "bf16":
pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = args.precision
def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
img = image["composite"].convert("RGB")
start_time = time.time()
result_image = pipeline(
prompt=prompt,
image=img,
height=img.height,
width=img.width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed),
).images[0]
latency = time.time() - start_time
if latency < 1:
latency = latency * 1000
latency_str = f"{latency:.2f}ms"
else:
latency_str = f"{latency:.2f}s"
torch.cuda.empty_cache()
return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
# Get the GPU properties
if torch.cuda.device_count() > 0:
gpu_properties = torch.cuda.get_device_properties(0)
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
gpu_name = torch.cuda.get_device_name(0)
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
header = gr.HTML(header_str)
with gr.Row(elem_id="main_row"):
with gr.Column(elem_id="column_input"):
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.ImageEditor(
height=640,
image_mode="RGB",
sources=["upload", "clipboard"],
type="pil",
label="Input",
show_label=False,
show_download_button=True,
interactive=True,
transforms=[],
canvas_size=(1024, 1024),
scale=1,
format="png",
layers=False,
)
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
with gr.Row():
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
with gr.Accordion("Advanced options", open=False):
with gr.Group():
num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
with gr.Group():
result = gr.Image(
format="png",
height=640,
image_mode="RGB",
type="pil",
label="Result",
show_label=False,
show_download_button=True,
interactive=False,
elem_id="output_image",
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt")
gr.Markdown("**2**. Upload an image")
gr.Markdown("**3**. Try different seeds to generate different results")
run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
run_outputs = [result, latency_result]
randomize_seed.click(
lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.on(
triggers=[prompt.submit, run_button.click],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
|