Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import os | |
| import shutil | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any | |
| import requests | |
| from fastapi import FastAPI, HTTPException, Query, Request | |
| from fastapi.responses import FileResponse, JSONResponse | |
# ---------------------------------------------------------------------------
# Runtime configuration. Every value is read once, at import time, from the
# process environment; the second argument of each lookup is the fallback.
# ---------------------------------------------------------------------------

# Upstream Space, either as "owner/name" or as a full http(s) URL.
ZERO_SPACE = os.environ.get("ZERO_SPACE", "schroneko/irodori-tts-zerogpu")

# Bearer token sent upstream; HF_TOKEN wins over HUGGING_FACE_HUB_TOKEN.
HF_TOKEN = os.environ.get("HF_TOKEN", "") or os.environ.get("HUGGING_FACE_HUB_TOKEN", "")

# Shared secret callers must present; an empty value disables the check.
TTS_API_KEY = os.environ.get("TTS_API_KEY", "")

# Externally visible base URL (no trailing slash); empty means "derive it
# from the incoming request".
PUBLIC_BASE_URL = os.environ.get("PUBLIC_BASE_URL", "").rstrip("/")

# Upper bound on the stripped text length accepted per request.
MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "240"))

# Default for the `seconds` parameter; "", "auto" and "none" (any case) all
# collapse to the empty string.
DEFAULT_SECONDS_RAW = os.environ.get("DEFAULT_SECONDS", "").strip()
if DEFAULT_SECONDS_RAW.lower() in {"", "auto", "none"}:
    DEFAULT_SECONDS = ""
else:
    DEFAULT_SECONDS = DEFAULT_SECONDS_RAW

# Defaults for the remaining synthesis parameters.
DEFAULT_DURATION_SCALE = float(os.environ.get("DEFAULT_DURATION_SCALE", "0.95"))
DEFAULT_STEPS = int(os.environ.get("DEFAULT_STEPS", "18"))
DEFAULT_SEED = int(os.environ.get("DEFAULT_SEED", "3407"))
DEFAULT_CAPTION = os.environ.get(
    "DEFAULT_CAPTION",
    "若く元気な女性の声。近い距離感で、明るくやわらかく自然に話している。",
)

# Number of `cache_*.mp3` files kept when the cache is pruned.
MAX_CACHE_ENTRIES = int(os.environ.get("MAX_CACHE_ENTRIES", "256"))

# Directory holding generated and cached audio; created on import.
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "/tmp/stackchan-audio"))
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

app = FastAPI(title="Irodori TTS StackChan API")
def _base_url(request: Request) -> str:
    """Return the externally visible base URL, without a trailing slash.

    ``PUBLIC_BASE_URL`` is returned as-is when configured. Otherwise the URL
    is rebuilt from the request: ``X-Forwarded-Proto`` / ``X-Forwarded-Host``
    when present, then the ``Host`` header, then the request URL itself.

    Args:
        request: The incoming request whose headers and URL are inspected.

    Returns:
        A string of the form ``scheme://host``.
    """
    if PUBLIC_BASE_URL:
        return PUBLIC_BASE_URL

    def _first_forwarded(header: str) -> str:
        # Chained proxies may append to a forwarded header, producing a
        # comma-separated list; only the first entry is used here.
        return request.headers.get(header, "").split(",", 1)[0].strip()

    # An empty or missing header falls through to the next source instead of
    # producing a malformed "://host" or "scheme://" URL.
    proto = _first_forwarded("x-forwarded-proto") or request.url.scheme
    host = (
        _first_forwarded("x-forwarded-host")
        or request.headers.get("host", "").strip()
        or request.url.netloc
    )
    return f"{proto}://{host}".rstrip("/")
def _check_key(key: str | None) -> None:
    """Reject the request unless ``key`` matches the configured API key.

    The check is skipped entirely when ``TTS_API_KEY`` is empty.

    Args:
        key: The key supplied by the caller, or ``None`` when absent.

    Raises:
        HTTPException: 401 with detail ``"invalid key"`` on mismatch.
    """
    # Imported here so this function does not depend on the file's import
    # block being changed.
    import hmac

    if not TTS_API_KEY:
        return

    # Compare in constant time so response timing does not reveal how many
    # leading characters of a guess were correct. Both sides are encoded to
    # bytes because compare_digest only accepts ASCII-only str arguments.
    supplied = (key or "").encode("utf-8")
    expected = TTS_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="invalid key")
