File size: 7,464 Bytes
be85c0f
 
 
 
 
 
 
 
 
e84f64a
 
f52f228
be85c0f
 
 
 
36f8bee
 
 
 
 
b52eab0
be85c0f
 
 
 
 
b52eab0
31c0008
 
 
 
 
 
 
b52eab0
31c0008
b52eab0
 
 
31c0008
be85c0f
b52eab0
 
 
 
 
 
 
 
 
 
 
 
 
 
be85c0f
b52eab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be85c0f
 
 
 
9f83ce7
 
be85c0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f83ce7
be85c0f
 
 
 
 
 
 
 
 
 
 
 
e84f64a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be85c0f
 
 
b52eab0
be85c0f
 
f52f228
 
 
 
 
 
 
 
 
 
be85c0f
 
 
f52f228
be85c0f
 
 
 
 
 
 
 
 
e84f64a
b52eab0
be85c0f
b52eab0
 
 
 
 
be85c0f
b52eab0
 
 
be85c0f
 
e84f64a
 
 
f52f228
 
 
 
 
 
 
be85c0f
 
 
f52f228
 
 
be85c0f
9f83ce7
f52f228
be85c0f
 
 
 
f52f228
be85c0f
 
 
 
b52eab0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import base64
import hmac
import os
import tempfile
import uuid
from pathlib import Path
from typing import Optional

import requests
import torch
import torchaudio
from torchaudio.transforms import Resample
from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field, HttpUrl

# Optional shared secret; when unset/empty, API-key checks are disabled.
SPACE_API_KEY = os.environ.get("SPACE_API_KEY")

# Hugging Face token, looked up under several common variable names. The first
# non-empty one wins; the last name is the final fallback.
HF_TOKEN = (
    os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    or os.environ.get("HF_TOKEN")
)

# Static service configuration.
MODEL_REPO = "IndexTeam/IndexTTS-2"
MAX_TEXT_LENGTH = 1000
DEFAULT_LANGUAGE = "en"

# Reported by /health; chosen from CUDA availability at import time.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Set token in environment before importing
if HF_TOKEN:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
    os.environ["HF_TOKEN"] = HF_TOKEN
    try:
        from huggingface_hub import login
        login(token=HF_TOKEN, add_to_git_credential=False)
    except ImportError:
        pass

# Local checkpoint directory (override with $MODEL_DIR); created up front so the
# download below has somewhere to write.
MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
os.makedirs(MODEL_DIR, exist_ok=True)

# First-start download. The presence of config.yaml is treated as the marker
# that the checkpoints were already fetched.
try:
    from huggingface_hub import snapshot_download

    if not (Path(MODEL_DIR) / "config.yaml").exists():
        print(f"Downloading IndexTTS2 model from {MODEL_REPO}...")
        snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, token=HF_TOKEN)
        print("Model download complete.")
except Exception as exc:
    # Deliberately non-fatal: the files may already be in place. The model
    # loader below raises a clear error if config.yaml is still missing.
    print(f"Warning: Could not download model: {exc}")

# Load IndexTTS2 once, at import time. Any failure (missing config, import
# error, checkpoint problems) aborts startup wrapped in a RuntimeError.
try:
    from indextts.infer_v2 import IndexTTS2

    cfg_path = os.path.join(MODEL_DIR, "config.yaml")
    if not Path(cfg_path).exists():
        raise FileNotFoundError(f"Config file not found at {cfg_path}. Model may not be downloaded.")

    # NOTE(review): fp16, the custom CUDA kernel and DeepSpeed are disabled
    # unconditionally here, regardless of DEVICE (which may be "cuda").
    # TODO confirm whether GPU deployments should enable them.
    tts_model = IndexTTS2(
        model_dir=MODEL_DIR,
        cfg_path=cfg_path,
        use_deepspeed=False,
        use_cuda_kernel=False,
        use_fp16=False,
    )
    print("IndexTTS2 model loaded successfully.")
except Exception as exc:
    raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc

app = FastAPI(title="indextts2-api", version="1.0.0")


class GenerateRequest(BaseModel):
    """JSON body accepted by the /generate endpoint."""

    # Text to synthesize: required, 1..MAX_TEXT_LENGTH characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=MAX_TEXT_LENGTH,
    )
    # Reference voice: an http(s) URL, or otherwise treated as base64 audio.
    speaker_wav: str = Field(
        ...,
        description="HTTPS URL or base64-encoded audio",
    )
    # Accepted for API compatibility; /generate does not forward it to the model.
    language: Optional[str] = Field(
        default=DEFAULT_LANGUAGE,
        description="ISO code, default en",
    )


