Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
# Module-level pipeline: written by load_model_fn, read by generate_fn.
# None means no model has been loaded yet. Because this is a module global,
# a single pipeline instance serves every caller in this process.
pipe: "StableDiffusionPipeline | None" = None
def load_model_fn(model_id, device_choice):
    """Load a Stable Diffusion pipeline and stream status messages to the UI.

    This is a generator; Gradio shows each yielded string in the bound
    status component as soon as it is produced.

    Args:
        model_id: Hugging Face model ID (or local path) forwarded to
            ``StableDiffusionPipeline.from_pretrained``. Leading/trailing
            whitespace is ignored.
        device_choice: Target device string, e.g. ``"cpu"`` or ``"cuda"``.

    Yields:
        Status strings: a progress message, then either a success or a
        failure message (early validation failures yield only the failure).

    Side effects:
        Replaces the module-level ``pipe`` only after the new pipeline has
        been fully loaded AND moved to the target device. On any failure the
        previously loaded pipeline (or ``None``) is kept, so the status text
        and the pipeline used for generation never disagree.
    """
    global pipe
    model_id = (model_id or "").strip()
    if not model_id:
        yield "❌ 載入失敗:請輸入模型 ID"
        return
    # Fail fast instead of spending minutes loading weights on the CPU only
    # for ``.to("cuda")`` to raise afterwards.
    if device_choice == "cuda" and not torch.cuda.is_available():
        yield "❌ 載入失敗:目前環境沒有可用的 CUDA GPU,請改選 CPU"
        return
    try:
        # Forced rule: on CPU the pipeline must use float32 to run at all;
        # float16 is only chosen when a CUDA device is actually present.
        if device_choice == "cpu":
            dtype = torch.float32
        else:
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        yield f"⏳ 正在載入模型至 {device_choice.upper()}... (這可能需要幾分鐘)"
        # Load into a local first: the old pipeline stays referenced (and
        # usable) until the new one is ready, at the cost of briefly holding
        # both in memory.
        new_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
        )
        new_pipe = new_pipe.to(device_choice)
    except Exception as e:  # UI boundary: report any failure in the status box
        yield f"❌ 載入失敗:{e}"
        return
    # Publish only a fully loaded, device-placed pipeline.
    pipe = new_pipe
    yield f"✅ 成功載入:{model_id} ({device_choice.upper()})"
def generate_fn(prompt, steps, guidance):
    """Generate one image from ``prompt`` with the currently loaded pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        steps: Number of denoising steps; converted with ``int()`` because
            slider values may arrive as floats.
        guidance: Value passed to the pipeline as ``guidance_scale``.

    Returns:
        ``(image, message)`` for the image and status components. ``image``
        is ``None`` when no model is loaded or generation raises.
    """
    # Only reads the module-level ``pipe``, so no ``global`` declaration is needed.
    if pipe is None:
        return None, "⚠️ 請先點擊『載入模型』!"
    try:
        result = pipe(
            prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance,
        )
        image = result.images[0]  # ensure a single image is returned
        # diffusers sets this to a per-image list of bools when the safety
        # checker ran, or None when it did not.
        nsfw_flags = getattr(result, "nsfw_content_detected", None)
    except Exception as e:  # UI boundary: surface the error in the status box
        return None, f"❌ 錯誤:{e}"
    # A flagged image has been blanked by the safety checker; don't call that
    # a success.
    if nsfw_flags and nsfw_flags[0]:
        return image, "⚠️ 安全檢查器判定內容可能不當,圖片已被替換為黑色。"
    return image, "✨ 生成成功!"
# Two-column control panel: model settings and generation parameters on the
# left, result image and status text on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛠️ CPU 優化版 Stable Diffusion 控制台")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. 模型設定")
            model_input = gr.Textbox(
                label="Hugging Face Model ID",
                value="runwayml/stable-diffusion-v1-5",
            )
            # Default to "cpu": this console is the CPU-optimized variant.
            device_radio = gr.Radio(["cpu", "cuda"], value="cpu", label="執行設備 (預設 CPU)")
            load_btn = gr.Button("🔄 載入模型", variant="secondary")
            load_status = gr.Markdown("系統狀態:等待指令")
            gr.Markdown("### 2. 繪圖參數")
            prompt = gr.Textbox(label="提示詞 (Prompt)", lines=3, placeholder="An astronaut riding a horse")
            steps = gr.Slider(1, 50, value=20, step=1, label="步數 (CPU 建議 15-20)")
            guidance = gr.Slider(1, 20, value=7.5, step=0.5, label="提示詞強度")
            gen_btn = gr.Button("🚀 開始生成", variant="primary")
        with gr.Column(scale=2):
            output_img = gr.Image(label="生成結果")
            status_msg = gr.Textbox(label="執行訊息", interactive=False)
    # load_model_fn is a generator: each yielded string updates load_status.
    load_btn.click(fn=load_model_fn, inputs=[model_input, device_radio], outputs=load_status)
    gen_btn.click(fn=generate_fn, inputs=[prompt, steps, guidance], outputs=[output_img, status_msg])

# Enable the request queue explicitly: streaming updates from a generator
# handler (load_model_fn) require it on Gradio 3.x, and it is already the
# default behaviour on Gradio 4+, so this is safe on both.
demo.queue()
demo.launch()