"""
Matcha-TTS Standalone API Server
================================
Architecture: per-core worker pool + 1x audio cache + text chunking.

- Intra-op parallelism is capped at one thread (``torch.set_num_threads(1)``)
  so each inference call occupies a single CPU core.
- The worker pool size comes from the ``TTS_WORKERS`` env var
  (default: ``min(cpu_count, 4)``).
- Audio is cached once at 1x speed; changing speed/volume only re-runs
  FFmpeg on the cached file, never the model.
- Long text is split into sentence-sized chunks before synthesis to avoid
  running out of memory.
"""
import asyncio
import contextlib
import hashlib
import math
import os
import re
import shutil
import subprocess
import sys
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
import uvicorn
from fastapi import BackgroundTasks, Body, FastAPI, HTTPException
from fastapi.responses import FileResponse, JSONResponse

# ─── Config ─────────────────────────────────────────────────────
current_dir = os.path.dirname(os.path.abspath(__file__))
# Make the vendored Matcha-TTS package importable (must run before the matcha imports below).
sys.path.insert(0, os.path.join(current_dir, "Matcha-TTS"))

CHECKPOINT_PATH = os.path.join(current_dir, "model", "checkpoint_016_fp16.ckpt")
VOCODER_PATH = os.path.join(current_dir, "model", "generator_v1_fp16")
SAMPLE_RATE = 22050
MAX_CHUNK_CHARS = 300
CLEANER = "basic_cleaners_vi_female"

# Cache
CACHE_DIR = Path(os.path.join(current_dir, "cache_1x"))
CACHE_DIR.mkdir(exist_ok=True)
CACHE_MAX_FILES = 500

# Worker pool — each worker is meant to occupy exactly one core.
NUM_WORKERS = int(os.environ.get("TTS_WORKERS", min(os.cpu_count() or 2, 4)))

# Limit PyTorch: each inference call uses a single intra-op thread.
torch.set_num_threads(1)

# ─── Imports from Matcha-TTS ────────────────────────────────────
from matcha.hifigan.config import v1  # noqa: E402
from matcha.hifigan.env import AttrDict  # noqa: E402
from matcha.hifigan.models import Generator as HiFiGAN  # noqa: E402
from matcha.models.matcha_tts import MatchaTTS  # noqa: E402
from matcha.text import text_to_sequence  # noqa: E402
from matcha.utils.utils import intersperse  # noqa: E402

