File size: 1,296 Bytes
e6b0653
 
3739bc2
dba4863
3739bc2
e6b0653
3739bc2
0ec91bc
3739bc2
e6b0653
 
 
 
 
 
 
 
 
3739bc2
e6b0653
 
8bd2801
e6b0653
 
3739bc2
e6b0653
 
 
dba4863
3739bc2
dba4863
e6b0653
 
dba4863
e6b0653
 
3739bc2
dba4863
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import subprocess
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from peft import PeftModel

LORA_DIR = "./lora_girls"

# 1. 如果没训练好的 lora-girls,就先训练
if not os.path.exists(LORA_DIR):
    print("🚀 未检测到 LoRA 权重,开始训练 ...")
    result = subprocess.run(["python", "train_lora.py"], capture_output=True, text=True)
    print(result.stdout)
    if result.returncode != 0:
        print("❌ 训练失败:", result.stderr)
        raise SystemExit("训练出错,停止运行")
    print("✅ 训练完成,LoRA 已生成。")

# 2. 训练完成后,加载基座模型 + LoRA
print("🔄 正在加载模型和 LoRA ...")
pipe = StableDiffusionPipeline.from_pretrained("prompthero/openjourney", torch_dtype=torch.float16).to("cuda")
pipe.unet = PeftModel.from_pretrained(pipe.unet, LORA_DIR)
print("✅ 模型加载完成,可以开始生成图片!")

# 3. 定义 Gradio 界面
def infer(prompt):
    image = pipe(prompt).images[0]
    return image

with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ LoRA 文生图 (Girls)")
    prompt = gr.Textbox(label="输入你的提示词")
    output = gr.Image()
    btn = gr.Button("生成")
    btn.click(fn=infer, inputs=prompt, outputs=output)

demo.launch()