Spaces:
Build error
Build error
| import os | |
| import io | |
| import gc | |
| import json | |
| import time | |
| import base64 | |
| import shutil | |
| import requests | |
| from pathlib import Path | |
| from typing import Optional, Dict, List | |
| from urllib.parse import urlparse, parse_qs | |
| import torch | |
| from PIL import Image | |
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
# Heavy libraries (diffusers, huggingface_hub) are imported lazily on first use.
diffusers_loaded = False


def load_diffusers():
    """Import diffusers on first call and return its pipeline/scheduler classes.

    The three classes are also bound as module globals, and ``diffusers_loaded``
    records that the import already happened so later calls skip it.

    Returns:
        (DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline)
    """
    global DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, diffusers_loaded
    if diffusers_loaded:
        return DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline
    from diffusers import (
        DiffusionPipeline,
        DPMSolverMultistepScheduler,
        StableDiffusionPipeline,
    )
    diffusers_loaded = True
    return DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline
huggingface_hub_loaded = False


def load_huggingface_hub():
    """Import ``huggingface_hub.snapshot_download`` on first call and return it.

    The function is also bound as a module global; ``huggingface_hub_loaded``
    marks that the import already happened.
    """
    global snapshot_download, huggingface_hub_loaded
    if huggingface_hub_loaded:
        return snapshot_download
    from huggingface_hub import snapshot_download
    huggingface_hub_loaded = True
    return snapshot_download
# Runtime configuration, read once from the environment at import time.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# NOTE(review): falls back to a fixed demo key when PUBLIC_API_KEY is unset.
PUBLIC_API_KEY = os.environ.get("PUBLIC_API_KEY", "demo-key-123456")
SPACE_URL = os.environ.get("SPACE_URL", "https://your-space.hf.space")
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY", "")
def get_writable_dir() -> Path:
    """Return the first candidate directory that can be created and written to.

    Candidates, in order: ``$MODELS_DIR`` (when set and non-empty),
    ``~/.hf_models`` and ``/tmp/hf_models``. Writability is probed by creating
    and removing a ``.write_test`` file. If every candidate fails, ``./models``
    under the current working directory is created and returned unprobed.
    """
    home_models = os.path.join(os.path.expanduser("~"), ".hf_models")
    for candidate in filter(None, (os.getenv("MODELS_DIR"), home_models, "/tmp/hf_models")):
        directory = Path(candidate)
        probe = directory / ".write_test"
        try:
            directory.mkdir(parents=True, exist_ok=True)
            probe.touch()
            probe.unlink()
        except OSError:  # includes PermissionError
            continue
        return directory
    fallback = Path.cwd() / "models"
    fallback.mkdir(parents=True, exist_ok=True)
    return fallback
