|
|
""" |
|
|
Model Loader Utilities

Helpers for loading the various model components.
|
|
""" |
|
|
|
|
|
import torch |
|
|
import json |
|
|
import safetensors.torch |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
def load_unet_from_safetensors(unet_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Build a UNet2DConditionModel from a JSON config and fill it with safetensors weights.

    Args:
        unet_path: Path to the safetensors file holding the UNet state dict.
        config_path: Path to the JSON file whose contents are passed to
            ``UNet2DConditionModel.from_config``.
        device: Device the model is moved to once the weights are loaded.
        dtype: Dtype the model is cast to once the weights are loaded.

    Returns:
        The loaded UNet2DConditionModel, or None if any step raised
        (the error is printed, not re-raised).
    """
    try:
        # Local import: diffusers is only required when this function is called.
        from diffusers import UNet2DConditionModel

        with open(config_path, 'r') as config_file:
            config = json.load(config_file)

        # Instantiate the architecture first, then populate it from disk.
        model = UNet2DConditionModel.from_config(config)
        weights = safetensors.torch.load_file(unet_path)
        model.load_state_dict(weights)
        model.to(device, dtype)
    except Exception as e:
        print(f"Error loading UNet: {e}")
        return None
    else:
        return model
|
|
|
|
|
|
|
|
def load_vae_from_safetensors(vae_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Build an AutoencoderKL from a JSON config and fill it with safetensors weights.

    Args:
        vae_path: Path to the safetensors file holding the VAE state dict.
        config_path: Path to the JSON file whose contents are passed to
            ``AutoencoderKL.from_config``.
        device: Device the model is moved to once the weights are loaded.
        dtype: Dtype the model is cast to once the weights are loaded.

    Returns:
        The loaded AutoencoderKL, or None if any step raised
        (the error is printed, not re-raised).
    """
    try:
        # Local import: diffusers is only required when this function is called.
        from diffusers import AutoencoderKL

        with open(config_path, 'r') as config_file:
            config = json.load(config_file)

        # Instantiate the architecture first, then populate it from disk.
        model = AutoencoderKL.from_config(config)
        weights = safetensors.torch.load_file(vae_path)
        model.load_state_dict(weights)
        model.to(device, dtype)
    except Exception as e:
        print(f"Error loading VAE: {e}")
        return None
    else:
        return model
|
|
|
|
|
|
|
|
def create_scheduler(scheduler_type: str = "EulerAncestral", model_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
    """
    Instantiate a diffusers scheduler from a pretrained model's ``scheduler`` subfolder.

    Args:
        scheduler_type: One of "DDPM", "DDIM", "DPMSolverMultistep" or
            "EulerAncestral". Any other value prints a notice and falls
            back to DDPM.
        model_id: Model identifier or path handed to ``from_pretrained``.

    Returns:
        The scheduler object, or None if creation raised
        (the error is printed, not re-raised).
    """
    # (scheduler_type value, diffusers class name) pairs, checked in order.
    known_schedulers = (
        ("DDPM", "DDPMScheduler"),
        ("DDIM", "DDIMScheduler"),
        ("DPMSolverMultistep", "DPMSolverMultistepScheduler"),
        ("EulerAncestral", "EulerAncestralDiscreteScheduler"),
    )

    try:
        class_name = next(
            (name for key, name in known_schedulers if scheduler_type == key),
            None,
        )
        if class_name is None:
            print(f"Unsupported scheduler type: {scheduler_type}, using DDPM")
            class_name = "DDPMScheduler"

        # Local import: diffusers is only required when this function is called.
        import diffusers

        scheduler_cls = getattr(diffusers, class_name)
        return scheduler_cls.from_pretrained(model_id, subfolder="scheduler")
    except Exception as e:
        print(f"Error creating scheduler: {e}")
        return None
|
|
|
|
|
|
|
|
def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load a SentenceTransformer model (used for the Qwen3 embeddings) onto a device.

    Args:
        model_path: Path or identifier handed to ``SentenceTransformer``.
        device: Device the model is moved to after construction.

    Returns:
        The SentenceTransformer model, or None if the package is not
        installed or loading raised (a message is printed in both cases).
    """
    # Keep the import in its own try block so that only a genuinely missing
    # package triggers the "not available" warning. An ImportError raised
    # later, while the model itself is being constructed, is a load failure
    # and must be reported as such.
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        print("Warning: sentence-transformers not available. Using mock embeddings.")
        return None
    except Exception as e:
        # The package exists but failed to initialise; keep the best-effort
        # contract of returning None instead of propagating.
        print(f"Error loading Qwen model: {e}")
        return None

    try:
        model = SentenceTransformer(model_path)
        model.to(device)
        return model
    except Exception as e:
        print(f"Error loading Qwen model: {e}")
        return None
|
|
|
|
|
|
|
|
def save_model_components(
    unet,
    vae,
    adapter,
    text_encoder,
    save_dir: str,
    save_format: str = "safetensors"
):
    """
    Write the state dicts of the given components into a checkpoint directory.

    Each non-None component among ``unet``, ``vae`` and ``adapter`` is saved
    as ``<name>.safetensors`` or ``<name>.pt`` inside ``save_dir``. Errors
    raised while saving are printed, not re-raised.

    Args:
        unet: UNet model, or None to skip it.
        vae: VAE model, or None to skip it.
        adapter: Qwen embedding adapter, or None to skip it.
        text_encoder: Qwen text encoder. Accepted for interface
            compatibility; this function does not write it to disk.
        save_dir: Directory to save into; created if it does not exist.
        save_format: "safetensors" saves with ``safetensors.torch.save_file``;
            any other value saves with ``torch.save`` using a ``.pt`` suffix.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    # Saved in this order; a failure stops the remaining components.
    named_components = (("unet", unet), ("vae", vae), ("adapter", adapter))

    try:
        as_safetensors = save_format == "safetensors"
        suffix = "safetensors" if as_safetensors else "pt"

        for stem, component in named_components:
            if component is None:
                continue
            if as_safetensors:
                safetensors.torch.save_file(
                    component.state_dict(),
                    os.path.join(save_dir, f"{stem}.{suffix}"),
                )
            else:
                torch.save(
                    component.state_dict(),
                    os.path.join(save_dir, f"{stem}.{suffix}"),
                )

        print(f"Model components saved to {save_dir}")
    except Exception as e:
        print(f"Error saving model components: {e}")
|
|
|
|
|
|
|
|
def load_unet_with_lora(
    unet_path: str,
    unet_config_path: str,
    lora_weights_path: Optional[str] = None,
    lora_config_path: Optional[str] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet from safetensors and optionally wrap it with LoRA weights.

    LoRA is applied only when both ``lora_weights_path`` and
    ``lora_config_path`` are given. ``peft`` is imported only in that case,
    so a plain UNet can be loaded without peft installed.

    Args:
        unet_path: Path to the base UNet safetensors file.
        unet_config_path: Path to the base UNet config JSON file.
        lora_weights_path: Optional path to the LoRA weights. Files ending
            in ``.safetensors`` are read with safetensors; anything else
            is read with ``torch.load``.
        lora_config_path: Optional path handed to
            ``LoraConfig.from_pretrained``.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        The UNet (wrapped by peft when LoRA was applied), or None if the
        base UNet could not be loaded or any step raised (the error is
        printed, not re-raised).
    """
    try:
        unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)
        if unet is None:
            # The base loader has already printed the underlying cause;
            # stop here instead of failing later with an opaque NoneType error.
            print("Error loading UNet with LoRA: base UNet could not be loaded")
            return None

        if lora_weights_path and lora_config_path:
            # Imported here so peft is only required when LoRA is requested.
            from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

            print(f"Loading LoRA weights from {lora_weights_path}")

            if lora_weights_path.endswith(".safetensors"):
                lora_state_dict = safetensors.torch.load_file(lora_weights_path)
            else:
                # NOTE(security): torch.load unpickles the file; only use it
                # on LoRA checkpoints from a trusted source.
                lora_state_dict = torch.load(lora_weights_path, map_location=device)

            lora_config = LoraConfig.from_pretrained(lora_config_path)

            unet = get_peft_model(unet, lora_config)
            set_peft_model_state_dict(unet, lora_state_dict)

            print("LoRA weights applied to UNet")
        elif lora_weights_path or lora_config_path:
            # Only one of the two LoRA inputs was supplied; say so instead of
            # silently returning a model without LoRA.
            print(
                "Warning: both lora_weights_path and lora_config_path are required "
                "to apply LoRA; returning base UNet without LoRA"
            )

        # Adapter layers added by peft are moved/cast together with the base model.
        unet.to(device, dtype)
        return unet

    except Exception as e:
        print(f"Error loading UNet with LoRA: {e}")
        return None
|
|
|
|
|
|
|
|
def load_fused_unet(
    fused_unet_path: str,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet whose LoRA weights have already been fused into it.

    Args:
        fused_unet_path: Path handed to ``UNet2DConditionModel.from_pretrained``.
        device: Device the model is moved to after loading.
        dtype: Dtype requested at load time and applied again when moving.

    Returns:
        The loaded UNet2DConditionModel, or None if loading raised
        (the error is printed, not re-raised).
    """
    try:
        # Local import: diffusers is only required when this function is called.
        from diffusers import UNet2DConditionModel

        fused = UNet2DConditionModel.from_pretrained(fused_unet_path, torch_dtype=dtype)
        fused.to(device, dtype)
        print(f"Fused UNet loaded from {fused_unet_path}")
    except Exception as e:
        print(f"Error loading fused UNet: {e}")
        return None
    else:
        return fused
|
|
|
|
|
|
|
|
def load_checkpoint(checkpoint_path: str, device: str = "cuda"):
    """
    Read a training checkpoint from disk.

    Args:
        checkpoint_path: Path to the checkpoint. Files ending in
            ``.safetensors`` are read with safetensors; anything else is
            read with ``torch.load``.
        device: Device the loaded tensors are mapped to.

    Returns:
        The loaded checkpoint data, or None if loading raised
        (the error is printed, not re-raised).
    """
    try:
        is_safetensors = checkpoint_path.endswith(".safetensors")
        if not is_safetensors:
            # NOTE(review): torch.load unpickles the file; only use it on
            # checkpoints from a trusted source.
            return torch.load(checkpoint_path, map_location=device)
        return safetensors.torch.load_file(checkpoint_path, device=device)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None
|
|
|