# ─── App ────────────────────────────────────────────────────────
app = FastAPI(
    title="Matcha-TTS Standalone API",
    description="Per-core worker TTS API with 1x cache and text chunking",
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models = {}  # "matcha" -> acoustic model, "vocoder" -> HiFi-GAN; filled in startup()
tts_executor = None  # ThreadPoolExecutor, initialized in startup()

# Sentence boundary: split right after ASCII or full-width (CJK) terminators/separators,
# consuming any whitespace that follows.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?;,。！？；，])\s*")


# ─── Text Chunking ──────────────────────────────────────────────
def split_text_into_chunks(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list:
    """Split *text* into chunks of at most *max_chars* characters.

    Split points are preferred in this order: newlines, then sentence
    punctuation (sentences are greedily packed into chunks), then a hard cut
    for any single sentence that is still longer than *max_chars*.

    Args:
        text: Input text; may contain several newline-separated paragraphs.
        max_chars: Maximum chunk length in characters (must be > 0).

    Returns:
        A list of stripped, non-empty chunks. If *text* has no
        non-whitespace content, returns ``[text[:max_chars]]``.

    Raises:
        ValueError: If *max_chars* is not positive.
    """
    if max_chars <= 0:
        raise ValueError(f"max_chars must be positive, got {max_chars}")

    chunks = []
    # 1. Split on newlines first.
    for para in text.split("\n"):
        para = para.strip()
        if not para:
            continue
        if len(para) <= max_chars:
            chunks.append(para)
            continue

        # 2. Split on punctuation and greedily pack sentences (+1 for the joining space).
        current = ""
        for sent in _SENTENCE_SPLIT_RE.split(para):
            sent = sent.strip()
            if not sent:
                continue
            if len(current) + len(sent) + 1 <= max_chars:
                current = (current + " " + sent).strip()
                continue
            if current:
                chunks.append(current)
            if len(sent) > max_chars:
                # 3. Hard-cut a single sentence that alone exceeds max_chars.
                chunks.extend(sent[i:i + max_chars] for i in range(0, len(sent), max_chars))
                current = ""
            else:
                current = sent
        if current:
            chunks.append(current)

    return chunks if chunks else [text[:max_chars]]


# ─── Synthesis (runs on worker threads) ─────────────────────────
def synthesise_chunk(text_chunk: str) -> np.ndarray:
    """Synthesise one short text chunk and return the waveform as float32.

    Intended to run on a ``tts_executor`` worker thread. Requires
    ``models["matcha"]`` and ``models["vocoder"]`` to be loaded (see startup()).
    The vocoder output is clamped to [-1, 1].
    """
    token_ids = intersperse(text_to_sequence(text_chunk, [CLEANER])[0], 0)
    x = torch.tensor(token_ids, dtype=torch.long, device=device)[None]  # add batch dim
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)

    # fp16 autocast only on CUDA; on CPU the models were converted to float32 in startup().
    precision_ctx = (
        torch.amp.autocast(device_type="cuda", dtype=torch.float16)
        if device.type == "cuda"
        else contextlib.nullcontext()
    )
    with torch.inference_mode(), precision_ctx:
        output = models["matcha"].synthesise(
            x, x_lengths, n_timesteps=10, temperature=0.667, spks=None, length_scale=1.0
        )
        audio = models["vocoder"](output["mel"]).clamp(-1, 1).squeeze().cpu().numpy()

    if device.type == "cuda":
        torch.cuda.empty_cache()
    return audio.astype(np.float32)


def _cache_path_for(text: str) -> Path:
    """Return the 1x cache file path for *text* (SHA-256 of its UTF-8 bytes)."""
    text_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()
    return CACHE_DIR / f"{text_hash}.wav"


def _atomic_write_wav(path: Path, audio: np.ndarray) -> None:
    """Write *audio* as a WAV to *path* without ever exposing a partial file.

    The data goes to a temp file in the same directory (suffix ``.part`` so it
    is not matched by ``*.wav`` globs) and is then moved into place with
    ``os.replace``, so concurrent readers see either no file or a complete one.
    """
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=".tmp-", suffix=".part")
    os.close(fd)
    try:
        # format must be explicit: soundfile cannot infer it from the ".part" suffix.
        sf.write(tmp_name, audio, SAMPLE_RATE, format="WAV")
        os.replace(tmp_name, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_name)
        raise


def _evict_old_cache_files(keep: Path) -> None:
    """Delete the oldest (by mtime) cached WAVs until at most CACHE_MAX_FILES remain.

    Files that disappear concurrently are skipped; *keep* is never deleted.
    """
    entries = []
    for f in CACHE_DIR.glob("*.wav"):
        try:
            entries.append((f.stat().st_mtime, f))
        except FileNotFoundError:
            continue  # removed concurrently (another eviction or clear_cache)
    if len(entries) <= CACHE_MAX_FILES:
        return
    entries.sort(key=lambda entry: entry[0])
    for _, old in entries[:len(entries) - CACHE_MAX_FILES]:
        if old != keep:
            old.unlink(missing_ok=True)


def synthesise_full_text(text: str) -> str:
    """Synthesise *text* (chunked), store it in the 1x cache and return the WAV path.

    On a cache hit the existing file is returned without running the model.
    On a miss, each chunk is synthesised and the parts are concatenated and
    peak-normalised to 0.95. The result is written atomically to the cache,
    and old entries are evicted.

    Raises:
        ValueError: If no audio could be produced (every chunk was blank).
    """
    cached_path = _cache_path_for(text)
    text_hash = cached_path.stem
    if cached_path.exists():
        print(f"[✓] Cache HIT (hash: {text_hash[:8]})")
        return str(cached_path)

    # Cache miss → run the model.
    chunks = split_text_into_chunks(text, MAX_CHUNK_CHARS)
    print(f"[~] Cache MISS → Tổng hợp {len(chunks)} chunks")

    audio_parts = []
    for i, chunk in enumerate(chunks):
        if not chunk.strip():
            continue
        t0 = time.time()
        part = synthesise_chunk(chunk)
        dt = time.time() - t0
        print(f" Chunk {i+1}/{len(chunks)}: {len(chunk)} chars → {len(part)/SAMPLE_RATE:.1f}s audio [{dt:.2f}s]")
        audio_parts.append(part)

    if not audio_parts:
        raise ValueError("Không tạo được âm thanh")

    audio = np.concatenate(audio_parts)
    max_val = np.max(np.abs(audio))
    if max_val > 0:
        audio = (audio / max_val * 0.95).astype(np.float32)

    _atomic_write_wav(cached_path, audio)
    print(f"[✓] Đã lưu cache 1x: {cached_path.name}")

    _evict_old_cache_files(keep=cached_path)
    return str(cached_path)


