Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
import base64
|
| 2 |
import os
|
| 3 |
-
import sys
|
| 4 |
import tempfile
|
| 5 |
import uuid
|
| 6 |
-
from io import StringIO
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Optional
|
| 9 |
|
|
@@ -21,42 +19,62 @@ HF_TOKEN = (
|
|
| 21 |
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 22 |
or os.getenv("HF_TOKEN")
|
| 23 |
)
|
| 24 |
-
|
| 25 |
MAX_TEXT_LENGTH = 1000
|
| 26 |
DEFAULT_LANGUAGE = "en"
|
| 27 |
|
| 28 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
|
| 30 |
-
# Set token in environment before importing
|
| 31 |
if HF_TOKEN:
|
| 32 |
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
|
| 33 |
os.environ["HF_TOKEN"] = HF_TOKEN
|
| 34 |
-
# Also login explicitly via huggingface_hub
|
| 35 |
try:
|
| 36 |
from huggingface_hub import login
|
| 37 |
login(token=HF_TOKEN, add_to_git_credential=False)
|
| 38 |
except ImportError:
|
| 39 |
-
pass
|
| 40 |
|
| 41 |
-
#
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
sys.stdin = StringIO("y\n") # Auto-accept TOS
|
| 45 |
-
|
| 46 |
-
from TTS.api import TTS
|
| 47 |
|
| 48 |
try:
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
if
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
class GenerateRequest(BaseModel):
|
|
@@ -133,7 +151,7 @@ def _preprocess_audio_wav(path: str, target_sr: int = 24000, target_peak: float
|
|
| 133 |
@app.post("/health")
|
| 134 |
def health(x_api_key: Optional[str] = Header(default=None)):
|
| 135 |
_require_api_key(x_api_key)
|
| 136 |
-
return {"status": "ok", "model": "
|
| 137 |
|
| 138 |
|
| 139 |
def _cleanup_files(*files: str):
|
|
@@ -160,14 +178,17 @@ def generate(
|
|
| 160 |
try:
|
| 161 |
speaker_file = _temp_speaker_file(payload.speaker_wav)
|
| 162 |
speaker_file = _preprocess_audio_wav(speaker_file)
|
| 163 |
-
output_file = os.path.join(tempfile.gettempdir(), f"
|
| 164 |
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
text=payload.text,
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
split_sentences=True,
|
| 171 |
)
|
| 172 |
|
| 173 |
# Light post-process to avoid end-of-file artifacts
|
|
@@ -198,4 +219,4 @@ def generate(
|
|
| 198 |
|
| 199 |
@app.get("/")
|
| 200 |
def root():
|
| 201 |
-
return {"name": "
|
|
|
|
| 1 |
import base64
|
| 2 |
import os
|
|
|
|
| 3 |
import tempfile
|
| 4 |
import uuid
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Optional
|
| 7 |
|
|
|
|
| 19 |
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 20 |
or os.getenv("HF_TOKEN")
|
| 21 |
)
|
| 22 |
+
MODEL_REPO = "IndexTeam/IndexTTS-2"
|
| 23 |
MAX_TEXT_LENGTH = 1000
|
| 24 |
DEFAULT_LANGUAGE = "en"
|
| 25 |
|
| 26 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
|
| 28 |
+
# Set the Hugging Face token in the environment before the imports below run
|
| 29 |
if HF_TOKEN:
|
| 30 |
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
|
| 31 |
os.environ["HF_TOKEN"] = HF_TOKEN
|
|
|
|
| 32 |
try:
|
| 33 |
from huggingface_hub import login
|
| 34 |
login(token=HF_TOKEN, add_to_git_credential=False)
|
| 35 |
except ImportError:
|
| 36 |
+
pass
|
| 37 |
|
| 38 |
+
# Download model checkpoints from Hugging Face
|
| 39 |
+
MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
|
| 40 |
+
os.makedirs(MODEL_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
try:
|
| 43 |
+
from huggingface_hub import snapshot_download
|
| 44 |
+
|
| 45 |
+
# Download model if not already present
|
| 46 |
+
if not Path(MODEL_DIR, "config.yaml").exists():
|
| 47 |
+
print(f"Downloading IndexTTS2 model from {MODEL_REPO}...")
|
| 48 |
+
snapshot_download(
|
| 49 |
+
repo_id=MODEL_REPO,
|
| 50 |
+
local_dir=MODEL_DIR,
|
| 51 |
+
token=HF_TOKEN,
|
| 52 |
+
)
|
| 53 |
+
print("Model download complete.")
|
| 54 |
+
except Exception as exc:
|
| 55 |
+
print(f"Warning: Could not download model: {exc}")
|
| 56 |
+
# Continue anyway - the load step below raises if config.yaml is still missing
|
| 57 |
|
| 58 |
+
# Initialize IndexTTS2
|
| 59 |
+
try:
|
| 60 |
+
from indextts.infer_v2 import IndexTTS2
|
| 61 |
+
|
| 62 |
+
cfg_path = os.path.join(MODEL_DIR, "config.yaml")
|
| 63 |
+
if not Path(cfg_path).exists():
|
| 64 |
+
raise FileNotFoundError(f"Config file not found at {cfg_path}. Model may not be downloaded.")
|
| 65 |
+
|
| 66 |
+
tts_model = IndexTTS2(
|
| 67 |
+
cfg_path=cfg_path,
|
| 68 |
+
model_dir=MODEL_DIR,
|
| 69 |
+
use_fp16=False, # FP16 disabled: hard-coded CPU-safe setting, not derived from DEVICE
|
| 70 |
+
use_cuda_kernel=False, # CPU mode
|
| 71 |
+
use_deepspeed=False, # CPU mode
|
| 72 |
+
)
|
| 73 |
+
print("IndexTTS2 model loaded successfully.")
|
| 74 |
+
except Exception as exc:
|
| 75 |
+
raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
|
| 76 |
+
|
| 77 |
+
app = FastAPI(title="indextts2-api", version="1.0.0")
|
| 78 |
|
| 79 |
|
| 80 |
class GenerateRequest(BaseModel):
|
|
|
|
| 151 |
@app.post("/health")
|
| 152 |
def health(x_api_key: Optional[str] = Header(default=None)):
|
| 153 |
_require_api_key(x_api_key)
|
| 154 |
+
return {"status": "ok", "model": "indextts2", "device": DEVICE}
|
| 155 |
|
| 156 |
|
| 157 |
def _cleanup_files(*files: str):
|
|
|
|
| 178 |
try:
|
| 179 |
speaker_file = _temp_speaker_file(payload.speaker_wav)
|
| 180 |
speaker_file = _preprocess_audio_wav(speaker_file)
|
| 181 |
+
output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
|
| 182 |
|
| 183 |
+
# IndexTTS2 inference
|
| 184 |
+
# Note: language parameter is kept for API compatibility but IndexTTS2
|
| 185 |
+
# handles multilingual text automatically (supports English, Turkish, Chinese, etc.)
|
| 186 |
+
tts_model.infer(
|
| 187 |
+
spk_audio_prompt=speaker_file,
|
| 188 |
text=payload.text,
|
| 189 |
+
output_path=output_file,
|
| 190 |
+
use_random=False, # Deterministic output
|
| 191 |
+
verbose=False,
|
|
|
|
| 192 |
)
|
| 193 |
|
| 194 |
# Light post-process to avoid end-of-file artifacts
|
|
|
|
| 219 |
|
| 220 |
@app.get("/")
|
| 221 |
def root():
|
| 222 |
+
return {"name": "indextts2-api", "endpoints": ["/health", "/generate"]}
|