# Author: schroneko
# Commit: fec5b89 "Expose Hugging Face token status"
# (Hugging Face file-viewer residue: Raw | History | Blame | Contribute | Delete; 10.2 kB)
from __future__ import annotations

import hashlib
import hmac
import json
import os
import shutil
import time
import uuid
from pathlib import Path
from typing import Any

import requests
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse, JSONResponse
# Runtime configuration, read once from the process environment at import time.
# The second argument of every lookup is the value used when the variable is unset.

# Target of synthesis calls: an "owner/name" Space id or a full http(s) URL.
ZERO_SPACE = os.environ.get("ZERO_SPACE", "schroneko/irodori-tts-zerogpu")
# Hugging Face token: HF_TOKEN first, HUGGING_FACE_HUB_TOKEN when HF_TOKEN is empty.
HF_TOKEN = os.environ.get("HF_TOKEN", "") or os.environ.get("HUGGING_FACE_HUB_TOKEN", "")
# Shared secret checked on /synthesis; empty disables the check.
TTS_API_KEY = os.environ.get("TTS_API_KEY", "")
# Optional fixed origin for generated links (trailing slashes removed).
PUBLIC_BASE_URL = os.environ.get("PUBLIC_BASE_URL", "").rstrip("/")
MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "240"))

# "", "auto" and "none" (case-insensitive) are all normalized to "".
DEFAULT_SECONDS_RAW = os.environ.get("DEFAULT_SECONDS", "").strip()
if DEFAULT_SECONDS_RAW.lower() in {"", "auto", "none"}:
    DEFAULT_SECONDS = ""
else:
    DEFAULT_SECONDS = DEFAULT_SECONDS_RAW

DEFAULT_DURATION_SCALE = float(os.environ.get("DEFAULT_DURATION_SCALE", "0.95"))
DEFAULT_STEPS = int(os.environ.get("DEFAULT_STEPS", "18"))
DEFAULT_SEED = int(os.environ.get("DEFAULT_SEED", "3407"))
# Default voice prompt (Japanese): "A young, energetic woman's voice, speaking
# at close range, brightly, softly and naturally."
DEFAULT_CAPTION = os.environ.get(
    "DEFAULT_CAPTION",
    "若く元気な女性の声。近い距離感で、明るくやわらかく自然に話している。",
)

MAX_CACHE_ENTRIES = int(os.environ.get("MAX_CACHE_ENTRIES", "256"))
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "/tmp/stackchan-audio"))
# Created up front so later writes, globs and lookups in it do not fail.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI(title="Irodori TTS StackChan API")
def _first_forwarded_value(value: str | None) -> str:
    """Return the first comma-separated entry of a proxy header, stripped ("" if absent)."""
    if not value:
        return ""
    return value.split(",", 1)[0].strip()


def _base_url(request: Request) -> str:
    """Return the externally visible ``scheme://host`` used to build absolute links.

    ``PUBLIC_BASE_URL`` wins when it is set. Otherwise the scheme comes from
    ``X-Forwarded-Proto`` (else the request's own scheme). The host comes from
    ``X-Forwarded-Host``, then ``Host``, then the request's netloc. Chained
    proxies that append to the forwarded headers produce lists such as
    ``"https, http"``; only the first entry is used. An empty header falls
    through to the next source instead of producing ``"://host"``.

    Args:
        request: The incoming request whose headers/URL describe the origin.

    Returns:
        Base URL without a trailing slash.
    """
    if PUBLIC_BASE_URL:
        return PUBLIC_BASE_URL
    headers = request.headers
    proto = _first_forwarded_value(headers.get("x-forwarded-proto")) or request.url.scheme
    host = (
        _first_forwarded_value(headers.get("x-forwarded-host"))
        or headers.get("host")
        or request.url.netloc
    )
    # NOTE(review): forwarded headers are client-controllable unless a trusted
    # proxy overwrites them; set PUBLIC_BASE_URL to pin the advertised origin.
    return f"{proto}://{host}".rstrip("/")
