Spaces:
Paused
Paused
| """ | |
| Matcha-TTS Standalone API Server | |
| ================================ | |
Architecture: Per-Core Worker + Cache 1x + Text Chunking
- Each request uses only 1 CPU/GPU core (torch.set_num_threads(1))
- The worker pool scales automatically with the number of cores (TTS_WORKERS env)
- Cache of 1x audio: changing the speed only runs FFmpeg, the model is not re-run
- Long text is split into smaller pieces by sentence to avoid OOM
| """ | |
| import os | |
| import sys | |
| import re | |
| import tempfile | |
| import subprocess | |
| import hashlib | |
| import shutil | |
| import time | |
| import torch | |
| import soundfile as sf | |
| import numpy as np | |
| import uvicorn | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from pathlib import Path | |
| from fastapi import FastAPI, HTTPException, Body, BackgroundTasks | |
| from fastapi.responses import FileResponse, JSONResponse | |
# --- Config -----------------------------------------------------------------
# Directory that contains this file; every path below is resolved against it.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Put the "Matcha-TTS" folder that sits next to this file first on the import
# path.
sys.path.insert(0, os.path.join(current_dir, "Matcha-TTS"))

# Model files, both expected under "<current_dir>/model".
CHECKPOINT_PATH = os.path.join(current_dir, "model", "checkpoint_016_fp16.ckpt")
VOCODER_PATH = os.path.join(current_dir, "model", "generator_v1_fp16")

SAMPLE_RATE = 22050
MAX_CHUNK_CHARS = 300
CLEANER = "basic_cleaners_vi_female"

# Cache: directory is created at import time if it does not exist yet.
CACHE_DIR = Path(current_dir) / "cache_1x"
CACHE_DIR.mkdir(exist_ok=True)
CACHE_MAX_FILES = 500

# Worker pool - each worker occupies exactly 1 core. The TTS_WORKERS
# environment variable overrides the default of min(cpu_count or 2, 4).
NUM_WORKERS = int(os.getenv("TTS_WORKERS", min(os.cpu_count() or 2, 4)))

# PyTorch limit: each inference call uses only 1 thread.
torch.set_num_threads(1)
# --- Imports from Matcha-TTS ------------------------------------------------
| from matcha.hifigan.config import v1 | |
| from matcha.hifigan.env import AttrDict | |
| from matcha.hifigan.models import Generator as HiFiGAN | |
| from matcha.models.matcha_tts import MatchaTTS | |
| from matcha.text import text_to_sequence | |
| from matcha.utils.utils import intersperse | |
# --- App --------------------------------------------------------------------
app = FastAPI(
    title="Matcha-TTS Standalone API",
    description="Per-core worker TTS API with 1x cache and text chunking",
)

# Use CUDA when PyTorch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Loaded networks; startup() stores them under the keys "matcha" and "vocoder".
models = {}

# Thread pool used for synthesis jobs; stays None until startup() creates it.
tts_executor = None
# --- Text Chunking ----------------------------------------------------------
# Clause boundaries at which an over-long paragraph may be split:
#   * ASCII . ! ? ; ,  - only when followed by whitespace, so digit groups such
#     as "1.000.000" or "3,5" are never torn apart;
#   * the full-width forms (ideographic full stop, full-width ! ? ; ,), which
#     are normally written without a following space.
# The full-width characters are spelled as \u escapes so the pattern does not
# depend on how this file is encoded.
# NOTE(review): in the reviewed copy of this file the full-width characters
# were mis-encoded; the set below is a reconstruction - confirm it matches the
# intended punctuation list.
_CLAUSE_BOUNDARY_RE = re.compile(
    r"(?<=[.!?;,])\s+|(?<=[\u3002\uff01\uff1f\uff1b\uff0c])\s*"
)


def _hard_wrap(segment: str, max_chars: int) -> list:
    """Cut ``segment`` into pieces of at most ``max_chars``, preferring whitespace."""
    pieces = []
    rest = segment
    while len(rest) > max_chars:
        # Look backwards from the limit for a whitespace character to cut at;
        # fall back to a cut inside the word when the window has none.
        cut = next(
            (i for i in range(max_chars, 0, -1) if rest[i].isspace()),
            max_chars,
        )
        pieces.append(rest[:cut].rstrip())
        rest = rest[cut:].lstrip()
    if rest:
        pieces.append(rest)
    return pieces