def _require_api_key(x_api_key: Optional[str]):
    """Enforce the shared-secret API key, if one is configured.

    When ``SPACE_API_KEY`` is unset or empty, authentication is disabled and
    the call returns immediately.

    Args:
        x_api_key: API-key header value supplied by the endpoint, or ``None``
            when the client sent no such header.

    Raises:
        HTTPException: 401 when a key is configured and ``x_api_key`` is
            missing or does not match it.
    """
    if not SPACE_API_KEY:
        return
    # Constant-time comparison so the key cannot be recovered byte-by-byte
    # from response timing. Compare UTF-8 bytes because compare_digest raises
    # TypeError for str arguments containing non-ASCII characters.
    if x_api_key is None or not hmac.compare_digest(
        x_api_key.encode("utf-8"), SPACE_API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized")


def _write_temp_audio_from_url(url: HttpUrl) -> str:
    """Download speaker audio from ``url`` into a new temporary file.

    The file is created with ``delete=False``, so the caller owns it and is
    responsible for removing it. Its suffix is taken from the URL path, or
    ``.wav`` when the path has no extension.

    Args:
        url: Validated HTTP(S) URL of the reference audio.

    Returns:
        Filesystem path of the downloaded file.

    Raises:
        HTTPException: 400 when the server answers with a status code >= 400.
        requests.RequestException: connection errors/timeouts are propagated.
    """
    # SECURITY(review): the URL is client-supplied and fetched with no host
    # allow-list (SSRF risk against internal addresses) and no cap on the
    # download size. Consider both before exposing this service publicly.

    # url.path may be None for some URL objects; Path(None) would TypeError.
    suffix = Path(url.path or "").suffix or ".wav"

    # Use the response as a context manager so the streamed connection is
    # always released, including on the error paths below.
    with requests.get(url, stream=True, timeout=30) as response:
        if response.status_code >= 400:
            raise HTTPException(
                status_code=400,
                detail=f"Could not fetch speaker audio: {response.status_code}",
            )
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            try:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
            except BaseException:
                # The caller never receives the name of a half-written file,
                # so remove it here instead of leaking it in the temp dir.
                tmp.close()
                Path(tmp.name).unlink(missing_ok=True)
                raise
        return tmp.name


def _write_temp_audio_from_base64(payload: str) -> str:
    """Decode base64 speaker audio into a new temporary ``.wav`` file.

    Accepts either a bare base64 string or a data URI of the form
    ``data:<mime>;base64,<payload>``. The file is created with
    ``delete=False``; the caller owns it and must remove it.

    Args:
        payload: Base64 text (optionally data-URI wrapped).

    Returns:
        Filesystem path of the written file.

    Raises:
        HTTPException: 400 when the payload is not decodable base64 or
            decodes to zero bytes.
    """
    # Strip a data-URI header if present. Without this, b64decode silently
    # drops the ':', ';' and ',' characters but keeps the header letters,
    # which misaligns and corrupts every decoded byte.
    if payload.startswith("data:"):
        header, sep, body = payload.partition(",")
        if sep and header.endswith(";base64"):
            payload = body
    try:
        raw = base64.b64decode(payload)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc
    if not raw:
        # Zero bytes can never be valid audio; report a client error now rather
        # than a 500 from the audio loader later.
        raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(raw)
        return tmp.name


def _temp_speaker_file(speaker_wav: str) -> str:
    """Materialize the request's ``speaker_wav`` value as a temporary file.

    Values starting with ``http://`` or ``https://`` are downloaded; anything
    else is decoded as base64 audio. Returns the path of the new temp file.
    """
    is_remote = speaker_wav.startswith(("http://", "https://"))
    if not is_remote:
        return _write_temp_audio_from_base64(speaker_wav)
    return _write_temp_audio_from_url(HttpUrl(speaker_wav))


def _preprocess_audio_wav(path: str, target_sr: int = 24000, target_peak: float = 0.98) -> str:
    """Rewrite an audio file in place as mono, resampled, peak-limited 16-bit WAV.

    Processing steps:
    - multi-channel audio is down-mixed to mono by averaging the channels;
    - audio is resampled to ``target_sr`` when its rate differs;
    - if the absolute peak exceeds ``target_peak``, the signal is scaled down
      so the peak equals ``target_peak``. Quieter audio is left untouched
      (gain is capped at 1.0), so this is a limiter, not a loudness normalizer.

    The result always overwrites ``path`` and is always encoded as WAV, even
    when the file name carries another extension (e.g. a suffix copied from a
    download URL). Without an explicit format, torchaudio infers the container
    from the extension and can fail on, or write a non-WAV file for, such names.

    Args:
        path: Audio file readable by ``torchaudio.load``; overwritten in place.
        target_sr: Output sample rate in Hz.
        target_peak: Maximum absolute sample value after limiting.

    Returns:
        ``path`` unchanged, for call-chaining convenience.
    """
    wav, sr = torchaudio.load(path)

    # Down-mix to a single channel, keeping the leading channel dimension.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr != target_sr:
        wav = Resample(orig_freq=sr, new_freq=target_sr)(wav)
        sr = target_sr

    # Peak limiting (attenuate only, never amplify).
    peak = wav.abs().max().item() if wav.numel() else 0.0
    if peak > 0:
        wav = wav * min(target_peak / peak, 1.0)

    # Overwrite the input file to avoid extra temp files; force WAV encoding.
    torchaudio.save(path, wav, sr, format="wav", bits_per_sample=16)
    return path


@app.post("/health")
def health(x_api_key: Optional[str] = Header(default=None)):
    """Liveness probe: report service status and the detected compute device."""
    _require_api_key(x_api_key)
    status = {"status": "ok", "model": "indextts2", "device": DEVICE}
    return status


def _cleanup_files(*files: str):
    """Best-effort removal of temporary files (used as a background task).

    Falsy entries (``None``, ``""``) and paths that no longer exist are
    skipped; any error raised while unlinking is ignored.
    """
    for target in filter(None, files):
        candidate = Path(target)
        if not candidate.exists():
            continue
        try:
            candidate.unlink(missing_ok=True)
        except Exception:
            pass  # Cleanup must never fail the request.


@app.post("/generate")
def generate(
    payload: GenerateRequest = Body(...),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    x_api_key: Optional[str] = Header(default=None),
):
    """Synthesize ``payload.text`` in the voice of ``payload.speaker_wav``.

    The speaker prompt is fetched/decoded into a temp file and preprocessed;
    IndexTTS2 writes its output to a uniquely named temp WAV, which is
    post-processed and streamed back. Both temp files are deleted after the
    response is sent, or immediately on failure.

    Returns:
        ``FileResponse`` with ``audio/wav`` content on success, or a
        ``JSONResponse`` with status 500 and ``{"error": ...}`` for
        unexpected failures.

    Raises:
        HTTPException: 401 from the API-key check, or 400 for an unusable
            ``speaker_wav``; temp files are cleaned up before re-raising.
    """
    _require_api_key(x_api_key)

    speaker_file = None
    output_file = None

    try:
        speaker_file = _temp_speaker_file(payload.speaker_wav)
        speaker_file = _preprocess_audio_wav(speaker_file)
        output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")

        # payload.language is accepted for API compatibility but is not passed
        # to the model; language handling is left to IndexTTS2 itself
        # (TODO confirm which languages the checkpoint actually supports).
        tts_model.infer(
            spk_audio_prompt=speaker_file,
            text=payload.text,
            output_path=output_file,
            use_random=False,  # Deterministic output
            verbose=False,
        )

        # Verify the model produced a file *before* post-processing; otherwise
        # torchaudio.load inside _preprocess_audio_wav fails first with a far
        # less helpful error and this check never triggers.
        if not Path(output_file).exists():
            raise RuntimeError(f"TTS generation failed: output file was not created at {output_file}")

        # Run the output through the same pipeline as the speaker prompt
        # (mono, resampled, peak-limited, 16-bit PCM).
        output_file = _preprocess_audio_wav(output_file)

        # Temp files must outlive the handler until FileResponse has streamed.
        background_tasks.add_task(_cleanup_files, speaker_file, output_file)

        return FileResponse(output_file, media_type="audio/wav", filename="output.wav")

    except HTTPException:
        _cleanup_files(speaker_file, output_file)
        raise
    except Exception as exc:  # pragma: no cover
        _cleanup_files(speaker_file, output_file)
        return JSONResponse(status_code=500, content={"error": str(exc)})


@app.get("/")
def root():
    """Describe the service and list its public endpoints."""
    endpoints = ["/health", "/generate"]
    return {"name": "indextts2-api", "endpoints": endpoints}