def _check_key(key: str | None) -> None:
    """Reject the request unless *key* matches ``TTS_API_KEY``.

    The check is skipped entirely when ``TTS_API_KEY`` is empty. A missing key
    (``None``) is treated as an empty string, so it never matches a configured
    secret.

    Raises:
        HTTPException: 401 "invalid key" when a secret is configured and *key*
            does not match it.
    """
    if not TTS_API_KEY:
        return
    supplied = (key or "").encode("utf-8")
    expected = TTS_API_KEY.encode("utf-8")
    # compare_digest's running time does not depend on where the first differing
    # byte is, unlike "!=", so the secret cannot be recovered by timing requests.
    # Comparing UTF-8 bytes also avoids compare_digest's TypeError on non-ASCII str.
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="invalid key")
def _seed_for_speaker(speaker: str) -> str:
    """Derive a deterministic per-speaker seed as a decimal string.

    The seed is the first 8 bytes (big-endian) of SHA-256 over
    ``"{DEFAULT_SEED}:{speaker}"``, with the top bit cleared so the value fits
    in a non-negative signed 64-bit integer.
    """
    material = f"{DEFAULT_SEED}:{speaker}".encode("utf-8")
    prefix = hashlib.sha256(material).digest()[:8]
    low_63_bits = (1 << 63) - 1
    return str(int.from_bytes(prefix, byteorder="big") & low_63_bits)
def _cache_key(payload: dict[str, Any]) -> str:
    """Return the SHA-256 hex digest of *payload*'s canonical JSON encoding.

    Canonical means sorted keys, no whitespace separators and non-ASCII kept
    verbatim (UTF-8 encoded). Equal payloads therefore map to the same key
    regardless of dict insertion order.
    """
    canonical = json.dumps(payload, sort_keys=True, ensure_ascii=False, separators=(",", ":"))
    hasher = hashlib.sha256()
    hasher.update(canonical.encode("utf-8"))
    return hasher.hexdigest()
def _cache_path(cache_key: str) -> Path:
    """Return the location of the cached MP3 for *cache_key* inside OUTPUT_DIR."""
    filename = f"cache_{cache_key}.mp3"
    return OUTPUT_DIR.joinpath(filename)
def _prune_cache() -> None:
    """Delete the oldest ``cache_*.mp3`` files so at most MAX_CACHE_ENTRIES remain.

    Age is judged by modification time, and the newest files are kept. A
    non-positive limit removes every cache file. Files that vanish during the
    scan (e.g. deleted by a concurrent prune) are skipped rather than aborting
    with FileNotFoundError. Other files in OUTPUT_DIR are never touched.
    """
    entries: list[tuple[float, Path]] = []
    for path in OUTPUT_DIR.glob("cache_*.mp3"):
        try:
            entries.append((path.stat().st_mtime, path))
        except FileNotFoundError:
            # Removed between glob() and stat(); nothing left to prune.
            continue
    entries.sort(key=lambda entry: entry[0], reverse=True)
    # Clamp so a negative setting cannot turn the slice into "delete a tail".
    keep = max(MAX_CACHE_ENTRIES, 0)
    for _, path in entries[keep:]:
        path.unlink(missing_ok=True)
def _zero_space_base_url() -> str:
    """Resolve ZERO_SPACE to the base URL that the Gradio calls are made against.

    Surrounding whitespace and trailing slashes are removed first. A value
    starting with ``http://`` or ``https://`` is then used as-is. Otherwise it is
    treated as an ``owner/name`` id, with every ``/`` replaced by ``-``, under
    ``.hf.space``.
    """
    target = ZERO_SPACE.strip().rstrip("/")
    if target.startswith(("http://", "https://")):
        return target
    subdomain = target.replace("/", "-")
    return f"https://{subdomain}.hf.space"
def _zero_headers(x_ip_token: str = "") -> dict[str, str]:
    """Build the HTTP headers for requests to the ZeroGPU Space.

    Always sends ``Content-Type: application/json``. Adds a Bearer
    ``Authorization`` header when HF_TOKEN is set, and forwards *x_ip_token* as
    ``x-ip-token`` when it is non-empty.
    """
    optional = (
        ("Authorization", f"Bearer {HF_TOKEN}" if HF_TOKEN else ""),
        ("x-ip-token", x_ip_token),
    )
    return {
        "Content-Type": "application/json",
        **{name: value for name, value in optional if value},
    }
