"""HTTP API that forwards text-to-speech requests to a ZeroGPU Space.

The service exposes ``/synthesis`` (generate or reuse an MP3), ``/audio/<file>``
(download a generated MP3), ``/health`` and ``/``.  Generated audio is stored in
``OUTPUT_DIR``; results are additionally cached there under a name derived from
the request parameters so identical requests skip the upstream call.
"""

from __future__ import annotations

import hashlib
import hmac
import json
import os
import shutil
import time
import uuid
from pathlib import Path
from typing import Any

import requests
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse, JSONResponse

# --- Configuration (all read once from the environment at import time) -------

# Either "<owner>/<space>" or a full http(s) base URL of the upstream Space.
ZERO_SPACE = os.getenv("ZERO_SPACE", "schroneko/irodori-tts-zerogpu")
HF_TOKEN = os.getenv("HF_TOKEN", "") or os.getenv("HUGGING_FACE_HUB_TOKEN", "")
# When empty, the `key` query parameter is not checked at all.
TTS_API_KEY = os.getenv("TTS_API_KEY", "")
# When set, overrides the base URL otherwise derived from request headers.
PUBLIC_BASE_URL = os.getenv("PUBLIC_BASE_URL", "").rstrip("/")
MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "240"))
DEFAULT_SECONDS_RAW = os.getenv("DEFAULT_SECONDS", "").strip()
# "", "auto" and "none" (case-insensitive) all normalise to the empty string.
DEFAULT_SECONDS = (
    "" if DEFAULT_SECONDS_RAW.lower() in {"", "auto", "none"} else DEFAULT_SECONDS_RAW
)
DEFAULT_DURATION_SCALE = float(os.getenv("DEFAULT_DURATION_SCALE", "0.95"))
DEFAULT_STEPS = int(os.getenv("DEFAULT_STEPS", "18"))
DEFAULT_SEED = int(os.getenv("DEFAULT_SEED", "3407"))
DEFAULT_CAPTION = os.getenv(
    "DEFAULT_CAPTION",
    "若く元気な女性の声。近い距離感で、明るくやわらかく自然に話している。",
)
MAX_CACHE_ENTRIES = int(os.getenv("MAX_CACHE_ENTRIES", "256"))
OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "/tmp/stackchan-audio"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI(title="Irodori TTS StackChan API")


def _base_url(request: Request) -> str:
    """Return the externally visible base URL (no trailing slash).

    ``PUBLIC_BASE_URL`` wins when configured; otherwise the URL is rebuilt
    from the ``x-forwarded-proto`` / ``x-forwarded-host`` headers, falling
    back to the ``host`` header and finally to the request URL itself.
    """
    if PUBLIC_BASE_URL:
        return PUBLIC_BASE_URL
    proto = request.headers.get("x-forwarded-proto", request.url.scheme)
    host = request.headers.get(
        "x-forwarded-host",
        request.headers.get("host", request.url.netloc),
    )
    return f"{proto}://{host}".rstrip("/")


def _check_key(key: str | None) -> None:
    """Reject the request unless *key* matches ``TTS_API_KEY``.

    No check is performed when ``TTS_API_KEY`` is empty.

    Raises:
        HTTPException: 401 when a key is configured and *key* differs.
    """
    if not TTS_API_KEY:
        return
    # Compare as bytes in constant time: `!=` short-circuits on the first
    # differing character, and compare_digest() rejects non-ASCII str input.
    supplied = (key or "").encode("utf-8")
    if not hmac.compare_digest(supplied, TTS_API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="invalid key")


def _seed_for_speaker(speaker: str) -> str:
    """Derive a deterministic seed string for *speaker*.

    The seed is the first 8 bytes of SHA-256("<DEFAULT_SEED>:<speaker>")
    read big-endian and masked to 63 bits, rendered in decimal.
    """
    digest = hashlib.sha256(f"{DEFAULT_SEED}:{speaker}".encode("utf-8")).digest()
    value = int.from_bytes(digest[:8], "big") & ((1 << 63) - 1)
    return str(value)


