neboximate commited on
Commit
eb7b307
·
verified ·
1 Parent(s): 7ada6ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import os
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import soundfile as sf
8
+ import torch
9
+ from fastapi import FastAPI
10
+ from pydantic import BaseModel
11
+ from huggingface_hub import hf_hub_download
12
+ from safetensors.torch import load_file
13
+ from TTS.tts.configs.xtts_config import XttsConfig
14
+ from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
15
+
16
+ # Torch >= 2.6 safety (older versions just ignore this)
17
+ try:
18
+ from torch.serialization import add_safe_globals
19
+ add_safe_globals([XttsConfig, XttsArgs, XttsAudioConfig])
20
+ except Exception:
21
+ pass
22
+
23
+ # ---------- CONFIG ----------
24
+
25
+ REPO_ID = "softwarebusters/qiuhuaTTSv2" # HF model repo id
26
+ CHECKPOINT_FILE = "checkpoint_7000_infer_fp16.safetensors"
27
+ CONFIG_FILE = "config.json"
28
+
29
+ SPEAKER_REFERENCE = "speaker_ref.wav" # short wav you will upload
30
+ SR_OUT = 24000
31
+
32
+
33
+ def pick_device() -> str:
34
+ if torch.cuda.is_available():
35
+ return "cuda"
36
+ if torch.backends.mps.is_available():
37
+ return "mps"
38
+ return "cpu"
39
+
40
+
41
+ device = pick_device()
42
+ print(f"🚀 Using device: {device}")
43
+
44
+ # ---------- LOAD MODEL AT STARTUP ----------
45
+
46
+ print("📥 Downloading model files from Hugging Face…")
47
+ ckpt_path = hf_hub_download(REPO_ID, CHECKPOINT_FILE)
48
+ cfg_path = hf_hub_download(REPO_ID, CONFIG_FILE)
49
+
50
+ print("📄 Loading XTTS config…")
51
+ config = XttsConfig()
52
+ config.load_json(cfg_path)
53
+
54
+ print("🧠 Initializing XTTS model…")
55
+ model = Xtts.init_from_config(config)
56
+
57
+ # base XTTS files (model.pth, dvae.pth, mel_stats.json, vocab.json)
58
+ base_dir = os.path.dirname(ckpt_path)
59
+
60
+ print("📦 Loading base XTTS weights…")
61
+ model.load_checkpoint(
62
+ config=config,
63
+ checkpoint_dir=base_dir,
64
+ vocab_path=os.path.join(base_dir, "vocab.json"),
65
+ use_deepspeed=False,
66
+ )
67
+
68
+ print(f"📦 Applying fine-tuned checkpoint: {ckpt_path}")
69
+ state_dict = load_file(ckpt_path)
70
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
71
+ print(" missing keys:", len(missing), "| unexpected:", len(unexpected))
72
+
73
+ model.to(device)
74
+ model.eval()
75
+ print("✅ Model ready.")
76
+
77
+
78
+ # ---------- SPEAKER LATENTS ----------
79
+
80
+ if not os.path.exists(SPEAKER_REFERENCE):
81
+ raise FileNotFoundError(
82
+ f"Speaker reference file not found: {SPEAKER_REFERENCE}. "
83
+ "Upload a short WAV file named 'speaker_ref.wav' to the Space."
84
+ )
85
+
86
+ print("🎙️ Computing speaker latents…")
87
+ with torch.inference_mode():
88
+ gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
89
+ audio_path=[SPEAKER_REFERENCE]
90
+ )
91
+ print("✅ Speaker latents ready.")
92
+
93
+
94
+ # ---------- FASTAPI APP ----------
95
+
96
+ app = FastAPI(title="XTTS v2 TTS API (Space)")
97
+
98
+
99
+ class TtsRequest(BaseModel):
100
+ text: str
101
+ language: str = "en"
102
+ temperature: float = 0.7
103
+ speed: float = 1.0
104
+
105
+
106
+ class TtsResponse(BaseModel):
107
+ audio_base64: str
108
+ sample_rate: int
109
+
110
+
111
+ @app.get("/health")
112
+ def health():
113
+ return {"status": "ok"}
114
+
115
+
116
+ @app.post("/tts", response_model=TtsResponse)
117
+ def tts(req: TtsRequest):
118
+ if not req.text.strip():
119
+ return TtsResponse(audio_base64="", sample_rate=SR_OUT)
120
+
121
+ with torch.inference_mode():
122
+ out = model.inference(
123
+ text=req.text,
124
+ language=req.language,
125
+ gpt_cond_latent=gpt_cond_latent,
126
+ speaker_embedding=speaker_embedding,
127
+ temperature=req.temperature,
128
+ speed=req.speed,
129
+ enable_text_splitting=True,
130
+ )
131
+
132
+ wav = np.asarray(out["wav"], dtype=np.float32)
133
+
134
+ buf = io.BytesIO()
135
+ sf.write(buf, wav, SR_OUT, format="WAV")
136
+ audio_bytes = buf.getvalue()
137
+
138
+ audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
139
+
140
+ return TtsResponse(audio_base64=audio_b64, sample_rate=SR_OUT)