def _raise_for_response(response: requests.Response) -> None:
    """Raise unless *response* reports success (``response.ok``).

    Raises:
        RuntimeError: carrying the status code and the stripped response body.
    """
    if not response.ok:
        body = response.text.strip()
        raise RuntimeError(f"ZeroGPU request failed: {response.status_code} {body}")
def _predict_zero_space(payload: dict[str, Any], x_ip_token: str = "") -> Any:
    """Run one synthesis on the ZeroGPU Space through its ``gradio_api/call`` API.

    Step 1 POSTs the job to ``v2/synthesize`` and retries on ``synthesize`` if
    the Space answers 405. It then reads the returned ``event_id``. Step 2
    streams the event lines for that id and returns the JSON-decoded ``data``
    of the first ``complete`` event.

    Args:
        payload: Synthesis parameters (text, speaker, seconds, duration_scale,
            steps, seed, caption).
        x_ip_token: Optional ``x-ip-token`` header value to forward.

    Returns:
        The decoded ``complete`` payload.

    Raises:
        RuntimeError: on HTTP errors, a missing event id, an ``error`` event
            (message is its raw data), or a stream that ends without
            ``complete``.
    """
    base_url = _zero_space_base_url()
    response = _post_zero_space(base_url, "v2/synthesize", payload, x_ip_token)
    if response.status_code == 405:
        response = _post_zero_space(base_url, "synthesize", payload, x_ip_token)
    _raise_for_response(response)
    event_id = response.json().get("event_id")
    if not event_id:
        raise RuntimeError(f"ZeroGPU did not return event_id: {response.text}")

    # With stream=True the connection stays checked out until the body is fully
    # read or the response is closed. The with-block closes it on every exit:
    # the early return, the raises, and a stream that just ends.
    with requests.get(
        f"{base_url}/gradio_api/call/synthesize/{event_id}",
        headers=_zero_headers(x_ip_token),
        stream=True,
        timeout=(10, 300),
    ) as stream:
        _raise_for_response(stream)
        event_name = ""
        for line in stream.iter_lines(decode_unicode=True):
            if not line:
                continue
            if line.startswith("event:"):
                event_name = line.split(":", 1)[1].strip()
                continue
            if not line.startswith("data:"):
                continue
            data = line.split(":", 1)[1].strip()
            if event_name == "error":
                raise RuntimeError(data)
            if event_name == "complete":
                # Only the final payload is used, so only it is JSON-decoded;
                # intermediate events cannot abort the call with a decode error.
                return json.loads(data)
    raise RuntimeError("ZeroGPU stream ended before completion")
def _post_zero_space(base_url: str, endpoint: str, payload: dict[str, Any], x_ip_token: str = "") -> requests.Response:
    """POST a synthesis job to ``{base_url}/gradio_api/call/{endpoint}``.

    For the ``synthesize`` endpoint the parameters are sent positionally as
    ``{"data": [text, speaker, seconds, duration_scale, steps, seed, caption]}``.
    Any other endpoint receives *payload* unchanged as the JSON body. The raw
    response is returned without status checking (30 s timeout).
    """
    if endpoint != "synthesize":
        body: dict[str, Any] = payload
    else:
        ordered_fields = ("text", "speaker", "seconds", "duration_scale", "steps", "seed", "caption")
        body = {"data": [payload[field] for field in ordered_fields]}
    url = f"{base_url}/gradio_api/call/{endpoint}"
    return requests.post(url, json=body, headers=_zero_headers(x_ip_token), timeout=30)
