Spaces:
Sleeping
Sleeping
File size: 4,847 Bytes
"""NeuTTS FastAPI backend — runs on HuggingFace Spaces."""
from __future__ import annotations

import asyncio
import io
import os
import secrets
import sys
import tempfile
import threading
import traceback
from pathlib import Path

import numpy as np
import soundfile as sf
import uvicorn
from fastapi import FastAPI, File, Form, Header, HTTPException, UploadFile
from fastapi.responses import Response
from neutts import NeuTTS
# ─── Config ───────────────────────────────────────────────────────────────────
# The model settings below are read from environment variables, with defaults.
# Leaving NEUTTS_API_KEY unset (or empty) disables API-key checking entirely.
API_KEY = os.getenv("NEUTTS_API_KEY", "")
BACKBONE = os.getenv("NEUTTS_BACKBONE", "neuphonic/neutts-nano-q8-gguf")
DEVICE = os.getenv("NEUTTS_DEVICE", "cpu")
CODEC = os.getenv("NEUTTS_CODEC", "neuphonic/neucodec-onnx-decoder")
# Sample rate (Hz) written into the WAV files returned by /generate.
SAMPLE_RATE = 24_000
# ─── Model loading (at startup) ───────────────────────────────────────────────
# The model is loaded once, at import time. A failure is logged instead of
# raised so the server still starts: /health then reports model_loaded=False
# and /generate answers 503.
print(f"[backend] Loading NeuTTS: backbone={BACKBONE} device={DEVICE} codec={CODEC}", flush=True)
_tts: NeuTTS | None = None
try:
    _tts = NeuTTS(
        backbone_repo=BACKBONE,
        backbone_device=DEVICE,
        codec_repo=CODEC,
        codec_device="cpu",  # the codec always runs on CPU, regardless of DEVICE
    )
except Exception as exc:
    print(f"[backend] WARNING: model load failed: {exc}", file=sys.stderr, flush=True)
    # The one-line message alone (e.g. a bare KeyError) is often too little to
    # diagnose a bad backbone/codec/device setting, so dump the traceback too.
    print(traceback.format_exc(), file=sys.stderr, flush=True)
else:
    print("[backend] Model loaded OK", flush=True)
# Lazily populated Whisper cache for /transcribe: the loaded model object and
# the model name it was loaded under ("" until the first transcription).
_whisper_model, _whisper_model_name = None, ""
# ─── FastAPI app ──────────────────────────────────────────────────────────────
# Every route below is registered on this single application instance.
app = FastAPI(version="1.0", title="NeuTTS backend")
def _check_key(key: str | None) -> None:
    """Reject the request unless *key* matches the configured ``API_KEY``.

    When ``API_KEY`` is empty, authentication is disabled and any *key*,
    including ``None``, is accepted.

    Args:
        key: Value of the client's ``X-API-Key`` header, or ``None`` if absent.

    Raises:
        HTTPException: 401 when a key is configured and *key* does not match.
    """
    if not API_KEY:
        return
    # Constant-time comparison, so response timing does not reveal how much of
    # a guessed key was correct. Compare bytes so non-ASCII input cannot make
    # compare_digest raise TypeError.
    if key is None or not secrets.compare_digest(key.encode(), API_KEY.encode()):
        raise HTTPException(status_code=401, detail="Invalid API key")
@app.get("/health")
def health(x_api_key: str | None = Header(default=None)):
    """Report that the service is up, along with the active model settings.

    The API key is required when one is configured. ``model_loaded`` is False
    when loading NeuTTS at startup failed.
    """
    _check_key(x_api_key)
    return dict(
        status="ok",
        model_loaded=_tts is not None,
        backbone=BACKBONE,
        device=DEVICE,
        codec=CODEC,
    )
# The NeuTTS instance is shared by the whole process. Inference now runs in
# worker threads (so it no longer blocks the event loop), and this lock keeps
# calls serialized, as they implicitly were when they ran on the loop.
# NOTE(review): assumes NeuTTS is not safe to call from several threads at
# once; drop the lock if that is confirmed otherwise.
_tts_lock = threading.Lock()


def _synthesize(
    tts: NeuTTS,
    ref_path: str,
    text: str,
    ref_text: str,
    temperature: float,
    top_k: int,
) -> bytes:
    """Encode the reference audio at *ref_path*, synthesize *text* conditioned
    on it, and return the result as WAV-file bytes at ``SAMPLE_RATE``.

    This is blocking; call it from a worker thread.
    """
    with _tts_lock:
        ref_codes = tts.encode_reference(ref_path)
        wav = tts.infer(text, ref_codes, ref_text, temperature=temperature, top_k=top_k)
    buf = io.BytesIO()
    sf.write(buf, np.asarray(wav, dtype=np.float32), SAMPLE_RATE, format="WAV")
    return buf.getvalue()


