import os import subprocess import gradio as gr from diffusers import StableDiffusionPipeline import torch from peft import PeftModel LORA_DIR = "./lora_girls" # 1. 如果没训练好的 lora-girls,就先训练 if not os.path.exists(LORA_DIR): print("🚀 未检测到 LoRA 权重,开始训练 ...") result = subprocess.run(["python", "train_lora.py"], capture_output=True, text=True) print(result.stdout) if result.returncode != 0: print("❌ 训练失败:", result.stderr) raise SystemExit("训练出错,停止运行") print("✅ 训练完成,LoRA 已生成。") # 2. 训练完成后,加载基座模型 + LoRA print("🔄 正在加载模型和 LoRA ...") pipe = StableDiffusionPipeline.from_pretrained("prompthero/openjourney", torch_dtype=torch.float16).to("cuda") pipe.unet = PeftModel.from_pretrained(pipe.unet, LORA_DIR) print("✅ 模型加载完成,可以开始生成图片!") # 3. 定义 Gradio 界面 def infer(prompt): image = pipe(prompt).images[0] return image with gr.Blocks() as demo: gr.Markdown("# 🖼️ LoRA 文生图 (Girls)") prompt = gr.Textbox(label="输入你的提示词") output = gr.Image() btn = gr.Button("生成") btn.click(fn=infer, inputs=prompt, outputs=output) demo.launch()