def _copy_result_file(result: Any, x_ip_token: str = "") -> Path:
    """Save the audio referenced by a ZeroGPU result as a new MP3 in OUTPUT_DIR.

    Finding the reference:
    - For a non-empty list/tuple, the reference is its first element.
    - For a dict or str result, the reference is the result itself.
    - A dict reference resolves to the first truthy of ``url``, ``path``,
      ``name``, repeating while that value is itself a dict.

    ``http://``/``https://`` references are downloaded with the ZeroGPU
    headers. Anything else is treated as a local file path and copied.

    Args:
        result: Decoded ``complete`` payload from the Space.
        x_ip_token: Optional ``x-ip-token`` header value forwarded on downloads.

    Returns:
        Path of the newly written ``<ms>_<uuid>.mp3`` file in OUTPUT_DIR.

    Raises:
        RuntimeError: if no reference is found, the download fails, or the
            referenced local file does not exist.
    """
    source: Any = None
    if isinstance(result, (list, tuple)) and result:
        source = result[0]
    elif isinstance(result, (dict, str)):
        source = result
    # Prefer a download URL over a server-side path: a path reported by a remote
    # Space refers to that machine's filesystem. Top-level dicts now get the same
    # url-first resolution as dicts inside a list result.
    while isinstance(source, dict):
        source = source.get("url") or source.get("path") or source.get("name")
    if not source:
        raise RuntimeError(f"Could not find generated audio path in result: {result!r}")

    output_path = OUTPUT_DIR / f"{int(time.time() * 1000)}_{uuid.uuid4().hex}.mp3"
    source_text = str(source)
    if source_text.startswith(("http://", "https://")):
        # NOTE(review): _zero_headers attaches the HF token to this download;
        # this assumes the URL points back at the configured Space — confirm.
        response = requests.get(source_text, headers=_zero_headers(x_ip_token), timeout=60)
        _raise_for_response(response)
        output_path.write_bytes(response.content)
        return output_path
    # NOTE(review): copies whichever local path the Space reports into
    # OUTPUT_DIR, which /audio serves without a key; only safe while the
    # Space is trusted.
    source_path = Path(source_text)
    if not source_path.is_file():
        raise RuntimeError(f"Generated audio file is missing: {source_path}")
    shutil.copyfile(source_path, output_path)
    return output_path
def _result_metadata(result: Any) -> dict[str, Any]:
    """Return ``result[1]`` if *result* is a list/tuple whose second item is a dict.

    The dict is returned as-is (not copied). Any other shape yields a fresh
    empty dict.
    """
    if not isinstance(result, (list, tuple)) or len(result) < 2:
        return {}
    candidate = result[1]
    return candidate if isinstance(candidate, dict) else {}
@app.get("/")
def root() -> dict[str, str]:
    """Landing endpoint: a fixed payload identifying the service."""
    return dict(ok="true", service="irodori-tts-stackchan-api")
@app.get("/health")
def health() -> dict[str, str]:
    """Report liveness, key configuration values and the cache fill level.

    Every value is a string. ``has_hf_token`` is ``"true"``/``"false"`` and does
    not reveal the token. ``cache_entries`` counts the ``cache_*.mp3`` files
    currently in OUTPUT_DIR.
    """
    cached_files = sum(1 for _ in OUTPUT_DIR.glob("cache_*.mp3"))
    token_flag = "true" if HF_TOKEN else "false"
    return {
        "ok": "true",
        "zero_space": ZERO_SPACE,
        "duration_scale": str(DEFAULT_DURATION_SCALE),
        "has_hf_token": token_flag,
        "cache_entries": str(cached_files),
        "max_cache_entries": str(MAX_CACHE_ENTRIES),
    }
@app.get("/audio/{filename}")
def audio(filename: str) -> FileResponse:
    """Serve a generated MP3 from OUTPUT_DIR.

    Only a bare file name that resolves to a regular file directly inside
    OUTPUT_DIR is served. Path separators, ``.``/``..`` and symlinks that
    resolve elsewhere all get the same 404 as a missing file.

    Raises:
        HTTPException: 404 "audio not found".
    """
    path = OUTPUT_DIR / filename
    # Defence in depth: do not rely on route matching alone to keep requests
    # inside OUTPUT_DIR.
    is_plain_name = Path(filename).name == filename
    if not is_plain_name or path.resolve().parent != OUTPUT_DIR.resolve() or not path.is_file():
        raise HTTPException(status_code=404, detail="audio not found")
    return FileResponse(path, media_type="audio/mpeg", filename=filename)
