OOM on 32gb / diffusers

#25
by xtat - opened

Is there something special that I need to do to get it to fit in 32gb on diffusers?

I am using the example code from the diffusers PR

import torch
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image

# Load the full LTX-2 image-to-video pipeline in bf16; without quantization
# the 19B transformer alone can exceed 32 GB of VRAM at load time.
pipe = LTX2ImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
# Move submodules to GPU one at a time; whole-model residency still OOMs on 32 GB.
pipe.enable_model_cpu_offload()

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
# NOTE(review): both literals below were truncated in the original paste and have
# been re-closed here so the snippet parses; confirm the exact wording against the
# diffusers LTX-2 example before relying on it.
prompt = (
    "An astronaut hatches from a fragile egg on the surface of the Moon, the shell "
    "cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts "
    "and drifts outward in slow arcs."
)
negative_prompt = (
    "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, "
    "motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, "
    "transitions"
)

frame_rate = 24.0
# return_dict=False + output_type="np" yields (video, audio) as numpy arrays;
# video values are floats in [0, 1].
video, audio = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=20,
    guidance_scale=4.0,
    output_type="np",
    return_dict=False,
)
# Scale float frames to 8-bit RGB before encoding.
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

# Mux the first batch element's frames and audio track into an MP4.
encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_sample.mp4",
)

These setups work for me:

  1. Use SDNQ 4bit quantization:
    https://huggingface.co/Disty0/LTX-2-SDNQ-4bit-dynamic
  2. Or, use this code to load the 8bit quantization:
import os
import torch


from diffusers import LTX2Pipeline
from diffusers.models.transformers.transformer_ltx2 import LTX2VideoTransformer3DModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Required allocator setting for fp8 workloads; must be set before the first
# CUDA allocation (safe here because no tensor has touched the GPU yet).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Downloading pre-quantized FP8 transformer weights...")

# Download the fp8 safetensors file
fp8_weights_path = hf_hub_download(
    repo_id="Lightricks/LTX-2",
    filename="ltx-2-19b-dev-fp8.safetensors",
)

print(f"Downloaded to: {fp8_weights_path}")
print("Loading transformer model structure...")

# First, load the model structure from the config; weights loaded here in bf16
# are overwritten below by the fp8 checkpoint.
transformer = LTX2VideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-2",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

print("Loading fp8 weights into model...")

# Load the fp8 weights
fp8_state_dict = load_file(fp8_weights_path)

# Load the weights into the model. strict=False is needed because the fp8
# checkpoint's key set may not match the bf16 structure exactly, but we still
# surface any mismatches instead of silently discarding the result.
load_result = transformer.load_state_dict(fp8_state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} keys missing from checkpoint "
          f"(e.g. {load_result.missing_keys[:3]})")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys in checkpoint "
          f"(e.g. {load_result.unexpected_keys[:3]})")

print("Loading full pipeline with pre-quantized transformer...")
pipeline = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

# Enable CPU offload for additional memory savings
pipeline.enable_model_cpu_offload()

print("Pipeline loaded with pre-quantized FP8 weights!")

I think option 1 is better for a low VRAM GPU.
Good Luck.

Sign up or log in to comment