def _cache_key(payload: dict[str, Any]) -> str:
    """Return the SHA-256 hex digest of *payload* in canonical JSON form."""
    # sort_keys + fixed separators make the serialisation order-independent.
    data = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(data.encode("utf-8")).hexdigest()


def _cache_path(cache_key: str) -> Path:
    """Return the cache file path inside ``OUTPUT_DIR`` for *cache_key*."""
    return OUTPUT_DIR / f"cache_{cache_key}.mp3"


def _cache_mtime(path: Path) -> float:
    """Return *path*'s mtime, or 0.0 if the file has already disappeared."""
    try:
        return path.stat().st_mtime
    except FileNotFoundError:
        # Another request pruned it between glob() and stat(); sort it last.
        return 0.0


def _prune_cache() -> None:
    """Delete the oldest cache files beyond ``MAX_CACHE_ENTRIES``.

    Only ``cache_*.mp3`` files in ``OUTPUT_DIR`` are considered; the newest
    (by mtime) are kept.
    """
    paths = sorted(OUTPUT_DIR.glob("cache_*.mp3"), key=_cache_mtime, reverse=True)
    for path in paths[MAX_CACHE_ENTRIES:]:
        path.unlink(missing_ok=True)


def _zero_space_base_url() -> str:
    """Return the upstream base URL derived from ``ZERO_SPACE``.

    A value that already starts with http:// or https:// is used as-is;
    otherwise "<owner>/<space>" becomes "https://<owner>-<space>.hf.space".
    """
    space = ZERO_SPACE.strip().rstrip("/")
    if space.startswith(("http://", "https://")):
        return space
    return f"https://{space.replace('/', '-')}.hf.space"


def _zero_headers(x_ip_token: str = "") -> dict[str, str]:
    """Build upstream request headers.

    Adds a bearer ``Authorization`` header when ``HF_TOKEN`` is set and an
    ``x-ip-token`` header when *x_ip_token* is non-empty.
    """
    headers = {"Content-Type": "application/json"}
    if HF_TOKEN:
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
    if x_ip_token:
        headers["x-ip-token"] = x_ip_token
    return headers


def _raise_for_response(response: requests.Response) -> None:
    """Raise ``RuntimeError`` with status and body text unless *response* is OK."""
    if response.ok:
        return
    detail = response.text.strip()
    raise RuntimeError(f"ZeroGPU request failed: {response.status_code} {detail}")


def _predict_zero_space(payload: dict[str, Any], x_ip_token: str = "") -> Any:
    """Run one synthesis call upstream and return the ``complete`` event data.

    The request is first POSTed to the ``v2/synthesize`` endpoint and retried
    on ``synthesize`` if that answers 405.  The returned ``event_id`` is then
    used to read the event stream until a ``complete`` event arrives.

    Raises:
        RuntimeError: on a non-OK upstream response, a missing ``event_id``,
            an ``error`` event, or a stream that ends without ``complete``.
    """
    base_url = _zero_space_base_url()
    response = _post_zero_space(base_url, "v2/synthesize", payload, x_ip_token)
    if response.status_code == 405:
        response = _post_zero_space(base_url, "synthesize", payload, x_ip_token)
    _raise_for_response(response)
    event_id = response.json().get("event_id")
    if not event_id:
        raise RuntimeError(f"ZeroGPU did not return event_id: {response.text}")

    # `with` guarantees the streamed connection is released on every exit
    # path (early return on "complete", "error" event, or any exception).
    with requests.get(
        f"{base_url}/gradio_api/call/synthesize/{event_id}",
        headers=_zero_headers(x_ip_token),
        stream=True,
        timeout=(10, 300),
    ) as stream:
        _raise_for_response(stream)
        # Event streams are UTF-8; do not rely on the header-based guess,
        # which may fall back to ISO-8859-1 for text/* content types.
        stream.encoding = "utf-8"
        event_name = ""
        for line in stream.iter_lines(decode_unicode=True):
            if not line:
                continue
            if line.startswith("event:"):
                event_name = line.split(":", 1)[1].strip()
                continue
            if not line.startswith("data:"):
                continue
            data = line.split(":", 1)[1].strip()
            if event_name == "error":
                raise RuntimeError(data)
            # Only the final payload is needed; skip parsing everything else.
            if event_name == "complete":
                return json.loads(data)
    raise RuntimeError("ZeroGPU stream ended before completion")


