Spaces:
Runtime error
Runtime error
| import os | |
| import subprocess | |
| import gradio as gr | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| from peft import PeftModel | |
| LORA_DIR = "./lora_girls" | |
| # 1. 如果没训练好的 lora-girls,就先训练 | |
| if not os.path.exists(LORA_DIR): | |
| print("🚀 未检测到 LoRA 权重,开始训练 ...") | |
| result = subprocess.run(["python", "train_lora.py"], capture_output=True, text=True) | |
| print(result.stdout) | |
| if result.returncode != 0: | |
| print("❌ 训练失败:", result.stderr) | |
| raise SystemExit("训练出错,停止运行") | |
| print("✅ 训练完成,LoRA 已生成。") | |
| # 2. 训练完成后,加载基座模型 + LoRA | |
| print("🔄 正在加载模型和 LoRA ...") | |
| pipe = StableDiffusionPipeline.from_pretrained("prompthero/openjourney", torch_dtype=torch.float16).to("cuda") | |
| pipe.unet = PeftModel.from_pretrained(pipe.unet, LORA_DIR) | |
| print("✅ 模型加载完成,可以开始生成图片!") | |
| # 3. 定义 Gradio 界面 | |
| def infer(prompt): | |
| image = pipe(prompt).images[0] | |
| return image | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# 🖼️ LoRA 文生图 (Girls)") | |
| prompt = gr.Textbox(label="输入你的提示词") | |
| output = gr.Image() | |
| btn = gr.Button("生成") | |
| btn.click(fn=infer, inputs=prompt, outputs=output) | |
| demo.launch() | |