study-pic / app.py
ZhongZhiYY's picture
Update app.py
8bd2801 verified
import os
import subprocess
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from peft import PeftModel
LORA_DIR = "./lora_girls"
# 1. 如果没训练好的 lora-girls,就先训练
if not os.path.exists(LORA_DIR):
print("🚀 未检测到 LoRA 权重,开始训练 ...")
result = subprocess.run(["python", "train_lora.py"], capture_output=True, text=True)
print(result.stdout)
if result.returncode != 0:
print("❌ 训练失败:", result.stderr)
raise SystemExit("训练出错,停止运行")
print("✅ 训练完成,LoRA 已生成。")
# 2. 训练完成后,加载基座模型 + LoRA
print("🔄 正在加载模型和 LoRA ...")
pipe = StableDiffusionPipeline.from_pretrained("prompthero/openjourney", torch_dtype=torch.float16).to("cuda")
pipe.unet = PeftModel.from_pretrained(pipe.unet, LORA_DIR)
print("✅ 模型加载完成,可以开始生成图片!")
# 3. 定义 Gradio 界面
def infer(prompt):
image = pipe(prompt).images[0]
return image
with gr.Blocks() as demo:
gr.Markdown("# 🖼️ LoRA 文生图 (Girls)")
prompt = gr.Textbox(label="输入你的提示词")
output = gr.Image()
btn = gr.Button("生成")
btn.click(fn=infer, inputs=prompt, outputs=output)
demo.launch()