Spaces:
Paused
Paused
| from __future__ import annotations | |
| import base64 | |
| import os | |
| import platform | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from .config import AppConfig, load_config, save_config | |
| from .models import ConfigUpdate, I2IRequest, InpaintRequest, T2IRequest | |
| from .services.novelai import generate_i2i, generate_inpaint, generate_t2i | |
| def _ensure_x64(n: int) -> int: | |
| if n <= 64: | |
| return 64 | |
| if n % 64 == 0: | |
| return n | |
| return ((n // 64) + 1) * 64 if (n / 64) % 1 >= 0.5 else (n // 64) * 64 | |
| def _as_data_uri(b64: str) -> str: | |
| return f"data:image/png;base64,{b64}" | |
app = FastAPI(version="1.0.0", title="New NAI")

# Cross-origin policy: every origin, method and header is accepted; cookies and
# other credentials are not sent cross-origin.
# NOTE(review): the bundled frontend is served by this same app (the "/"
# StaticFiles mount at the bottom of this module), so it does not need CORS.
# With a wildcard origin, any web page the user visits can call this API from
# the browser and read the responses, e.g. the config handler's
# AppConfig.model_dump() output. Consider restricting allow_origins when this
# runs on a user's machine.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=False,
    allow_origins=["*"],
)
def health() -> Dict[str, Any]:
    """Liveness probe: unconditionally reports that the service is up."""
    payload: Dict[str, Any] = dict(status="ok")
    return payload
def get_config() -> Dict[str, Any]:
    """Return the persisted application config as a plain dict.

    NOTE(review): model_dump() serialises every model field. If ``key``
    (checked by the generation handlers) is a field, the credential is
    returned to the caller as well; confirm that is intended.
    """
    return load_config().model_dump()
def update_config(update: ConfigUpdate) -> Dict[str, Any]:
    """Merge the non-None fields of ``update`` into the stored config and persist it.

    Fields whose value is None are skipped and keep their current value
    (presumably that covers fields the client did not send, if ConfigUpdate
    defaults them to None; verify against the model). Returns the full config
    as saved.
    """
    changes = update.model_dump(exclude_none=True)
    cfg = load_config()
    for field_name, new_value in changes.items():
        setattr(cfg, field_name, new_value)
    save_config(cfg)
    return cfg.model_dump()
# Title shown by every native folder-picker dialog below.
_PICKER_TITLE = "选择保存目录"
# Seconds to wait for an external picker process before giving up on it.
_PICKER_TIMEOUT_S = 60


class _PickerCancelled(Exception):
    """A picker dialog was shown but closed (or timed out) without a folder being chosen."""


def _run_picker_process(cmd: list[str], encoding: str | None = None) -> str | None:
    """Run an external picker command and interpret its outcome.

    Returns the stripped stdout when it is non-empty (the chosen folder), or
    None when the command reported an error (non-zero exit status or output on
    stderr), so the caller can try the next strategy. Launch errors (e.g. a
    missing executable) and decode errors propagate to the caller.

    Raises:
        _PickerCancelled: the process exited cleanly without printing anything
            (the dialog was dismissed) or did not finish within _PICKER_TIMEOUT_S.
    """
    import subprocess

    try:
        res = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            encoding=encoding,
            timeout=_PICKER_TIMEOUT_S,
        )
    except subprocess.TimeoutExpired as e:
        # The dialog was on screen; starting another picker would only pop a second one.
        raise _PickerCancelled from e
    out = (res.stdout or "").strip()
    if out:
        return out
    if res.returncode == 0 and not (res.stderr or "").strip():
        # Clean exit with nothing printed: the user dismissed the dialog.
        raise _PickerCancelled
    return None


