File size: 7,736 Bytes
1e103b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | import torch
from nunchaku.utils import get_gpu_memory, get_precision
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
class QwenBackend:
    """Loader for Nunchaku Qwen-Image text-to-image and image-edit pipelines.

    ``load()`` builds two diffusers pipelines that share scheduler, VAE,
    text encoder and tokenizer:

    * a ``QwenImagePipeline`` (text-to-image), stored in ``self.pipeline``
    * a ``QwenImageEditPlusPipeline`` (editing), stored in ``self.edit_pipeline``

    Both attributes are ``None`` until ``load()`` has completed.
    """

    def __init__(self, model_id, optimized_model_path=None, optimized_edit_model_path=None, uma=False):
        """Store the configuration; nothing is loaded until ``load()`` is called.

        Args:
            model_id: Identifier/path passed to the diffusers and transformers
                ``from_pretrained`` calls for the base pipeline components.
            optimized_model_path: Path of the Nunchaku transformer used for
                text-to-image. Required by ``load()``.
            optimized_edit_model_path: Optional path of a separate Nunchaku
                transformer for editing. When omitted, the text-to-image
                transformer is reused for the edit pipeline.
            uma: When true, the text encoder is loaded in 8-bit and the
                transformers/VAE are moved to CUDA; when false, per-layer and
                sequential CPU offloading are enabled instead.
        """
        self.model_id = model_id
        self.optimized_model_path = optimized_model_path
        self.optimized_edit_model_path = optimized_edit_model_path
        self.uma = uma
        self.pipeline = None
        # Initialised here so the attribute exists before load() is called.
        self.edit_pipeline = None
        # Not read anywhere in this class; value taken from the reference example.
        self.rank = 32

    def load(self):
        """Build both pipelines and configure device placement.

        Returns:
            A ``(pipeline, edit_pipeline)`` tuple; the same objects are also
            stored on ``self.pipeline`` and ``self.edit_pipeline``.

        Raises:
            ValueError: If no ``optimized_model_path`` was configured.
        """
        print(f"Loading Qwen backend from {self.model_id}...")
        if not self.optimized_model_path:
            # Fail early with a clear message instead of handing None to
            # from_pretrained further down.
            raise ValueError(
                "No optimized model path provided for QwenBackend. "
                "This requires the Nunchaku optimized model."
            )

        scheduler = self._build_scheduler()
        transformer_t2i, transformer_edit = self._load_transformers()

        print(f"Loading QwenImagePipeline from {self.model_id}...")
        # Imported lazily so merely importing this module does not pull in diffusers.
        from diffusers import QwenImageEditPlusPipeline, QwenImagePipeline

        text_encoder = self._load_8bit_text_encoder() if self.uma else None

        # 1. Edit pipeline first: it loads the remaining components
        #    (VAE, tokenizer, processor, ...) from model_id.
        print(f"Loading QwenImageEditPlusPipeline from {self.model_id}...")
        pipeline_kwargs = {
            "transformer": transformer_edit,
            "scheduler": scheduler,
            "torch_dtype": torch.bfloat16,
        }
        if text_encoder is not None:
            pipeline_kwargs["text_encoder"] = text_encoder
        edit_pipeline = QwenImageEditPlusPipeline.from_pretrained(self.model_id, **pipeline_kwargs)

        # 2. T2I pipeline reuses the edit pipeline's components, swapping in
        #    the T2I transformer.
        print("Creating QwenImagePipeline (T2I) with shared components...")
        text_encoder = self._ensure_text_encoder(edit_pipeline, text_encoder)
        tokenizer = self._ensure_tokenizer(edit_pipeline)
        pipeline = QwenImagePipeline(
            transformer=transformer_t2i,
            scheduler=edit_pipeline.scheduler,
            vae=edit_pipeline.vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        if self.uma:
            self._move_to_cuda(transformer_t2i, transformer_edit, edit_pipeline)
        else:
            self._enable_offload(transformer_t2i, transformer_edit, pipeline, edit_pipeline)

        self.pipeline = pipeline
        self.edit_pipeline = edit_pipeline
        return self.pipeline, self.edit_pipeline

    @staticmethod
    def _build_scheduler():
        """Create the FlowMatchEulerDiscreteScheduler with the example's config."""
        import math

        from diffusers import FlowMatchEulerDiscreteScheduler

        scheduler_config = {
            "base_image_seq_len": 256,
            "base_shift": math.log(3),
            "invert_sigmas": False,
            "max_image_seq_len": 8192,
            "max_shift": math.log(3),
            "num_train_timesteps": 1000,
            "shift": 1.0,
            "shift_terminal": None,
            "stochastic_sampling": False,
            "time_shift_type": "exponential",
            "use_beta_sigmas": False,
            "use_dynamic_shifting": True,
            "use_exponential_sigmas": False,
            "use_karras_sigmas": False,
        }
        return FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

    def _load_transformers(self):
        """Load the T2I transformer and, if configured, a separate edit one.

        Returns ``(transformer_t2i, transformer_edit)``; both names refer to
        the same object when no edit model path was given.
        """
        print(f"Loading T2I NunchakuQwenImageTransformer2DModel from {self.optimized_model_path} with FA2...")
        transformer_t2i = NunchakuQwenImageTransformer2DModel.from_pretrained(
            self.optimized_model_path,
            attn_implementation="flash_attention_2",
        )
        if not self.optimized_edit_model_path:
            print("Using shared transformer for Edit pipeline...")
            return transformer_t2i, transformer_t2i

        print(f"Loading Edit NunchakuQwenImageTransformer2DModel from {self.optimized_edit_model_path} with FA2...")
        transformer_edit = NunchakuQwenImageTransformer2DModel.from_pretrained(
            self.optimized_edit_model_path,
            attn_implementation="flash_attention_2",
        )
        return transformer_t2i, transformer_edit

    def _load_8bit_text_encoder(self):
        """Load the text encoder from ``model_id`` quantized to 8-bit (UMA mode)."""
        print("UMA mode: Loading text_encoder in 8-bit using BitsAndBytes...")
        from transformers import AutoModel, BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        # NOTE: trust_remote_code=True lets the model repository run its own
        # Python code; model_id must point to a trusted source.
        return AutoModel.from_pretrained(
            self.model_id,
            subfolder="text_encoder",
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

    def _ensure_text_encoder(self, edit_pipeline, text_encoder):
        """Return the text encoder to share, loading one if the pipeline has none.

        ``text_encoder`` is the encoder already loaded by ``load()`` (or
        ``None``); it is only used when ``edit_pipeline`` ended up without one.
        """
        if edit_pipeline.text_encoder is not None:
            return edit_pipeline.text_encoder

        print("Text encoder not found in edit_pipeline, loading manually...")
        if text_encoder is None:
            from transformers import AutoModel

            # NOTE: trust_remote_code=True — see _load_8bit_text_encoder.
            text_encoder = AutoModel.from_pretrained(
                self.model_id,
                subfolder="text_encoder",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
        # Register it on the edit pipeline too, so that pipeline is not left
        # without a text encoder.
        edit_pipeline.register_modules(text_encoder=text_encoder)
        return text_encoder

    def _ensure_tokenizer(self, edit_pipeline):
        """Return the edit pipeline's tokenizer, loading and registering one if absent."""
        tokenizer = edit_pipeline.tokenizer
        if tokenizer is not None:
            return tokenizer

        print("Tokenizer not found in edit_pipeline, loading manually...")
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(self.model_id, subfolder="tokenizer", trust_remote_code=True)
        edit_pipeline.register_modules(tokenizer=tokenizer)
        return tokenizer

    @staticmethod
    def _move_to_cuda(transformer_t2i, transformer_edit, edit_pipeline):
        """UMA mode: move the transformers and the shared VAE to CUDA.

        Components are moved one by one instead of calling
        ``pipeline.to("cuda")`` because the pipeline also holds the 8-bit text
        encoder, whose placement is left to bitsandbytes.
        """
        print("UMA mode enabled: Text encoder loaded in 8-bit. Moving other components to GPU.")
        print("Moving T2I Transformer to CUDA...")
        transformer_t2i.to("cuda")
        # Identity check: skip the second move when one transformer serves both pipelines.
        if transformer_edit is not transformer_t2i:
            print("Moving Edit Transformer to CUDA...")
            transformer_edit.to("cuda")
        # The VAE object is shared, so moving it once covers the T2I pipeline as well.
        if getattr(edit_pipeline, "vae", None) is not None:
            print("Moving VAE to CUDA...")
            edit_pipeline.vae.to("cuda")

    @staticmethod
    def _enable_offload(transformer_t2i, transformer_edit, pipeline, edit_pipeline):
        """Non-UMA mode: per-layer offload for transformers, sequential offload for the rest."""
        print("Non-UMA mode: Using aggressive per-layer offloading.")
        offload_targets = [transformer_t2i]
        if transformer_edit is not transformer_t2i:
            offload_targets.append(transformer_edit)
        for transformer in offload_targets:
            transformer.set_offload(True, use_pin_memory=True, num_blocks_on_gpu=8)

        # The transformers are offloaded by set_offload() above, so keep them
        # out of the pipelines' sequential CPU offload. This is done
        # explicitly on each pipeline rather than relying on the exclusion
        # list being shared between pipeline classes.
        # NOTE(review): in diffusers the list appears to be class-level — confirm.
        # The membership test keeps repeated load() calls from adding duplicates.
        for pipe in (edit_pipeline, pipeline):
            if "transformer" not in pipe._exclude_from_cpu_offload:
                pipe._exclude_from_cpu_offload.append("transformer")
            pipe.enable_sequential_cpu_offload()
|