def _seed_for_speaker(speaker: str) -> str:
    """Derive a deterministic seed string for ``speaker``.

    The seed is the first eight bytes of SHA-256 over
    ``"<DEFAULT_SEED>:<speaker>"``, read as a big-endian integer and reduced
    to its low 63 bits, rendered in decimal.
    """
    material = f"{DEFAULT_SEED}:{speaker}".encode("utf-8")
    leading = hashlib.sha256(material).digest()[:8]
    # Reducing modulo 2**63 keeps exactly the low 63 bits of the value.
    return str(int.from_bytes(leading, "big") % (1 << 63))
def _cache_key(payload: dict[str, Any]) -> str:
    """Return the SHA-256 hex digest of ``payload`` in canonical JSON form.

    Keys are sorted and separators are compact, so two payloads with the
    same content produce the same key regardless of insertion order.
    Non-ASCII text is hashed as UTF-8 rather than as ``\\u`` escapes.
    """
    canonical = json.dumps(
        payload,
        separators=(",", ":"),
        sort_keys=True,
        ensure_ascii=False,
    )
    hasher = hashlib.sha256()
    hasher.update(canonical.encode("utf-8"))
    return hasher.hexdigest()
def _cache_path(cache_key: str) -> Path:
    """Return the location in ``OUTPUT_DIR`` of the cached MP3 for ``cache_key``."""
    filename = f"cache_{cache_key}.mp3"
    return OUTPUT_DIR.joinpath(filename)
def _prune_cache() -> None:
    """Delete all but the ``MAX_CACHE_ENTRIES`` most recently modified cache files.

    Only files matching ``cache_*.mp3`` in ``OUTPUT_DIR`` are considered.
    Files that disappear while the directory is being scanned are skipped.
    """
    dated: list[tuple[float, Path]] = []
    for path in OUTPUT_DIR.glob("cache_*.mp3"):
        try:
            dated.append((path.stat().st_mtime, path))
        except FileNotFoundError:
            # Another request pruned this file between glob() and stat();
            # it is already gone, so there is nothing left to do for it.
            continue

    # Newest first; everything past the limit is removed.
    dated.sort(key=lambda entry: entry[0], reverse=True)
    for _, path in dated[MAX_CACHE_ENTRIES:]:
        path.unlink(missing_ok=True)
def _zero_space_base_url() -> str:
    """Return the base URL of the upstream Space named by ``ZERO_SPACE``.

    A value that already starts with ``http://`` or ``https://`` is used
    unchanged (minus surrounding whitespace and trailing slashes). Any other
    value has each ``/`` replaced by ``-`` and is placed under ``hf.space``.
    """
    target = ZERO_SPACE.strip().rstrip("/")
    if target.startswith(("http://", "https://")):
        return target
    subdomain = "-".join(target.split("/"))
    return f"https://{subdomain}.hf.space"