def _post_zero_space(
    base_url: str,
    endpoint: str,
    payload: dict[str, Any],
    x_ip_token: str = "",
) -> requests.Response:
    """POST *payload* to ``<base_url>/gradio_api/call/<endpoint>``.

    For the ``synthesize`` endpoint the payload is repacked into a positional
    ``{"data": [...]}`` list; any other endpoint receives *payload* unchanged.
    """
    body: dict[str, Any] = payload
    if endpoint == "synthesize":
        body = {
            "data": [
                payload["text"],
                payload["speaker"],
                payload["seconds"],
                payload["duration_scale"],
                payload["steps"],
                payload["seed"],
                payload["caption"],
            ]
        }
    return requests.post(
        f"{base_url}/gradio_api/call/{endpoint}",
        json=body,
        headers=_zero_headers(x_ip_token),
        timeout=30,
    )


def _copy_result_file(result: Any, x_ip_token: str = "") -> Path:
    """Store the audio referenced by *result* in ``OUTPUT_DIR``.

    *result* may be a list/tuple (first element is used), a dict (``path`` or
    ``name``), or a string.  If the extracted value is itself a dict, its
    ``url``, ``path`` or ``name`` is used.  An http(s) source is downloaded;
    anything else is treated as a local file path and copied.

    Returns:
        Path of the newly written ``<millis>_<uuid>.mp3`` file.

    Raises:
        RuntimeError: when no source can be found, the download fails, or the
            local source file does not exist.
    """
    source: Any = None
    if isinstance(result, (list, tuple)) and result:
        source = result[0]
    elif isinstance(result, dict):
        source = result.get("path") or result.get("name")
    elif isinstance(result, str):
        source = result
    if isinstance(source, dict):
        source = source.get("url") or source.get("path") or source.get("name")
    if not source:
        raise RuntimeError(f"Could not find generated audio path in result: {result!r}")

    output_path = OUTPUT_DIR / f"{int(time.time() * 1000)}_{uuid.uuid4().hex}.mp3"
    source_text = str(source)
    if source_text.startswith(("http://", "https://")):
        # NOTE(review): the HF token header is sent to whatever URL the
        # upstream result names - confirm that is always the Space itself.
        response = requests.get(
            source_text,
            headers=_zero_headers(x_ip_token),
            timeout=60,
        )
        _raise_for_response(response)
        output_path.write_bytes(response.content)
        return output_path

    source_path = Path(source_text)
    if not source_path.is_file():
        raise RuntimeError(f"Generated audio file is missing: {source_path}")
    shutil.copyfile(source_path, output_path)
    return output_path


def _result_metadata(result: Any) -> dict[str, Any]:
    """Return ``result[1]`` when it is a dict, otherwise an empty dict."""
    if isinstance(result, (list, tuple)) and len(result) > 1 and isinstance(result[1], dict):
        return result[1]
    return {}


@app.get("/")
def root() -> dict[str, str]:
    """Return a static identification payload."""
    return {"ok": "true", "service": "irodori-tts-stackchan-api"}


@app.get("/health")
def health() -> dict[str, str]:
    """Report configuration and the current number of cache files."""
    cache_entries = sum(1 for _ in OUTPUT_DIR.glob("cache_*.mp3"))
    return {
        "ok": "true",
        "zero_space": ZERO_SPACE,
        "duration_scale": str(DEFAULT_DURATION_SCALE),
        "has_hf_token": str(bool(HF_TOKEN)).lower(),
        "cache_entries": str(cache_entries),
        "max_cache_entries": str(MAX_CACHE_ENTRIES),
    }


