File size: 10,371 Bytes
4960ef6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 |
"""
Model Loader Utilities
模型加载工具 - 用于加载各种模型组件
"""
import torch
import json
import safetensors.torch
from typing import Optional
def load_unet_from_safetensors(unet_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load a UNet2DConditionModel from a safetensors weight file and a JSON config.

    Args:
        unet_path: Path to the UNet safetensors weight file.
        config_path: Path to the UNet config JSON file.
        device: Device the model is moved to after loading.
        dtype: Data type for the model weights.

    Returns:
        The loaded UNet2DConditionModel, or None if loading fails.
    """
    try:
        from diffusers import UNet2DConditionModel

        # Instantiate from the JSON config, then load the raw weights into it.
        with open(config_path, 'r') as cfg_file:
            config = json.load(cfg_file)
        model = UNet2DConditionModel.from_config(config)
        model.load_state_dict(safetensors.torch.load_file(unet_path))
        model.to(device, dtype)
        return model
    except Exception as e:
        # Broad on purpose: callers treat a None return as "load failed".
        print(f"Error loading UNet: {e}")
        return None
def load_vae_from_safetensors(vae_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load an AutoencoderKL VAE from a safetensors weight file and a JSON config.

    Args:
        vae_path: Path to the VAE safetensors weight file.
        config_path: Path to the VAE config JSON file.
        device: Device the model is moved to after loading.
        dtype: Data type for the model weights.

    Returns:
        The loaded AutoencoderKL, or None if loading fails.
    """
    try:
        from diffusers import AutoencoderKL

        # Instantiate from the JSON config, then load the raw weights into it.
        with open(config_path, 'r') as cfg_file:
            config = json.load(cfg_file)
        model = AutoencoderKL.from_config(config)
        model.load_state_dict(safetensors.torch.load_file(vae_path))
        model.to(device, dtype)
        return model
    except Exception as e:
        # Broad on purpose: callers treat a None return as "load failed".
        print(f"Error loading VAE: {e}")
        return None
def create_scheduler(scheduler_type: str = "EulerAncestral", model_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
    """
    Create a noise scheduler for the diffusion process.

    Args:
        scheduler_type: One of "DDPM", "DDIM", "DPMSolverMultistep",
            or "EulerAncestral". Unknown values fall back to DDPM.
        model_id: Model ID whose scheduler config is loaded.

    Returns:
        Scheduler object, or None if creation fails.
    """
    # Supported scheduler names mapped to their diffusers class names.
    scheduler_classes = {
        "DDPM": "DDPMScheduler",
        "DDIM": "DDIMScheduler",
        "DPMSolverMultistep": "DPMSolverMultistepScheduler",
        "EulerAncestral": "EulerAncestralDiscreteScheduler",
    }
    try:
        class_name = scheduler_classes.get(scheduler_type)
        if class_name is None:
            # Fallback warning is printed before importing, matching the
            # original control flow for unsupported types.
            print(f"Unsupported scheduler type: {scheduler_type}, using DDPM")
            class_name = "DDPMScheduler"
        import diffusers
        scheduler_cls = getattr(diffusers, class_name)
        return scheduler_cls.from_pretrained(model_id, subfolder="scheduler")
    except Exception as e:
        print(f"Error creating scheduler: {e}")
        return None
def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load the Qwen3 embedding model via sentence-transformers.

    Args:
        model_path: Path to the Qwen model.
        device: Device the model is moved to after loading.

    Returns:
        A SentenceTransformer model, or None when the optional dependency
        is missing or loading fails.
    """
    try:
        from sentence_transformers import SentenceTransformer

        embedder = SentenceTransformer(model_path)
        embedder.to(device)
        return embedder
    except ImportError:
        # Optional dependency: callers fall back to mock embeddings on None.
        print("Warning: sentence-transformers not available. Using mock embeddings.")
        return None
    except Exception as e:
        print(f"Error loading Qwen model: {e}")
        return None
def save_model_components(
    unet,
    vae,
    adapter,
    text_encoder,
    save_dir: str,
    save_format: str = "safetensors"
):
    """
    Save model components for training checkpoints.

    Components that are None are skipped. NOTE(review): text_encoder is
    accepted for interface symmetry but is never written to disk here —
    only unet, vae and adapter are saved; confirm this is intentional.

    Args:
        unet: UNet model.
        vae: VAE model.
        adapter: Qwen embedding adapter.
        text_encoder: Qwen text encoder (currently not saved).
        save_dir: Directory to save components into (created if missing).
        save_format: "safetensors" for safetensors files; any other value
            falls back to torch .pt files.
    """
    import os
    os.makedirs(save_dir, exist_ok=True)
    try:
        # (file stem, component) pairs; text_encoder intentionally excluded.
        components = (("unet", unet), ("vae", vae), ("adapter", adapter))
        if save_format == "safetensors":
            for stem, component in components:
                if component is not None:
                    safetensors.torch.save_file(
                        component.state_dict(),
                        os.path.join(save_dir, f"{stem}.safetensors")
                    )
        else:  # PyTorch pickle format
            for stem, component in components:
                if component is not None:
                    torch.save(component.state_dict(), os.path.join(save_dir, f"{stem}.pt"))
        print(f"Model components saved to {save_dir}")
    except Exception as e:
        print(f"Error saving model components: {e}")
def load_unet_with_lora(
    unet_path: str,
    unet_config_path: str,
    lora_weights_path: Optional[str] = None,
    lora_config_path: Optional[str] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet from safetensors and optionally apply LoRA weights to it.

    Args:
        unet_path: Path to the base UNet safetensors weight file.
        unet_config_path: Path to the base UNet config JSON file.
        lora_weights_path: Optional path to LoRA weights (.safetensors or
            torch pickle). LoRA is applied only when both this and
            lora_config_path are provided.
        lora_config_path: Optional path to the LoRA config directory.
        device: Device to load model on.
        dtype: Data type for model weights.

    Returns:
        The UNet (wrapped in a PEFT model when LoRA is applied), or None
        if loading fails.
    """
    try:
        from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

        # Delegate base loading; the helper already moves the model to
        # device/dtype on success and returns None on failure.
        unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)
        if unet is None:
            # Fail fast with a clear message instead of letting the broad
            # except below report a confusing NoneType AttributeError.
            print("Error loading UNet with LoRA: base UNet failed to load")
            return None

        if lora_weights_path and lora_config_path:
            print(f"Loading LoRA weights from {lora_weights_path}")
            # LoRA weights may be stored as safetensors or as a torch pickle.
            if lora_weights_path.endswith(".safetensors"):
                lora_state_dict = safetensors.torch.load_file(lora_weights_path)
            else:
                lora_state_dict = torch.load(lora_weights_path, map_location=device)
            lora_config = LoraConfig.from_pretrained(lora_config_path)
            # Wrap the UNet in a PEFT model, then load the adapter weights.
            unet = get_peft_model(unet, lora_config)
            set_peft_model_state_dict(unet, lora_state_dict)
            print("LoRA weights applied to UNet")

        # Re-apply placement: the PEFT wrapper (if any) must also live on
        # the target device/dtype.
        unet.to(device, dtype)
        return unet
    except Exception as e:
        print(f"Error loading UNet with LoRA: {e}")
        return None
def load_fused_unet(
    fused_unet_path: str,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet whose LoRA weights have already been fused into the base.

    Args:
        fused_unet_path: Path to the fused UNet model directory.
        device: Device the model is moved to after loading.
        dtype: Data type for the model weights.

    Returns:
        The fused UNet model, or None if loading fails.
    """
    try:
        from diffusers import UNet2DConditionModel

        model = UNet2DConditionModel.from_pretrained(fused_unet_path, torch_dtype=dtype)
        model.to(device, dtype)
        print(f"Fused UNet loaded from {fused_unet_path}")
        return model
    except Exception as e:
        print(f"Error loading fused UNet: {e}")
        return None
def load_checkpoint(checkpoint_path: str, device: str = "cuda"):
    """
    Load a training checkpoint from disk.

    Dispatches on file extension: .safetensors files go through the
    safetensors loader, everything else through torch.load.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device to map loaded tensors onto.

    Returns:
        Dictionary containing checkpoint data, or None if loading fails.
    """
    try:
        if checkpoint_path.endswith(".safetensors"):
            return safetensors.torch.load_file(checkpoint_path, device=device)
        # NOTE(review): torch.load unpickles arbitrary objects — only feed
        # it checkpoints from trusted sources.
        return torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None
|