@app.post("/generate")
async def generate(
    text: str = Form(...),
    ref_text: str = Form(""),
    temperature: float = Form(1.0),
    top_k: int = Form(50),
    ref_audio: UploadFile = File(...),
    x_api_key: str | None = Header(default=None),
):
    """Synthesize *text* with NeuTTS using an uploaded reference clip.

    Form fields:
        text: Text to speak; surrounding whitespace is stripped.
        ref_text: Text passed to the model alongside the reference audio
            (presumably its transcript). When blank, a single space is sent.
        temperature, top_k: Forwarded unchanged to ``NeuTTS.infer``.
        ref_audio: Reference audio file.

    Returns:
        The generated audio as an ``audio/wav`` response.

    Raises:
        HTTPException: 401 for a bad API key, 503 if the model failed to load
            at startup, 422 for whitespace-only *text*, and 500 if encoding or
            synthesis fails.
    """
    _check_key(x_api_key)
    if _tts is None:
        raise HTTPException(status_code=503, detail="Model not loaded on backend")
    prompt = text.strip()
    if not prompt:
        # Whitespace-only text would otherwise reach the model as an empty prompt.
        raise HTTPException(status_code=422, detail="text must not be empty")

    # The suffix comes from a client-supplied filename. Keep it only if it is a
    # short plain extension, so it cannot break the temp path.
    suffix = Path(ref_audio.filename or "").suffix
    ext = suffix[1:]
    if not (ext.isascii() and ext.isalnum() and len(suffix) <= 10):
        suffix = ".wav"

    data = await ref_audio.read()
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        # Write inside the try so a failed write still removes the temp file.
        with os.fdopen(fd, "wb") as fh:
            fh.write(data)
        audio_bytes = await asyncio.to_thread(
            _synthesize,
            _tts,
            tmp_path,
            prompt,
            ref_text.strip() or " ",
            float(temperature),
            int(top_k),
        )
    except Exception as exc:
        print(f"[backend] /generate error:\n{traceback.format_exc()}", file=sys.stderr, flush=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        Path(tmp_path).unlink(missing_ok=True)
    return Response(content=audio_bytes, media_type="audio/wav")
# Guards the module-level Whisper cache. Loading and transcription now run in
# worker threads, so concurrent requests must not swap the cached model out
# from under each other, or load the same model twice.
_whisper_lock = threading.Lock()


def _whisper_transcribe(whisper_module, model_id: str, path: str) -> str:
    """Transcribe the audio file at *path* with Whisper model *model_id*.

    Reuses the cached model when *model_id* matches the one already loaded;
    otherwise loads it and replaces the cache. This is blocking; call it from
    a worker thread.

    Returns:
        The transcript text with surrounding whitespace stripped.
    """
    global _whisper_model, _whisper_model_name
    with _whisper_lock:
        if _whisper_model is None or _whisper_model_name != model_id:
            print(f"[backend] loading Whisper '{model_id}'...", flush=True)
            _whisper_model = whisper_module.load_model(model_id)
            _whisper_model_name = model_id
        result = _whisper_model.transcribe(path)
    return result["text"].strip()


@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    model_id: str = Form("base"),
    x_api_key: str | None = Header(default=None),
):
    """Transcribe an uploaded audio file with OpenAI Whisper.

    Form fields:
        audio: Audio file to transcribe.
        model_id: Official Whisper model name (default ``"base"``). The loaded
            model is cached and reused while the same name keeps being requested.

    Returns:
        ``{"text": <transcript>}``.

    Raises:
        HTTPException: 401 for a bad API key, 503 if ``openai-whisper`` is not
            installed, 400 for an unknown model name, and 500 if loading or
            transcription fails.
    """
    _check_key(x_api_key)
    try:
        import whisper as _w
    except ImportError:
        raise HTTPException(status_code=503, detail="openai-whisper not installed on backend") from None
    # whisper.load_model() also accepts a path to a local checkpoint file.
    # Accept only the official model names, so a client cannot make the server
    # load arbitrary files.
    if model_id not in _w.available_models():
        raise HTTPException(status_code=400, detail=f"Unknown Whisper model: {model_id!r}")

    # The suffix comes from a client-supplied filename. Keep it only if it is a
    # short plain extension, so it cannot break the temp path.
    suffix = Path(audio.filename or "").suffix
    ext = suffix[1:]
    if not (ext.isascii() and ext.isalnum() and len(suffix) <= 10):
        suffix = ".wav"

    data = await audio.read()
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        # Write inside the try so a failed write still removes the temp file.
        with os.fdopen(fd, "wb") as fh:
            fh.write(data)
        transcript = await asyncio.to_thread(_whisper_transcribe, _w, model_id, tmp_path)
    except Exception as exc:
        print(f"[backend] /transcribe error:\n{traceback.format_exc()}", file=sys.stderr, flush=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        Path(tmp_path).unlink(missing_ok=True)
    return {"text": transcript}
if __name__ == "__main__":
    # Listen on all interfaces. The port comes from $PORT, falling back to 7860.
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, port=port, host="0.0.0.0")
|