@app.get("/audio/{filename}")
def audio(filename: str) -> FileResponse:
    """Serve a generated MP3 from ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 when the file does not exist or *filename* does
            not resolve to a file directly inside ``OUTPUT_DIR``.
    """
    root_dir = OUTPUT_DIR.resolve()
    path = (root_dir / filename).resolve()
    # Refuse anything that escapes the output directory ("..", absolute
    # paths, platform-specific separators) instead of trusting the router.
    if path.parent != root_dir or not path.is_file():
        raise HTTPException(status_code=404, detail="audio not found")
    return FileResponse(path, media_type="audio/mpeg", filename=filename)


@app.get("/synthesis")
def synthesis(
    request: Request,
    key: str | None = Query(default=None),
    text: str = Query(..., min_length=1),
    speaker: str = Query(default="3"),
    seconds: str = Query(default=DEFAULT_SECONDS),
    duration_scale: float = Query(default=DEFAULT_DURATION_SCALE, gt=0.0, le=2.0),
    steps: int = Query(default=DEFAULT_STEPS, ge=1, le=80),
    seed: str = Query(default=""),
    caption: str = Query(default=DEFAULT_CAPTION),
) -> JSONResponse:
    """Synthesize *text* and return URLs for the resulting MP3.

    A cache file matching the full parameter set is reused when present;
    otherwise the upstream Space is called, the audio is stored, copied into
    the cache and the cache is pruned.

    Returns:
        JSON with ``success``, the effective ``speaker``/``seed``/
        ``durationScale``, ``metadata`` (including ``cacheHit`` and
        ``cacheKey``) and the audio URLs.  Any failure while generating or
        storing audio yields a 502 JSON body with ``success: false``.

    Raises:
        HTTPException: 401 for a bad key; 400 for empty or over-long text.
    """
    _check_key(key)
    text = text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"text is too long: max {MAX_TEXT_LENGTH}")

    # An explicit seed wins; otherwise each speaker gets a stable derived one.
    seed_value = str(seed).strip() or _seed_for_speaker(str(speaker))
    payload = {
        "text": text,
        "speaker": str(speaker),
        "seconds": str(seconds),
        "duration_scale": float(duration_scale),
        "steps": int(steps),
        "seed": seed_value,
        "caption": str(caption).strip() or DEFAULT_CAPTION,
    }
    cache_key = _cache_key(payload)
    cached_path = _cache_path(cache_key)

    if cached_path.is_file():
        base = _base_url(request)
        url = f"{base}/audio/{cached_path.name}"
        return JSONResponse(
            {
                "success": True,
                "isApiKeyValid": True,
                "speaker": str(speaker),
                "seed": seed_value,
                "durationScale": float(duration_scale),
                "metadata": {"cacheHit": True, "cacheKey": cache_key},
                "mp3StreamingUrl": url,
                "mp3DownloadUrl": url,
                "audioStatusUrl": f"{base}/health",
            }
        )

    try:
        x_ip_token = request.headers.get("x-ip-token", "").strip()
        result = _predict_zero_space(payload, x_ip_token=x_ip_token)
        metadata = _result_metadata(result)
        output_path = _copy_result_file(result, x_ip_token=x_ip_token)
        # NOTE(review): output_path is kept alongside the cache copy and is
        # never removed by _prune_cache(), which only matches cache_*.mp3.
        shutil.copyfile(output_path, cached_path)
        _prune_cache()
    except Exception as exc:  # boundary: report any upstream/storage failure as 502
        return JSONResponse(
            status_code=502,
            content={
                "success": False,
                "isApiKeyValid": bool(not TTS_API_KEY or key == TTS_API_KEY),
                "error": str(exc),
            },
        )

    base = _base_url(request)
    url = f"{base}/audio/{output_path.name}"
    return JSONResponse(
        {
            "success": True,
            "isApiKeyValid": True,
            "speaker": str(speaker),
            "seed": seed_value,
            "durationScale": float(duration_scale),
            "metadata": {**metadata, "cacheHit": False, "cacheKey": cache_key},
            "mp3StreamingUrl": url,
            "mp3DownloadUrl": url,
            "audioStatusUrl": f"{base}/health",
        }
    )