import shutil
import time
from pathlib import Path

import cv2
import gradio as gr
import PIL.Image
import torch
from diffusers import (
    DiffusionPipeline,
    QwenImageEditPlusPipeline,
)

from nunchaku import NunchakuQwenImageTransformer2DModel
from nunchaku.utils import get_gpu_memory, get_precision
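
# Model configuration. get_precision() selects the quantization format for
# the current GPU (e.g. "int4" on pre-Blackwell cards, "fp4" on Blackwell),
# so TRANSFORMER_ID resolves to a compatible SVDQuant checkpoint.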
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

RANK = 128
TRANSFORMER_ID = f"nunchaku-tech/nunchaku-qwen-image-edit-2509/svdq-{get_precision()}_r{RANK}-qwen-image-edit-2509.safetensors"
PIPELINE_ID = "Qwen/Qwen-Image-Edit-2509"

OUTPUT_DIR = Path(__file__).parent / "output"
IMAGES_DIR = OUTPUT_DIR / "images"
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
VIDEO_PATH = OUTPUT_DIR / "video.mp4"
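

# Wraps the quantized transformer and the Qwen-Image-Edit pipeline. compute()
# runs an edit and renders the per-step previews into an MP4 video.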
class Model:
    def __init__(self):
        self.progress = gr.Progress()
        self.num_inference_steps = 50
        self.current_inference_step = 0

        # Load the 4-bit SVDQuant transformer that stands in for the
        # pipeline's original BF16 transformer.
        transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
            TRANSFORMER_ID
        )

        pipeline = QwenImageEditPlusPipeline.from_pretrained(
            PIPELINE_ID,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
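
        # With more than 18 GiB of VRAM, whole-model CPU offload suffices;
        # on smaller GPUs, keep one transformer block on the GPU at a time
        # and offload the rest of the pipeline sequentially.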
        if get_gpu_memory() > 18:
            pipeline.enable_model_cpu_offload()
        else:
            transformer.set_offload(
                True,
                use_pin_memory=False,
                num_blocks_on_gpu=1,
            )
            # Nunchaku manages the transformer's offloading itself, so keep
            # it out of diffusers' sequential offload hooks.
            pipeline._exclude_from_cpu_offload.append("transformer")
            pipeline.enable_sequential_cpu_offload()

        self.pipeline = pipeline

    def compute(
        self,
        images: list[PIL.Image.Image],
        prompt: str,
        negative_prompt: str = " ",
        true_cfg_scale: float = 4.0,
        num_inference_steps: int = 40,
        image_width: int = 512,
        image_height: int = 512,
    ) -> tuple[PIL.Image.Image, Path]:
        self.num_inference_steps = num_inference_steps
        self.current_inference_step = 0
        self.progress((self.current_inference_step, self.num_inference_steps))
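
        # Clear previews left over from a previous run so the video only
        # contains frames from this generation.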
        shutil.rmtree(IMAGES_DIR, ignore_errors=True)
        IMAGES_DIR.mkdir(parents=True, exist_ok=True)

        if not images:
            raise gr.Error("No images provided. Please upload at least one image.")

        processed_images = []
        for i, img in enumerate(images):
            if img is None:
                raise gr.Error(f"Image {i + 1} is invalid or could not be loaded.")
            processed_images.append(img)
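
        # A fixed seed keeps runs reproducible; callback_on_step_end saves a
        # decoded preview frame after every denoising step.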
        inputs = dict(
            image=processed_images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            true_cfg_scale=true_cfg_scale,
            num_inference_steps=num_inference_steps,
            width=image_width,
            height=image_height,
            generator=torch.manual_seed(0),
            callback_on_step_end=self.callback,
        )
        output = self.pipeline(**inputs)
        output_image = output.images[0]
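
        # Stitch the per-step previews saved by callback() into an MP4 so
        # the UI can show how the image evolved during denoising.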
        image_files = sorted(IMAGES_DIR.glob("step_*.png"))

        if image_files:
            # All frames share the dimensions of the first preview image.
            first_img = cv2.imread(str(image_files[0]))
            height, width, _ = first_img.shape

            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            fps = 10
            video_writer = cv2.VideoWriter(
                str(VIDEO_PATH.absolute()), fourcc, fps, (width, height)
            )

            for img_path in image_files:
                img = cv2.imread(str(img_path))
                video_writer.write(img)

            video_writer.release()
            print(f"Video saved to: {VIDEO_PATH}")
        time.sleep(3)

        return output_image, VIDEO_PATH
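
    # diffusers step-end callback: decode the current latents into a PIL
    # preview, save it as a numbered frame, and advance the progress bar.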
    def callback(
        self,
        pipeline: DiffusionPipeline,
        step: int,
        timestep: int,
        callback_kwargs: dict,
    ):
        latents = callback_kwargs.get("latents", None)
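
        # Fall back to fixed dimensions (matching the UI defaults) when the
        # pipeline does not pass height/width through callback_kwargs.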
        height = callback_kwargs.get("height", 800)
        width = callback_kwargs.get("width", 512)

        if latents is not None:
            print(f"Latents shape: {latents.shape}, dtype: {latents.dtype}")

            # Unpack the patchified latents, undo the VAE's per-channel
            # normalization, and decode a preview image for this step.
            latents = pipeline._unpack_latents(
                latents, height, width, pipeline.vae_scale_factor
            )
            latents = latents.to(pipeline.vae.dtype)
            latents_mean = (
                torch.tensor(pipeline.vae.config.latents_mean)
                .view(1, pipeline.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(pipeline.vae.config.latents_std).view(
                1, pipeline.vae.config.z_dim, 1, 1, 1
            ).to(latents.device, latents.dtype)
            latents = latents / latents_std + latents_mean
            image = pipeline.vae.decode(latents, return_dict=False)[0][:, :, 0]
            image = pipeline.image_processor.postprocess(image, output_type="pil")
            image = image[0]

            # Zero-padded filenames keep the frames in step order for sorted().
            image.save(IMAGES_DIR / f"step_{step:03d}.png")

        self.current_inference_step += 1
        self.progress((self.current_inference_step, self.num_inference_steps))

        # Returning an empty dict leaves the pipeline's tensors unchanged.
        return {}


with gr.Blocks() as demo:
    gr.Markdown("# Nunchaku Qwen-Image-Edit-2509")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Input Images")

            image_inputs = gr.Gallery(
                label="Input Images",
                show_label=True,
                elem_id="gallery",
                columns=3,
                rows=2,
                object_fit="contain",
                height="auto",
                type="pil",
                allow_preview=True,
                interactive=True,
            )

        with gr.Column():
            gr.Markdown("## Output Image")

            image_output = gr.Image(
                label="Output Image",
                format="png",
            )

        with gr.Column():
            gr.Markdown("## Output Video")

            video_output = gr.Video(
                label="Output Video",
                format="mp4",
                show_download_button=True,
                streaming=True,
                autoplay=True,
                loop=False,
            )

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Prompts")

            prompt = gr.Textbox(label="Prompt:", lines=1)
            negative_prompt = gr.Textbox(label="Negative Prompt:", lines=1)

        with gr.Column():
            gr.Markdown("## Settings")

            true_cfg_scale = gr.Slider(
                0,
                20,
                value=4.0,
                step=0.1,
                interactive=True,
                label="True CFG scale:",
            )

            num_inference_steps = gr.Slider(
                1,
                300,
                value=50,
                step=1,
                interactive=True,
                label="Number of denoising steps:",
            )

            image_width = gr.Slider(
                128,
                1024,
                value=512,
                step=16,
                interactive=True,
                label="Image Width:",
            )

            image_height = gr.Slider(
                128,
                1024,
                value=800,
                step=16,
                interactive=True,
                label="Image Height:",
            )

    with gr.Row():
        run_button = gr.Button("Run")

    model = Model()
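
    # Event handler: unpack the gallery items into PIL images, then delegate
    # to Model.compute(), surfacing any failure as a gr.Error in the UI.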
    def process_images(
        images,
        prompt,
        negative_prompt,
        true_cfg_scale,
        num_inference_steps,
        image_width,
        image_height,
    ):
        """Wrapper around Model.compute() that handles errors gracefully."""
        pil_images = []

        # With type="pil", each gallery item is an (image, caption) pair;
        # keep the first PIL image found in each item.
        for contents in images:
            for content in contents:
                if isinstance(content, PIL.Image.Image):
                    pil_images.append(content)
                    break

        try:
            return model.compute(
                pil_images,
                prompt,
                negative_prompt,
                true_cfg_scale,
                num_inference_steps,
                image_width,
                image_height,
            )
        except Exception as e:
            print(f"Error processing images: {e}")
            raise gr.Error(f"Failed to process images: {e}")

    run_button.click(
        fn=process_images,
        inputs=[
            image_inputs,
            prompt,
            negative_prompt,
            true_cfg_scale,
            num_inference_steps,
            image_width,
            image_height,
        ],
        outputs=[
            image_output,
            video_output,
        ],
    )


if __name__ == "__main__":
    demo.launch(
        # Serve the generated video from the output directory; an absolute
        # path avoids depending on the current working directory.
        allowed_paths=[str(OUTPUT_DIR)],
        share=True,
    )