def _pick_dir_tk() -> str | None:
    """Ask with a tkinter directory dialog.

    Returns the chosen path, or None when Tk should not be used on this
    thread. Import or Tk start-up errors propagate so the caller moves on.

    Raises:
        _PickerCancelled: the dialog was shown but nothing was chosen.
    """
    import threading

    if platform.system() == "Darwin" and threading.current_thread() is not threading.main_thread():
        # Tk on macOS needs the main thread, and FastAPI runs plain-`def`
        # endpoints in a worker thread pool; leave macOS to the osascript fallback.
        return None
    import tkinter as tk  # type: ignore
    from tkinter import filedialog  # type: ignore

    root = tk.Tk()
    try:
        root.withdraw()  # hide the empty root window; only the dialog should appear
        path = filedialog.askdirectory(title=_PICKER_TITLE)
    finally:
        # Always tear the root down, even if the dialog itself raised.
        try:
            root.destroy()
        except Exception:
            pass
    if not path:  # "" (or an empty tuple on some platforms) means cancel
        raise _PickerCancelled
    return path


def _pick_dir_powershell() -> str | None:
    """Windows: Shell.Application.BrowseForFolder driven from PowerShell."""
    # Force UTF-8 on PowerShell's stdout and decode it as such. Otherwise the
    # console code page and Python's locale encoding can disagree and garble
    # non-ASCII folder names. utf-8-sig also drops a BOM if one is emitted.
    ps_cmd = (
        "[Console]::OutputEncoding=[System.Text.Encoding]::UTF8; "
        "$f=(New-Object -ComObject Shell.Application)"
        f'.BrowseForFolder(0,"{_PICKER_TITLE}",0); '
        "if($f){$f.Self.Path}"
    )
    return _run_picker_process(
        ["powershell", "-NoProfile", "-Command", ps_cmd], encoding="utf-8-sig"
    )


def _pick_dir_vbscript() -> str | None:
    """Windows fallback for hosts with PowerShell disabled: the same COM call via cscript."""
    # Spell the title as ChrW() calls so the script file is pure ASCII. cscript
    # treats a BOM-less .vbs as ANSI, so UTF-8 text in it would be garbled.
    title_expr = " & ".join(f"ChrW({ord(ch)})" for ch in _PICKER_TITLE)
    vbs = (
        'Set sh = CreateObject("Shell.Application")\n'
        f"Set f = sh.BrowseForFolder(0, {title_expr}, 0)\n"
        "If (Not f Is Nothing) Then\n"
        "    WScript.Echo f.Self.Path\n"
        "End If\n"
    )
    fd, script_path = tempfile.mkstemp(suffix=".vbs")
    try:
        # Close the file before cscript opens it.
        with os.fdopen(fd, "wb") as fh:
            fh.write(vbs.encode("ascii"))
        return _run_picker_process(["cscript", "//nologo", script_path])
    finally:
        try:
            os.remove(script_path)
        except OSError:
            pass


def _pick_dir_osascript() -> str | None:
    """macOS: AppleScript ``choose folder`` run through osascript."""
    script = (
        'tell application "System Events" to POSIX path of '
        f'(choose folder with prompt "{_PICKER_TITLE}")'
    )
    return _run_picker_process(["osascript", "-e", script], encoding="utf-8")


def _pick_dir_zenity() -> str | None:
    """Linux: GTK directory chooser via zenity."""
    return _run_picker_process(
        ["zenity", "--file-selection", "--directory", f"--title={_PICKER_TITLE}"],
        encoding="utf-8",
    )


def api_select_output_dir() -> Dict[str, Any]:
    """Pop up a native folder picker on the server host and return the chosen path.

    Strategies, tried in order until one yields a folder:
      1. tkinter (where it can run)
      2. Windows: PowerShell + Shell.Application.BrowseForFolder, then the same
         COM call from VBScript/cscript (for hosts with PowerShell disabled)
      3. macOS: osascript ``choose folder``
      4. Linux: ``zenity --file-selection --directory``

    A strategy that is unavailable or errors is skipped. Once a tkinter or
    Windows dialog has actually been shown and the user cancels it (or any
    external picker times out), no further dialogs are opened. osascript and
    zenity report cancellation through a non-zero exit, which is treated as
    a failure.

    Returns:
        ``{"path": <chosen folder>}``.

    Raises:
        HTTPException(500): the user cancelled, or every strategy failed.
            Cancellation keeps status 500, which is what clients previously
            received for it, but uses a distinct detail message.
    """
    system = platform.system()
    strategies = [_pick_dir_tk]
    if system == "Windows":
        strategies += [_pick_dir_powershell, _pick_dir_vbscript]
    elif system == "Darwin":
        strategies.append(_pick_dir_osascript)
    elif system == "Linux":
        strategies.append(_pick_dir_zenity)

    for pick in strategies:
        try:
            path = pick()
        except _PickerCancelled:
            raise HTTPException(status_code=500, detail="未选择目录:对话框已取消或超时") from None
        except Exception:
            # Best effort: this strategy is unusable here; try the next one.
            continue
        if path:
            return {"path": path}
    raise HTTPException(status_code=500, detail="无法打开文件夹选择器:所有策略均失败")