def split_text_into_chunks(text: str, max_chars: "int | None" = None) -> list:
    """Split text into small chunks: line breaks first, then punctuation, then force cut.

    Args:
        text: Input text.
        max_chars: Maximum length of a chunk. ``None`` (the default) means
            ``MAX_CHUNK_CHARS``, looked up at call time.

    Returns:
        A list of chunk strings. Blank lines are dropped and every chunk is
        stripped. When the input yields no chunk at all (empty or
        whitespace-only text) the list holds the single item
        ``text[:max_chars]``, so the result is never an empty list.

    Raises:
        ValueError: If ``max_chars`` is not positive.
    """
    if max_chars is None:
        max_chars = MAX_CHUNK_CHARS
    if max_chars <= 0:
        raise ValueError(f"max_chars must be a positive integer, got {max_chars!r}")

    chunks = []
    # 1. Split on line breaks first.
    for para in text.split("\n"):
        para = para.strip()
        if not para:
            continue
        if len(para) <= max_chars:
            chunks.append(para)
            continue

        # 2. Split on punctuation and greedily re-pack clauses into chunks.
        current = ""
        for sent in _CLAUSE_BOUNDARY_RE.split(para):
            sent = sent.strip()
            if not sent:
                continue
            # +1 accounts for the joining space.
            if len(current) + len(sent) + 1 <= max_chars:
                current = (current + " " + sent).strip()
                continue
            if current:
                chunks.append(current)
            if len(sent) > max_chars:
                # 3. Force cut when a single clause is too long.
                chunks.extend(_hard_wrap(sent, max_chars))
                current = ""
            else:
                current = sent
        if current:
            chunks.append(current)

    return chunks if chunks else [text[:max_chars]]
# --- Synthesis (runs on a worker thread) ------------------------------------
def synthesise_chunk(text_chunk: str) -> np.ndarray:
    """Synthesise one short piece of text and return its waveform.

    The text is turned into a token-id sequence with the ``CLEANER`` text
    cleaner, fed to ``models["matcha"]`` to obtain a mel spectrogram and then
    to ``models["vocoder"]`` to obtain audio. On CUDA both model calls run
    under float16 autocast and the CUDA allocator cache is released afterwards.

    Args:
        text_chunk: Text to synthesise.

    Returns:
        The vocoder output clamped to [-1, 1], squeezed, moved to the CPU and
        cast to ``float32``.
    """
    on_cuda = device.type == "cuda"

    # NOTE(review): intersperse(..., 0) presumably inserts token id 0 between
    # the symbol ids - confirm against matcha.utils.utils.
    token_ids = intersperse(text_to_sequence(text_chunk, [CLEANER])[0], 0)
    # [None] adds a leading dimension in front of the token sequence.
    x = torch.tensor(token_ids, dtype=torch.long, device=device)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)

    def _run_models():
        # Acoustic model -> mel, vocoder -> waveform, returned as a NumPy array.
        output = models["matcha"].synthesise(
            x, x_lengths, n_timesteps=10, temperature=0.667, spks=None, length_scale=1.0
        )
        return models["vocoder"](output["mel"]).clamp(-1, 1).squeeze().cpu().numpy()

    with torch.inference_mode():
        if on_cuda:
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                audio = _run_models()
        else:
            audio = _run_models()

    if on_cuda:
        torch.cuda.empty_cache()
    return audio.astype(np.float32)
def _write_wav_atomically(target, audio) -> None:
    """Write ``audio`` to ``target`` so readers never observe a partial file."""
    CACHE_DIR.mkdir(exist_ok=True)
    # The temp file lives in the cache directory so os.replace() stays on one
    # filesystem, and ends in ".tmp" so the "*.wav" globs never pick it up.
    # Note: mkstemp creates the file readable/writable by the owner only, and
    # the final cache file keeps that mode.
    tmp_fd, tmp_name = tempfile.mkstemp(
        dir=str(CACHE_DIR), prefix=f"{target.stem}.", suffix=".tmp"
    )
    os.close(tmp_fd)
    try:
        # format="WAV" is required because the temp name has no .wav extension.
        sf.write(tmp_name, audio, SAMPLE_RATE, format="WAV")
        os.replace(tmp_name, str(target))
    finally:
        # After a successful replace the temp name no longer exists; this only
        # removes leftovers of a failed write.
        try:
            os.remove(tmp_name)
        except OSError:
            pass


def _evict_oldest_cache_files() -> None:
    """Delete the oldest ``*.wav`` cache files beyond ``CACHE_MAX_FILES``."""
    entries = []
    for wav in CACHE_DIR.glob("*.wav"):
        try:
            entries.append((wav.stat().st_mtime, wav))
        except FileNotFoundError:
            # Removed by another worker or by a cache clear in the meantime.
            continue
    excess = len(entries) - CACHE_MAX_FILES
    if excess > 0:
        entries.sort(key=lambda entry: entry[0])
        for _, old in entries[:excess]:
            old.unlink(missing_ok=True)


