|
|
""" |
|
|
Qwen-SDXL Inference Script (Updated to use arch components) |
|
|
基于 Qwen3 Embedding 的 SDXL 推理管道 - 使用架构组件 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from typing import List, Optional, Union |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
from arch import QwenIllustriousInference |
|
|
|
|
|
|
|
|
def test_qwen_sdxl_inference():
    """
    Smoke-test the Qwen-SDXL inference pipeline built from the arch components.

    Builds a ``QwenIllustriousInference`` pipeline on CUDA with bfloat16 when
    CUDA is available, otherwise on CPU with float32. Each built-in prompt is
    then rendered at 512x512 and the first returned image is written to
    ``test_qwen_sdxl_<n>.png`` (a relative path). If the pipeline reports that
    it is not ready, a warning is printed and the run continues anyway. An
    error raised while handling one prompt is printed and does not stop the
    remaining prompts.
    """
    print("🧪 测试 Qwen-SDXL 推理管道 (使用架构组件)")
    print("=" * 50)

    # Probe CUDA once and derive both device and dtype from the answer.
    cuda_available = torch.cuda.is_available()
    pipeline = QwenIllustriousInference(
        device="cuda" if cuda_available else "cpu",
        dtype=torch.bfloat16 if cuda_available else torch.float32
    )

    if not pipeline.is_ready:
        print("⚠️ 管道未准备就绪,使用模拟模式进行测试")

    test_prompts = [
        "A beautiful landscape with mountains and rivers, oil painting style",
        "A cute cat wearing a red hat, anime style, high quality",
    ]

    # Settings shared by every test prompt.
    generation_settings = {
        "negative_prompt": "low quality, blurry, distorted, watermark",
        "height": 512,
        "width": 512,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
    }

    for index, prompt in enumerate(test_prompts, start=1):
        print(f"\n🎨 生成测试图像 {index}")
        print(f"📝 提示词: {prompt}")

        try:
            images = pipeline.generate(prompt=prompt, **generation_settings)

            if not images:
                print("❌ 图像生成失败")
                continue

            output_path = f"test_qwen_sdxl_{index}.png"
            images[0].save(output_path)
            print(f"💾 已保存: {output_path}")

        except Exception as e:
            # Keep going so one bad prompt does not abort the whole run.
            print(f"❌ 生成过程中发生错误: {e}")

    print("\n🎉 测试完成!")
|
|
|
|
|
|
|
|
def generate_single_image(
    prompt: str,
    negative_prompt: str = "low quality, blurry, distorted",
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    output_path: str = "output.png",
    adapter_path: Optional[str] = None,
    lora_weights_path: Optional[str] = None,
    lora_config_path: Optional[str] = None,
    use_fused_unet: bool = False,
    fused_unet_path: Optional[str] = None
) -> bool:
    """
    Generate a single image with the Qwen-SDXL pipeline and save it to disk.

    The pipeline is built on CUDA with bfloat16 when CUDA is available,
    otherwise on CPU with float32. Any exception raised while building the
    pipeline, generating, or saving is printed and reported as ``False``;
    this function does not propagate it.

    Args:
        prompt: Text prompt for image generation
        negative_prompt: Negative prompt
        height: Image height
        width: Image width
        num_inference_steps: Number of denoising steps
        guidance_scale: Guidance scale for CFG
        output_path: Path to save the generated image. Missing parent
            directories are created right before the image is written.
        adapter_path: Path to trained adapter weights (safetensors)
        lora_weights_path: Path to LoRA weights (safetensors)
        lora_config_path: Path to LoRA config directory
        use_fused_unet: Whether to use fused UNet with merged LoRA
        fused_unet_path: Path to fused UNet model directory

    Returns:
        True if generation successful, False otherwise
    """
    # Local import keeps this block self-contained; `os` is standard library.
    import os

    print("🎨 使用 Qwen-SDXL 生成图像")
    print(f"📝 提示词: {prompt}")
    print(f"📏 尺寸: {width}x{height}")

    try:
        # Probe CUDA once and derive both device and dtype from the answer.
        has_cuda = torch.cuda.is_available()

        pipeline = QwenIllustriousInference(
            adapter_path=adapter_path,
            lora_weights_path=lora_weights_path,
            lora_config_path=lora_config_path,
            use_fused_unet=use_fused_unet,
            fused_unet_path=fused_unet_path,
            device="cuda" if has_cuda else "cpu",
            dtype=torch.bfloat16 if has_cuda else torch.float32
        )

        if not pipeline.is_ready:
            print("❌ 管道未准备就绪")
            return False

        images = pipeline.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

        if not images:
            print("❌ 图像生成失败")
            return False

        # Create the destination directory only once there is an image to
        # write, so a finished generation is not lost to a missing folder
        # and failed runs leave no empty directories behind.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        images[0].save(output_path)
        print(f"✅ 图像已保存到: {output_path}")
        return True

    except Exception as e:
        # Top-level boundary: report the problem and signal failure via the
        # return value instead of raising.
        print(f"❌ 生成过程中发生错误: {e}")
        return False
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: with --prompt, generate one image; without a
    # prompt (or with --test), run the built-in smoke test instead.
    parser = argparse.ArgumentParser(description="Qwen-SDXL Inference")
    parser.add_argument("--prompt", type=str, help="Text prompt for generation", default=None)
    parser.add_argument("--negative_prompt", type=str, default="low quality, blurry", help="Negative prompt")
    parser.add_argument("--height", type=int, default=1024, help="Image height")
    parser.add_argument("--width", type=int, default=1024, help="Image width")
    parser.add_argument("--steps", type=int, default=35, help="Number of inference steps")
    parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale")
    parser.add_argument("--output", type=str, default="output.png", help="Output image path")
    parser.add_argument("--adapter_path", type=str, help="Path to trained adapter weights (safetensors)")
    parser.add_argument("--lora_weights_path", type=str, help="Path to LoRA weights (safetensors)")
    parser.add_argument("--lora_config_path", type=str, help="Path to LoRA config directory")
    parser.add_argument("--use_fused_unet", action="store_true", help="Use fused UNet with merged LoRA weights")
    parser.add_argument("--fused_unet_path", type=str, help="Path to fused UNet model directory")
    parser.add_argument("--test", action="store_true", help="Run test mode")

    args = parser.parse_args()

    if args.test or args.prompt is None:
        test_qwen_sdxl_inference()
    else:
        succeeded = generate_single_image(
            prompt=args.prompt,
            negative_prompt=args.negative_prompt,
            height=args.height,
            width=args.width,
            num_inference_steps=args.steps,
            guidance_scale=args.guidance_scale,
            output_path=args.output,
            adapter_path=args.adapter_path,
            lora_weights_path=args.lora_weights_path,
            lora_config_path=args.lora_config_path,
            use_fused_unet=args.use_fused_unet,
            fused_unet_path=args.fused_unet_path
        )
        if not succeeded:
            # Exit non-zero so shells and CI jobs can detect a failed
            # generation instead of seeing a silent success.
            raise SystemExit(1)
|
|
|