# koai / app.py — Hugging Face Space (upload by kines9661, commit c10f021, verified)
import base64
import gc
import io
import json
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse

import gradio as gr
import requests
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
# Lazy import of heavyweight libraries so the app starts quickly.
diffusers_loaded = False


def load_diffusers():
    """Import the diffusers classes on first use and cache them as globals.

    Returns:
        Tuple of (DiffusionPipeline, DPMSolverMultistepScheduler,
        StableDiffusionPipeline).
    """
    global DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, diffusers_loaded
    if not diffusers_loaded:
        from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline
        diffusers_loaded = True
    # Single exit point; the original duplicated this return in both branches.
    return DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline
huggingface_hub_loaded = False


def load_huggingface_hub():
    """Import huggingface_hub.snapshot_download on first use and cache it.

    Returns:
        The ``snapshot_download`` callable.
    """
    global snapshot_download, huggingface_hub_loaded
    if not huggingface_hub_loaded:
        from huggingface_hub import snapshot_download
        huggingface_hub_loaded = True
    # Single exit point; the original duplicated this return in both branches.
    return snapshot_download
# Deployment configuration, all sourced from environment variables.
HF_TOKEN = os.getenv("HF_TOKEN", "")  # Hugging Face token for gated/private repos
PUBLIC_API_KEY = os.getenv("PUBLIC_API_KEY", "demo-key-123456")  # shared key checked by API and UI
SPACE_URL = os.getenv("SPACE_URL", "https://your-space.hf.space")  # public base URL shown in API examples
CIVITAI_API_KEY = os.getenv("CIVITAI_API_KEY", "")  # optional token for Civitai downloads
def get_writable_dir() -> Path:
    """Return the first writable directory among the configured candidates.

    Tries $MODELS_DIR, then ~/.hf_models, then /tmp/hf_models. A candidate is
    accepted only after proving writability by touching (and removing) a probe
    file. Falls back to ./models under the current working directory.
    """
    for candidate in (
        os.getenv("MODELS_DIR"),
        os.path.join(os.path.expanduser("~"), ".hf_models"),
        "/tmp/hf_models",
    ):
        if not candidate:
            continue
        try:
            directory = Path(candidate)
            directory.mkdir(parents=True, exist_ok=True)
            probe = directory / ".write_test"
            probe.touch()
            probe.unlink()
        except (PermissionError, OSError):
            continue
        else:
            return directory
    default = Path.cwd() / "models"
    default.mkdir(parents=True, exist_ok=True)
    return default
# Resolve storage locations once at import time.
PERSIST_DIR = get_writable_dir()  # where downloaded model weights live
MODELS_DB_FILE = PERSIST_DIR / "model_library.json"  # persisted model registry
HF_CACHE_DIR = Path.home() / ".cache" / "huggingface" / "hub"  # huggingface_hub download cache
print(f"📁 模型目錄: {PERSIST_DIR}")
class ModelLibrary:
    """JSON-backed registry of downloaded models, keyed by sanitized repo id."""

    def __init__(self):
        self.models: Dict[str, dict] = {}
        self.load()

    def load(self):
        """Populate ``self.models`` from MODELS_DB_FILE; empty on any failure."""
        self.models = {}
        if not MODELS_DB_FILE.exists():
            return
        try:
            with open(MODELS_DB_FILE, "r", encoding="utf-8") as fh:
                self.models = json.load(fh)
        except Exception:
            self.models = {}
        else:
            print(f"✅ 已載入 {len(self.models)} 個模型記錄")

    def save(self):
        """Write the registry to disk; failures are logged, never raised."""
        try:
            with open(MODELS_DB_FILE, "w", encoding="utf-8") as fh:
                json.dump(self.models, fh, indent=2, ensure_ascii=False)
        except Exception as exc:
            print(f"保存失敗: {exc}")

    def _key(self, repo_id: str) -> str:
        # Sanitize "org/name" or "civitai:id:version" into a dict/path-safe key.
        return repo_id.replace("/", "--").replace(":", "--")

    def add(self, repo_id: str, name: Optional[str] = None, local_path: str = "", source: str = "huggingface"):
        """Insert or overwrite an entry for ``repo_id`` and persist immediately."""
        entry = {
            "repo_id": repo_id,
            "name": name or repo_id.split("/")[-1],
            "local_path": local_path,
            "source": source,
            "added_at": time.strftime("%Y-%m-%d %H:%M"),
            "status": "ready",
        }
        self.models[self._key(repo_id)] = entry
        self.save()

    def remove(self, repo_id: str):
        """Drop the entry for ``repo_id``; return True if it existed."""
        key = self._key(repo_id)
        if key not in self.models:
            return False
        del self.models[key]
        self.save()
        return True

    def list_display(self) -> List[str]:
        """Human-readable labels for the UI dropdowns."""
        if not self.models:
            return ["暫無已下載模型"]
        labels = []
        for entry in self.models.values():
            status = "✅" if entry.get("status") == "ready" else "⚠️"
            origin = "🤗" if entry.get("source") == "huggingface" else "🎨"
            labels.append(f"{status}{origin} {entry['name']}")
        return labels

    def resolve_repo(self, display: str) -> str:
        """Map a display label back to its repo id; pass unknown text through."""
        for entry in self.models.values():
            if entry["name"] in display or entry["repo_id"] in display:
                return entry["repo_id"]
        return display.strip()