def synthesise_full_text(text: str) -> str:
    """Synthesise the full text (chunked), store it in the 1x cache, return the path.

    The cache key is the SHA-256 of the UTF-8 text. On a hit the existing file
    is returned without running the model. On a miss the text is split with
    ``split_text_into_chunks``, each non-blank chunk is synthesised with
    ``synthesise_chunk``, the parts are concatenated, peak-normalised to 0.95
    and written to ``CACHE_DIR``; afterwards the oldest cache files are evicted
    if more than ``CACHE_MAX_FILES`` exist.

    Args:
        text: Text to synthesise.

    Returns:
        Path of the cached WAV file, as a string.

    Raises:
        ValueError: If no chunk produced audio.
    """
    text_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()
    cached_path = CACHE_DIR / f"{text_hash}.wav"
    if cached_path.exists():
        print(f"[β] Cache HIT (hash: {text_hash[:8]})")
        return str(cached_path)

    # Cache miss: run the model.
    chunks = split_text_into_chunks(text, MAX_CHUNK_CHARS)
    print(f"[~] Cache MISS β Tα»ng hợp {len(chunks)} chunks")
    audio_parts = []
    for i, chunk in enumerate(chunks):
        if not chunk.strip():
            continue
        t0 = time.time()
        part = synthesise_chunk(chunk)
        dt = time.time() - t0
        print(f" Chunk {i+1}/{len(chunks)}: {len(chunk)} chars β {len(part)/SAMPLE_RATE:.1f}s audio [{dt:.2f}s]")
        audio_parts.append(part)
    if not audio_parts:
        raise ValueError("KhΓ΄ng tαΊ‘o Δược Γ’m thanh")

    audio = np.concatenate(audio_parts)
    # Peak-normalise to 0.95; all-zero audio is left untouched to avoid 0/0.
    max_val = np.max(np.abs(audio))
    if max_val > 0:
        audio = (audio / max_val * 0.95).astype(np.float32)

    # Write via a temp file + rename: another worker checking
    # cached_path.exists() must never see a half-written WAV.
    _write_wav_atomically(cached_path, audio)
    print(f"[β] ΔΓ£ lΖ°u cache 1x: {cached_path.name}")

    # Clean up old entries if there are too many.
    _evict_oldest_cache_files()
    return str(cached_path)
# --- FFmpeg Speed/Volume ----------------------------------------------------
def _build_audio_filters(speed: float, volume: float) -> list:
    """Return the FFmpeg audio filter list for the given speed and volume."""
    filters = []
    if abs(volume - 1.0) >= 0.05:
        filters.append(f"volume={volume}")
    # Express the speed as a chain of atempo filters, each within [0.5, 2.0].
    # NOTE(review): presumably because a single atempo instance only accepts a
    # limited tempo range - confirm against the FFmpeg filter documentation.
    remaining = speed
    while remaining > 2.0:
        filters.append("atempo=2.0")
        remaining /= 2.0
    while remaining < 0.5:
        filters.append("atempo=0.5")
        remaining /= 0.5
    if abs(remaining - 1.0) > 0.01:
        filters.append(f"atempo={remaining}")
    return filters


def apply_ffmpeg(input_path: str, output_path: str, speed: float, volume: float = 1.0):
    """Write a speed/volume adjusted copy of ``input_path`` to ``output_path``.

    When both ``speed`` and ``volume`` are within 0.05 of 1.0 the file is just
    copied; otherwise ``ffmpeg`` is run with a ``volume`` filter and/or a chain
    of ``atempo`` filters.

    Args:
        input_path: Source audio file.
        output_path: Destination file; overwritten if it exists.
        speed: Tempo multiplier; must be positive and finite.
        volume: Volume multiplier (1.0 = unchanged).

    Raises:
        ValueError: If ``speed`` is zero, negative, infinite or NaN. Without
            this check the halving/doubling loops below never terminate for
            such values.
        subprocess.CalledProcessError: If ``ffmpeg`` exits with an error.
    """
    # The chained comparison is False for NaN as well.
    if not (0 < speed < float("inf")):
        raise ValueError(f"speed must be a positive finite number, got {speed!r}")

    if abs(speed - 1.0) < 0.05 and abs(volume - 1.0) < 0.05:
        shutil.copy(input_path, output_path)
        return

    filters = _build_audio_filters(speed, volume)
    filter_str = ",".join(filters) if filters else "anull"
    cmd = ["ffmpeg", "-y", "-i", input_path, "-filter:a", filter_str, output_path]
    subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
