qwenillustrious / arch /model_loader.py
lsmpp's picture
Add files using upload-large-folder tool
4960ef6 verified
"""
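Model loader utilities.

Helpers for loading individual model components (UNet, VAE, scheduler,
Qwen embedding model) and for saving and loading training checkpoints.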
"""
import torch
import json
import safetensors.torch
from typing import Optional
def load_unet_from_safetensors(unet_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load a UNet from a safetensors weight file and a JSON config file.

    The model is built from the config, the weights are loaded into it, and
    the result is moved to ``device`` and cast to ``dtype``.

    Args:
        unet_path: Path to UNet safetensors file
        config_path: Path to UNet config JSON file
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        UNet2DConditionModel, or None if loading fails. Failures are printed,
        not raised.
    """
    try:
        # Function-scope import: diffusers is only required when this loader
        # is actually called.
        from diffusers import UNet2DConditionModel

        # Read the config as UTF-8 explicitly instead of relying on the
        # locale's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            unet_config = json.load(f)

        # Build the architecture from the config, then fill in the weights.
        unet = UNet2DConditionModel.from_config(unet_config)
        state_dict = safetensors.torch.load_file(unet_path)
        unet.load_state_dict(state_dict)

        unet.to(device, dtype)
        return unet
    except Exception as e:
        print(f"Error loading UNet: {e}")
        return None
def load_vae_from_safetensors(vae_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load a VAE from a safetensors weight file and a JSON config file.

    The model is built from the config, the weights are loaded into it, and
    the result is moved to ``device`` and cast to ``dtype``.

    Args:
        vae_path: Path to VAE safetensors file
        config_path: Path to VAE config JSON file
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        AutoencoderKL, or None if loading fails. Failures are printed, not
        raised.
    """
    try:
        # Function-scope import: diffusers is only required when this loader
        # is actually called.
        from diffusers import AutoencoderKL

        # Read the config as UTF-8 explicitly instead of relying on the
        # locale's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            vae_config = json.load(f)

        # Build the architecture from the config, then fill in the weights.
        vae = AutoencoderKL.from_config(vae_config)
        state_dict = safetensors.torch.load_file(vae_path)
        vae.load_state_dict(state_dict)

        vae.to(device, dtype)
        return vae
    except Exception as e:
        print(f"Error loading VAE: {e}")
        return None
def create_scheduler(scheduler_type: str = "EulerAncestral", model_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
    """
    Create a scheduler for the diffusion process.

    The scheduler class is chosen from ``scheduler_type`` and instantiated
    from the ``scheduler`` subfolder of ``model_id``. Unrecognised types fall
    back to DDPM after printing a notice.

    Args:
        scheduler_type: One of "DDPM", "DDIM", "DPMSolverMultistep",
            "EulerAncestral"
        model_id: Model ID to load scheduler config from

    Returns:
        Scheduler object, or None if creation fails (the error is printed)
    """
    try:
        # Step 1: resolve the scheduler class. Only the selected class is
        # imported from diffusers.
        if scheduler_type == "DDIM":
            from diffusers import DDIMScheduler as scheduler_cls
        elif scheduler_type == "DPMSolverMultistep":
            from diffusers import DPMSolverMultistepScheduler as scheduler_cls
        elif scheduler_type == "EulerAncestral":
            from diffusers import EulerAncestralDiscreteScheduler as scheduler_cls
        else:
            # Both an explicit "DDPM" and any unknown type end up here; only
            # the unknown case is announced.
            if not (scheduler_type == "DDPM"):
                print(f"Unsupported scheduler type: {scheduler_type}, using DDPM")
            from diffusers import DDPMScheduler as scheduler_cls

        # Step 2: instantiate it from the model's scheduler config.
        return scheduler_cls.from_pretrained(model_id, subfolder="scheduler")
    except Exception as e:
        print(f"Error creating scheduler: {e}")
        return None
def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load the Qwen3 embedding model through sentence-transformers.

    Args:
        model_path: Path to Qwen model
        device: Device to load model on

    Returns:
        SentenceTransformer model, or None if loading fails (a message is
        printed in that case)
    """
    try:
        from sentence_transformers import SentenceTransformer

        qwen = SentenceTransformer(model_path)
        qwen.to(device)
    except Exception as e:
        # An ImportError gets its own message; every other failure is
        # reported with the exception text.
        if isinstance(e, ImportError):
            print("Warning: sentence-transformers not available. Using mock embeddings.")
        else:
            print(f"Error loading Qwen model: {e}")
        return None
    return qwen
def save_model_components(
    unet,
    vae,
    adapter,
    text_encoder,
    save_dir: str,
    save_format: str = "safetensors"
):
    """
    Save model components for training checkpoints.

    The state dict of each non-None component among ``unet``, ``vae`` and
    ``adapter`` is written to ``<save_dir>/<name>.<ext>``. ``save_dir`` is
    created if it does not exist.

    Args:
        unet: UNet model
        vae: VAE model
        adapter: Qwen embedding adapter
        text_encoder: Qwen text encoder. Accepted as part of the call
            signature, but this function does not write it to disk.
        save_dir: Directory to save components
        save_format: "safetensors" writes ``.safetensors`` files. Any other
            value writes ``.pt`` files with ``torch.save``; a notice is
            printed if that value is not "pt".

    Returns:
        None. Errors raised while writing are printed, not propagated.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    use_safetensors = save_format == "safetensors"
    if not use_safetensors and save_format != "pt":
        # The fallback to the PyTorch format is kept, but announced, so a
        # typo in save_format is no longer silent.
        print(f"Unsupported save format: {save_format}, using pt")
    extension = "safetensors" if use_safetensors else "pt"

    # One table drives both formats, so the components written cannot drift
    # apart between the two branches.
    components = (("unet", unet), ("vae", vae), ("adapter", adapter))

    try:
        for name, component in components:
            if component is None:
                continue
            target = os.path.join(save_dir, f"{name}.{extension}")
            if use_safetensors:
                safetensors.torch.save_file(component.state_dict(), target)
            else:
                torch.save(component.state_dict(), target)
        print(f"Model components saved to {save_dir}")
    except Exception as e:
        print(f"Error saving model components: {e}")
def load_unet_with_lora(
    unet_path: str,
    unet_config_path: str,
    lora_weights_path: Optional[str] = None,
    lora_config_path: Optional[str] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet and optionally apply LoRA weights to it.

    The base UNet is loaded with ``load_unet_from_safetensors``. LoRA is
    applied only when both ``lora_weights_path`` and ``lora_config_path`` are
    given; if only one is given, a warning is printed and the base UNet is
    returned without LoRA.

    Args:
        unet_path: Path to the base UNet safetensors file
        unet_config_path: Path to the base UNet config JSON file
        lora_weights_path: Optional path to LoRA weights. Files ending in
            ``.safetensors`` are read with safetensors; anything else is read
            with ``torch.load``.
        lora_config_path: Optional path passed to
            ``LoraConfig.from_pretrained``
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        The UNet (wrapped by ``peft.get_peft_model`` when LoRA was applied),
        or None if loading fails. Failures are printed, not raised.
    """
    try:
        unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)
        if unet is None:
            # The base loader reports its own failure by returning None.
            # Stop here with a clear message instead of failing later on
            # an attribute access against None.
            print("Error loading UNet with LoRA: base UNet could not be loaded")
            return None

        if lora_weights_path and lora_config_path:
            # peft is imported only on the LoRA path, so loading a plain UNet
            # does not require it to be installed.
            from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

            print(f"Loading LoRA weights from {lora_weights_path}")

            if lora_weights_path.endswith(".safetensors"):
                lora_state_dict = safetensors.torch.load_file(lora_weights_path)
            else:
                # NOTE(review): torch.load may unpickle arbitrary objects;
                # only use this path with weight files from a trusted source.
                lora_state_dict = torch.load(lora_weights_path, map_location=device)

            lora_config = LoraConfig.from_pretrained(lora_config_path)

            # Wrap the UNet with the LoRA adapters, then load their weights.
            unet = get_peft_model(unet, lora_config)
            set_peft_model_state_dict(unet, lora_state_dict)
            print("LoRA weights applied to UNet")
        elif lora_weights_path or lora_config_path:
            print(
                "Warning: both lora_weights_path and lora_config_path are "
                "required to apply LoRA; returning UNet without LoRA"
            )

        unet.to(device, dtype)
        return unet
    except Exception as e:
        print(f"Error loading UNet with LoRA: {e}")
        return None
def load_fused_unet(
    fused_unet_path: str,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet whose LoRA weights have already been fused in.

    Args:
        fused_unet_path: Path to fused UNet model directory
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        The loaded UNet, or None if loading fails (the error is printed)
    """
    try:
        from diffusers import UNet2DConditionModel

        load_kwargs = {"torch_dtype": dtype}
        fused = UNet2DConditionModel.from_pretrained(fused_unet_path, **load_kwargs)
        fused.to(device, dtype)
        print(f"Fused UNet loaded from {fused_unet_path}")
    except Exception as e:
        print(f"Error loading fused UNet: {e}")
        return None
    return fused
def load_checkpoint(checkpoint_path: str, device: str = "cuda"):
    """
    Load a training checkpoint from disk.

    Paths ending in ``.safetensors`` are read with safetensors; every other
    path is read with ``torch.load``.

    Args:
        checkpoint_path: Path to checkpoint file
        device: Device to load on

    Returns:
        The loaded checkpoint data, or None if loading fails (the error is
        printed)
    """
    try:
        is_safetensors = checkpoint_path.endswith(".safetensors")
        if not is_safetensors:
            return torch.load(checkpoint_path, map_location=device)
        return safetensors.torch.load_file(checkpoint_path, device=device)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None