# qwenillustrious/arch/pipeline.py
# Uploaded by lsmpp: "Add files using upload-large-folder tool" (commit 4960ef6, verified)
"""
Qwen-SDXL Inference Pipeline
Qwen-SDXL 推理管道 - 使用 Qwen3 嵌入模型替代 CLIP 文本编码器的 SDXL 推理管道
"""
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from typing import List, Optional, Union, Tuple
from .adapter import QwenEmbeddingAdapter
from .text_encoder import QwenTextEncoder
from .model_loader import load_qwen_model, load_unet_from_safetensors, load_vae_from_safetensors, create_scheduler
class QwenIllustriousInference:
    """Qwen-SDXL inference pipeline.

    An SDXL-style latent diffusion pipeline whose text conditioning comes from
    a Qwen3 embedding model (``QwenTextEncoder``) projected through a
    ``QwenEmbeddingAdapter`` instead of the CLIP text encoders.

    Attributes set by ``__init__``:
        text_encoder, adapter, unet, vae, scheduler: the pipeline components.
        is_ready: True when every component above is not ``None``.
        adapter_loaded: True once adapter weights were loaded successfully
            (either in ``__init__`` or through ``load_adapter``).
    """

    # Number of sequence positions the single Qwen embedding vector is
    # broadcast to before being projected by the adapter.
    TEXT_SEQ_LEN = 512
    # Latent channel count used when no UNet is attached (mock/testing path).
    MOCK_LATENT_CHANNELS = 4

    def __init__(
        self,
        qwen_model_path: str = "models/Qwen3-Embedding-0.6B",
        unet_path: str = "models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors",
        unet_config_path: str = "models/extracted_components/waiNSFWIllustrious_v140_unet_config.json",
        vae_path: str = "models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors",
        vae_config_path: str = "models/extracted_components/waiNSFWIllustrious_v140_vae_config.json",
        adapter_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/adapter/adapter.safetensors",
        lora_weights_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/lora_weights/lora_weights.safetensors",
        lora_config_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/lora_weights/adapter_config.json",
        use_fused_unet: bool = False,
        fused_unet_path: Optional[str] = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        scheduler_type: str = "DDPM"
    ):
        """Load every pipeline component and record whether all are available.

        Args:
            qwen_model_path: Path handed to ``QwenTextEncoder``.
            unet_path: UNet weights file (safetensors).
            unet_config_path: UNet config file (JSON).
            vae_path: VAE weights file (safetensors).
            vae_config_path: VAE config file (JSON).
            adapter_path: Optional adapter weights. A load failure is reported
                and the adapter keeps its freshly initialised weights.
            lora_weights_path: Optional LoRA weights for the UNet.
            lora_config_path: Optional LoRA config for the UNet.
            use_fused_unet: Use ``fused_unet_path`` (LoRA already merged).
                Takes priority over separate LoRA weights when the path is set.
            fused_unet_path: Path of the fused UNet.
            device: Device for the models and generated tensors.
            dtype: Dtype for the models and generated tensors.
            scheduler_type: Name forwarded to ``create_scheduler``.
        """
        self.device = device
        self.dtype = dtype
        self.vae_scale_factor = 8  # SDXL default
        self.adapter_loaded = False
        print("🚀 初始化 Qwen-SDXL 推理管道...")

        # Text encoder
        print("📝 初始化 Qwen 文本编码器...")
        self.text_encoder = QwenTextEncoder(
            model_path=qwen_model_path,
            device=device,
            freeze_encoder=True
        )

        # Adapter layer
        print("🔧 初始化适配器层...")
        self.adapter = QwenEmbeddingAdapter()
        self.adapter.to(device, dtype)

        # Adapter weights (best effort: a failure is reported, not raised)
        if adapter_path is not None:
            print(f"📥 加载适配器权重: {adapter_path}")
            try:
                adapter_state = self._read_state_dict(adapter_path, device)
                self.adapter.load_state_dict(adapter_state)
                self.adapter_loaded = True
            except Exception as e:
                print(f"⚠️ 加载适配器权重失败: {e}")

        # UNet (optionally with LoRA)
        print("🏗️ 加载 UNet 模型...")
        from .model_loader import load_unet_with_lora, load_fused_unet
        if use_fused_unet and fused_unet_path:
            # UNet whose LoRA weights are already merged in
            print("📦 使用融合LoRA权重的UNet...")
            self.unet = load_fused_unet(fused_unet_path, device, dtype)
        elif lora_weights_path and lora_config_path:
            # Base UNet plus separately stored LoRA weights
            print("🔧 加载UNet并应用LoRA权重...")
            self.unet = load_unet_with_lora(
                unet_path=unet_path,
                unet_config_path=unet_config_path,
                lora_weights_path=lora_weights_path,
                lora_config_path=lora_config_path,
                device=device,
                dtype=dtype
            )
        else:
            # Plain UNet from safetensors
            self.unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)

        # VAE
        print("🎨 加载 VAE 模型...")
        self.vae = load_vae_from_safetensors(vae_path, vae_config_path, device, dtype)

        # Scheduler
        print(f"⏰ 创建调度器 ({scheduler_type})...")
        self.scheduler = create_scheduler(scheduler_type)

        # The pipeline is usable only when every component exists
        self.is_ready = all([
            self.text_encoder is not None,
            self.adapter is not None,
            self.unet is not None,
            self.vae is not None,
            self.scheduler is not None
        ])
        if self.is_ready:
            print("✅ 管道初始化完成!")
        else:
            print("❌ 管道初始化失败,某些组件加载失败")

    @staticmethod
    def _read_state_dict(path: str, map_location):
        """Read a state dict from ``path`` (safetensors or torch checkpoint)."""
        if path.endswith(".safetensors"):
            import safetensors.torch
            return safetensors.torch.load_file(path)
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # weights_only=True where the installed torch version supports it.
        return torch.load(path, map_location=map_location)

    def _randn(self, shape, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw standard-normal noise of ``shape`` on ``self.device``."""
        try:
            from diffusers.utils.torch_utils import randn_tensor
        except ImportError:
            try:
                # Location used by older diffusers releases
                from diffusers.utils import randn_tensor
            except ImportError:
                randn_tensor = None
        if randn_tensor is not None:
            # Pass a torch.device rather than a string, since not every
            # diffusers release converts string devices itself.
            return randn_tensor(
                shape,
                generator=generator,
                device=torch.device(self.device),
                dtype=self.dtype,
            )
        # Plain torch fallback: sample on the generator's own device (a CPU
        # generator cannot sample directly on CUDA), then move the result.
        rand_device = generator.device if generator is not None else torch.device(self.device)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=self.dtype)
        return noise.to(self.device)

    def encode_prompts(
        self,
        prompts: List[str],
        negative_prompts: Optional[List[str]] = None,
        do_classifier_free_guidance: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode prompts with the Qwen text encoder and the adapter.

        Args:
            prompts: Positive prompts.
            negative_prompts: Optional negative prompts, forwarded unchanged
                to the text encoder.
            do_classifier_free_guidance: Forwarded to the text encoder.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds)`` as produced by the
            adapter. The original author's notes give the shapes as
            [B, 512, 2048] and [B, 1280] -- TODO confirm against the adapter.
        """
        # Raw embeddings from Qwen
        text_embeddings, pooled_embeddings = self.text_encoder.encode_prompts(
            prompts, negative_prompts, do_classifier_free_guidance
        )
        # Broadcast the per-prompt vector along a new sequence dimension
        # (expand creates a view, no copy).
        text_embeddings_seq = text_embeddings.unsqueeze(1).expand(-1, self.TEXT_SEQ_LEN, -1)
        # Project to the dimensions the UNet expects
        prompt_embeds = self.adapter.forward_text_embeddings(text_embeddings_seq.to(self.dtype))
        pooled_prompt_embeds = self.adapter.forward_pooled_embeddings(pooled_embeddings.to(self.dtype))
        return prompt_embeds, pooled_prompt_embeds

    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        generator: Optional[torch.Generator] = None
    ) -> torch.Tensor:
        """Create the initial noise latents.

        Args:
            batch_size: Number of images.
            height: Image height in pixels (floor-divided by the VAE factor).
            width: Image width in pixels (floor-divided by the VAE factor).
            generator: Optional RNG for reproducible noise. It is honoured on
                every path, including the mock path without a UNet.

        Returns:
            Noise of shape ``(batch_size, channels, height // f, width // f)``
            with ``f = self.vae_scale_factor``. When a UNet and a scheduler
            are present the noise is scaled by ``scheduler.init_noise_sigma``.
        """
        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor
        if self.unet is None:
            # Mock latents for testing: fixed channel count, no scheduler scaling
            shape = (batch_size, self.MOCK_LATENT_CHANNELS, latent_height, latent_width)
            return self._randn(shape, generator)
        shape = (batch_size, self.unet.config.in_channels, latent_height, latent_width)
        latents = self._randn(shape, generator)
        # Scale the initial noise to what the scheduler expects
        if self.scheduler is not None:
            latents = latents * self.scheduler.init_noise_sigma
        return latents

    def get_time_ids(
        self,
        height: int,
        width: int,
        original_size: Tuple[int, int],
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None
    ) -> torch.Tensor:
        """Build the SDXL micro-conditioning time IDs.

        Args:
            height: Fallback target height when ``target_size`` is None.
            width: Fallback target width when ``target_size`` is None.
            original_size: Pair placed first in the IDs.
            crops_coords_top_left: Pair placed second in the IDs.
            target_size: Pair placed last; defaults to ``(height, width)``.

        Returns:
            Tensor of shape ``(1, 6)`` on ``self.device`` with ``self.dtype``.
        """
        if target_size is None:
            target_size = (height, width)
        # Unpacking accepts any mix of tuples and lists, unlike ``+``
        add_time_ids = [*original_size, *crops_coords_top_left, *target_size]
        return torch.tensor([add_time_ids], dtype=self.dtype, device=self.device)

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        generator: Optional[torch.Generator] = None,
        return_type: str = "pil"
    ) -> "Union[List[Image.Image], np.ndarray]":
        """Generate images for the given prompt(s).

        Args:
            prompt: One prompt or a list of prompts (one image per prompt).
            negative_prompt: Optional negative prompt(s). A single negative
                prompt is repeated for every prompt in the batch.
            height: Output height in pixels.
            width: Output width in pixels.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance is used when this is
                greater than 1.0.
            generator: Optional RNG for the initial latents.
            return_type: ``"pil"`` returns PIL images; any other value returns
                the float numpy array.

        Returns:
            A list of PIL images, or a float numpy array laid out as
            (batch, height, width, channels) with values in [0, 1]. An empty
            list when the pipeline is not ready or no prompt was given.
        """
        if not self.is_ready:
            print("❌ 管道未准备就绪,无法生成图像")
            return []

        # Normalise prompts to lists
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        batch_size = len(prompt)
        if batch_size == 0:
            # Nothing to generate; do not hand an empty batch to the models
            print("⚠️ 未提供提示词,无法生成图像")
            return []
        if negative_prompt is not None and len(negative_prompt) == 1 and batch_size > 1:
            # Share a single negative prompt across the whole batch
            negative_prompt = list(negative_prompt) * batch_size
        do_classifier_free_guidance = guidance_scale > 1.0

        print(f"🎯 开始生成 {batch_size} 张图像...")
        print(f"📏 尺寸: {width}x{height}")
        print(f"🔄 推理步数: {num_inference_steps}")
        print(f"🎚️ 引导强度: {guidance_scale}")

        # 1. Encode prompts
        print("📝 编码提示词...")
        prompt_embeds, pooled_prompt_embeds = self.encode_prompts(
            prompt, negative_prompt, do_classifier_free_guidance
        )

        # 2. Prepare timesteps
        print("⏰ 准备时间步...")
        if self.scheduler is not None:
            self.scheduler.set_timesteps(num_inference_steps, device=self.device)
            timesteps = self.scheduler.timesteps
        else:
            timesteps = torch.linspace(1000, 0, num_inference_steps, device=self.device)

        # 3. Prepare latents
        print("🌀 准备潜在变量...")
        latents = self.prepare_latents(batch_size, height, width, generator)

        # 4. Prepare time IDs: one identical row per UNet batch entry
        original_size = (height, width)
        target_size = (height, width)
        add_time_ids = self.get_time_ids(height, width, original_size, target_size=target_size)
        rows = batch_size * 2 if do_classifier_free_guidance else batch_size
        add_time_ids = add_time_ids.repeat(rows, 1)

        # 5. Denoising loop
        print("🔄 开始去噪过程...")
        for i, t in enumerate(timesteps):
            # Duplicate latents so both guidance halves run in one pass
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if self.scheduler is not None:
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict noise
            if self.unet is not None:
                added_cond_kwargs = {
                    "text_embeds": pooled_prompt_embeds,
                    "time_ids": add_time_ids
                }
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]
            else:
                # Mock prediction so the loop still runs without a UNet
                # (testing only); previously noise_pred was left unbound here.
                noise_pred = torch.zeros_like(latent_model_input)

            # Classifier-free guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Scheduler step
            if self.scheduler is not None:
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if (i + 1) % 5 == 0:
                print(f" 步骤 {i+1}/{len(timesteps)} 完成")

        # 6. Decode latents
        print("🎨 解码生成图像...")
        if self.vae is not None:
            latents = latents / self.vae.config.scaling_factor
            images = self.vae.decode(latents, return_dict=False)[0]
        else:
            # Mock image generation for testing
            images = torch.randn(batch_size, 3, height, width, device=self.device)

        # 7. Map from [-1, 1] to [0, 1] and move channels last
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        if return_type == "pil":
            images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images]
        print("✅ 图像生成完成!")
        return images

    def save_adapter(self, save_path: str):
        """Save the adapter weights to ``save_path``.

        A ``.safetensors`` suffix selects the safetensors format; anything
        else is written with ``torch.save``. Failures are reported, not raised.
        """
        try:
            if save_path.endswith(".safetensors"):
                import safetensors.torch
                safetensors.torch.save_file(self.adapter.state_dict(), save_path)
            else:
                torch.save(self.adapter.state_dict(), save_path)
            print(f"✅ 适配器权重已保存到: {save_path}")
        except Exception as e:
            print(f"❌ 保存适配器权重失败: {e}")

    def load_adapter(self, load_path: str):
        """Load adapter weights from ``load_path``.

        A ``.safetensors`` suffix selects the safetensors format; anything
        else is read with ``torch.load``. Failures are reported, not raised.
        ``adapter_loaded`` is set to True on success.
        """
        try:
            state_dict = self._read_state_dict(load_path, self.device)
            self.adapter.load_state_dict(state_dict)
            self.adapter_loaded = True
            print(f"✅ 适配器权重已从 {load_path} 加载")
        except Exception as e:
            print(f"❌ 加载适配器权重失败: {e}")