# Storage locations, resolved once at import time.
PERSIST_DIR = get_writable_dir()
# JSON registry of downloaded models (read/written by ModelLibrary).
MODELS_DB_FILE = PERSIST_DIR.joinpath("model_library.json")
# Hub cache directory passed to snapshot_download.
HF_CACHE_DIR = Path.home().joinpath(".cache", "huggingface", "hub")
print("📁 模型目錄: " + str(PERSIST_DIR))
class ModelLibrary:
    """Persistent registry of downloaded models, stored as JSON in MODELS_DB_FILE.

    ``self.models`` maps a filesystem-safe key (see ``_key``) to a record with
    the keys ``repo_id``, ``name``, ``local_path``, ``source``, ``added_at`` and
    ``status`` (as written by ``add``).
    """

    def __init__(self):
        self.models: Dict[str, dict] = {}
        self.load()

    def load(self):
        """(Re)read the registry from MODELS_DB_FILE.

        A missing, unreadable or malformed file yields an empty registry. The
        reason is printed instead of being silently discarded. Entries that are
        not dicts with ``repo_id`` and ``name`` are dropped, because
        ``list_display`` and ``resolve_repo`` index those keys directly.
        """
        self.models = {}
        if not MODELS_DB_FILE.exists():
            return
        try:
            with open(MODELS_DB_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            print(f"⚠️ 模型記錄讀取失敗: {e}")
            return
        if not isinstance(data, dict):
            print("⚠️ 模型記錄格式錯誤, 已忽略")
            return
        self.models = {
            key: record
            for key, record in data.items()
            if isinstance(record, dict) and "repo_id" in record and "name" in record
        }
        print(f"✅ 已載入 {len(self.models)} 個模型記錄")

    def save(self):
        """Write the registry to MODELS_DB_FILE atomically (best effort).

        The JSON goes to a sibling ``.tmp`` file first and is then moved over
        the real file with ``os.replace``, so a crash mid-write cannot leave a
        truncated registry behind. Failures are printed, not raised.
        """
        tmp_file = MODELS_DB_FILE.with_name(MODELS_DB_FILE.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(self.models, f, indent=2, ensure_ascii=False)
            os.replace(tmp_file, MODELS_DB_FILE)
        except (OSError, TypeError, ValueError) as e:
            print(f"保存失敗: {e}")

    def _key(self, repo_id: str) -> str:
        """Map a repo id to its registry key ('/' and ':' become '--')."""
        return repo_id.replace("/", "--").replace(":", "--")

    def add(self, repo_id: str, name: Optional[str] = None, local_path: str = "", source: str = "huggingface"):
        """Insert or overwrite the record for ``repo_id`` and persist the registry.

        ``name`` defaults to the last '/'-separated part of ``repo_id``.
        """
        self.models[self._key(repo_id)] = {
            "repo_id": repo_id,
            "name": name or repo_id.split("/")[-1],
            "local_path": local_path,
            "source": source,
            "added_at": time.strftime("%Y-%m-%d %H:%M"),
            "status": "ready",
        }
        self.save()

    def remove(self, repo_id: str):
        """Delete the record for ``repo_id``; return True if one existed."""
        key = self._key(repo_id)
        if key not in self.models:
            return False
        del self.models[key]
        self.save()
        return True

    @staticmethod
    def _display(record: dict) -> str:
        """Format one record as its dropdown label: status icon, source icon, name."""
        emoji = "✅" if record.get("status") == "ready" else "⚠️"
        source_icon = "🤗" if record.get("source") == "huggingface" else "🎨"
        return f"{emoji}{source_icon} {record['name']}"

    def list_display(self) -> List[str]:
        """Return dropdown labels for all records, or a single placeholder entry."""
        if not self.models:
            return ["暫無已下載模型"]
        return [self._display(record) for record in self.models.values()]

    def resolve_repo(self, display: Optional[str]) -> str:
        """Map a dropdown label (or typed text) back to a repo id.

        Resolution order:
        1. A record whose label or repo id equals ``display`` exactly.
        2. The record with the longest name contained in ``display``. Longest
           wins so that a model named "sd" never shadows "sd-xl".
        3. Otherwise ``display`` itself, stripped.

        An empty or None ``display`` resolves to "".
        """
        if not display:
            return ""
        records = list(self.models.values())
        for record in records:
            if display == self._display(record) or display == record["repo_id"]:
                return record["repo_id"]
        matches = [r for r in records if r["name"] in display or r["repo_id"] in display]
        if matches:
            return max(matches, key=lambda r: len(r["name"]))["repo_id"]
        return display.strip()
# Module-wide model registry; ModelLibrary() reads MODELS_DB_FILE on construction.
model_library: ModelLibrary = ModelLibrary()
class ModelManager:
    """Download, locate, load and delete diffusion models kept under PERSIST_DIR.

    Models come from the Hugging Face Hub (``owner/name`` ids, fetched with
    ``snapshot_download``) or from Civitai (page URLs, stored under a synthetic
    ``civitai:<model_id>[:<version_id>]`` id). Successful downloads are
    registered in ``model_library``. Loaded pipelines are cached in
    ``self.loaded``, keyed by that id.
    """

    # Glob patterns for single-file checkpoints, in order of preference.
    SINGLE_FILE_PATTERNS = ("*.safetensors", "*.ckpt")

    def __init__(self):
        # id -> loaded pipeline object
        self.loaded: Dict[str, object] = {}
        # id of the pipeline most recently loaded or reused, or None
        self.current: Optional[str] = None

    def local_path(self, repo_id: str) -> Path:
        """Return the directory under PERSIST_DIR for ``repo_id`` ('/' and ':' become '--')."""
        return PERSIST_DIR / repo_id.replace("/", "--").replace(":", "--")

    def _single_files(self, lp: Path) -> List[Path]:
        """List single-file checkpoints directly inside ``lp`` (.safetensors before .ckpt).

        In-progress downloads end in ``.part`` and are never matched.
        """
        found: List[Path] = []
        for pattern in self.SINGLE_FILE_PATTERNS:
            found.extend(sorted(lp.glob(pattern)))
        return found

    def is_downloaded(self, repo_id: str) -> bool:
        """True if the directory holds a diffusers layout or a single-file checkpoint."""
        lp = self.local_path(repo_id)
        return (lp / "model_index.json").exists() or bool(self._single_files(lp))

    def size_gb(self, repo_id: str) -> float:
        """Total size of all files under ``repo_id``'s directory, in GiB (2 decimals)."""
        lp = self.local_path(repo_id)
        if not lp.exists():
            return 0.0
        total = sum(f.stat().st_size for f in lp.rglob("*") if f.is_file())
        return round(total / (1024 ** 3), 2)

    def storage_info(self) -> dict:
        """Summarize registered model count and size, storage dir and free disk space."""
        total = sum(self.size_gb(m["repo_id"]) for m in model_library.models.values())
        try:
            free = shutil.disk_usage(PERSIST_DIR).free / (1024 ** 3)
        except OSError:
            free = 0.0
        return {
            "model_count": len(model_library.models),
            "total_size_gb": round(total, 2),
            "persist_dir": str(PERSIST_DIR),
            "free_space_gb": round(free, 2),
        }

    def parse_civitai_url(self, url: str) -> Optional[Dict]:
        """Extract model and version ids from a Civitai model page URL.

        Example: ``https://civitai.com/models/4201/name?modelVersionId=130072``.

        Returns:
            ``{"model_id": str, "version_id": str or None, "url": url}``, or
            None when ``url`` is not a Civitai model URL with a numeric id.
        """
        if not url or "civitai.com" not in url:
            return None
        try:
            parsed = urlparse(url)
        except ValueError as e:
            print(f"解析 Civitai URL 失敗: {e}")
            return None
        if "/models/" not in parsed.path:
            return None
        model_id = parsed.path.split("/models/", 1)[1].split("/", 1)[0]
        # Civitai model ids are numeric; anything else cannot be fetched from the API.
        if not model_id.isdigit():
            return None
        version_id = parse_qs(parsed.query).get("modelVersionId", [None])[0]
        return {"model_id": model_id, "version_id": version_id, "url": url}

    @staticmethod
    def _civitai_key(model_id: str, version_id: Optional[str]) -> str:
        """Build the synthetic id under which a Civitai model is stored."""
        key = f"civitai:{model_id}"
        return f"{key}:{version_id}" if version_id else key

    def _canonical_id(self, repo_id: str) -> str:
        """Map a Civitai page URL to its stored ``civitai:...`` id; other ids pass through."""
        if "civitai.com" in repo_id:
            parsed = self.parse_civitai_url(repo_id)
            if parsed:
                return self._civitai_key(parsed["model_id"], parsed["version_id"])
        return repo_id

    @staticmethod
    def _pick_civitai_file(files: List[dict]) -> dict:
        """Pick the primary "Model" file or the first .safetensors file; else the first file."""
        for candidate in files:
            if candidate.get("type") == "Model" or str(candidate.get("name", "")).endswith(".safetensors"):
                return candidate
        return files[0]

    @staticmethod
    def _redact(text: str) -> str:
        """Mask CIVITAI_API_KEY in ``text``; the key is appended to download URLs,
        and request errors quote that URL."""
        if CIVITAI_API_KEY:
            return text.replace(CIVITAI_API_KEY, "***")
        return text

    @staticmethod
    def _stream_to_file(url: str, output_file: Path, timeout: int = 300) -> None:
        """Stream ``url`` into ``output_file`` so that only complete files appear.

        Bytes go to ``<output_file>.part`` and are moved into place only after
        the transfer ends. When the response carries an unencoded
        Content-Length, the received byte count must also match it. On any
        failure the partial file is removed and the exception re-raised.
        """
        tmp_file = output_file.with_name(output_file.name + ".part")
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                # With Content-Encoding, Content-Length counts encoded bytes and can't be compared.
                expected = 0
                if not response.headers.get("content-encoding"):
                    expected = int(response.headers.get("content-length") or 0)
                received = 0
                with open(tmp_file, "wb") as fh:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            fh.write(chunk)
                            received += len(chunk)
            if expected and received != expected:
                raise IOError(f"incomplete download: {received}/{expected} bytes")
            os.replace(tmp_file, output_file)
        except BaseException:
            try:
                tmp_file.unlink()
            except OSError:
                pass
            raise

    def download_civitai_model(self, url: str) -> tuple:
        """Download one file of a Civitai model into its local directory.

        Uses the version named by ``modelVersionId`` when it appears in the
        model's version list, otherwise the first listed version. The API key
        is masked in any returned error text.

        Returns:
            (ok, message) where message is user-facing status text.
        """
        parsed = self.parse_civitai_url(url)
        if not parsed:
            return False, "❌ 無效的 Civitai URL"
        model_id = parsed["model_id"]
        version_id = parsed["version_id"]
        try:
            api_url = f"https://civitai.com/api/v1/models/{model_id}"
            headers = {}
            if CIVITAI_API_KEY:
                headers["Authorization"] = f"Bearer {CIVITAI_API_KEY}"
            print(f"📡 獲取模型信息: {api_url}")
            response = requests.get(api_url, headers=headers, timeout=30)
            response.raise_for_status()
            model_info = response.json()
            model_name = model_info.get("name", f"civitai-{model_id}")
            versions = model_info.get("modelVersions", [])
            if not versions:
                return False, "❌ 找不到模型版本"
            selected_version = None
            if version_id:
                selected_version = next((v for v in versions if str(v.get("id")) == version_id), None)
            if not selected_version:
                selected_version = versions[0]
            files = selected_version.get("files", [])
            if not files:
                return False, "❌ 找不到下載文件"
            download_file = self._pick_civitai_file(files)
            download_url = download_file["downloadUrl"]
            # The name comes from a remote API: keep only its last path component
            # so it can never point outside the model directory.
            file_name = Path(str(download_file.get("name") or "")).name
            if file_name in ("", ".", ".."):
                file_name = f"civitai-{model_id}.safetensors"
            file_size = (download_file.get("sizeKB") or 0) / 1024 / 1024
            repo_id = self._civitai_key(model_id, version_id)
            local_path = self.local_path(repo_id)
            local_path.mkdir(parents=True, exist_ok=True)
            output_file = local_path / file_name
            if output_file.exists():
                model_library.add(repo_id, model_name, str(local_path), source="civitai")
                return True, f"✅ 模型已存在: {model_name}"
            print(f"⬇️ 下載 Civitai 模型: {model_name} ({file_size:.2f} GB)")
            start = time.time()
            if CIVITAI_API_KEY:
                download_url += ("&" if "?" in download_url else "?") + f"token={CIVITAI_API_KEY}"
            self._stream_to_file(download_url, output_file)
            elapsed = time.time() - start
            actual_size = self.size_gb(repo_id)
            model_library.add(repo_id, model_name, str(local_path), source="civitai")
            return True, f"✅ Civitai 模型下載完成\n名稱: {model_name}\n大小: {actual_size} GB\n耗時: {elapsed:.1f} 秒"
        except requests.RequestException as e:
            return False, f"❌ 網絡錯誤: {self._redact(str(e))[:150]}"
        except Exception as e:
            return False, f"❌ 下載失敗: {self._redact(str(e))[:150]}"

    def download(self, repo_id: str):
        """Download ``repo_id`` (HF repo id or Civitai URL) unless already present.

        Returns:
            (ok, message) where message is user-facing status text.
        """
        if "civitai.com" in repo_id:
            return self.download_civitai_model(repo_id)
        lp = self.local_path(repo_id)
        if self.is_downloaded(repo_id):
            model_library.add(repo_id, repo_id.split("/")[-1], str(lp), source="huggingface")
            return True, f"✅ 模型已存在: {lp}"
        print(f"⬇️ 下載 HF 模型: {repo_id}")
        try:
            snapshot_download_func = load_huggingface_hub()
            start = time.time()
            snapshot_download_func(
                repo_id=repo_id,
                local_dir=str(lp),
                local_dir_use_symlinks=False,
                cache_dir=str(HF_CACHE_DIR),
                token=HF_TOKEN or None,
            )
            elapsed = time.time() - start
            size = self.size_gb(repo_id)
            model_library.add(repo_id, repo_id.split("/")[-1], str(lp), source="huggingface")
            return True, f"✅ HF 模型下載完成\n大小: {size} GB\n耗時: {elapsed:.1f} 秒"
        except Exception as e:
            err = str(e)
            if "401" in err or "403" in err:
                return False, "❌ 需要授權: 請設置 HF_TOKEN"
            if "not found" in err.lower():
                return False, "❌ 找不到模型"
            return False, f"❌ 下載失敗: {err[:150]}"

    def load(self, repo_id: str):
        """Ensure ``repo_id`` is downloaded, then load (or reuse) its pipeline.

        ``repo_id`` may be an HF repo id, a stored ``civitai:...`` id or a
        Civitai page URL. URLs are mapped to their ``civitai:...`` id so the
        pipeline is read from the directory the download actually used.

        Returns:
            (ok, message, pipeline or None)
        """
        key = self._canonical_id(repo_id)
        if not self.is_downloaded(key):
            # download() dispatches on the original string (URL vs HF repo id).
            ok, msg = self.download(repo_id)
            if not ok:
                return False, msg, None
        if key in self.loaded:
            self.current = key
            return True, "使用已加載模型", self.loaded[key]
        lp = self.local_path(key)
        try:
            print(f"⚙️ 加載模型: {lp}")
            DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline = load_diffusers()
            single_files = self._single_files(lp)
            # Pipelines load onto the CPU by default, so no device_map is passed.
            if single_files and not (lp / "model_index.json").exists():
                pipe = StableDiffusionPipeline.from_single_file(
                    str(single_files[0]),
                    torch_dtype=torch.float32,
                    safety_checker=None,
                    low_cpu_mem_usage=True,
                )
            else:
                pipe = DiffusionPipeline.from_pretrained(
                    str(lp),
                    torch_dtype=torch.float32,
                    safety_checker=None,
                    low_cpu_mem_usage=True,
                    use_safetensors=True,
                )
            if getattr(pipe, "scheduler", None) is not None:
                try:
                    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
                except Exception as e:
                    # Best effort: keep the pipeline's own scheduler if DPM-Solver can't be built.
                    print(f"⚠️ 無法切換調度器, 保留原調度器: {e}")
            if getattr(pipe, "text_encoder", None) is None:
                return False, "❌ 缺少 text_encoder", None
            self.loaded[key] = pipe
            self.current = key
            return True, "✅ 加載成功", pipe
        except Exception as e:
            error_msg = str(e)
            print(f"❌ 錯誤: {error_msg}")
            return False, f"❌ 加載失敗: {error_msg[:200]}", None

    def delete(self, repo_id: str):
        """Unload ``repo_id``, drop its registry record and remove its directory.

        Returns:
            (ok, message) where message is user-facing status text.
        """
        if self.loaded.pop(repo_id, None) is not None:
            gc.collect()
        if self.current == repo_id:
            self.current = None
        model_library.remove(repo_id)
        lp = self.local_path(repo_id)
        if lp.exists():
            try:
                shutil.rmtree(lp)
                return True, f"✅ 已刪除: {repo_id}"
            except Exception as e:
                return False, f"⚠️ 刪除失敗: {str(e)[:200]}"
        return True, f"✅ 已移除記錄: {repo_id}"
# Shared singletons: the model manager used by both the UI and the API, and the
# FastAPI app that the API routes and the Gradio UI are attached to.
model_manager: ModelManager = ModelManager()
app: FastAPI = FastAPI()
# Request body accepted by api_generate.
class GenerateRequest(BaseModel):
    prompt: str  # text prompt for the pipeline
    model: str  # model identifier passed to model_manager.load
    negative_prompt: str = ""  # optional; an empty string means "no negative prompt"
    guidance_scale: float = 7.5
    num_steps: int = 25  # requested inference steps (the handler bounds this)
    api_key: str  # compared against PUBLIC_API_KEY
# Response body returned by api_generate; failures are reported in-band via status/error.
class GenerateResponse(BaseModel):
    status: int  # 200 on success, 500 when loading or generation failed
    image_base64: Optional[str] = None  # "data:image/png;base64,..." data URI on success
    model: Optional[str] = None  # echo of the requested model on success
    generation_time: Optional[float] = None  # inference time in seconds on success
    error: Optional[str] = None  # truncated error text on failure
@app.post("/api/generate", response_model=GenerateResponse)
async def api_generate(req: GenerateRequest):
    """Generate one 512x512 PNG for ``req`` and return it as a base64 data URI.

    Raises:
        HTTPException: 401 for a wrong API key; 400 for a prompt with fewer
            than 3 non-blank characters.

    Model-loading and inference failures are returned in-band as
    ``GenerateResponse(status=500, error=...)``.

    NOTE(review): loading and inference are synchronous and run on the event
    loop thread, so they block every other request on this app (including the
    mounted Gradio UI) until they finish.
    """
    if req.api_key != PUBLIC_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    if not req.prompt or len(req.prompt.strip()) < 3:
        raise HTTPException(status_code=400, detail="Prompt too short")
    ok, msg, pipe = model_manager.load(req.model)
    if not ok or pipe is None:
        return GenerateResponse(status=500, error=msg)
    try:
        steps = max(4, min(int(req.num_steps), 50))  # clamp to the UI slider range
        t0 = time.time()
        img = pipe(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or None,
            num_inference_steps=steps,
            guidance_scale=req.guidance_scale,
            height=512,
            width=512,
        ).images[0]
        dt = time.time() - t0
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return GenerateResponse(status=200, image_base64=f"data:image/png;base64,{b64}", model=req.model, generation_time=round(dt, 2))
    except Exception as e:
        return GenerateResponse(status=500, error=str(e)[:200])
# NOTE(review): the path is chosen to sit beside /api/generate; confirm it against any external clients.
@app.get("/api/models")
async def api_models():
    """Return every model-library record plus current storage statistics."""
    return {"models": model_library.models, "storage": model_manager.storage_info()}
# NOTE(review): the path is a conventional choice; confirm it against any monitoring setup.
@app.get("/health")
async def health():
    """Liveness probe: storage statistics and the ids of currently loaded pipelines."""
    return {"status": "ok", "storage": model_manager.storage_info(), "loaded_models": list(model_manager.loaded.keys())}
def build_api_examples(prompt, model_repo, negative, guidance, steps, space_url=None, api_key=None):
    """Build a JSON request body and an equivalent cURL command for /api/generate.

    Args:
        prompt: Prompt text; defaults to "A beautiful sunset" when empty.
        model_repo: Model id; defaults to "prompthero/openjourney-v4" when empty.
        negative: Negative prompt (None/empty becomes "").
        guidance: Guidance scale, copied as-is.
        steps: Step count, converted with int().
        space_url: Base URL of the deployment; defaults to SPACE_URL.
        api_key: Key embedded in the body; defaults to PUBLIC_API_KEY.

    Returns:
        (json_body, curl_command). The body and URL are shell-quoted, so prompts
        containing quotes still produce a valid command.
    """
    import shlex

    if space_url is None:
        space_url = SPACE_URL
    if api_key is None:
        api_key = PUBLIC_API_KEY
    if not model_repo:
        model_repo = "prompthero/openjourney-v4"
    payload = {
        "prompt": prompt or "A beautiful sunset",
        "model": model_repo.strip(),
        "negative_prompt": negative or "",
        "guidance_scale": guidance,
        "num_steps": int(steps),
        "api_key": api_key,
    }
    json_body = json.dumps(payload, indent=2, ensure_ascii=False)
    endpoint = f"{space_url.rstrip('/')}/api/generate"  # avoid "//" with a trailing slash
    curl = f"curl -X POST {shlex.quote(endpoint)} -H 'Content-Type: application/json' -d {shlex.quote(json_body)}"
    return json_body, curl
def ui_generate(prompt, model_display, repo_input, negative, guidance, steps, api_key):
    """Gradio handler: generate one 512x512 image.

    The model comes from ``repo_input`` when it has more than 3 non-blank
    characters, otherwise from the selected ``model_display`` dropdown entry.

    Returns:
        (PIL image or None, status text)
    """
    if api_key != PUBLIC_API_KEY:
        return None, "❌ API Key 無效"
    # Whitespace-only prompts are rejected like empty ones.
    if not prompt or len(prompt.strip()) < 3:
        return None, "❌ 提示詞太短"
    typed_repo = (repo_input or "").strip()
    if len(typed_repo) > 3:
        repo_id = typed_repo
    elif model_display:
        repo_id = model_library.resolve_repo(model_display)
    else:
        # Nothing typed and nothing selected.
        repo_id = ""
    if not repo_id or "暫無" in repo_id:
        return None, "❌ 請選擇或輸入模型"
    ok, msg, pipe = model_manager.load(repo_id)
    if not ok or pipe is None:
        return None, msg
    try:
        num_steps = max(4, min(int(steps), 50))  # clamp to the slider range
        t0 = time.time()
        img = pipe(
            prompt=prompt,
            negative_prompt=negative or None,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            height=512,
            width=512,
        ).images[0]
        dt = time.time() - t0
        size = model_manager.size_gb(repo_id)
        txt = f"✅ 成功\n模型: {repo_id}\n大小: {size} GB\n耗時: {dt:.1f} 秒"
        return img, txt
    except Exception as e:
        return None, f"❌ 失敗: {str(e)[:200]}"
def ui_download(repo_id, api_key):
    """Gradio handler: download a model (HF repo id or Civitai URL) after a key check.

    Always returns (status text, dropdown update listing the library, storage info).
    """
    def reply(message):
        return message, gr.update(choices=model_library.list_display()), model_manager.storage_info()

    if api_key != PUBLIC_API_KEY:
        return reply("❌ API Key 無效")
    target = repo_id.strip() if repo_id else ""
    if len(target) < 3:
        return reply("❌ 請輸入模型地址或 Civitai URL")
    _ok, message = model_manager.download(target)
    return reply(message)
def ui_delete(model_display, api_key):
    """Gradio handler: delete the selected library model after a key check.

    Always returns (status text, dropdown update listing the library, storage info).
    """
    def reply(message):
        return message, gr.update(choices=model_library.list_display()), model_manager.storage_info()

    if api_key != PUBLIC_API_KEY:
        return reply("❌ API Key 無效")
    if not model_display or "暫無" in model_display:
        return reply("❌ 請先選擇模型")
    _ok, message = model_manager.delete(model_library.resolve_repo(model_display))
    return reply(message)
def ui_update_api_examples(prompt, model_display, repo_input, negative, guidance, steps):
    """Rebuild the JSON/cURL API examples from the current form values.

    The model is the typed repo (when it has more than 3 non-blank characters),
    else the selected library entry, else "prompthero/openjourney-v4".
    """
    typed = repo_input.strip() if repo_input else ""
    if len(typed) > 3:
        chosen = typed
    elif model_display and "暫無" not in model_display:
        chosen = model_library.resolve_repo(model_display)
    else:
        chosen = "prompthero/openjourney-v4"
    return build_api_examples(prompt, chosen, negative, guidance, steps)
def _library_choices_update():
    """Gradio update that reloads a dropdown's choices from ``model_library``."""
    return gr.update(choices=model_library.list_display())


def _refresh_library_views():
    """Fresh values for (generate-tab dropdown, library dropdown, storage info box)."""
    return _library_choices_update(), _library_choices_update(), model_manager.storage_info()


with gr.Blocks(title="AI 圖片生成", theme=gr.themes.Soft()) as demo:
    # NOTE(review): PUBLIC_API_KEY is printed in this header and pre-filled into
    # every "API Key" box below, so any page visitor can use it. Confirm this is intended.
    gr.Markdown(f"""
# 🖼️ AI 圖片生成(HF + Civitai)
- 📁 模型目錄: `{PERSIST_DIR}`
- 🔑 API Key: `{PUBLIC_API_KEY}`
- 🤗 **HF**: `prompthero/openjourney-v4` · 🎨 **Civitai**: `https://civitai.com/models/xxx`
""")
    # list_display() never returns an empty list (it falls back to a placeholder).
    initial_choices = model_library.list_display()
    with gr.Tabs():
        with gr.Tab("🎨 生圖"):
            with gr.Row():
                with gr.Column():
                    model_dropdown = gr.Dropdown(label="已下載模型", choices=initial_choices, value=initial_choices[0] if initial_choices else None)
                    repo_box = gr.Textbox(label="或輸入模型地址 / Civitai URL", placeholder="prompthero/openjourney-v4")
                    prompt_box = gr.Textbox(label="提示詞", lines=3, placeholder="a beautiful sunset...")
                    negative_box = gr.Textbox(label="負面提示詞", lines=2, placeholder="blurry, low quality...")
                    with gr.Row():
                        guidance_slider = gr.Slider(1, 20, 7.5, step=0.5, label="Guidance")
                        steps_slider = gr.Slider(4, 50, 25, step=1, label="Steps")
                    api_key_box = gr.Textbox(label="API Key", value=PUBLIC_API_KEY, type="password")
                    gen_btn = gr.Button("🚀 生成", variant="primary")
                with gr.Column():
                    out_img = gr.Image(label="結果")
                    out_msg = gr.Textbox(label="狀態", lines=5)
            with gr.Accordion("📡 API 示例", open=False):
                api_json = gr.Code(label="JSON", language="json")
                api_curl = gr.Code(label="cURL", language="shell")
        with gr.Tab("⬇️ 模型管理"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### 下載新模型")
                    dl_repo = gr.Textbox(label="模型地址 / URL", placeholder="prompthero/openjourney-v4")
                    dl_key = gr.Textbox(label="API Key", value=PUBLIC_API_KEY, type="password")
                    dl_btn = gr.Button("⬇️ 下載")
                    dl_msg = gr.Textbox(label="結果", lines=4)
                with gr.Column():
                    gr.Markdown("### 管理模型")
                    lib_dropdown = gr.Dropdown(label="已下載模型", choices=initial_choices)
                    del_key = gr.Textbox(label="API Key", value=PUBLIC_API_KEY, type="password")
                    del_btn = gr.Button("🗑️ 刪除")
                    del_msg = gr.Textbox(label="結果", lines=2)
            refresh_btn = gr.Button("🔄 刷新")
            storage_box = gr.JSON(label="存儲信息")

    # Events are wired after both tabs exist, so any handler can update components in either tab.
    example_inputs = [prompt_box, model_dropdown, repo_box, negative_box, guidance_slider, steps_slider]
    example_outputs = [api_json, api_curl]
    library_outputs = [model_dropdown, lib_dropdown, storage_box]

    # Generation may download a new model, so refresh every library view afterwards.
    gen_btn.click(
        fn=ui_generate,
        inputs=[prompt_box, model_dropdown, repo_box, negative_box, guidance_slider, steps_slider, api_key_box],
        outputs=[out_img, out_msg],
    ).then(fn=_refresh_library_views, outputs=library_outputs)
    for comp in example_inputs:
        comp.change(fn=ui_update_api_examples, inputs=example_inputs, outputs=example_outputs)
    dl_btn.click(fn=ui_download, inputs=[dl_repo, dl_key], outputs=[dl_msg, lib_dropdown, storage_box]).then(
        fn=_library_choices_update, outputs=[model_dropdown]
    )
    del_btn.click(fn=ui_delete, inputs=[lib_dropdown, del_key], outputs=[del_msg, lib_dropdown, storage_box]).then(
        fn=_library_choices_update, outputs=[model_dropdown]
    )
    refresh_btn.click(fn=_refresh_library_views, outputs=library_outputs)
    # Populate storage info, current choices and the API examples on every page load.
    demo.load(fn=_refresh_library_views, outputs=library_outputs)
    demo.load(fn=ui_update_api_examples, inputs=example_inputs, outputs=example_outputs)
print("🚀 啟動中...")
# Permit cross-origin calls to the JSON API from any origin, method and header.
cors_options = {"allow_origins": ["*"], "allow_methods": ["*"], "allow_headers": ["*"]}
app.add_middleware(CORSMiddleware, **cors_options)
# Serve the Gradio UI at the root path; routes registered earlier on `app` take precedence.
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    # Script entry point: serve the combined FastAPI + Gradio app on 0.0.0.0:7860.
    from uvicorn import run as serve

    print("✅ 準備完成,正在啟動服務器...")
    serve(app, host="0.0.0.0", port=7860)