import base64
import io
import os
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
# --------------------------------------------------
# Torch >= 2.6 safety (ignored on older versions)
# --------------------------------------------------
# torch >= 2.6 defaults torch.load to weights_only=True; allow-listing the
# XTTS config classes lets checkpoints that embed them be deserialized.
try:
    from torch.serialization import add_safe_globals
    add_safe_globals([XttsConfig, XttsArgs, XttsAudioConfig])
except Exception:
    # Older torch has no add_safe_globals; nothing to do there.
    pass
# --------------------------------------------------
# CONFIG
# --------------------------------------------------
REPO_ID = "softwarebusters/qiuhuaTTSv2"  # HuggingFace model repo
CHECKPOINT_FILE = "checkpoint_7000_infer_fp16.safetensors"  # only fp16 checkpoint
CONFIG_FILE = "config.json"
SPEAKER_REFERENCE = "speaker_ref.wav"  # must exist in Space files
SR_OUT = 24000  # sample rate (Hz) used when encoding and reporting audio
def pick_device() -> str:
    """Return the best available torch device name: "cuda", "mps", or "cpu".

    The MPS probe is guarded with getattr so the function also works on
    torch builds that predate the MPS backend (where ``torch.backends``
    has no ``mps`` attribute and the original code raised AttributeError).
    """
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
# Select the compute device once at startup; reused by model.to(device) below.
device = pick_device()
print(f"🚀 Using device: {device}")
# --------------------------------------------------
# AUTH TOKEN (for private repo)
# --------------------------------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")  # Optional if repo is public
# --------------------------------------------------
# DOWNLOAD ALL REQUIRED FILES
# --------------------------------------------------
print("📥 Downloading required files from HuggingFace Hub…")


def _fetch(filename: str) -> str:
    """Download one file from the model repo and return its local cache path."""
    return hf_hub_download(REPO_ID, filename, token=HF_TOKEN)


# 1) Fine-tuned checkpoint (fp16)
ckpt_path = _fetch(CHECKPOINT_FILE)
# 2) Model config
cfg_path = _fetch(CONFIG_FILE)
# 3) Base XTTS model files (minimum set)
model_pth = _fetch("model.pth")
dvae_pth = _fetch("dvae.pth")
mel_path = _fetch("mel_stats.json")
vocab_path = _fetch("vocab.json")
# All files land in the same cache directory; load_checkpoint reads from it.
base_dir = os.path.dirname(model_pth)
# --------------------------------------------------
# LOAD XTTS MODEL
# --------------------------------------------------
print("📄 Loading XTTS config…")
config = XttsConfig()
config.load_json(cfg_path)
print("🧠 Initializing XTTS model…")
model = Xtts.init_from_config(config)
print("📦 Loading base XTTS weights (model.pth, dvae.pth, mel_stats.json)…")
model.load_checkpoint(
    config=config,
    checkpoint_dir=base_dir,
    vocab_path=vocab_path,
    use_deepspeed=False,
)
# Overlay the fine-tuned fp16 weights on top of the base checkpoint.
# strict=False tolerates key mismatches between base and fine-tune; the
# printed counts let operators sanity-check how much of the fine-tune matched.
print(f"📦 Applying fine-tuned checkpoint: {ckpt_path}")
state_dict = load_file(ckpt_path)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(" missing keys:", len(missing), "| unexpected keys:", len(unexpected))
model.to(device)
model.eval()
print("✅ Model loaded and ready.")
# --------------------------------------------------
# SPEAKER EMBEDDINGS
# --------------------------------------------------
# Fail fast at startup if the reference WAV is missing, rather than erroring
# on the first /tts request.
if not os.path.exists(SPEAKER_REFERENCE):
    raise FileNotFoundError(
        f"Reference speaker file not found: {SPEAKER_REFERENCE}. "
        "Upload a short WAV file named 'speaker_ref.wav'."
    )
print("🎙️ Computing speaker conditioning latents…")
# Conditioning latents are computed once here and reused for every request.
with torch.inference_mode():
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=[SPEAKER_REFERENCE]
    )
print("✅ Speaker latents ready.")
# --------------------------------------------------
# FASTAPI APP
# --------------------------------------------------
app = FastAPI(title="XTTS v2 TTS API (HuggingFace Space)")


class TtsRequest(BaseModel):
    """Request body for POST /tts."""
    text: str
    language: str = "en"  # language code passed through to XTTS
    temperature: float = 0.7  # sampling temperature for generation
    speed: float = 1.0  # playback-speed multiplier


class TtsResponse(BaseModel):
    """Response body: WAV bytes base64-encoded for JSON transport."""
    audio_base64: str
    sample_rate: int
@app.get("/health")
def health():
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")
@app.post("/tts", response_model=TtsResponse)
def tts(req: TtsRequest):
    """Synthesize speech for req.text and return base64-encoded WAV audio.

    Blank or whitespace-only text yields an empty audio payload (not an
    error) so clients can treat it as a no-op.
    """
    if not req.text.strip():
        return TtsResponse(audio_base64="", sample_rate=SR_OUT)
    with torch.inference_mode():
        out = model.inference(
            text=req.text,
            language=req.language,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=req.temperature,
            speed=req.speed,
            enable_text_splitting=True,
        )
    wav = out["wav"]
    # XTTS may return a numpy array or a torch tensor; a tensor living on a
    # CUDA/MPS device must be moved to host memory before numpy can read it
    # (np.asarray on a device tensor raises).
    if torch.is_tensor(wav):
        wav = wav.detach().cpu().numpy()
    wav = np.asarray(wav, dtype=np.float32)
    # Write audio to an in-memory buffer (no temp files on disk)
    buf = io.BytesIO()
    sf.write(buf, wav, SR_OUT, format="WAV")
    audio_bytes = buf.getvalue()
    # Encode to base64 for JSON response
    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
    return TtsResponse(audio_base64=audio_b64, sample_rate=SR_OUT)
|