OOM on 32gb / diffusers
#25 by xtat
Is there something special that I need to do to get it to fit in 32 GB on diffusers? I am using the example code from the diffusers PR:
import torch
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image
pipe = LTX2ImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
prompt = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outwa>
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transitio>
frame_rate = 24.0
video, audio = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=20,
    guidance_scale=4.0,
    output_type="np",
    return_dict=False,
)
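# Frames come back as float arrays in [0, 1] with output_type="np";
# convert to uint8 and wrap as a tensor for the video encoder.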
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
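# Mux the video frames and the generated audio track into an MP4 file.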
encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_sample.mp4",
)
These setups work for me:
- Use SDNQ 4-bit quantization: https://huggingface.co/Disty0/LTX-2-SDNQ-4bit-dynamic (a loading sketch is at the end of this reply).
- Or, use this code to load the 8-bit (FP8) quantization:
import os
import torch
from diffusers import LTX2Pipeline
from diffusers.models.transformers.transformer_ltx2 import LTX2VideoTransformer3DModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Reduce CUDA allocator fragmentation (helps avoid OOM with large models)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Downloading pre-quantized FP8 transformer weights...")
# Download the fp8 safetensors file
fp8_weights_path = hf_hub_download(
    repo_id="Lightricks/LTX-2",
    filename="ltx-2-19b-dev-fp8.safetensors",
)
print(f"Downloaded to: {fp8_weights_path}")
print("Loading transformer model structure...")
# First, load the model structure from the config
transformer = LTX2VideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-2",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
print("Loading fp8 weights into model...")
# Load the fp8 weights
fp8_state_dict = load_file(fp8_weights_path)
# Load the weights into the model
transformer.load_state_dict(fp8_state_dict, strict=False)
print("Loading full pipeline with pre-quantized transformer...")
pipeline = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
# Enable CPU offload for additional memory savings
pipeline.enable_model_cpu_offload()
print("Pipeline loaded with pre-quantized FP8 weights!")
I think option 1 is better for a low VRAM GPU.
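For option 1, here is a minimal loading sketch (untested here; it assumes the sdnq package is installed and that importing it makes the SDNQ quantization backend available to diffusers, so check the model card for the exact, up-to-date steps):

import torch
import sdnq  # noqa: F401 -- assumed requirement: importing registers the SDNQ backend
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline

# Load the pre-quantized 4-bit checkpoint linked above.
pipe = LTX2ImageToVideoPipeline.from_pretrained(
    "Disty0/LTX-2-SDNQ-4bit-dynamic",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # keep offload enabled; quantization alone may not fit in 32 GB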
Good luck.