# XTTS v2 TTS API — HuggingFace Space entrypoint
# (non-code web-UI paste residue removed: "Spaces: / Sleeping / Sleeping")
import base64
import io
import os

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from safetensors.torch import load_file
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
| # -------------------------------------------------- | |
| # Torch >= 2.6 safety (ignored on older versions) | |
| # -------------------------------------------------- | |
| try: | |
| from torch.serialization import add_safe_globals | |
| add_safe_globals([XttsConfig, XttsArgs, XttsAudioConfig]) | |
| except Exception: | |
| pass | |
| #--------------------------------------- | |
| # CONFIG | |
| # -------------------------------------------------- | |
| REPO_ID = "softwarebusters/qiuhuaTTSv2" # HuggingFace model repo | |
| CHECKPOINT_FILE = "checkpoint_7000_infer_fp16.safetensors" # only fp16 checkpoint | |
| CONFIG_FILE = "config.json" | |
| SPEAKER_REFERENCE = "speaker_ref.wav" # must exist in Space files | |
| SR_OUT = 24000 | |
| def pick_device() -> str: | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| if torch.backends.mps.is_available(): | |
| return "mps" | |
| return "cpu" | |
| device = pick_device() | |
| print(f"🚀 Using device: {device}") | |
| # -------------------------------------------------- | |
| # AUTH TOKEN (for private repo) | |
| # -------------------------------------------------- | |
| HF_TOKEN = os.environ.get("HF_TOKEN") # Optional if repo is public | |
| # -------------------------------------------------- | |
| # DOWNLOAD ALL REQUIRED FILES | |
| # -------------------------------------------------- | |
| print("📥 Downloading required files from HuggingFace Hub…") | |
| # 1) Fine-tuned checkpoint (fp16) | |
| ckpt_path = hf_hub_download( | |
| REPO_ID, | |
| CHECKPOINT_FILE, | |
| token=HF_TOKEN, | |
| ) | |
| # 2) Model config | |
| cfg_path = hf_hub_download( | |
| REPO_ID, | |
| CONFIG_FILE, | |
| token=HF_TOKEN, | |
| ) | |
| # 3) Base XTTS model files (minimum set) | |
| model_pth = hf_hub_download(REPO_ID, "model.pth", token=HF_TOKEN) | |
| dvae_pth = hf_hub_download(REPO_ID, "dvae.pth", token=HF_TOKEN) | |
| mel_path = hf_hub_download(REPO_ID, "mel_stats.json", token=HF_TOKEN) | |
| vocab_path = hf_hub_download(REPO_ID, "vocab.json", token=HF_TOKEN) | |
| base_dir = os.path.dirname(model_pth) # All files are downloaded into the same cache dir | |
| # -------------------------------------------------- | |
| # LOAD XTTS MODEL | |
| # -------------------------------------------------- | |
| print("📄 Loading XTTS config…") | |
| config = XttsConfig() | |
| config.load_json(cfg_path) | |
| print("🧠 Initializing XTTS model…") | |
| model = Xtts.init_from_config(config) | |
| print("📦 Loading base XTTS weights (model.pth, dvae.pth, mel_stats.json)…") | |
| model.load_checkpoint( | |
| config=config, | |
| checkpoint_dir=base_dir, | |
| vocab_path=vocab_path, | |
| use_deepspeed=False, | |
| ) | |
| print(f"📦 Applying fine-tuned checkpoint: {ckpt_path}") | |
| state_dict = load_file(ckpt_path) | |
| missing, unexpected = model.load_state_dict(state_dict, strict=False) | |
| print(" missing keys:", len(missing), "| unexpected keys:", len(unexpected)) | |
| model.to(device) | |
| model.eval() | |
| print("✅ Model loaded and ready.") | |
| # -------------------------------------------------- | |
| # SPEAKER EMBEDDINGS | |
| # -------------------------------------------------- | |
| if not os.path.exists(SPEAKER_REFERENCE): | |
| raise FileNotFoundError( | |
| f"Reference speaker file not found: {SPEAKER_REFERENCE}. " | |
| "Upload a short WAV file named 'speaker_ref.wav'." | |
| ) | |
| print("🎙️ Computing speaker conditioning latents…") | |
| with torch.inference_mode(): | |
| gpt_cond_latent, speaker_embedding = model.get_conditioning_latents( | |
| audio_path=[SPEAKER_REFERENCE] | |
| ) | |
| print("✅ Speaker latents ready.") | |
| # -------------------------------------------------- | |
| # FASTAPI APP | |
| # -------------------------------------------------- | |
| app = FastAPI(title="XTTS v2 TTS API (HuggingFace Space)") | |
| class TtsRequest(BaseModel): | |
| text: str | |
| language: str = "en" | |
| temperature: float = 0.7 | |
| speed: float = 1.0 | |
| class TtsResponse(BaseModel): | |
| audio_base64: str | |
| sample_rate: int | |
| def health(): | |
| return {"status": "ok"} | |
| def tts(req: TtsRequest): | |
| if not req.text.strip(): | |
| return TtsResponse(audio_base64="", sample_rate=SR_OUT) | |
| with torch.inference_mode(): | |
| out = model.inference( | |
| text=req.text, | |
| language=req.language, | |
| gpt_cond_latent=gpt_cond_latent, | |
| speaker_embedding=speaker_embedding, | |
| temperature=req.temperature, | |
| speed=req.speed, | |
| enable_text_splitting=True, | |
| ) | |
| wav = np.asarray(out["wav"], dtype=np.float32) | |
| # Write audio to an in-memory buffer | |
| buf = io.BytesIO() | |
| sf.write(buf, wav, SR_OUT, format="WAV") | |
| audio_bytes = buf.getvalue() | |
| # Encode to base64 for JSON response | |
| audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") | |
| return TtsResponse(audio_base64=audio_b64, sample_rate=SR_OUT) | |