| """ |
| Qwen-SDXL Inference Pipeline |
| Qwen-SDXL 推理管道 - 使用 Qwen3 嵌入模型替代 CLIP 文本编码器的 SDXL 推理管道 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from PIL import Image |
| from typing import List, Optional, Union, Tuple |
|
|
| from .adapter import QwenEmbeddingAdapter |
| from .text_encoder import QwenTextEncoder |
| from .model_loader import load_qwen_model, load_unet_from_safetensors, load_vae_from_safetensors, create_scheduler |
|
|
|
|
class QwenIllustriousInference:
    """
    Qwen-SDXL inference pipeline.

    An SDXL-style text-to-image pipeline in which the CLIP text encoders are
    replaced by a Qwen3 embedding model (``QwenTextEncoder``) followed by a
    ``QwenEmbeddingAdapter``. The adapter outputs are fed to the UNet as
    ``encoder_hidden_states`` and as the ``text_embeds`` micro-conditioning.
    """

    # Length of the sequence handed to the adapter. Each prompt's single
    # embedding vector is repeated this many times along axis 1.
    # NOTE(review): presumably this must match the length the adapter/UNet
    # were trained with; confirm before changing.
    _TEXT_SEQ_LEN = 512

    def __init__(
        self,
        qwen_model_path: str = "models/Qwen3-Embedding-0.6B",
        unet_path: str = "models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors",
        unet_config_path: str = "models/extracted_components/waiNSFWIllustrious_v140_unet_config.json",
        vae_path: str = "models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors",
        vae_config_path: str = "models/extracted_components/waiNSFWIllustrious_v140_vae_config.json",
        adapter_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/adapter/adapter.safetensors",
        lora_weights_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/lora_weights/lora_weights.safetensors",
        lora_config_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/lora_weights/adapter_config.json",
        use_fused_unet: bool = False,
        fused_unet_path: Optional[str] = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        scheduler_type: str = "DDPM"
    ):
        """
        Build every pipeline component.

        Args:
            qwen_model_path: Model path handed to ``QwenTextEncoder``.
            unet_path / unet_config_path: Base UNet weights and config.
            vae_path / vae_config_path: VAE weights and config.
            adapter_path: Optional adapter weights. ``.safetensors`` files are
                read with safetensors; anything else is read with
                ``torch.load``. A load failure is reported, and the freshly
                initialised adapter is kept.
            lora_weights_path / lora_config_path: When both are set (and no
                fused UNet is used), the UNet is loaded through
                ``load_unet_with_lora``.
            use_fused_unet / fused_unet_path: When both are set, the UNet is
                loaded through ``load_fused_unet`` instead.
            device: Device for every component.
            dtype: Dtype for the adapter, UNet and VAE.
            scheduler_type: Name passed to ``create_scheduler``.
        """
        self.device = device
        self.dtype = dtype
        self.vae_scale_factor = 8

        print("🚀 初始化 Qwen-SDXL 推理管道...")

        print("📝 初始化 Qwen 文本编码器...")
        self.text_encoder = QwenTextEncoder(
            model_path=qwen_model_path,
            device=device,
            freeze_encoder=True
        )

        print("🔧 初始化适配器层...")
        self.adapter = QwenEmbeddingAdapter()
        self.adapter.to(device, dtype)
        # A freshly constructed nn.Module starts in training mode. Switch to
        # eval so that dropout-style layers (if the adapter has any) do not
        # make inference stochastic.
        self.adapter.eval()

        if adapter_path is not None:
            print(f"📥 加载适配器权重: {adapter_path}")
            try:
                self.adapter.load_state_dict(self._read_state_dict(adapter_path))
            except Exception as e:
                # Best effort: keep the randomly initialised adapter, but say so.
                print(f"⚠️ 加载适配器权重失败: {e}")

        print("🏗️ 加载 UNet 模型...")
        from .model_loader import load_unet_with_lora, load_fused_unet

        if use_fused_unet and not fused_unet_path:
            print("⚠️ use_fused_unet=True 但未提供 fused_unet_path,忽略融合 UNet 选项")

        if use_fused_unet and fused_unet_path:
            print("📦 使用融合LoRA权重的UNet...")
            self.unet = load_fused_unet(fused_unet_path, device, dtype)
        elif lora_weights_path and lora_config_path:
            print("🔧 加载UNet并应用LoRA权重...")
            self.unet = load_unet_with_lora(
                unet_path=unet_path,
                unet_config_path=unet_config_path,
                lora_weights_path=lora_weights_path,
                lora_config_path=lora_config_path,
                device=device,
                dtype=dtype
            )
        else:
            self.unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)

        print("🎨 加载 VAE 模型...")
        self.vae = load_vae_from_safetensors(vae_path, vae_config_path, device, dtype)

        print(f"⏰ 创建调度器 ({scheduler_type})...")
        self.scheduler = create_scheduler(scheduler_type)

        # generate() relies on this flag: it only runs when every component exists.
        self.is_ready = all([
            self.text_encoder is not None,
            self.adapter is not None,
            self.unet is not None,
            self.vae is not None,
            self.scheduler is not None
        ])

        if self.is_ready:
            print("✅ 管道初始化完成!")
        else:
            print("❌ 管道初始化失败,某些组件加载失败")

    def _read_state_dict(self, path: str) -> dict:
        """Read a state dict from a ``.safetensors`` file or a torch checkpoint."""
        if path.endswith(".safetensors"):
            import safetensors.torch
            return safetensors.torch.load_file(path)
        # weights_only=True refuses arbitrary pickled objects. A state dict as
        # written by save_adapter() contains only tensors.
        return torch.load(path, map_location=self.device, weights_only=True)

    def encode_prompts(
        self,
        prompts: List[str],
        negative_prompts: Optional[List[str]] = None,
        do_classifier_free_guidance: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode prompts with the Qwen3 text encoder followed by the adapter.

        Args:
            prompts: Positive prompts.
            negative_prompts: Optional negative prompts, forwarded unchanged to
                ``text_encoder.encode_prompts``.
            do_classifier_free_guidance: Forwarded to the text encoder.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds)`` produced by the adapter's
            ``forward_text_embeddings`` and ``forward_pooled_embeddings``.
        """
        text_embeddings, pooled_embeddings = self.text_encoder.encode_prompts(
            prompts, negative_prompts, do_classifier_free_guidance
        )

        # text_embeddings must be 2-D: after unsqueeze(1) the three-size
        # expand() repeats each vector _TEXT_SEQ_LEN times along the new axis,
        # giving (rows, _TEXT_SEQ_LEN, dim).
        text_embeddings_seq = text_embeddings.unsqueeze(1).expand(-1, self._TEXT_SEQ_LEN, -1)

        prompt_embeds = self.adapter.forward_text_embeddings(text_embeddings_seq.to(self.dtype))
        pooled_prompt_embeds = self.adapter.forward_pooled_embeddings(pooled_embeddings.to(self.dtype))

        return prompt_embeds, pooled_prompt_embeds

    def _randn(self, shape: Tuple[int, ...], generator: Optional[torch.Generator]) -> torch.Tensor:
        """Draw standard-normal noise of ``shape`` on ``self.device`` / ``self.dtype``."""
        try:
            from diffusers.utils.torch_utils import randn_tensor  # newer diffusers
        except ImportError:
            try:
                from diffusers.utils import randn_tensor  # older diffusers
            except ImportError:
                randn_tensor = None

        if randn_tensor is not None:
            # Pass a torch.device: some diffusers versions read `device.type`
            # and fail on a plain string when a generator is given.
            return randn_tensor(shape, generator=generator, device=torch.device(self.device), dtype=self.dtype)

        # Sample on the generator's own device (e.g. a CPU generator with a
        # CUDA pipeline), then move the result to the pipeline device.
        rand_device = generator.device if generator is not None else self.device
        return torch.randn(shape, generator=generator, device=rand_device, dtype=self.dtype).to(self.device)

    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        generator: Optional[torch.Generator] = None
    ) -> torch.Tensor:
        """
        Sample the initial latents.

        Shape is ``(batch_size, C, height // 8, width // 8)``. ``C`` is the
        UNet's ``config.in_channels``, or 4 when no UNet is loaded. When a
        scheduler is present, the noise is scaled by its ``init_noise_sigma``.
        """
        num_channels = self.unet.config.in_channels if self.unet is not None else 4
        shape = (
            batch_size,
            num_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        latents = self._randn(shape, generator)

        if self.scheduler is not None:
            latents = latents * self.scheduler.init_noise_sigma

        return latents

    def get_time_ids(
        self,
        height: int,
        width: int,
        original_size: Tuple[int, int],
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None
    ) -> torch.Tensor:
        """
        Build the SDXL micro-conditioning time ids.

        Returns a ``(1, 6)`` tensor in the order
        ``original_size + crops_coords_top_left + target_size``.
        ``target_size`` defaults to ``(height, width)``. Any 2-element
        sequences (tuples or lists) are accepted.
        """
        if target_size is None:
            target_size = (height, width)

        values = [*original_size, *crops_coords_top_left, *target_size]
        return torch.tensor([values], dtype=self.dtype, device=self.device)

    @staticmethod
    def _normalize_prompts(
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]]
    ) -> Tuple[List[str], Optional[List[str]]]:
        """
        Coerce the prompts to lists and align the negatives with the positives.

        A single negative-prompt string is broadcast to every positive prompt.
        A negative-prompt list must have the same length as the prompts.

        Raises:
            ValueError: if ``prompt`` is empty or the list lengths differ.
        """
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        if not prompts:
            raise ValueError("prompt must contain at least one prompt")

        if negative_prompt is None:
            return prompts, None
        if isinstance(negative_prompt, str):
            return prompts, [negative_prompt] * len(prompts)

        negatives = list(negative_prompt)
        if len(negatives) != len(prompts):
            raise ValueError(
                f"negative_prompt has {len(negatives)} entries but prompt has {len(prompts)}"
            )
        return prompts, negatives

    @staticmethod
    def _numpy_to_uint8(images: np.ndarray) -> np.ndarray:
        """Map floats in [0, 1] to uint8, rounding rather than truncating."""
        return (images * 255).round().astype(np.uint8)

    def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Undo the VAE latent scaling and decode the latents to image tensors."""
        latents = latents.to(self.dtype) / self.vae.config.scaling_factor
        return self.vae.decode(latents, return_dict=False)[0]

    def _postprocess(self, images: torch.Tensor, return_type: str) -> "Union[List[Image.Image], np.ndarray]":
        """
        Convert decoded images into the output format.

        Decoded values are mapped from [-1, 1] to [0, 1] and clamped, then
        converted to a channels-last float32 numpy array. When ``return_type``
        is ``"pil"``, that array is further converted to PIL images.
        """
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()

        if return_type == "pil":
            return [Image.fromarray(arr) for arr in self._numpy_to_uint8(images)]
        return images

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        generator: Optional[torch.Generator] = None,
        return_type: str = "pil"
    ) -> "Union[List[Image.Image], np.ndarray]":
        """
        Generate images with the Qwen-SDXL pipeline.

        Args:
            prompt: A prompt or a list of prompts; one image per prompt.
            negative_prompt: Optional negative prompt(s). A single string is
                applied to every prompt.
            height / width: Output size in pixels. The latent size is
                ``size // 8``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance strength. CFG is
                enabled when the value is > 1.
            generator: Optional RNG used for the initial latents.
            return_type: ``"pil"`` returns a list of PIL images. Any other
                value returns a float32 numpy array in [0, 1], channels-last.

        Returns:
            The generated images, or ``[]`` if the pipeline is not ready.

        Raises:
            ValueError: if the prompts are empty or the negative prompts do
                not match them in number.
        """
        if not self.is_ready:
            print("❌ 管道未准备就绪,无法生成图像")
            return []

        prompt, negative_prompt = self._normalize_prompts(prompt, negative_prompt)
        batch_size = len(prompt)
        do_classifier_free_guidance = guidance_scale > 1.0

        print(f"🎯 开始生成 {batch_size} 张图像...")
        print(f"📏 尺寸: {width}x{height}")
        print(f"🔄 推理步数: {num_inference_steps}")
        print(f"🎚️ 引导强度: {guidance_scale}")

        print("📝 编码提示词...")
        prompt_embeds, pooled_prompt_embeds = self.encode_prompts(
            prompt, negative_prompt, do_classifier_free_guidance
        )

        print("⏰ 准备时间步...")
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        print("🌀 准备潜在变量...")
        latents = self.prepare_latents(batch_size, height, width, generator)

        # Every row of the time ids is identical, so one row per UNet batch
        # entry is enough (2 * batch_size entries when CFG is enabled).
        add_time_ids = self.get_time_ids(height, width, (height, width), target_size=(height, width))
        num_rows = batch_size * (2 if do_classifier_free_guidance else 1)
        added_cond_kwargs = {
            "text_embeds": pooled_prompt_embeds,
            "time_ids": add_time_ids.repeat(num_rows, 1),
        }

        print("🔄 开始去噪过程...")
        num_steps = len(timesteps)
        for i, t in enumerate(timesteps):
            # With CFG the UNet sees two stacked copies of the latents. Its
            # output is split back with chunk(2) into (uncond, text) halves.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input.to(self.dtype),
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if (i + 1) % 5 == 0 or i + 1 == num_steps:
                print(f" 步骤 {i+1}/{num_steps} 完成")

        print("🎨 解码生成图像...")
        images = self._postprocess(self._decode_latents(latents), return_type)

        print("✅ 图像生成完成!")
        return images

    def save_adapter(self, save_path: str):
        """
        Save the adapter weights, creating the parent directory if needed.

        ``.safetensors`` paths are written with safetensors; anything else is
        written with ``torch.save``. Failures are reported, not raised.
        """
        from pathlib import Path

        try:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            state_dict = self.adapter.state_dict()
            if save_path.endswith(".safetensors"):
                import safetensors.torch
                safetensors.torch.save_file(state_dict, save_path)
            else:
                torch.save(state_dict, save_path)
            print(f"✅ 适配器权重已保存到: {save_path}")
        except Exception as e:
            print(f"❌ 保存适配器权重失败: {e}")

    def load_adapter(self, load_path: str):
        """
        Load adapter weights saved by ``save_adapter``.

        ``.safetensors`` paths are read with safetensors; anything else is read
        with ``torch.load(weights_only=True)``. Failures are reported, not
        raised.
        """
        try:
            self.adapter.load_state_dict(self._read_state_dict(load_path))
            print(f"✅ 适配器权重已从 {load_path} 加载")
        except Exception as e:
            print(f"❌ 加载适配器权重失败: {e}")
|
|