# ─── FFmpeg Speed/Volume ────────────────────────────────────────
def apply_ffmpeg(input_path: str, output_path: str, speed: float, volume: float = 1.0):
    """Write *input_path* to *output_path* with tempo *speed* and gain *volume* applied.

    If both factors are within 0.05 of 1.0 the file is copied unchanged.
    Otherwise ``atempo`` stages are chained so each stage stays within [0.5, 2.0].

    Raises:
        ValueError: If *speed* is not finite and > 0, or *volume* is not finite and >= 0.
            Without this check the atempo loops would never terminate.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    if not (math.isfinite(speed) and speed > 0):
        raise ValueError(f"speed must be a finite number > 0, got {speed!r}")
    if not (math.isfinite(volume) and volume >= 0):
        raise ValueError(f"volume must be a finite number >= 0, got {volume!r}")

    if abs(speed - 1.0) < 0.05 and abs(volume - 1.0) < 0.05:
        shutil.copy(input_path, output_path)
        return

    filters = []
    if abs(volume - 1.0) >= 0.05:
        filters.append(f"volume={volume}")

    remaining = speed
    while remaining > 2.0:
        filters.append("atempo=2.0")
        remaining /= 2.0
    while remaining < 0.5:
        filters.append("atempo=0.5")
        remaining /= 0.5
    if abs(remaining - 1.0) > 0.01:
        filters.append(f"atempo={remaining}")

    filter_str = ",".join(filters) if filters else "anull"
    # -nostdin: never let ffmpeg read the server's stdin; -loglevel error keeps stderr small.
    cmd = ["ffmpeg", "-nostdin", "-loglevel", "error", "-y", "-i", input_path,
           "-filter:a", filter_str, output_path]
    result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
    if result.returncode != 0:
        # Log ffmpeg's own message server-side (it contains local paths, so don't return it).
        tail = result.stderr.decode("utf-8", "replace").strip()[-500:]
        print(f"[❌] FFmpeg stderr: {tail}")
        raise subprocess.CalledProcessError(result.returncode, cmd, stderr=result.stderr)


def cleanup_file(filepath: str):
    """Best-effort delete of *filepath*: a missing file is ignored, other OS errors are logged."""
    try:
        os.remove(filepath)
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"[⚠] Could not delete {filepath}: {e}")


# ─── Startup ────────────────────────────────────────────────────
def _to_runtime_precision(module: torch.nn.Module) -> torch.nn.Module:
    """Convert *module* to fp16 on CUDA, fp32 otherwise."""
    return module.half() if device.type == "cuda" else module.float()


@app.on_event("startup")
def startup():
    """Create the worker pool, load Matcha-TTS and HiFi-GAN, then run one warmup call."""
    global tts_executor
    tts_executor = ThreadPoolExecutor(max_workers=NUM_WORKERS, thread_name_prefix="tts-worker")
    print(f"[+] Device: {device} | Workers: {NUM_WORKERS} | Cache: {CACHE_DIR}")

    print(f"[!] Loading Matcha-TTS checkpoint: {CHECKPOINT_PATH}")
    # NOTE: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
    model = MatchaTTS(**checkpoint["hyper_parameters"])
    model.load_state_dict(checkpoint["state_dict"])
    model = _to_runtime_precision(model.to(device))
    model.eval()
    models["matcha"] = model

    print(f"[!] Loading HiFi-GAN Vocoder: {VOCODER_PATH}")
    h = AttrDict(v1)
    vocoder = HiFiGAN(h).to(device)
    vocoder.load_state_dict(torch.load(VOCODER_PATH, map_location=device)["generator"])
    vocoder = _to_runtime_precision(vocoder)
    vocoder.eval()
    vocoder.remove_weight_norm()
    models["vocoder"] = vocoder
    print("[✓] All models loaded!")

    # Warmup: one throwaway inference so the first real request is not slowed down.
    print("[!] Warming up...")
    try:
        synthesise_chunk("khởi động")
        print("[✓] Warmup complete!")
    except Exception as e:
        print(f"[⚠] Warmup failed: {e}")


# ─── API Endpoints ──────────────────────────────────────────────
@app.post("/synthesize")
@app.post("/v1/audio/speech")
async def synthesize(
    background_tasks: BackgroundTasks,
    text: str = Body(None, embed=True),
    input: str = Body(None, embed=True),  # noqa: A002 — public request field name
    speed: float = Body(1.0, embed=True),
    volume: float = Body(1.0, embed=True),
    bypass_cache: bool = Body(False, embed=True),
):
    """Synthesise speech for ``text`` (or ``input``) and return it as a WAV file.

    The 1x audio is produced or fetched from the cache on ``tts_executor``.
    Speed and volume are then applied by FFmpeg on the default executor, so
    the event loop is never blocked. The temp output file is deleted after
    the response is sent.

    Raises:
        HTTPException 400: Empty text, or an invalid speed/volume.
        HTTPException 500: Synthesis or FFmpeg failure.
    """
    # Accept either the "text" or the "input" field.
    actual_text = text or input or ""
    if not actual_text.strip():
        raise HTTPException(status_code=400, detail="Văn bản không được để trống")
    # Validate before doing any work: non-positive or non-finite speed would hang apply_ffmpeg.
    if not (math.isfinite(speed) and speed > 0):
        raise HTTPException(status_code=400, detail="speed phải là số dương hữu hạn")
    if not (math.isfinite(volume) and volume >= 0):
        raise HTTPException(status_code=400, detail="volume phải là số không âm hữu hạn")

    t0 = time.time()
    loop = asyncio.get_running_loop()

    # 1. Create or fetch the 1x cache entry (runs on a worker thread).
    if bypass_cache:
        # Drop any existing cache entry so the model runs again.
        _cache_path_for(actual_text).unlink(missing_ok=True)

    try:
        cached_1x = await loop.run_in_executor(tts_executor, synthesise_full_text, actual_text)
    except Exception as e:
        if device.type == "cuda":
            torch.cuda.empty_cache()
        print(f"[❌] Synthesis error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # 2. Apply speed/volume (FFmpeg only; off the event loop, separate from the TTS pool).
    temp_fd, temp_out = tempfile.mkstemp(suffix=".wav")
    os.close(temp_fd)
    try:
        await loop.run_in_executor(None, apply_ffmpeg, cached_1x, temp_out, speed, volume)
    except Exception as e:
        cleanup_file(temp_out)
        raise HTTPException(status_code=500, detail=f"FFmpeg error: {e}") from e

    dt = time.time() - t0
    try:
        # A cache file older than this request means it was a hit.
        is_cached = "CACHE" if os.path.getmtime(cached_1x) < t0 else "NEW"
    except OSError:
        is_cached = "EVICTED"  # removed concurrently; only affects this log line
    print(f"[✓] Response: {len(actual_text)} chars | {is_cached} | speed={speed}x | {dt:.2f}s")

    background_tasks.add_task(cleanup_file, temp_out)
    return FileResponse(temp_out, media_type="audio/wav")


@app.post("/clear_cache")
def clear_cache():
    """Delete every cached ``*.wav`` file and return how many were removed.

    Files are unlinked one by one and the directory itself is kept, so
    concurrent cache writes are not broken. In-flight ``.part`` temp files are
    left alone. This is a sync endpoint, so FastAPI runs the blocking I/O in
    its threadpool.
    """
    try:
        cleared = 0
        for f in CACHE_DIR.glob("*.wav"):
            f.unlink(missing_ok=True)
            cleared += 1
        CACHE_DIR.mkdir(exist_ok=True)
        return {"status": "ok", "cleared": cleared}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/cache/stats")
def cache_stats():
    """Report the cached file count, total size in MB, configured limit and cache directory."""
    files = list(CACHE_DIR.glob("*.wav"))
    total_bytes = 0
    for f in files:
        with contextlib.suppress(FileNotFoundError):  # may be evicted between glob and stat
            total_bytes += f.stat().st_size
    return {
        "files": len(files),
        "total_mb": round(total_bytes / 1024 / 1024, 1),
        "max_files": CACHE_MAX_FILES,
        "cache_dir": str(CACHE_DIR),
    }


@app.get("/health")
async def health():
    """Liveness probe: device, worker count and which models are loaded."""
    return {
        "status": "ok",
        "device": str(device),
        "workers": NUM_WORKERS,
        "models_loaded": list(models.keys()),
    }


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    # Pass the app object (not "api_server:app") so this works whatever the file is
    # named, and the module is not imported a second time.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")