# qiuhuaTTSv2-api / app.py
# (source: neboximate's HuggingFace Space, commit 24f2a39)
import base64
import io
import os
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
# --------------------------------------------------
# Torch >= 2.6 safety (ignored on older versions)
# --------------------------------------------------
# Recent torch restricts torch.load to an allow-list of classes; register the
# XTTS config classes so checkpoints referencing them can still be loaded.
# Older torch has no add_safe_globals, so any failure here is deliberately
# ignored and loading proceeds as before.
try:
    from torch.serialization import add_safe_globals as _allow_classes
    _allow_classes([XttsConfig, XttsArgs, XttsAudioConfig])
except Exception:
    pass
# --------------------------------------------------
# CONFIG
# --------------------------------------------------
REPO_ID = "softwarebusters/qiuhuaTTSv2"  # HuggingFace model repo
CHECKPOINT_FILE = "checkpoint_7000_infer_fp16.safetensors"  # only fp16 checkpoint
CONFIG_FILE = "config.json"  # XTTS config file inside the repo
SPEAKER_REFERENCE = "speaker_ref.wav"  # must exist in Space files
SR_OUT = 24000  # sample rate (Hz) written into the WAV and reported to clients
def pick_device() -> str:
    """Select the best available torch device.

    Preference order: CUDA GPU, then Apple MPS, then CPU.

    Returns:
        One of "cuda", "mps", or "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists on torch >= 1.12; guard the attribute
    # lookup so older installs fall through to CPU instead of raising
    # AttributeError at import time.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
# Resolve the runtime device once at import time; everything below uses it.
device = pick_device()
print(f"🚀 Using device: {device}")

# --------------------------------------------------
# AUTH TOKEN (for private repo)
# --------------------------------------------------
# Only required when the model repo is private; None is fine for public repos.
HF_TOKEN = os.getenv("HF_TOKEN")
# --------------------------------------------------
# DOWNLOAD ALL REQUIRED FILES
# --------------------------------------------------
print("📥 Downloading required files from HuggingFace Hub…")


def _fetch(filename: str) -> str:
    """Download *filename* from REPO_ID (authenticated when HF_TOKEN is set)."""
    return hf_hub_download(REPO_ID, filename, token=HF_TOKEN)


ckpt_path = _fetch(CHECKPOINT_FILE)  # fine-tuned fp16 checkpoint
cfg_path = _fetch(CONFIG_FILE)       # model config
model_pth = _fetch("model.pth")      # base XTTS weights
dvae_pth = _fetch("dvae.pth")        # base DVAE weights
mel_path = _fetch("mel_stats.json")  # mel normalization stats
vocab_path = _fetch("vocab.json")    # tokenizer vocabulary

# hf_hub_download places all files of one repo revision into the same cache
# directory, so the checkpoint dir can be derived from any downloaded file.
base_dir = os.path.dirname(model_pth)
# --------------------------------------------------
# LOAD XTTS MODEL
# --------------------------------------------------
print("📄 Loading XTTS config…")
# Parse the downloaded config.json into an XttsConfig object.
config = XttsConfig()
config.load_json(cfg_path)

print("🧠 Initializing XTTS model…")
# Build the model skeleton from the config (no weights loaded yet).
model = Xtts.init_from_config(config)

print("📦 Loading base XTTS weights (model.pth, dvae.pth, mel_stats.json)…")
# load_checkpoint resolves model.pth / dvae.pth / mel_stats.json by name
# inside checkpoint_dir — which is why they were all downloaded above even
# though their paths are not referenced again directly.
model.load_checkpoint(
    config=config,
    checkpoint_dir=base_dir,
    vocab_path=vocab_path,
    use_deepspeed=False,
)

print(f"📦 Applying fine-tuned checkpoint: {ckpt_path}")
# Overlay the fine-tuned fp16 weights on top of the base model.
# strict=False tolerates key mismatches between base and fine-tuned state;
# the counts printed below make silent mismatches visible in the logs.
state_dict = load_file(ckpt_path)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(" missing keys:", len(missing), "| unexpected keys:", len(unexpected))

model.to(device)
model.eval()  # inference mode: disables dropout, etc.
print("✅ Model loaded and ready.")
# --------------------------------------------------
# SPEAKER EMBEDDINGS
# --------------------------------------------------
# The reference WAV must ship alongside the app; it is read from the working
# directory, not from the model repo.
if not os.path.exists(SPEAKER_REFERENCE):
    raise FileNotFoundError(
        f"Reference speaker file not found: {SPEAKER_REFERENCE}. "
        "Upload a short WAV file named 'speaker_ref.wav'."
    )

print("🎙️ Computing speaker conditioning latents…")
# Compute the voice-cloning conditioning once at startup; every /tts request
# reuses these latents instead of re-processing the reference audio.
with torch.inference_mode():
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=[SPEAKER_REFERENCE]
    )
print("✅ Speaker latents ready.")
# --------------------------------------------------
# FASTAPI APP
# --------------------------------------------------
# Single application exposing /health and /tts.
app = FastAPI(title="XTTS v2 TTS API (HuggingFace Space)")
class TtsRequest(BaseModel):
    """Request payload for POST /tts."""

    text: str  # text to synthesize; blank/whitespace-only yields empty audio
    language: str = "en"  # language code forwarded to XTTS inference
    temperature: float = 0.7  # sampling temperature forwarded to XTTS
    speed: float = 1.0  # speech speed multiplier forwarded to XTTS
class TtsResponse(BaseModel):
    """Response body for POST /tts."""

    audio_base64: str  # base64-encoded WAV bytes (empty string for blank input)
    sample_rate: int  # sample rate of the encoded audio, in Hz
@app.get("/health")
def health():
    """Liveness probe; always reports the service as up."""
    return dict(status="ok")
@app.post("/tts", response_model=TtsResponse)
def tts(req: TtsRequest):
    """Synthesize speech for ``req.text`` and return base64-encoded WAV audio."""
    # Nothing to synthesize: hand back an empty payload instead of running
    # the model on blank input.
    if not req.text.strip():
        return TtsResponse(audio_base64="", sample_rate=SR_OUT)

    # Run XTTS inference without autograd bookkeeping, reusing the speaker
    # latents computed at startup.
    with torch.inference_mode():
        synthesis = model.inference(
            text=req.text,
            language=req.language,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=req.temperature,
            speed=req.speed,
            enable_text_splitting=True,
        )

    waveform = np.asarray(synthesis["wav"], dtype=np.float32)

    # Serialize to WAV in memory, then base64-encode for the JSON body.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, waveform, SR_OUT, format="WAV")
    encoded = base64.b64encode(wav_buffer.getvalue()).decode("utf-8")
    return TtsResponse(audio_base64=encoded, sample_rate=SR_OUT)