# Singleton registry shared by the manager, the JSON API, and the UI.
model_library = ModelLibrary()
class ModelManager:
    """Downloads (HF Hub / Civitai), loads, and deletes diffusion models.

    Weights are stored under ``PERSIST_DIR`` in directories named after the
    repo id with ``/`` and ``:`` replaced by ``--``. Loaded pipelines are
    cached in ``self.loaded`` keyed by repo id.
    """

    def __init__(self):
        # repo_id -> loaded diffusers pipeline
        self.loaded: Dict[str, Any] = {}
        # repo_id of the most recently used pipeline (informational only)
        self.current: Optional[str] = None

    def local_path(self, repo_id: str) -> Path:
        """Directory under PERSIST_DIR where this repo's files are stored."""
        return PERSIST_DIR / repo_id.replace("/", "--").replace(":", "--")

    def is_downloaded(self, repo_id: str) -> bool:
        """True when a diffusers layout or a single-file checkpoint exists locally."""
        lp = self.local_path(repo_id)
        return (lp / "model_index.json").exists() or any(lp.glob("*.safetensors"))

    def size_gb(self, repo_id: str) -> float:
        """Total on-disk size of the model directory in GB (0.0 when absent)."""
        lp = self.local_path(repo_id)
        if not lp.exists():
            return 0.0
        total = sum(f.stat().st_size for f in lp.rglob("*") if f.is_file())
        return round(total / (1024 ** 3), 2)

    def storage_info(self) -> dict:
        """Summary of model count, used space, and free space for the UI/API."""
        total = sum(self.size_gb(m["repo_id"]) for m in model_library.models.values())
        try:
            free = shutil.disk_usage(PERSIST_DIR).free / (1024 ** 3)
        except OSError:  # was a bare except; disk_usage only raises OSError
            free = 0.0
        return {
            "model_count": len(model_library.models),
            "total_size_gb": round(total, 2),
            "persist_dir": str(PERSIST_DIR),
            "free_space_gb": round(free, 2),
        }

    def parse_civitai_url(self, url: str) -> Optional[Dict]:
        """Extract the model id (and optional modelVersionId) from a Civitai URL.

        Returns a dict with ``model_id``, ``version_id`` (None when absent) and
        the original ``url``, or None for non-Civitai / unrecognized URLs.
        """
        try:
            if "civitai.com" not in url:
                return None
            parsed = urlparse(url)
            if "/models/" in url:
                # ".../models/<id>[/slug][?query]" -> "<id>"
                model_id = url.split("/models/")[1].split("/")[0].split("?")[0]
                version_id = parse_qs(parsed.query).get("modelVersionId", [None])[0]
                return {"model_id": model_id, "version_id": version_id, "url": url}
            return None
        except Exception as e:
            print(f"解析 Civitai URL 失敗: {e}")
            return None

    def download_civitai_model(self, url: str) -> tuple:
        """Download a checkpoint file from Civitai into PERSIST_DIR.

        Prefers the file tagged ``type == "Model"`` or any ``.safetensors``
        file of the selected version. Returns (ok, human-readable message).
        """
        parsed = self.parse_civitai_url(url)
        if not parsed:
            return False, "❌ 無效的 Civitai URL"
        model_id = parsed["model_id"]
        version_id = parsed["version_id"]
        try:
            api_url = f"https://civitai.com/api/v1/models/{model_id}"
            headers = {}
            if CIVITAI_API_KEY:
                headers["Authorization"] = f"Bearer {CIVITAI_API_KEY}"
            print(f"📡 獲取模型信息: {api_url}")
            response = requests.get(api_url, headers=headers, timeout=30)
            response.raise_for_status()
            model_info = response.json()
            model_name = model_info.get("name", f"civitai-{model_id}")
            versions = model_info.get("modelVersions", [])
            if not versions:
                return False, "❌ 找不到模型版本"
            selected_version = None
            if version_id:
                # Honor an explicit ?modelVersionId=... request.
                selected_version = next((v for v in versions if str(v["id"]) == version_id), None)
            if not selected_version:
                selected_version = versions[0]
            files = selected_version.get("files", [])
            if not files:
                return False, "❌ 找不到下載文件"
            download_file = None
            for f in files:
                if f.get("type") == "Model" or f["name"].endswith(".safetensors"):
                    download_file = f
                    break
            if not download_file:
                download_file = files[0]
            download_url = download_file["downloadUrl"]
            file_name = download_file["name"]
            file_size = download_file.get("sizeKB", 0) / 1024 / 1024  # KB -> GB
            repo_id = f"civitai:{model_id}"
            if version_id:
                repo_id += f":{version_id}"
            local_path = self.local_path(repo_id)
            local_path.mkdir(parents=True, exist_ok=True)
            output_file = local_path / file_name
            if output_file.exists():
                model_library.add(repo_id, model_name, str(local_path), source="civitai")
                return True, f"✅ 模型已存在: {model_name}"
            print(f"⬇️ 下載 Civitai 模型: {model_name} ({file_size:.2f} GB)")
            start = time.time()
            # Civitai accepts the API key as a "token" query parameter on downloads.
            if CIVITAI_API_KEY and "?" in download_url:
                download_url += f"&token={CIVITAI_API_KEY}"
            elif CIVITAI_API_KEY:
                download_url += f"?token={CIVITAI_API_KEY}"
            response = requests.get(download_url, stream=True, timeout=300)
            response.raise_for_status()
            # Stream to disk in 8 KiB chunks to keep memory flat.
            with open(output_file, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
            elapsed = time.time() - start
            actual_size = self.size_gb(repo_id)
            model_library.add(repo_id, model_name, str(local_path), source="civitai")
            return True, f"✅ Civitai 模型下載完成\n名稱: {model_name}\n大小: {actual_size} GB\n耗時: {elapsed:.1f} 秒"
        except requests.RequestException as e:
            return False, f"❌ 網絡錯誤: {str(e)[:150]}"
        except Exception as e:
            return False, f"❌ 下載失敗: {str(e)[:150]}"

    def download(self, repo_id: str):
        """Download from HF Hub, or from Civitai when given a civitai.com URL.

        Returns (ok, human-readable message). No-op if already on disk.
        """
        if "civitai.com" in repo_id:
            return self.download_civitai_model(repo_id)
        lp = self.local_path(repo_id)
        if self.is_downloaded(repo_id):
            model_library.add(repo_id, repo_id.split("/")[-1], str(lp), source="huggingface")
            return True, f"✅ 模型已存在: {lp}"
        print(f"⬇️ 下載 HF 模型: {repo_id}")
        try:
            snapshot_download_func = load_huggingface_hub()
            start = time.time()
            snapshot_download_func(
                repo_id=repo_id,
                local_dir=str(lp),
                local_dir_use_symlinks=False,  # real files, so loading works without the cache
                cache_dir=str(HF_CACHE_DIR),
                token=HF_TOKEN or None,
            )
            elapsed = time.time() - start
            size = self.size_gb(repo_id)
            model_library.add(repo_id, repo_id.split("/")[-1], str(lp), source="huggingface")
            return True, f"✅ HF 模型下載完成\n大小: {size} GB\n耗時: {elapsed:.1f} 秒"
        except Exception as e:
            err = str(e)
            if "401" in err or "403" in err:
                return False, "❌ 需要授權: 請設置 HF_TOKEN"
            if "not found" in err.lower():
                return False, "❌ 找不到模型"
            return False, f"❌ 下載失敗: {err[:150]}"

    def load(self, repo_id: str):
        """Load a pipeline on CPU, downloading first if needed.

        Returns (ok, message, pipeline-or-None). Loaded pipelines are cached
        in ``self.loaded`` and reused on subsequent calls.
        """
        if not self.is_downloaded(repo_id):
            ok, msg = self.download(repo_id)
            if not ok:
                return False, msg, None
        if repo_id in self.loaded:
            self.current = repo_id
            return True, "使用已加載模型", self.loaded[repo_id]
        lp = self.local_path(repo_id)
        try:
            print(f"⚙️ 加載模型: {lp}")
            DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline = load_diffusers()
            safetensors_files = list(lp.glob("*.safetensors"))
            if safetensors_files and not (lp / "model_index.json").exists():
                # Single-file checkpoint (typical for Civitai downloads).
                pipe = StableDiffusionPipeline.from_single_file(
                    str(safetensors_files[0]),
                    torch_dtype=torch.float32,
                    device_map="cpu",
                    safety_checker=None,
                    low_cpu_mem_usage=True,
                )
            else:
                # Standard diffusers directory layout.
                pipe = DiffusionPipeline.from_pretrained(
                    str(lp),
                    torch_dtype=torch.float32,
                    device_map="cpu",
                    safety_checker=None,
                    low_cpu_mem_usage=True,
                    use_safetensors=True,
                )
            if hasattr(pipe, "scheduler") and pipe.scheduler is not None:
                try:
                    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
                except Exception:  # was a bare except; keep default scheduler on failure
                    pass
            if not hasattr(pipe, "text_encoder") or pipe.text_encoder is None:
                return False, "❌ 缺少 text_encoder", None
            self.loaded[repo_id] = pipe
            self.current = repo_id
            return True, "✅ 加載成功", pipe
        except Exception as e:
            error_msg = str(e)
            print(f"❌ 錯誤: {error_msg}")
            return False, f"❌ 加載失敗: {error_msg[:200]}", None

    def delete(self, repo_id: str):
        """Evict from the cache, drop the registry entry, and remove the files."""
        if repo_id in self.loaded:
            del self.loaded[repo_id]
            gc.collect()  # release the pipeline's (large) tensors promptly
        model_library.remove(repo_id)
        lp = self.local_path(repo_id)
        if lp.exists():
            try:
                shutil.rmtree(lp)
                return True, f"✅ 已刪除: {repo_id}"
            except Exception as e:
                return False, f"⚠️ 刪除失敗: {str(e)[:200]}"
        return True, f"✅ 已移除記錄: {repo_id}"
# Module-level singletons: the model manager and the FastAPI app that serves
# both the JSON API and (later) the mounted Gradio UI.
model_manager = ModelManager()
app = FastAPI()
class GenerateRequest(BaseModel):
    """Request body for POST /api/generate."""

    prompt: str
    model: str  # HF repo id (e.g. "org/name") or a civitai.com model URL
    negative_prompt: str = ""
    guidance_scale: float = 7.5
    num_steps: int = 25  # clamped to [4, 50] by the endpoint
    api_key: str  # must equal PUBLIC_API_KEY
class GenerateResponse(BaseModel):
    """Response body for POST /api/generate (success or failure)."""

    status: int  # 200 on success, 500 on generation/load failure
    image_base64: Optional[str] = None  # "data:image/png;base64,..." data URI
    model: Optional[str] = None
    generation_time: Optional[float] = None  # seconds, rounded to 2 decimals
    error: Optional[str] = None  # populated only on failure
@app.post("/api/generate")
async def api_generate(req: GenerateRequest):
    """Authenticate, load the requested model, and return a base64-PNG image."""
    if req.api_key != PUBLIC_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    if not req.prompt or len(req.prompt) < 3:
        raise HTTPException(status_code=400, detail="Prompt too short")
    ok, msg, pipe = model_manager.load(req.model)
    if not ok or pipe is None:
        return GenerateResponse(status=500, error=msg)
    try:
        started = time.time()
        steps = max(4, min(int(req.num_steps), 50))  # keep step count in a sane range
        result = pipe(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or None,
            num_inference_steps=steps,
            guidance_scale=req.guidance_scale,
            height=512,
            width=512,
        )
        elapsed = time.time() - started
        buffer = io.BytesIO()
        result.images[0].save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return GenerateResponse(
            status=200,
            image_base64=f"data:image/png;base64,{encoded}",
            model=req.model,
            generation_time=round(elapsed, 2),
        )
    except Exception as e:
        return GenerateResponse(status=500, error=str(e)[:200])
@app.get("/api/models")
async def api_models():
    """List the model registry together with storage statistics."""
    payload = {
        "models": model_library.models,
        "storage": model_manager.storage_info(),
    }
    return payload
@app.get("/health")
async def health():
    """Liveness probe: storage stats plus which pipelines are in memory."""
    return {
        "status": "ok",
        "storage": model_manager.storage_info(),
        "loaded_models": list(model_manager.loaded.keys()),
    }
def build_api_examples(prompt, model_repo, negative, guidance, steps):
    """Render the JSON body and matching cURL command for the API docs panel.

    Returns (json_body, curl_command) as two strings.
    """
    repo = model_repo or "prompthero/openjourney-v4"
    payload = {
        "prompt": prompt or "A beautiful sunset",
        "model": repo.strip(),
        "negative_prompt": negative or "",
        "guidance_scale": guidance,
        "num_steps": int(steps),
        "api_key": PUBLIC_API_KEY,
    }
    json_body = json.dumps(payload, indent=2, ensure_ascii=False)
    curl = f"curl -X POST {SPACE_URL}/api/generate -H 'Content-Type: application/json' -d '{json_body}'"
    return json_body, curl
def ui_generate(prompt, model_display, repo_input, negative, guidance, steps, api_key):
    """Gradio handler: validate inputs, load the model, and generate one image.

    Returns (PIL image or None, status text).
    """
    if api_key != PUBLIC_API_KEY:
        return None, "❌ API Key 無效"
    if not prompt or len(prompt) < 3:
        return None, "❌ 提示詞太短"
    # A typed-in repo/URL takes priority over the dropdown selection.
    typed = repo_input.strip() if repo_input else ""
    repo_id = typed if len(typed) > 3 else model_library.resolve_repo(model_display)
    if not repo_id or "暫無" in repo_id:
        return None, "❌ 請選擇或輸入模型"
    ok, msg, pipe = model_manager.load(repo_id)
    if not ok or pipe is None:
        return None, msg
    try:
        started = time.time()
        result = pipe(
            prompt=prompt,
            negative_prompt=negative or None,
            num_inference_steps=max(4, min(int(steps), 50)),
            guidance_scale=guidance,
            height=512,
            width=512,
        )
        elapsed = time.time() - started
        size = model_manager.size_gb(repo_id)
        status = f"✅ 成功\n模型: {repo_id}\n大小: {size} GB\n耗時: {elapsed:.1f} 秒"
        return result.images[0], status
    except Exception as e:
        return None, f"❌ 失敗: {str(e)[:200]}"
def ui_download(repo_id, api_key):
    """Gradio handler: download a model, then refresh the dropdown and storage info."""
    current_choices = gr.update(choices=model_library.list_display())
    if api_key != PUBLIC_API_KEY:
        return "❌ API Key 無效", current_choices, model_manager.storage_info()
    if not repo_id or len(repo_id.strip()) < 3:
        return "❌ 請輸入模型地址或 Civitai URL", current_choices, model_manager.storage_info()
    ok, msg = model_manager.download(repo_id.strip())
    # Recompute choices after the download so the new model appears.
    return msg, gr.update(choices=model_library.list_display()), model_manager.storage_info()
def ui_delete(model_display, api_key):
    """Gradio handler: delete the selected model, then refresh UI state."""
    def _state():
        # Fresh dropdown choices and storage stats for the outputs.
        return gr.update(choices=model_library.list_display()), model_manager.storage_info()

    if api_key != PUBLIC_API_KEY:
        dropdown, storage = _state()
        return "❌ API Key 無效", dropdown, storage
    if not model_display or "暫無" in model_display:
        dropdown, storage = _state()
        return "❌ 請先選擇模型", dropdown, storage
    repo_id = model_library.resolve_repo(model_display)
    ok, msg = model_manager.delete(repo_id)
    dropdown, storage = _state()
    return msg, dropdown, storage
def ui_update_api_examples(prompt, model_display, repo_input, negative, guidance, steps):
    """Keep the API example panel in sync with the current form values."""
    typed = repo_input.strip() if repo_input else ""
    if len(typed) > 3:
        repo_id = typed
    elif model_display and "暫無" not in model_display:
        repo_id = model_library.resolve_repo(model_display)
    else:
        repo_id = "prompthero/openjourney-v4"
    return build_api_examples(prompt, repo_id, negative, guidance, steps)
# Gradio UI: two tabs (generation and model management) wired to the shared
# library/manager singletons. Layout nesting follows gr.Blocks context-manager
# semantics, so statement order and `with` depth define the page structure.
with gr.Blocks(title="AI 圖片生成", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
# 🖼️ AI 圖片生成(HF + Civitai)
- 📁 模型目錄: `{PERSIST_DIR}`
- 🔑 API Key: `{PUBLIC_API_KEY}`
- 🤗 **HF**: `prompthero/openjourney-v4` · 🎨 **Civitai**: `https://civitai.com/models/xxx`
""")
    with gr.Tabs():
        with gr.Tab("🎨 生圖"):
            with gr.Row():
                with gr.Column():
                    model_dropdown = gr.Dropdown(label="已下載模型", choices=model_library.list_display(), value=model_library.list_display()[0] if model_library.list_display() else None)
                    repo_box = gr.Textbox(label="或輸入模型地址 / Civitai URL", placeholder="prompthero/openjourney-v4")
                    prompt_box = gr.Textbox(label="提示詞", lines=3, placeholder="a beautiful sunset...")
                    negative_box = gr.Textbox(label="負面提示詞", lines=2, placeholder="blurry, low quality...")
                    with gr.Row():
                        guidance_slider = gr.Slider(1, 20, 7.5, step=0.5, label="Guidance")
                        steps_slider = gr.Slider(4, 50, 25, step=1, label="Steps")
                    api_key_box = gr.Textbox(label="API Key", value=PUBLIC_API_KEY, type="password")
                    gen_btn = gr.Button("🚀 生成", variant="primary")
                with gr.Column():
                    out_img = gr.Image(label="結果")
                    out_msg = gr.Textbox(label="狀態", lines=5)
                    with gr.Accordion("📡 API 示例", open=False):
                        api_json = gr.Code(label="JSON", language="json")
                        api_curl = gr.Code(label="cURL", language="shell")
            # After generating, also refresh the dropdown (a generate call may
            # have auto-downloaded a new model).
            gen_btn.click(fn=ui_generate, inputs=[prompt_box, model_dropdown, repo_box, negative_box, guidance_slider, steps_slider, api_key_box], outputs=[out_img, out_msg]).then(fn=lambda: gr.update(choices=model_library.list_display()), outputs=[model_dropdown])
            # Keep the API example panel live as any generation input changes.
            for comp in [prompt_box, model_dropdown, repo_box, negative_box, guidance_slider, steps_slider]:
                comp.change(fn=ui_update_api_examples, inputs=[prompt_box, model_dropdown, repo_box, negative_box, guidance_slider, steps_slider], outputs=[api_json, api_curl])
        with gr.Tab("⬇️ 模型管理"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### 下載新模型")
                    dl_repo = gr.Textbox(label="模型地址 / URL", placeholder="prompthero/openjourney-v4")
                    dl_key = gr.Textbox(label="API Key", value=PUBLIC_API_KEY, type="password")
                    dl_btn = gr.Button("⬇️ 下載")
                    dl_msg = gr.Textbox(label="結果", lines=4)
                with gr.Column():
                    gr.Markdown("### 管理模型")
                    lib_dropdown = gr.Dropdown(label="已下載模型", choices=model_library.list_display())
                    del_key = gr.Textbox(label="API Key", value=PUBLIC_API_KEY, type="password")
                    del_btn = gr.Button("🗑️ 刪除")
                    del_msg = gr.Textbox(label="結果", lines=2)
            refresh_btn = gr.Button("🔄 刷新")
            storage_box = gr.JSON(label="存儲信息")
            # Each management action also refreshes the generation tab's dropdown.
            dl_btn.click(fn=ui_download, inputs=[dl_repo, dl_key], outputs=[dl_msg, lib_dropdown, storage_box]).then(fn=lambda: gr.update(choices=model_library.list_display()), outputs=[model_dropdown])
            del_btn.click(fn=ui_delete, inputs=[lib_dropdown, del_key], outputs=[del_msg, lib_dropdown, storage_box]).then(fn=lambda: gr.update(choices=model_library.list_display()), outputs=[model_dropdown])
            refresh_btn.click(fn=lambda: gr.update(choices=model_library.list_display()), outputs=[lib_dropdown]).then(fn=model_manager.storage_info, outputs=[storage_box]).then(fn=lambda: gr.update(choices=model_library.list_display()), outputs=[model_dropdown])
print("🚀 啟動中...")
# Allow cross-origin browser calls to the JSON API.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Serve the Gradio UI at "/" on the same FastAPI app as the API routes.
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    import uvicorn
    print("✅ 準備完成,正在啟動服務器...")
    uvicorn.run(app, host="0.0.0.0", port=7860)