Code for running on <= 24GB cards

#8
by rvorias

Tested with an RTX 5000 (24GB). It's slow, though.

import torch
from PIL import Image

from diffusers import (
    BitsAndBytesConfig as DiffusersBitsAndBytesConfig,
    QwenImageLayeredPipeline,
    QwenImageTransformer2DModel,
)
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import Qwen2_5_VLForConditionalGeneration


model_id = "Qwen/Qwen-Image-Layered"
torch_dtype = torch.bfloat16

# Quantize the transformer to 4-bit NF4, leaving the first block's
# img_mod modulation layers unquantized.
quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
transformer = QwenImageTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
)
# Park the quantized transformer on the CPU; the offload hooks set up below
# move it to the GPU only while it is in use.
transformer = transformer.to("cpu")

# Quantize the Qwen2.5-VL text encoder to 4-bit NF4 as well.
quantization_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    subfolder="text_encoder",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
)
text_encoder = text_encoder.to("cpu")

# Build the pipeline around the pre-quantized components. Model CPU offload
# keeps only the active sub-model on the GPU, trading speed for VRAM.
pipeline = QwenImageLayeredPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch_dtype,
)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=None)

image = Image.open("workdir/bc1e03f40776b8bee006ea2b2b0d8103.webp").convert("RGBA")
inputs = {
    "image": image,
    "generator": torch.Generator(device="cuda").manual_seed(777),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "layers": 4,
    "resolution": 640,  # Resolution bucket: 640 or 1024; 640 is recommended for this version.
    "cfg_normalize": True,  # Whether to enable CFG normalization.
    "use_en_prompt": True,  # Auto-detect the caption language if no caption is provided.
}

with torch.inference_mode():
    output = pipeline(**inputs)
    layer_images = output.images[0]  # a list of RGBA images, one per decomposed layer

# Save each layer separately.
for i, layer in enumerate(layer_images):
    layer.save(f"{i}.png")
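
If even this quantized setup does not fit, diffusers also has a more aggressive leaf-level offload that can replace the model-level offload above. Whether it plays well with the 4-bit components here is something to verify; a sketch:

# Alternative: sequential offload streams weights piece by piece. Much lower
# peak VRAM than enable_model_cpu_offload(), but noticeably slower per step.
pipeline.enable_sequential_cpu_offload()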

@rvorias can you elaborate on the "slow" part? I want to get a 5090 Mobile with 24GB VRAM. Is it usable at all?

Use SageAttention and Triton; you can run full bf16 on 10GB VRAM and 64GB RAM at almost 7s/it.
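
(For reference, a common way to route diffusers' attention through SageAttention, not taken from this thread's code, is to monkey-patch PyTorch's SDPA entry point before building the pipeline. A minimal sketch, assuming the sageattention package is installed and your GPU and head dimensions are supported:

import torch.nn.functional as F
from sageattention import sageattn

_original_sdpa = F.scaled_dot_product_attention

def _sdpa_with_sage(query, key, value, *args, **kwargs):
    # sageattn's default layout expects (batch, heads, seq_len, head_dim),
    # which matches what diffusers passes to SDPA. Fall back to the stock
    # kernel for shapes or dtypes SageAttention rejects.
    try:
        return sageattn(query, key, value, is_causal=kwargs.get("is_causal", False))
    except Exception:
        return _original_sdpa(query, key, value, *args, **kwargs)

# Apply the patch before constructing the pipeline so it takes effect everywhere.
F.scaled_dot_product_attention = _sdpa_with_sage

The fallback keeps non-attention uses of SDPA working; this global patch is blunt, so scope it to your inference script.)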

@rzgar

Thanks for replying!

Just to make sure I understand correctly: you mentioned that SageAttention + Triton can run full bf16 on 10GB VRAM at ~7s/it.

Does this mean that on an RTX 5090 Mobile with 24GB VRAM, I should be able to run the model:

- in full bf16 precision (no quantization needed),
- without CPU offload (everything fits in VRAM),
- significantly faster than 7s/it?

Or would you still recommend using your 4-bit quantization code for 24GB cards?
Thanks!
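
For concreteness, the no-quantization setup I have in mind is just this (a sketch; I haven't confirmed the whole pipeline fits in 24GB):

import torch
from diffusers import QwenImageLayeredPipeline

# Full bf16 with every component resident on the GPU: no bitsandbytes, no offload.
pipeline = QwenImageLayeredPipeline.from_pretrained(
    "Qwen/Qwen-Image-Layered",
    torch_dtype=torch.bfloat16,
).to("cuda")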