def _zero_headers(x_ip_token: str = "") -> dict[str, str]:
    """Build the HTTP headers used for requests to the upstream Space.

    ``Content-Type`` is always present. ``Authorization`` is added only when
    ``HF_TOKEN`` is set, and ``x-ip-token`` only when one is passed in.
    """
    return {
        "Content-Type": "application/json",
        **({"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}),
        **({"x-ip-token": x_ip_token} if x_ip_token else {}),
    }
def _raise_for_response(response: requests.Response) -> None:
    """Raise ``RuntimeError`` when ``response`` is not OK.

    The error message carries the status code and the stripped response body.
    """
    if not response.ok:
        body = response.text.strip()
        raise RuntimeError(f"ZeroGPU request failed: {response.status_code} {body}")
def _predict_zero_space(payload: dict[str, Any], x_ip_token: str = "") -> Any:
    """Submit a synthesis job to the upstream Space and wait for its result.

    The payload is POSTed to the ``v2/synthesize`` call endpoint; if that
    answers HTTP 405 the request is retried against ``synthesize``. The
    returned ``event_id`` is then followed on the event stream until a
    ``complete`` event arrives.

    Args:
        payload: Synthesis parameters, as built by the caller.
        x_ip_token: Optional ``x-ip-token`` header value to forward upstream.

    Returns:
        The JSON-decoded ``data`` of the ``complete`` event.

    Raises:
        RuntimeError: On a non-OK HTTP response, a missing ``event_id``, an
            ``error`` event, or a stream that ends without ``complete``.
    """
    base_url = _zero_space_base_url()
    response = _post_zero_space(base_url, "v2/synthesize", payload, x_ip_token)
    if response.status_code == 405:
        response = _post_zero_space(base_url, "synthesize", payload, x_ip_token)
    _raise_for_response(response)

    event_id = response.json().get("event_id")
    if not event_id:
        raise RuntimeError(f"ZeroGPU did not return event_id: {response.text}")

    # The context manager closes the streamed connection on every exit path
    # (return, raise, or normal end of stream) so it is not leaked.
    with requests.get(
        f"{base_url}/gradio_api/call/synthesize/{event_id}",
        headers=_zero_headers(x_ip_token),
        stream=True,
        timeout=(10, 300),
    ) as stream:
        _raise_for_response(stream)

        # Event streams are UTF-8. Pinning the encoding keeps iter_lines()
        # from falling back to a header-derived guess (or to raw bytes when
        # no encoding can be determined), which would break the str
        # comparisons below or garble non-ASCII data.
        stream.encoding = "utf-8"

        event_name = ""
        for line in stream.iter_lines(decode_unicode=True):
            if not line:
                continue
            if line.startswith("event:"):
                event_name = line.split(":", 1)[1].strip()
                continue
            if not line.startswith("data:"):
                continue
            data = line.split(":", 1)[1].strip()
            if event_name == "error":
                # Error data is reported verbatim, without JSON decoding.
                raise RuntimeError(data)
            parsed = json.loads(data)
            if event_name == "complete":
                return parsed

    raise RuntimeError("ZeroGPU stream ended before completion")
def _post_zero_space(base_url: str, endpoint: str, payload: dict[str, Any], x_ip_token: str = "") -> requests.Response:
    """POST ``payload`` to a call endpoint of the upstream Space.

    For the ``synthesize`` endpoint the payload is converted into a
    positional ``{"data": [...]}`` body; any other endpoint receives the
    payload dict unchanged.
    """
    if endpoint == "synthesize":
        # Order matters: the values are sent positionally.
        ordered_fields = (
            "text",
            "speaker",
            "seconds",
            "duration_scale",
            "steps",
            "seed",
            "caption",
        )
        body: dict[str, Any] = {"data": [payload[field] for field in ordered_fields]}
    else:
        body = payload

    target = f"{base_url}/gradio_api/call/{endpoint}"
    return requests.post(
        target,
        json=body,
        headers=_zero_headers(x_ip_token),
        timeout=30,
    )
def _copy_result_file(result: Any, x_ip_token: str = "") -> Path:
    """Store the audio referenced by ``result`` as a new file in ``OUTPUT_DIR``.

    ``result`` may be a non-empty list/tuple (its first element is used), a
    dict, or a string. A dict is searched for ``url``, ``path`` and ``name``
    in that order, whether it is the result itself or the first list
    element. An http(s) reference is downloaded; anything else is treated as
    a local file path and copied.

    Args:
        result: The upstream result describing the generated audio.
        x_ip_token: Optional ``x-ip-token`` header value used for downloads.

    Returns:
        Path of the new ``<millisecond timestamp>_<uuid>.mp3`` file.

    Raises:
        RuntimeError: If no reference is found, the download fails, or the
            local source file does not exist.
    """
    candidate: Any = result
    if isinstance(result, (list, tuple)) and result:
        candidate = result[0]

    # Apply the same lookup to both shapes so a top-level dict also prefers
    # its "url": a bare "path" reported by a remote service is not
    # necessarily readable from this process.
    if isinstance(candidate, dict):
        source: Any = candidate.get("url") or candidate.get("path") or candidate.get("name")
    else:
        source = candidate
    if not source:
        raise RuntimeError(f"Could not find generated audio path in result: {result!r}")

    output_path = OUTPUT_DIR / f"{int(time.time() * 1000)}_{uuid.uuid4().hex}.mp3"
    source_text = str(source)

    if source_text.startswith(("http://", "https://")):
        response = requests.get(source_text, headers=_zero_headers(x_ip_token), timeout=60)
        _raise_for_response(response)
        output_path.write_bytes(response.content)
        return output_path

    source_path = Path(source_text)
    if not source_path.is_file():
        raise RuntimeError(f"Generated audio file is missing: {source_path}")
    shutil.copyfile(source_path, output_path)
    return output_path
def _result_metadata(result: Any) -> dict[str, Any]:
    """Return the metadata dict carried as the second element of ``result``.

    An empty dict is returned when ``result`` is not a list/tuple, has fewer
    than two elements, or its second element is not a dict.
    """
    if not isinstance(result, (list, tuple)) or len(result) < 2:
        return {}
    extra = result[1]
    return extra if isinstance(extra, dict) else {}
# NOTE(review): no route decorator is visible on this function in this view
# of the file — confirm how (and at which path) it is registered on `app`.
def root() -> dict[str, str]:
    """Return a static payload identifying this service."""
    info = dict(ok="true", service="irodori-tts-stackchan-api")
    return info
# NOTE(review): no route decorator is visible on this function in this view;
# `synthesis` advertises "<base>/health" as the status URL, so it is
# presumably registered at that path — confirm.
def health() -> dict[str, str]:
    """Report the service configuration and current cache size.

    Every value is rendered as a string, including booleans and counts.
    """
    cached_files = list(OUTPUT_DIR.glob("cache_*.mp3"))
    token_present = bool(HF_TOKEN)

    status: dict[str, str] = {"ok": "true"}
    status["zero_space"] = ZERO_SPACE
    status["duration_scale"] = str(DEFAULT_DURATION_SCALE)
    status["has_hf_token"] = str(token_present).lower()
    status["cache_entries"] = str(len(cached_files))
    status["max_cache_entries"] = str(MAX_CACHE_ENTRIES)
    return status
# NOTE(review): no route decorator is visible on this function in this view;
# `synthesis` builds URLs of the form "<base>/audio/<name>", so it is
# presumably registered at that path — confirm.
def audio(filename: str) -> FileResponse:
    """Serve one generated MP3 from ``OUTPUT_DIR``.

    Args:
        filename: Bare name of a file inside ``OUTPUT_DIR``.

    Returns:
        The file as an ``audio/mpeg`` response.

    Raises:
        HTTPException: 404 when ``filename`` is not a bare file name or no
            such file exists.
    """
    # `filename` is caller-controlled. Joining an absolute path onto
    # OUTPUT_DIR discards OUTPUT_DIR entirely, and a name containing
    # separators can reach other directories, so only a single path
    # component is accepted. The same 404 is used for every rejection so
    # the response does not reveal why a name was refused.
    if Path(filename).name != filename or filename in {"", ".", ".."}:
        raise HTTPException(status_code=404, detail="audio not found")

    path = OUTPUT_DIR / filename
    if not path.is_file():
        raise HTTPException(status_code=404, detail="audio not found")
    return FileResponse(path, media_type="audio/mpeg", filename=filename)
def _success_response(
    request: Request,
    *,
    speaker: str,
    seed_value: str,
    duration_scale: float,
    metadata: dict[str, Any],
    audio_name: str,
) -> JSONResponse:
    """Build the success payload shared by the cache-hit and cache-miss paths."""
    base = _base_url(request)
    url = f"{base}/audio/{audio_name}"
    return JSONResponse(
        {
            "success": True,
            "isApiKeyValid": True,
            "speaker": str(speaker),
            "seed": seed_value,
            "durationScale": float(duration_scale),
            "metadata": metadata,
            "mp3StreamingUrl": url,
            "mp3DownloadUrl": url,
            "audioStatusUrl": f"{base}/health",
        }
    )


def _store_in_cache(source: Path, cached_path: Path) -> None:
    """Best-effort: publish ``source`` as ``cached_path``, then prune the cache."""
    # Copy to a staging name first and rename into place, so a concurrent
    # request never sees (and serves) a partially written cache file. The
    # staging name does not match "cache_*.mp3", so it is never counted or
    # pruned as a cache entry.
    staging = cached_path.with_name(f"{cached_path.name}.{uuid.uuid4().hex}.tmp")
    try:
        shutil.copyfile(source, staging)
        os.replace(staging, cached_path)
        _prune_cache()
    except OSError:
        # The audio itself was generated successfully; failing to cache it
        # only costs a future cache hit, so the request must not fail here.
        staging.unlink(missing_ok=True)


# NOTE(review): no route decorator is visible on this function in this view
# of the file — confirm how (and at which path) it is registered on `app`.
def synthesis(
    request: Request,
    key: str | None = Query(default=None),
    text: str = Query(..., min_length=1),
    speaker: str = Query(default="3"),
    seconds: str = Query(default=DEFAULT_SECONDS),
    duration_scale: float = Query(default=DEFAULT_DURATION_SCALE, gt=0.0, le=2.0),
    steps: int = Query(default=DEFAULT_STEPS, ge=1, le=80),
    seed: str = Query(default=""),
    caption: str = Query(default=DEFAULT_CAPTION),
) -> JSONResponse:
    """Synthesize speech for ``text`` and return URLs of the resulting MP3.

    Identical requests (same text, speaker, seconds, duration scale, steps,
    seed and caption) are answered from the on-disk cache when possible.

    Returns:
        A JSON response with ``success``, the audio URLs and metadata; or a
        502 JSON response with ``success: False`` and the error text when
        the upstream synthesis fails.

    Raises:
        HTTPException: 401 for a wrong API key; 400 when the stripped text
            is empty or longer than ``MAX_TEXT_LENGTH``.
    """
    _check_key(key)

    text = text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"text is too long: max {MAX_TEXT_LENGTH}")

    # An explicit seed wins; otherwise each speaker gets its own stable seed.
    seed_value = str(seed).strip() or _seed_for_speaker(str(speaker))
    payload = {
        "text": text,
        "speaker": str(speaker),
        "seconds": str(seconds),
        "duration_scale": float(duration_scale),
        "steps": int(steps),
        "seed": seed_value,
        "caption": str(caption).strip() or DEFAULT_CAPTION,
    }
    cache_key = _cache_key(payload)
    cached_path = _cache_path(cache_key)

    if cached_path.is_file():
        return _success_response(
            request,
            speaker=speaker,
            seed_value=seed_value,
            duration_scale=duration_scale,
            metadata={"cacheHit": True, "cacheKey": cache_key},
            audio_name=cached_path.name,
        )

    # Only the upstream work is guarded: a failure here means no audio was
    # produced, which is reported to the caller as a 502.
    try:
        x_ip_token = request.headers.get("x-ip-token", "").strip()
        result = _predict_zero_space(payload, x_ip_token=x_ip_token)
        metadata = _result_metadata(result)
        output_path = _copy_result_file(result, x_ip_token=x_ip_token)
    except Exception as exc:
        return JSONResponse(
            status_code=502,
            content={
                "success": False,
                "isApiKeyValid": bool(not TTS_API_KEY or key == TTS_API_KEY),
                "error": str(exc),
            },
        )

    _store_in_cache(output_path, cached_path)

    # NOTE(review): the per-request output file is kept because its URL is
    # returned below; nothing visible in this file removes those files
    # later (pruning only covers "cache_*.mp3") — confirm cleanup elsewhere.
    return _success_response(
        request,
        speaker=speaker,
        seed_value=seed_value,
        duration_scale=duration_scale,
        metadata={**metadata, "cacheHit": False, "cacheKey": cache_key},
        audio_name=output_path.name,
    )