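"""Gradio demo for Qwen-Image-Edit-2509 with a Nunchaku SVDQuant transformer.

Edits one or more input images from a text prompt, saves a preview frame at
every denoising step, and stitches the frames into an MP4 of the denoising
trajectory.
"""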
import shutil
import time
from pathlib import Path
import cv2
import gradio as gr
import PIL.Image
import torch
from diffusers import (
DiffusionPipeline, # type: ignore
QwenImageEditPlusPipeline, # type: ignore
)
from nunchaku import NunchakuQwenImageTransformer2DModel
from nunchaku.utils import get_gpu_memory, get_precision
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
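# SVDQuant settings: get_precision() selects "int4" or "fp4" depending on the
# detected GPU, and RANK is the rank of the quantized transformer's low-rank
# branch.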
RANK = 128
TRANSFORMER_ID = f"nunchaku-tech/nunchaku-qwen-image-edit-2509/svdq-{get_precision()}_r{RANK}-qwen-image-edit-2509.safetensors"
PIPELINE_ID = "Qwen/Qwen-Image-Edit-2509"
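# Per-step preview frames are written under output/images and stitched into
# output/video.mp4 after generation.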
OUTPUT_DIR = Path(__file__).parent / "output"
IMAGES_DIR = OUTPUT_DIR / "images"
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
VIDEO_PATH = OUTPUT_DIR / "video.mp4"
class Model:
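    """Wraps the Nunchaku-quantized Qwen-Image-Edit pipeline and captures a
    preview frame at every denoising step for the output video."""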
def __init__(self):
self.progress = gr.Progress()
self.num_inference_steps = 50
self.current_inference_step = 0
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
TRANSFORMER_ID
)
pipeline = QwenImageEditPlusPipeline.from_pretrained(
PIPELINE_ID,
transformer=transformer,
torch_dtype=torch.bfloat16,
)
        if get_gpu_memory() > 18:
            # Enough VRAM (> 18 GB): offload whole sub-models to CPU only
            # while they are idle.
            pipeline.enable_model_cpu_offload()
        else:
            # Low VRAM: stream transformer blocks to the GPU one at a time
            # and sequentially offload the remaining components.
            transformer.set_offload(
                True,
                use_pin_memory=False,
                num_blocks_on_gpu=1,
            )
            pipeline._exclude_from_cpu_offload.append("transformer")
            pipeline.enable_sequential_cpu_offload()
self.pipeline = pipeline
def compute(
self,
images: list[PIL.Image.Image],
prompt: str,
        negative_prompt: str = " ",  # a single space is the conventional "empty" negative prompt for Qwen-Image
true_cfg_scale: float = 4.0,
num_inference_steps: int = 40,
image_width: int = 512,
image_height: int = 512,
) -> tuple[PIL.Image.Image, Path]:
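        """Run the edit pipeline, saving a preview frame at every denoising
        step, then stitch the frames into an MP4.

        Returns the final edited image and the path to the denoising video.
        """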
        self.num_inference_steps = num_inference_steps
        self.current_inference_step = 0
        # Record the requested dimensions; the step callback needs them to
        # unpack latents (callback_kwargs does not include height/width).
        self.image_width = image_width
        self.image_height = image_height
self.progress((self.current_inference_step, self.num_inference_steps))
shutil.rmtree(IMAGES_DIR, ignore_errors=True)
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
# Validate inputs
if not images:
raise gr.Error("No images provided. Please upload at least one image.")
# Ensure all images are valid PIL Images
processed_images = []
for i, img in enumerate(images):
if img is None:
raise gr.Error(f"Image {i + 1} is invalid or could not be loaded.")
processed_images.append(img)
inputs = dict(
image=processed_images,
prompt=prompt,
negative_prompt=negative_prompt,
true_cfg_scale=true_cfg_scale,
num_inference_steps=num_inference_steps,
width=image_width,
height=image_height,
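            # Fixed seed so repeated runs with the same inputs are reproducible.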
generator=torch.manual_seed(0),
callback_on_step_end=self.callback,
)
output = self.pipeline(**inputs)
output_image = output.images[0]
        # Stitch the saved per-step frames into a video, sorted by step number.
image_files = sorted(IMAGES_DIR.glob("step_*.png"))
if image_files:
# Read first image to get dimensions
first_img = cv2.imread(str(image_files[0]))
height, width, _ = first_img.shape
# Create video writer
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
fps = 10 # Adjust frame rate as needed
video_writer = cv2.VideoWriter(
str(VIDEO_PATH.absolute()), fourcc, fps, (width, height)
)
# Add each image to video
for img_path in image_files:
img = cv2.imread(str(img_path))
video_writer.write(img)
video_writer.release()
print(f"Video saved to: {VIDEO_PATH}")
        # Short pause, presumably to let the video file finish writing before
        # Gradio serves it.
        time.sleep(3)
return output_image, VIDEO_PATH
def callback(
self,
pipeline: DiffusionPipeline,
step: int,
timestep: int,
callback_kwargs: dict,
):
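        """Per-step callback: decode the current latents with the VAE,
        save a PNG preview frame, and advance the progress bar."""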
        latents = callback_kwargs.get("latents", None)
        height = callback_kwargs.get("height", self.image_height)
        width = callback_kwargs.get("width", self.image_width)
if latents is not None:
latents = pipeline._unpack_latents(
latents, height, width, pipeline.vae_scale_factor
)
latents = latents.to(pipeline.vae.dtype)
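            # De-normalize using the VAE's per-channel latent statistics; note
            # that latents_std below holds the reciprocal of the std, so
            # dividing by it multiplies by the true std.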
latents_mean = (
torch.tensor(pipeline.vae.config.latents_mean)
.view(1, pipeline.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(pipeline.vae.config.latents_std).view(
1, pipeline.vae.config.z_dim, 1, 1, 1
).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
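            # The Qwen VAE is a video VAE: decode, then keep the single frame.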
image = pipeline.vae.decode(latents, return_dict=False)[0][:, :, 0]
image = pipeline.image_processor.postprocess(image, output_type="pil")
image = image[0]
image.save(IMAGES_DIR / f"step_{step:03d}.png")
self.current_inference_step += 1
self.progress((self.current_inference_step, self.num_inference_steps))
return {}
with gr.Blocks() as demo:
gr.Markdown("# Nunchaku Qwen-Image-Edit-2509")
with gr.Row():
with gr.Column():
gr.Markdown("## Input Images")
image_inputs = gr.Gallery(
label="Input Images",
show_label=True,
elem_id="gallery",
columns=3,
rows=2,
object_fit="contain",
height="auto",
type="pil",
allow_preview=True,
interactive=True,
)
with gr.Column():
gr.Markdown("## Output Image")
image_output = gr.Image(
label="Output Image",
format="png",
)
with gr.Column():
gr.Markdown("## Output Video")
video_output = gr.Video(
label="Output Video",
format="mp4",
show_download_button=True,
streaming=True,
autoplay=True,
loop=False,
)
with gr.Row():
with gr.Column():
gr.Markdown("## Prompts")
prompt = gr.Textbox(label="Prompt:", lines=1)
negative_prompt = gr.Textbox(label="Negative Prompt:", lines=1)
with gr.Column():
gr.Markdown("## Settings")
true_cfg_scale = gr.Slider(
0,
20,
value=4.0,
step=0.1,
interactive=True,
label="True CFG scale:",
)
num_inference_steps = gr.Slider(
1,
300,
value=50,
step=1,
interactive=True,
label="Number of denoising steps:",
)
image_width = gr.Slider(
128,
1024,
value=512,
step=16,
interactive=True,
label="Image Width:",
)
image_height = gr.Slider(
128,
1024,
value=800,
step=16,
interactive=True,
label="Image Height:",
)
with gr.Row():
run_button = gr.Button("Run")
model = Model()
def process_images(
images,
prompt,
negative_prompt,
true_cfg_scale,
num_inference_steps,
image_width,
image_height,
):
"""Wrapper function to handle errors gracefully"""
        # gr.Gallery with type="pil" yields (image, caption) tuples; collect the images.
        pil_images = []
        for item in images:
            content = item[0] if isinstance(item, (tuple, list)) else item
            if isinstance(content, PIL.Image.Image):
                pil_images.append(content)
try:
return model.compute(
pil_images,
prompt,
negative_prompt,
true_cfg_scale,
num_inference_steps,
image_width,
image_height,
)
except Exception as e:
print(f"Error processing images: {e}")
            raise gr.Error(f"Failed to process images: {e}") from e
    # Connect the Run button to the generation function
run_button.click(
fn=process_images,
inputs=[
image_inputs,
prompt,
negative_prompt,
true_cfg_scale,
num_inference_steps,
image_width,
image_height,
],
outputs=[
image_output,
video_output,
],
)
if __name__ == "__main__":
demo.launch(
        # Allow Gradio to serve the generated video; VIDEO_PATH is an absolute
        # path under OUTPUT_DIR, so allow the whole output directory.
        allowed_paths=[str(OUTPUT_DIR)],
share=True,
)