Spaces:
Runtime error
Runtime error
File size: 1,296 Bytes
e6b0653 3739bc2 dba4863 3739bc2 e6b0653 3739bc2 0ec91bc 3739bc2 e6b0653 3739bc2 e6b0653 8bd2801 e6b0653 3739bc2 e6b0653 dba4863 3739bc2 dba4863 e6b0653 dba4863 e6b0653 3739bc2 dba4863 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | import os
import subprocess
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from peft import PeftModel
LORA_DIR = "./lora_girls"
# 1. 如果没训练好的 lora-girls,就先训练
if not os.path.exists(LORA_DIR):
print("🚀 未检测到 LoRA 权重,开始训练 ...")
result = subprocess.run(["python", "train_lora.py"], capture_output=True, text=True)
print(result.stdout)
if result.returncode != 0:
print("❌ 训练失败:", result.stderr)
raise SystemExit("训练出错,停止运行")
print("✅ 训练完成,LoRA 已生成。")
# 2. 训练完成后,加载基座模型 + LoRA
print("🔄 正在加载模型和 LoRA ...")
pipe = StableDiffusionPipeline.from_pretrained("prompthero/openjourney", torch_dtype=torch.float16).to("cuda")
pipe.unet = PeftModel.from_pretrained(pipe.unet, LORA_DIR)
print("✅ 模型加载完成,可以开始生成图片!")
# 3. 定义 Gradio 界面
def infer(prompt):
image = pipe(prompt).images[0]
return image
with gr.Blocks() as demo:
gr.Markdown("# 🖼️ LoRA 文生图 (Girls)")
prompt = gr.Textbox(label="输入你的提示词")
output = gr.Image()
btn = gr.Button("生成")
btn.click(fn=infer, inputs=prompt, outputs=output)
demo.launch()
|