"""FastAPI service exposing a fine-tuned XTTS v2 text-to-speech model.

At import time this module downloads the model assets from the HuggingFace
Hub, loads the base XTTS weights plus a fine-tuned fp16 checkpoint, computes
speaker-conditioning latents from a reference WAV, and exposes two HTTP
endpoints: ``GET /health`` and ``POST /tts`` (base64-encoded WAV in JSON).
"""

import base64
import io
import os

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig

# --------------------------------------------------
# Torch >= 2.6 safety (ignored on older versions)
# --------------------------------------------------
# torch 2.6 switched torch.load to weights_only=True by default; the XTTS
# config classes must be allow-listed so the base checkpoint can unpickle.
try:
    from torch.serialization import add_safe_globals

    add_safe_globals([XttsConfig, XttsArgs, XttsAudioConfig])
except Exception:
    # Older torch has no add_safe_globals and does not need it.
    pass

# --------------------------------------------------
# CONFIG
# --------------------------------------------------
REPO_ID = "softwarebusters/qiuhuaTTSv2"  # HuggingFace model repo
CHECKPOINT_FILE = "checkpoint_7000_infer_fp16.safetensors"  # only fp16 checkpoint
CONFIG_FILE = "config.json"
SPEAKER_REFERENCE = "speaker_ref.wav"  # must exist in Space files
SR_OUT = 24000  # output sample rate used when encoding the WAV response


def pick_device() -> str:
    """Return the best available torch device string: cuda > mps > cpu."""
    if torch.cuda.is_available():
        return "cuda"
    # hasattr guard: torch builds without the MPS backend (e.g. < 1.12)
    # have no torch.backends.mps and would raise AttributeError here.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(f"🚀 Using device: {device}")

# --------------------------------------------------
# AUTH TOKEN (for private repo)
# --------------------------------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")  # Optional if repo is public

# --------------------------------------------------
# DOWNLOAD ALL REQUIRED FILES
# --------------------------------------------------
print("📥 Downloading required files from HuggingFace Hub…")

# 1) Fine-tuned checkpoint (fp16)
ckpt_path = hf_hub_download(REPO_ID, CHECKPOINT_FILE, token=HF_TOKEN)

# 2) Model config
cfg_path = hf_hub_download(REPO_ID, CONFIG_FILE, token=HF_TOKEN)

# 3) Base XTTS model files (minimum set)
model_pth = hf_hub_download(REPO_ID, "model.pth", token=HF_TOKEN)
dvae_pth = hf_hub_download(REPO_ID, "dvae.pth", token=HF_TOKEN)
mel_path = hf_hub_download(REPO_ID, "mel_stats.json", token=HF_TOKEN)
vocab_path = hf_hub_download(REPO_ID, "vocab.json", token=HF_TOKEN)

# All files are downloaded into the same cache dir
base_dir = os.path.dirname(model_pth)

# --------------------------------------------------
# LOAD XTTS MODEL
# --------------------------------------------------
print("📄 Loading XTTS config…")
config = XttsConfig()
config.load_json(cfg_path)

print("🧠 Initializing XTTS model…")
model = Xtts.init_from_config(config)

print("📦 Loading base XTTS weights (model.pth, dvae.pth, mel_stats.json)…")
model.load_checkpoint(
    config=config,
    checkpoint_dir=base_dir,
    vocab_path=vocab_path,
    use_deepspeed=False,
)

print(f"📦 Applying fine-tuned checkpoint: {ckpt_path}")
state_dict = load_file(ckpt_path)
# strict=False: an inference-only checkpoint may omit training-only weights;
# the missing/unexpected counts are printed for visibility instead.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(" missing keys:", len(missing), "| unexpected keys:", len(unexpected))

model.to(device)
model.eval()
print("✅ Model loaded and ready.")

# --------------------------------------------------
# SPEAKER EMBEDDINGS
# --------------------------------------------------
if not os.path.exists(SPEAKER_REFERENCE):
    raise FileNotFoundError(
        f"Reference speaker file not found: {SPEAKER_REFERENCE}. "
        "Upload a short WAV file named 'speaker_ref.wav'."
    )

print("🎙️ Computing speaker conditioning latents…")
# Latents are computed once at startup and reused for every /tts request.
with torch.inference_mode():
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=[SPEAKER_REFERENCE]
    )
print("✅ Speaker latents ready.")

# --------------------------------------------------
# FASTAPI APP
# --------------------------------------------------
app = FastAPI(title="XTTS v2 TTS API (HuggingFace Space)")


class TtsRequest(BaseModel):
    """Request body for POST /tts."""

    # Text to synthesize; whitespace-only input yields an empty response.
    text: str
    language: str = "en"
    temperature: float = 0.7
    speed: float = 1.0


class TtsResponse(BaseModel):
    """Response body for POST /tts."""

    # Base64-encoded WAV bytes; empty string when no audio was produced.
    audio_base64: str
    sample_rate: int


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/tts", response_model=TtsResponse)
def tts(req: TtsRequest):
    """Synthesize speech for ``req.text`` and return base64-encoded WAV."""
    if not req.text.strip():
        # Empty / whitespace-only input: short-circuit with an empty
        # payload rather than running the model.
        return TtsResponse(audio_base64="", sample_rate=SR_OUT)

    with torch.inference_mode():
        out = model.inference(
            text=req.text,
            language=req.language,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=req.temperature,
            speed=req.speed,
            enable_text_splitting=True,
        )

    wav = np.asarray(out["wav"], dtype=np.float32)

    # Write audio to an in-memory buffer
    buf = io.BytesIO()
    sf.write(buf, wav, SR_OUT, format="WAV")
    audio_bytes = buf.getvalue()

    # Encode to base64 for JSON response
    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
    return TtsResponse(audio_base64=audio_b64, sample_rate=SR_OUT)