def api_open_dir(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Open a directory in the host's file manager, creating it first if needed.

    ``payload["path"]`` selects the directory. When it is missing or empty, the
    configured ``output_dir`` is used instead. Windows uses ``os.startfile``,
    macOS ``open`` and other systems ``xdg-open``.

    Returns:
        ``{"ok": True, "path": <directory as a string>}``.

    Raises:
        HTTPException(400): no path was given and no ``output_dir`` is configured.
        HTTPException(500): the directory could not be created, or the platform
            opener failed (including exiting with a non-zero status).
    """
    import subprocess

    path = (payload or {}).get("path") or ""
    if not path:
        path = load_config().output_dir
    if not path:
        raise HTTPException(status_code=400, detail="未提供路径且配置中未设置 output_dir")

    p = Path(path)
    try:
        p.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"创建目录失败: {e}") from e

    target = str(p)
    try:
        if hasattr(os, "startfile"):
            os.startfile(target)  # Windows
        else:
            opener = "open" if platform.system() == "Darwin" else "xdg-open"
            # check=True: a failing opener must surface as an error rather
            # than being reported back as ok=True.
            subprocess.run([opener, target], check=True)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"无法打开目录: {e}") from e
    return {"ok": True, "path": target}
def api_t2i(req: T2IRequest):
    """Text-to-image: generate an image from the prompt settings in ``req``.

    Width/height default to 768 when missing or 0 and are snapped to a
    multiple of 64 by ``_ensure_x64``. A missing negative prompt is sent as "".

    Returns:
        ``{"image_base64": <PNG data: URI>, "saved_path": <second value
        returned by generate_t2i>}`` (presumably where the image was written;
        verify against the service).

    Raises:
        HTTPException(400): no ``key`` is configured.
        HTTPException: passed through unchanged if the service layer raises one.
        HTTPException(500): any other generation error.
    """
    cfg = load_config()
    if not cfg.key:
        raise HTTPException(status_code=400, detail="尚未配置 key,请先在配置中设置 key。")
    width = _ensure_x64(req.width or 768)
    height = _ensure_x64(req.height or 768)
    try:
        b64, saved = generate_t2i(
            cfg,
            prompt=req.prompt,
            negative=req.negative or "",
            width=width,
            height=height,
            scale=req.scale,
            steps=req.steps,
            sampler=req.sampler,
            noise_schedule=req.noise_schedule,
            seed=req.seed,
            variety=req.variety,
            decrisp=req.decrisp,
            cfg_rescale=req.cfg_rescale,
        )
    except HTTPException:
        # Already carries a deliberate status/detail; don't flatten it into a 500.
        raise
    except Exception as e:
        # str(e) can be empty; fall back to the exception type so the client sees something.
        raise HTTPException(status_code=500, detail=str(e) or type(e).__name__) from e
    return {"image_base64": _as_data_uri(b64), "saved_path": saved}
def api_i2i(req: I2IRequest):
    """Image-to-image: regenerate ``req.image_base64`` guided by the prompts in ``req``.

    Width/height default to 768 when missing or 0 and are snapped to a
    multiple of 64 by ``_ensure_x64``. ``strength`` falls back to 0.5 and
    ``noise`` to 0.0 when falsy.
    NOTE(review): because of ``or``, an explicit ``strength=0.0`` is also
    replaced by 0.5; confirm that is intended.

    Returns:
        ``{"image_base64": <PNG data: URI>, "saved_path": <second value
        returned by generate_i2i>}``.

    Raises:
        HTTPException(400): no ``key`` is configured.
        HTTPException: passed through unchanged if the service layer raises one.
        HTTPException(500): any other generation error.
    """
    cfg = load_config()
    if not cfg.key:
        raise HTTPException(status_code=400, detail="尚未配置 key,请先在配置中设置 key。")
    width = _ensure_x64(req.width or 768)
    height = _ensure_x64(req.height or 768)
    try:
        b64, saved = generate_i2i(
            cfg,
            positive=req.positive,
            negative=req.negative or "",
            image_base64=req.image_base64,
            width=width,
            height=height,
            scale=req.scale,
            steps=req.steps,
            sampler=req.sampler,
            noise_schedule=req.noise_schedule,
            strength=req.strength or 0.5,
            noise=req.noise or 0.0,
            seed=req.seed,
            variety=req.variety,
            decrisp=req.decrisp,
            cfg_rescale=req.cfg_rescale,
        )
    except HTTPException:
        # Already carries a deliberate status/detail; don't flatten it into a 500.
        raise
    except Exception as e:
        # str(e) can be empty; fall back to the exception type so the client sees something.
        raise HTTPException(status_code=500, detail=str(e) or type(e).__name__) from e
    return {"image_base64": _as_data_uri(b64), "saved_path": saved}
def api_inpaint(req: InpaintRequest):
    """Inpainting: regenerate the masked region of ``req.image_base64``.

    ``req.mask_base64`` and ``req.add_original_image`` are forwarded unchanged
    to generate_inpaint. Width/height default to 768 when missing or 0 and
    are snapped to a multiple of 64 by ``_ensure_x64``. ``strength`` falls
    back to 0.5 and ``noise`` to 0.0 when falsy.
    NOTE(review): because of ``or``, an explicit ``strength=0.0`` is also
    replaced by 0.5; confirm that is intended.

    Returns:
        ``{"image_base64": <PNG data: URI>, "saved_path": <second value
        returned by generate_inpaint>}``.

    Raises:
        HTTPException(400): no ``key`` is configured.
        HTTPException: passed through unchanged if the service layer raises one.
        HTTPException(500): any other generation error.
    """
    cfg = load_config()
    if not cfg.key:
        raise HTTPException(status_code=400, detail="尚未配置 key,请先在配置中设置 key。")
    width = _ensure_x64(req.width or 768)
    height = _ensure_x64(req.height or 768)
    try:
        b64, saved = generate_inpaint(
            cfg,
            positive=req.positive,
            negative=req.negative or "",
            image_base64=req.image_base64,
            mask_base64=req.mask_base64,
            add_original_image=req.add_original_image,
            width=width,
            height=height,
            scale=req.scale,
            steps=req.steps,
            sampler=req.sampler,
            noise_schedule=req.noise_schedule,
            strength=req.strength or 0.5,
            noise=req.noise or 0.0,
            seed=req.seed,
            variety=req.variety,
            decrisp=req.decrisp,
            cfg_rescale=req.cfg_rescale,
        )
    except HTTPException:
        # Already carries a deliberate status/detail; don't flatten it into a 500.
        raise
    except Exception as e:
        # str(e) can be empty; fall back to the exception type so the client sees something.
        raise HTTPException(status_code=500, detail=str(e) or type(e).__name__) from e
    return {"image_base64": _as_data_uri(b64), "saved_path": saved}
# Frontend static assets (only the necessary UI; no tutorial or repository links).
_frontend_dir = Path(__file__).resolve().parent.parent / "frontend"
# Notification-sound assets: /ring -> <project root>/ring.
_ring_dir = Path(__file__).resolve().parent.parent / "ring"
# StaticFiles raises at construction when its directory does not exist. Mount
# the sound folder only when it is present, so the app still starts without it.
# Order matters: "/ring" must be registered before the catch-all "/" mount,
# which would otherwise shadow it.
if _ring_dir.is_dir():
    app.mount("/ring", StaticFiles(directory=_ring_dir, html=False), name="ring")
app.mount("/", StaticFiles(directory=_frontend_dir, html=True), name="frontend")