def _synthesis_response(
    request: Request,
    speaker: str,
    seed_value: str,
    duration_scale: float,
    metadata: dict[str, Any],
    audio_path: Path,
) -> JSONResponse:
    """Build the success payload shared by cache hits and fresh syntheses."""
    base_url = _base_url(request)
    url = f"{base_url}/audio/{audio_path.name}"
    return JSONResponse(
        {
            "success": True,
            "isApiKeyValid": True,
            "speaker": speaker,
            "seed": seed_value,
            "durationScale": float(duration_scale),
            "metadata": metadata,
            "mp3StreamingUrl": url,
            "mp3DownloadUrl": url,
            "audioStatusUrl": f"{base_url}/health",
        }
    )


def _store_in_cache(output_path: Path, cached_path: Path) -> Path:
    """Move a freshly generated file into its cache slot and return the path to serve."""
    if MAX_CACHE_ENTRIES <= 0:
        # A non-positive limit means nothing is kept in the cache (pruning
        # would delete the new entry at once), so serve the one-off file.
        return output_path
    try:
        # os.replace renames within OUTPUT_DIR. Readers see either no cache
        # file or the complete one, and no duplicate is left behind to pile up.
        os.replace(output_path, cached_path)
    except OSError:
        # Caching is an optimisation; fall back to serving the one-off file.
        return output_path
    try:
        _prune_cache()
    except OSError:
        # Housekeeping is best-effort and must not fail a finished synthesis.
        pass
    return cached_path


@app.get("/synthesis")
def synthesis(
    request: Request,
    key: str | None = Query(default=None),
    text: str = Query(..., min_length=1),
    speaker: str = Query(default="3"),
    seconds: str = Query(default=DEFAULT_SECONDS),
    duration_scale: float = Query(default=DEFAULT_DURATION_SCALE, gt=0.0, le=2.0),
    steps: int = Query(default=DEFAULT_STEPS, ge=1, le=80),
    seed: str = Query(default=""),
    caption: str = Query(default=DEFAULT_CAPTION),
) -> JSONResponse:
    """Synthesize *text* (or reuse a cached result) and return links to the MP3.

    Results are cached under a hash of the full normalized parameter set. A hit
    is answered immediately with ``metadata.cacheHit`` true. A miss runs the
    ZeroGPU Space, moves the audio into the cache and links to it.

    Raises:
        HTTPException: 401 from the key check; 400 for empty or too-long text.

    Returns:
        JSON success payload. If the remote call fails, a 502 JSON body
        ``{"success": false, ...}`` carrying the error text.
    """
    _check_key(key)
    text = text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"text is too long: max {MAX_TEXT_LENGTH}")
    speaker = str(speaker)
    # Without an explicit seed each speaker gets its own stable derived seed.
    seed_value = str(seed).strip() or _seed_for_speaker(speaker)
    payload = {
        "text": text,
        "speaker": speaker,
        "seconds": str(seconds),
        "duration_scale": float(duration_scale),
        "steps": int(steps),
        "seed": seed_value,
        "caption": str(caption).strip() or DEFAULT_CAPTION,
    }
    cache_key = _cache_key(payload)
    cached_path = _cache_path(cache_key)
    if cached_path.is_file():
        return _synthesis_response(
            request,
            speaker,
            seed_value,
            duration_scale,
            {"cacheHit": True, "cacheKey": cache_key},
            cached_path,
        )
    try:
        x_ip_token = request.headers.get("x-ip-token", "").strip()
        result = _predict_zero_space(payload, x_ip_token=x_ip_token)
        metadata = _result_metadata(result)
        output_path = _copy_result_file(result, x_ip_token=x_ip_token)
    except Exception as exc:
        # Boundary handler: any remote/download failure becomes a JSON 502.
        return JSONResponse(
            status_code=502,
            content={
                "success": False,
                # _check_key above already raised for a bad key, so this is True.
                "isApiKeyValid": True,
                "error": str(exc),
            },
        )
    served_path = _store_in_cache(output_path, cached_path)
    return _synthesis_response(
        request,
        speaker,
        seed_value,
        duration_scale,
        {**metadata, "cacheHit": False, "cacheKey": cache_key},
        served_path,
    )