def cleanup_file(filepath: str):
    """Delete ``filepath`` on a best-effort basis; never raises for a bad path.

    A missing file is ignored. Any other failure to remove the path is printed
    instead of being raised, because this runs as a background task after the
    response has been sent. Unlike a bare ``except``, this does not swallow
    ``KeyboardInterrupt`` or ``SystemExit``.

    Args:
        filepath: Path of the file to delete.
    """
    try:
        os.remove(filepath)
    except FileNotFoundError:
        # Already gone - nothing to do.
        pass
    except (OSError, TypeError, ValueError) as exc:
        print(f"[!] Could not remove {filepath!r}: {exc}")
# --- Startup ----------------------------------------------------------------
# NOTE(review): nothing in the reviewed copy of this file calls or registers
# this function (no decorator is visible above it) - confirm that it is hooked
# into the application's startup.
def startup():
    """Create the worker pool, load both models into ``models`` and warm up.

    Sets the module-level ``tts_executor``, loads the Matcha-TTS checkpoint
    from ``CHECKPOINT_PATH`` and the HiFi-GAN generator from ``VOCODER_PATH``,
    casts each to half precision on CUDA (float32 otherwise), puts them in
    eval mode and finally synthesises a short phrase. A warm-up failure is
    printed and does not stop startup.
    """
    global tts_executor
    tts_executor = ThreadPoolExecutor(max_workers=NUM_WORKERS, thread_name_prefix="tts-worker")
    print(f"[+] Device: {device} | Workers: {NUM_WORKERS} | Cache: {CACHE_DIR}")

    use_half = device.type == "cuda"

    print(f"[!] Loading Matcha-TTS checkpoint: {CHECKPOINT_PATH}")
    # NOTE: weights_only=False unpickles arbitrary objects from the checkpoint
    # file; only load checkpoints from a trusted source.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
    acoustic = MatchaTTS(**checkpoint["hyper_parameters"])
    acoustic.load_state_dict(checkpoint["state_dict"])
    acoustic = acoustic.to(device)
    acoustic = acoustic.half() if use_half else acoustic.float()
    acoustic.eval()
    models["matcha"] = acoustic

    print(f"[!] Loading HiFi-GAN Vocoder: {VOCODER_PATH}")
    generator_config = AttrDict(v1)
    generator = HiFiGAN(generator_config).to(device)
    generator_state = torch.load(VOCODER_PATH, map_location=device)
    generator.load_state_dict(generator_state["generator"])
    generator = generator.half() if use_half else generator.float()
    generator.eval()
    generator.remove_weight_norm()
    models["vocoder"] = generator
    print("[β] All models loaded!")

    # Warmup
    print("[!] Warming up...")
    try:
        synthesise_chunk("khα»i Δα»ng")
    except Exception as e:
        print(f"[β ] Warmup failed: {e}")
    else:
        print("[β] Warmup complete!")
