# D3vShoaib's picture
# Add Git LFS support and remove binary files
# 04eaca9
# raw
# history blame
# 5.78 kB
import random
import time
import torch
from diffusers import FluxKontextPipeline
from PIL import Image
from utils import get_args
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
import gradio as gr
# Upper bound used for the seed slider and the random-seed button.
MAX_SEED = 1000000000

# Command-line options; this script reads `precision`, `use_qencoder` and
# `gradio_root_path` from the result.
args = get_args()

# Fail fast on an unsupported precision. An explicit raise is used rather than
# `assert` because assertions are stripped under `python -O`, which would let a
# bad value fall through and be interpolated into the checkpoint path below.
if args.precision not in ("bf16", "int4", "fp4"):
    raise ValueError(f"Unsupported precision {args.precision!r}; expected 'bf16', 'int4' or 'fp4'.")

# Components that override the stock pipeline parts; stays empty for bf16.
pipeline_init_kwargs = {}
if args.precision != "bf16":
    # Quantized transformer checkpoint selected by the requested precision.
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
    )
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        # Optionally swap in the quantized T5 text encoder as well.
        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
        )
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

pipeline = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
# Remember which precision was loaded on the pipeline object itself.
pipeline.precision = args.precision
def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
    """Run the pipeline on the edited image and report the elapsed time.

    Args:
        image: Value of the input editor; a mapping whose ``"composite"``
            entry holds the image to process.
        prompt: Text prompt forwarded to the pipeline.
        num_inference_steps: Step count forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Seed for the ``torch.Generator`` handed to the pipeline.

    Returns:
        The first image produced by the pipeline and the latency formatted
        as milliseconds (when under one second) or seconds.

    Raises:
        gr.Error: If no input image is available.
    """
    composite = None if image is None else image["composite"]
    if composite is None:
        # Surface a readable message in the UI instead of an AttributeError
        # from calling .convert() on None.
        raise gr.Error("Please upload an image before running.")
    img = composite.convert("RGB")

    try:
        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments.
        start_time = time.perf_counter()
        result_image = pipeline(
            prompt=prompt,
            image=img,
            height=img.height,
            width=img.width,
            # int() guards against a slider delivering a float such as 28.0.
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            # manual_seed rejects non-integer values, hence the int().
            generator=torch.Generator().manual_seed(int(seed)),
        ).images[0]
        latency = time.perf_counter() - start_time
    finally:
        # Release cached GPU memory even when inference fails.
        torch.cuda.empty_cache()

    if latency < 1:
        latency_str = f"{latency * 1000:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    return result_image, latency_str
# Build the Gradio UI: a header, an input column (editor, prompt, seed and
# advanced sliders) and an output column (result image, latency, instructions),
# then wire the buttons to `run`.
# NOTE(review): the source this was recovered from had its indentation
# flattened; the nesting of Group/Row/Accordion below is a reconstruction —
# confirm against the intended layout.
with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
    # Read the HTML template with an explicit encoding so the result does not
    # depend on the platform's default locale.
    with open("assets/description.html", "r", encoding="utf-8") as f:
        DESCRIPTION = f.read()

    # Get the GPU properties
    if torch.cuda.device_count() > 0:
        gpu_properties = torch.cuda.get_device_properties(0)
        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
        gpu_name = torch.cuda.get_device_name(0)
        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."

    # Fill the template fields; count_info is passed as an empty string.
    header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
    header = gr.HTML(header_str)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.ImageEditor(
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Input",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")

            with gr.Row():
                seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
                randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)

            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt")
            gr.Markdown("**2**. Upload an image")
            gr.Markdown("**3**. Try different seeds to generate different results")

    # Order matches the positional parameters of `run` and its two return values.
    run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]

    # "Random Seed" first writes a fresh seed into the slider, then runs with it.
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    # Submitting the prompt box and clicking "Run" both trigger `run`.
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )
if __name__ == "__main__":
    # Enable the request queue first, then start the server in debug mode
    # with share=True, mounted under the configured root path.
    queued_demo = demo.queue()
    queued_demo.launch(
        debug=True,
        share=True,
        root_path=args.gradio_root_path,
    )