# qwenillustrious / inference_updated.py
# lsmpp's picture
# Add files using upload-large-folder tool
# d926b4c verified
"""
Qwen-SDXL Inference Script (Updated to use arch components)
基于 Qwen3 Embedding 的 SDXL 推理管道 - 使用架构组件
"""
import torch
from typing import List, Optional, Union
from PIL import Image
# Import from arch components
from arch import QwenIllustriousInference
def test_qwen_sdxl_inference():
    """
    Run a two-prompt smoke test of the Qwen-SDXL pipeline built from arch components.

    A QwenIllustriousInference is created on CUDA with bfloat16 when CUDA is
    available, otherwise on CPU with float32. One 512x512 image is requested
    per built-in prompt, and the first returned image is saved as
    ``test_qwen_sdxl_<n>.png`` (relative path, 1-based index).

    Any exception raised while generating or saving an image is printed and
    the loop moves on to the next prompt.
    """
    print("🧪 测试 Qwen-SDXL 推理管道 (使用架构组件)")
    print("=" * 50)

    # Probe CUDA once and derive both the device and the dtype from the result.
    cuda_available = torch.cuda.is_available()
    pipeline = QwenIllustriousInference(
        device="cuda" if cuda_available else "cpu",
        dtype=torch.bfloat16 if cuda_available else torch.float32,
    )

    # A pipeline that reports it is not ready only triggers a warning here;
    # the prompts below are still attempted.
    if not pipeline.is_ready:
        print("⚠️ 管道未准备就绪,使用模拟模式进行测试")

    test_prompts = [
        "A beautiful landscape with mountains and rivers, oil painting style",
        "A cute cat wearing a red hat, anime style, high quality",
    ]

    # Settings shared by every test prompt.
    generation_settings = {
        "negative_prompt": "low quality, blurry, distorted, watermark",
        "height": 512,  # smaller-than-default size for the test run
        "width": 512,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
    }

    for index, prompt in enumerate(test_prompts, start=1):
        print(f"\n🎨 生成测试图像 {index}")
        print(f"📝 提示词: {prompt}")
        try:
            images = pipeline.generate(prompt=prompt, **generation_settings)
            if not images:
                print("❌ 图像生成失败")
                continue
            output_path = f"test_qwen_sdxl_{index}.png"
            images[0].save(output_path)
            print(f"💾 已保存: {output_path}")
        except Exception as exc:
            # Keep going so one failing prompt does not hide the others.
            print(f"❌ 生成过程中发生错误: {exc}")

    print("\n🎉 测试完成!")
def generate_single_image(
    prompt: str,
    negative_prompt: str = "low quality, blurry, distorted",
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    output_path: str = "output.png",
    adapter_path: Optional[str] = None,
    lora_weights_path: Optional[str] = None,
    lora_config_path: Optional[str] = None,
    use_fused_unet: bool = False,
    fused_unet_path: Optional[str] = None
) -> bool:
    """
    Generate one image with the Qwen-SDXL pipeline and save it to disk.

    The pipeline is built on CUDA with bfloat16 when CUDA is available,
    otherwise on CPU with float32. Only the first image returned by the
    pipeline is saved.

    Args:
        prompt: Text prompt for image generation
        negative_prompt: Negative prompt
        height: Image height
        width: Image width
        num_inference_steps: Number of denoising steps
        guidance_scale: Guidance scale for CFG
        output_path: Path to save the generated image
        adapter_path: Path to trained adapter weights (safetensors)
        lora_weights_path: Path to LoRA weights (safetensors)
        lora_config_path: Path to LoRA config directory
        use_fused_unet: Whether to use fused UNet with merged LoRA
        fused_unet_path: Path to fused UNet model directory

    Returns:
        True if an image was generated and saved. False if the pipeline
        reports it is not ready, if generation returns nothing, or if any
        exception is raised along the way (the exception is printed, not
        propagated).
    """
    print("🎨 使用 Qwen-SDXL 生成图像")
    print(f"📝 提示词: {prompt}")
    print(f"📏 尺寸: {width}x{height}")

    try:
        # Probe CUDA once; device and dtype are both chosen from the result.
        cuda_available = torch.cuda.is_available()
        pipeline = QwenIllustriousInference(
            adapter_path=adapter_path,
            lora_weights_path=lora_weights_path,
            lora_config_path=lora_config_path,
            use_fused_unet=use_fused_unet,
            fused_unet_path=fused_unet_path,
            device="cuda" if cuda_available else "cpu",
            dtype=torch.bfloat16 if cuda_available else torch.float32,
        )

        # Guard: do not attempt generation with a pipeline that is not ready.
        if not pipeline.is_ready:
            print("❌ 管道未准备就绪")
            return False

        images = pipeline.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

        # Guard: an empty/falsy result counts as a failed generation.
        if not images:
            print("❌ 图像生成失败")
            return False

        images[0].save(output_path)
        print(f"✅ 图像已保存到: {output_path}")
        return True
    except Exception as exc:
        # Report the failure through the boolean result instead of raising.
        print(f"❌ 生成过程中发生错误: {exc}")
        return False
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Qwen-SDXL Inference")
parser.add_argument("--prompt", type=str, help="Text prompt for generation", default=None)
parser.add_argument("--negative_prompt", type=str, default="low quality, blurry", help="Negative prompt")
parser.add_argument("--height", type=int, default=1024, help="Image height")
parser.add_argument("--width", type=int, default=1024, help="Image width")
parser.add_argument("--steps", type=int, default=35, help="Number of inference steps")
parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale")
parser.add_argument("--output", type=str, default="output.png", help="Output image path")
parser.add_argument("--adapter_path", type=str, help="Path to trained adapter weights (safetensors)")
parser.add_argument("--lora_weights_path", type=str, help="Path to LoRA weights (safetensors)")
parser.add_argument("--lora_config_path", type=str, help="Path to LoRA config directory")
parser.add_argument("--use_fused_unet", action="store_true", help="Use fused UNet with merged LoRA weights")
parser.add_argument("--fused_unet_path", type=str, help="Path to fused UNet model directory")
parser.add_argument("--test", action="store_true", help="Run test mode")
args = parser.parse_args()
if args.test or args.prompt is None:
# Run test mode
test_qwen_sdxl_inference()
else:
# Generate single image
generate_single_image(
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
output_path=args.output,
adapter_path=args.adapter_path,
lora_weights_path=args.lora_weights_path,
lora_config_path=args.lora_config_path,
use_fused_unet=args.use_fused_unet,
fused_unet_path=args.fused_unet_path
)