# --- API Endpoints ----------------------------------------------------------
# NOTE(review): no route decorator is visible above this handler in the
# reviewed copy of this file - confirm how it is registered on `app`.
async def synthesize(
    background_tasks: BackgroundTasks,
    text: str = Body(None, embed=True),
    input: str = Body(None, embed=True),
    speed: float = Body(1.0, embed=True),
    volume: float = Body(1.0, embed=True),
    bypass_cache: bool = Body(False, embed=True)
):
    """Synthesise speech for the given text and return it as a WAV file.

    Body fields:
        text / input: Text to speak; ``text`` wins when both are given.
        speed: Tempo multiplier, must be positive and finite (default 1.0).
        volume: Volume multiplier (default 1.0).
        bypass_cache: When true, the cached 1x render of this text is deleted
            first so the model runs again.

    Responses:
        200 with ``audio/wav``; 400 for empty text or an invalid speed; 500 if
        synthesis or FFmpeg fails.
    """
    # Accept both the "text" and the "input" parameter.
    actual_text = text or input or ""
    if not actual_text.strip():
        raise HTTPException(status_code=400, detail="VΔn bαΊ£n khΓ΄ng Δược Δα» trα»ng")
    # Reject speeds that the atempo decomposition cannot handle (0, negative,
    # inf, NaN) before any work is done; the comparison is False for NaN too.
    if not (0 < speed < float("inf")):
        raise HTTPException(status_code=400, detail="speed must be a positive finite number")

    t0 = time.time()
    loop = asyncio.get_running_loop()

    # 1. Create or fetch the 1x cache entry (runs on a worker thread).
    if bypass_cache:
        # Drop the old cache entry, if any.
        text_hash = hashlib.sha256(actual_text.encode("utf-8")).hexdigest()
        old_cache = CACHE_DIR / f"{text_hash}.wav"
        old_cache.unlink(missing_ok=True)
    try:
        cached_1x = await loop.run_in_executor(tts_executor, synthesise_full_text, actual_text)
    except Exception as e:
        if device.type == "cuda":
            torch.cuda.empty_cache()
        print(f"[β] Synthesis error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # 2. Apply speed/volume (FFmpeg only). Runs in the default executor so the
    #    blocking subprocess call neither stalls the event loop nor occupies a
    #    TTS worker.
    temp_fd, temp_out = tempfile.mkstemp(suffix=".wav")
    os.close(temp_fd)
    try:
        await loop.run_in_executor(None, apply_ffmpeg, cached_1x, temp_out, speed, volume)
    except Exception as e:
        cleanup_file(temp_out)
        raise HTTPException(status_code=500, detail=f"FFmpeg error: {e}") from e

    dt = time.time() - t0
    # A cache file last modified before this request started was a cache hit.
    try:
        is_cached = "CACHE" if os.path.getmtime(cached_1x) < t0 else "NEW"
    except OSError:
        # The entry may have been evicted meanwhile; only the log label is lost.
        is_cached = "UNKNOWN"
    print(f"[β] Response: {len(actual_text)} chars | {is_cached} | speed={speed}x | {dt:.2f}s")

    # Remove the temporary output once the response has been sent.
    background_tasks.add_task(cleanup_file, temp_out)
    return FileResponse(temp_out, media_type="audio/wav")
# NOTE(review): no route decorator is visible above this handler in the
# reviewed copy of this file - confirm how it is registered on `app`.
async def clear_cache():
    """Remove everything inside ``CACHE_DIR`` and report how many WAVs were cleared.

    The directory itself is kept (and created if missing) rather than deleted
    and re-created, so a synthesis worker writing a new cache file at the same
    moment never finds the directory gone.

    Returns:
        ``{"status": "ok", "cleared": <number of *.wav files removed>}``.

    Raises:
        HTTPException: 500 if the cache could not be cleared.
    """
    try:
        CACHE_DIR.mkdir(exist_ok=True)
        cleared = 0
        for entry in list(CACHE_DIR.iterdir()):
            if entry.is_dir():
                shutil.rmtree(entry)
                continue
            # missing_ok: another thread may have evicted the file already.
            entry.unlink(missing_ok=True)
            if entry.name.endswith(".wav"):
                cleared += 1
        return {"status": "ok", "cleared": cleared}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible above this handler in the
# reviewed copy of this file - confirm how it is registered on `app`.
async def cache_stats():
    """Report the number and total size of the ``*.wav`` files in ``CACHE_DIR``.

    Returns:
        A dict with ``files`` (count), ``total_mb`` (size in MiB, one decimal),
        ``max_files`` (``CACHE_MAX_FILES``) and ``cache_dir``.
    """
    file_count = 0
    total_bytes = 0
    for wav in CACHE_DIR.glob("*.wav"):
        try:
            total_bytes += wav.stat().st_size
        except FileNotFoundError:
            # Evicted or cleared between listing and stat; leave it out.
            continue
        file_count += 1
    return {
        "files": file_count,
        "total_mb": round(total_bytes / 1024 / 1024, 1),
        "max_files": CACHE_MAX_FILES,
        "cache_dir": str(CACHE_DIR),
    }
# Status report: always "ok", plus the selected torch device, the configured
# worker count and the keys currently present in `models`.
# NOTE(review): no route decorator is visible above this handler in the
# reviewed copy of this file - confirm how it is registered on `app`.
async def health():
    loaded_models = list(models)
    status_report = {
        "status": "ok",
        "device": str(device),
        "workers": NUM_WORKERS,
        "models_loaded": loaded_models,
    }
    return status_report
if __name__ == "__main__":
    # Listen on $PORT when set, otherwise on 7860, on all interfaces.
    listen_port = int(os.getenv("PORT", 7860))
    # The app is passed as the import string "api_server:app".
    uvicorn.run(
        "api_server:app",
        host="0.0.0.0",
        port=listen_port,
        log_level="info",
    )