| """ | |
| FastAPI TTS Server với Speed Control + Async Processing | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from typing import List | |
| import tempfile | |
| import os | |
| import time | |
| from datetime import datetime | |
| import soundfile as sf | |
| import numpy as np | |
| from pydub import AudioSegment | |
| import torch | |
| from vieneu_tts import VieNeuTTS | |
| import asyncio | |
| import concurrent.futures | |
| import threading | |
| # ========================================== | |
| # SETUP | |
| # ========================================== | |
| app = FastAPI(title="VieNeu-TTS API", version="1.0.0") | |
| # Global variables | |
| tts = None | |
| device = None | |
# Async control - semaphores to limit concurrent resource usage
gpu_semaphore = None  # GPU-bound tasks (default: 2 concurrent)
cpu_semaphore = None  # CPU-bound tasks (default: 4 concurrent)
io_semaphore = None   # I/O-bound tasks (default: 6 concurrent)
# Thread pool for blocking operations
thread_pool = None
| # Voice samples | |
| VOICE_SAMPLES = { | |
| "Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"}, | |
| "Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"}, | |
| "Bình (nam miền Bắc)": {"audio": "./sample/Bình (nam miền Bắc).wav", "text": "./sample/Bình (nam miền Bắc).txt"}, | |
| "Nguyên (nam miền Nam)": {"audio": "./sample/Nguyên (nam miền Nam).wav", "text": "./sample/Nguyên (nam miền Nam).txt"}, | |
| "Sơn (nam miền Nam)": {"audio": "./sample/Sơn (nam miền Nam).wav", "text": "./sample/Sơn (nam miền Nam).txt"}, | |
| "Đoan (nữ miền Nam)": {"audio": "./sample/Đoan (nữ miền Nam).wav", "text": "./sample/Đoan (nữ miền Nam).txt"}, | |
| "Ngọc (nữ miền Bắc)": {"audio": "./sample/Ngọc (nữ miền Bắc).wav", "text": "./sample/Ngọc (nữ miền Bắc).txt"}, | |
| "Ly (nữ miền Bắc)": {"audio": "./sample/Ly (nữ miền Bắc).wav", "text": "./sample/Ly (nữ miền Bắc).txt"}, | |
| "Dung (nữ miền Nam)": {"audio": "./sample/Dung (nữ miền Nam).wav", "text": "./sample/Dung (nữ miền Nam).txt"}, | |
| "Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"}, | |
| } | |
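# To register an additional voice, add an entry mapping a display name to its
# reference audio/text pair. A sketch (the "My Voice" files are hypothetical
# and would need to exist under ./sample):
# VOICE_SAMPLES["My Voice"] = {
#     "audio": "./sample/My Voice.wav",
#     "text": "./sample/My Voice.txt",
# }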
| # Cache for reference codes | |
| reference_cache = {} | |
| # ========================================== | |
| # MODELS | |
| # ========================================== | |
| class TTSRequest(BaseModel): | |
| text: str | |
| voice_choice: str = "Tuyên (nam miền Bắc)" | |
| speed_factor: float = 1.0 | |
| class TTSResponse(BaseModel): | |
| audio_url: str | |
| status: str | |
| processing_time: float | |
| voice_used: str | |
| speed_applied: float | |
| # ========================================== | |
| # HELPER FUNCTIONS | |
| # ========================================== | |
def apply_speed_control(audio, speed_factor):
    """Change playback speed with pydub by resampling the frame rate."""
    if speed_factor == 1.0:
        return audio
    sr = 24000
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
        sf.write(tmp.name, audio, sr)
        tmp_path = tmp.name
    sound = AudioSegment.from_wav(tmp_path)
    new_frame_rate = int(sound.frame_rate * speed_factor)
    sound_stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate})
    sound_stretched = sound_stretched.set_frame_rate(sr)
    audio_stretched = np.array(sound_stretched.get_array_of_samples()).astype(np.float32) / 32768.0
    if sound_stretched.channels == 2:
        audio_stretched = audio_stretched.reshape((-1, 2)).mean(axis=1)
    os.unlink(tmp_path)
    return audio_stretched
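# A quick sanity check of the frame-rate trick above (illustrative sketch, not
# called anywhere): resampling by speed_factor and forcing the rate back to
# 24 kHz shortens the audio and raises the pitch, so 1 s at 1.5x comes back as
# roughly 24000 / 1.5 ≈ 16000 samples.
def _example_speed_control():
    one_second = np.zeros(24000, dtype=np.float32)  # 1 s of silence at 24 kHz
    faster = apply_speed_control(one_second, 1.5)
    print(len(faster))  # expected to be about 16000 samples (~0.67 s)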
| # ========================================== | |
| # API ENDPOINTS | |
| # ========================================== | |
@app.get("/")
async def root():
| return { | |
| "message": "VieNeu-TTS API Server with Async Processing", | |
| "version": "1.0.0", | |
| "available_voices": list(VOICE_SAMPLES.keys()), | |
| "async_features": { | |
| "gpu_semaphore": "2 concurrent GPU tasks", | |
| "cpu_semaphore": "4 concurrent CPU tasks", | |
| "io_semaphore": "6 concurrent I/O tasks", | |
| "thread_pool": "6 worker threads" | |
| }, | |
| "endpoints": { | |
| "POST /tts": "Synthesize speech (standard)", | |
| "POST /fast-tts": "Fast TTS for external apps", | |
| "POST /bulk-tts": "Bulk processing (up to 50 requests)", | |
| "GET /voices": "List available voices", | |
| "GET /health": "Health check with async status", | |
| "GET /status": "Detailed async resource status", | |
| "POST /admin/update_settings": "Update async settings real-time", | |
| "POST /admin/clear_cache": "Clear reference cache", | |
| "GET /admin/settings": "Get current settings" | |
| }, | |
| "external_app_recommendations": { | |
| "single_requests": "Use POST /fast-tts with return_base64=true", | |
| "batch_requests": "Use POST /bulk-tts for up to 50 requests", | |
| "performance_tips": [ | |
| "Use same voice for consecutive requests (cache benefit)", | |
| "Keep text under 200 characters for best speed", | |
| "Use return_base64=true to avoid file I/O", | |
| "Consider bulk-tts for batches of 10-50 requests" | |
| ] | |
| } | |
| } | |
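# A minimal client sketch for discovering the API (assumes the `requests`
# package and a server running on the default 127.0.0.1:8000; never called here):
def _example_discover_api():
    import requests  # assumed to be installed in the client environment
    info = requests.get("http://127.0.0.1:8000/", timeout=5).json()
    print(info["version"], info["available_voices"])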
@app.get("/voices")
async def get_voices():
| return { | |
| "voices": list(VOICE_SAMPLES.keys()), | |
| "total": len(VOICE_SAMPLES) | |
| } | |
@app.get("/health")
async def health_check():
| model_status = "loaded" if tts is not None else "not_loaded" | |
| # Check CUDA memory if using GPU | |
| cuda_info = {} | |
| if device == "cuda" and torch.cuda.is_available(): | |
| cuda_info = { | |
| "cuda_memory_allocated": f"{torch.cuda.memory_allocated(0) / 1024**3:.2f} GB", | |
| "cuda_memory_reserved": f"{torch.cuda.memory_reserved(0) / 1024**3:.2f} GB" | |
| } | |
| # Async resource status | |
| async_status = {} | |
| if gpu_semaphore and cpu_semaphore and io_semaphore: | |
| async_status = { | |
| "gpu_available": gpu_semaphore._value, | |
| "cpu_available": cpu_semaphore._value, | |
| "io_available": io_semaphore._value, | |
| "thread_pool_active": thread_pool._threads if thread_pool else 0 | |
| } | |
| return { | |
| "status": "healthy", | |
| "model_status": model_status, | |
| "device": device, | |
| "cache_size": len(reference_cache), | |
| "async_resources": async_status, | |
| "timestamp": datetime.now().isoformat(), | |
| **cuda_info | |
| } | |
@app.get("/status")
async def get_status():
| """Detailed server status including async resource usage""" | |
| # Semaphore status | |
| semaphore_status = {} | |
| if gpu_semaphore and cpu_semaphore and io_semaphore: | |
| semaphore_status = { | |
| "gpu_semaphore": { | |
| "available": gpu_semaphore._value, | |
| "max_capacity": 2, | |
| "in_use": 2 - gpu_semaphore._value | |
| }, | |
| "cpu_semaphore": { | |
| "available": cpu_semaphore._value, | |
| "max_capacity": 4, | |
| "in_use": 4 - cpu_semaphore._value | |
| }, | |
| "io_semaphore": { | |
| "available": io_semaphore._value, | |
| "max_capacity": 6, | |
| "in_use": 6 - io_semaphore._value | |
| } | |
| } | |
| # Thread pool status | |
| thread_status = {} | |
| if thread_pool: | |
| thread_status = { | |
| "max_workers": thread_pool._max_workers, | |
| "active_threads": len(thread_pool._threads) if hasattr(thread_pool, '_threads') else 0 | |
| } | |
| # Model and cache info | |
| model_info = { | |
| "model_loaded": tts is not None, | |
| "device": device, | |
| "reference_cache_size": len(reference_cache), | |
| "cached_voices": list(reference_cache.keys()) | |
| } | |
| return { | |
| "server_status": "running", | |
| "async_processing": semaphore_status, | |
| "thread_pool": thread_status, | |
| "model_info": model_info, | |
| "available_voices": list(VOICE_SAMPLES.keys()), | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| async def _load_reference_text(ref_text_path: str) -> str: | |
| """Load reference text with I/O semaphore""" | |
| async with io_semaphore: | |
| loop = asyncio.get_event_loop() | |
| with open(ref_text_path, "r", encoding="utf-8") as f: | |
| return await loop.run_in_executor(thread_pool, f.read) | |
| async def _encode_reference_async(ref_audio_path: str, cache_key: str, request_id: str = "") -> torch.Tensor: | |
| """Encode reference audio with GPU semaphore""" | |
| async with gpu_semaphore: | |
| print(f" 🔄 [{request_id}] Encoding reference for {cache_key}...") | |
| loop = asyncio.get_event_loop() | |
| def encode_sync(): | |
| try: | |
| # Clear CUDA cache before encoding | |
| if device == "cuda": | |
| torch.cuda.empty_cache() | |
| ref_codes = tts.encode_reference(ref_audio_path) | |
| # Ensure ref_codes is on CPU for caching | |
| if hasattr(ref_codes, 'cpu'): | |
| ref_codes = ref_codes.cpu() | |
| return ref_codes | |
| except Exception as e: | |
| print(f" ❌ [{request_id}] Failed to encode reference: {e}") | |
| raise e | |
| ref_codes = await loop.run_in_executor(thread_pool, encode_sync) | |
| reference_cache[cache_key] = ref_codes | |
| print(f" ✅ [{request_id}] Reference encoded and cached") | |
| return ref_codes | |
| async def _generate_speech_async(text: str, ref_codes: torch.Tensor, ref_text_raw: str, request_id: str = "") -> np.ndarray: | |
| """Generate speech with GPU semaphore""" | |
| async with gpu_semaphore: | |
| print(f" 🎵 [{request_id}] Generating speech...") | |
| loop = asyncio.get_event_loop() | |
| def infer_sync(): | |
| try: | |
| # Clear CUDA cache before inference | |
| if device == "cuda": | |
| torch.cuda.empty_cache() | |
| wav = tts.infer(text, ref_codes, ref_text_raw) | |
| return wav | |
| except Exception as e: | |
| print(f" ❌ [{request_id}] Failed to generate speech: {e}") | |
| raise e | |
| wav = await loop.run_in_executor(thread_pool, infer_sync) | |
| print(f" ✅ [{request_id}] Speech generated") | |
| return wav | |
| async def _apply_speed_control_async(audio: np.ndarray, speed_factor: float, request_id: str = "") -> np.ndarray: | |
| """Apply speed control with CPU semaphore""" | |
| if speed_factor == 1.0: | |
| return audio | |
| async with cpu_semaphore: | |
| print(f" 🎚️ [{request_id}] Applying speed control: {speed_factor}x") | |
| loop = asyncio.get_event_loop() | |
| def speed_control_sync(): | |
| return apply_speed_control(audio, speed_factor) | |
| return await loop.run_in_executor(thread_pool, speed_control_sync) | |
| async def _save_audio_async(wav: np.ndarray, output_path: str) -> None: | |
| """Save audio file with I/O semaphore""" | |
| async with io_semaphore: | |
| loop = asyncio.get_event_loop() | |
| def save_sync(): | |
| sf.write(output_path, wav, 24000) | |
| await loop.run_in_executor(thread_pool, save_sync) | |
@app.post("/tts", response_model=TTSResponse)
async def synthesize_speech(request: TTSRequest):
    """
    Synthesize speech with speed control - async processing.
    """
| start_time = time.time() | |
| try: | |
| # Validate input | |
| if not request.text or len(request.text.strip()) == 0: | |
| raise HTTPException(status_code=400, detail="Text cannot be empty") | |
| if len(request.text) > 500: | |
| raise HTTPException(status_code=400, detail="Text too long (max 500 characters)") | |
| if request.voice_choice not in VOICE_SAMPLES: | |
| raise HTTPException(status_code=400, detail=f"Voice not found. Available: {list(VOICE_SAMPLES.keys())}") | |
| if not (0.5 <= request.speed_factor <= 2.0): | |
| raise HTTPException(status_code=400, detail="Speed factor must be between 0.5 and 2.0") | |
| request_id = f"REQ-{int(time.time() * 1000) % 100000}" | |
| print(f"🎤 [{request_id}] Processing: {request.text[:50]}...") | |
| print(f"🎚️ [{request_id}] Voice: {request.voice_choice}, Speed: {request.speed_factor}x") | |
| # Get reference audio and text paths | |
| voice_info = VOICE_SAMPLES[request.voice_choice] | |
| ref_audio_path = voice_info["audio"] | |
| ref_text_path = voice_info["text"] | |
| if not os.path.exists(ref_audio_path): | |
| raise HTTPException(status_code=500, detail=f"Reference audio not found: {ref_audio_path}") | |
| # Load reference text (async I/O) | |
| ref_text_raw = await _load_reference_text(ref_text_path) | |
| # Encode reference (with cache) - async GPU | |
| cache_key = request.voice_choice | |
| if cache_key in reference_cache: | |
| print(f" ✨ [{request_id}] Using cached reference for {cache_key}") | |
| ref_codes = reference_cache[cache_key] | |
| else: | |
| ref_codes = await _encode_reference_async(ref_audio_path, cache_key, request_id) | |
| # Generate speech - async GPU | |
| wav = await _generate_speech_async(request.text, ref_codes, ref_text_raw, request_id) | |
| # Apply speed control - async CPU | |
| wav = await _apply_speed_control_async(wav, request.speed_factor, request_id) | |
| # Prepare output path | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # milliseconds | |
| filename = f"tts_{request.speed_factor}x_{timestamp}.wav" | |
| # Create temp directory if not exists | |
| temp_dir = "./temp_audio" | |
| os.makedirs(temp_dir, exist_ok=True) | |
| output_path = os.path.join(temp_dir, filename) | |
| # Save output - async I/O | |
| await _save_audio_async(wav, output_path) | |
| processing_time = time.time() - start_time | |
| print(f" ✅ [{request_id}] Success! Processing time: {processing_time:.2f}s") | |
| print(f" 📁 [{request_id}] Saved: {output_path}") | |
| # Return response | |
| return TTSResponse( | |
| audio_url=f"/audio/{filename}", | |
| status="success", | |
| processing_time=processing_time, | |
| voice_used=request.voice_choice, | |
| speed_applied=request.speed_factor | |
| ) | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| processing_time = time.time() - start_time | |
| print(f" ❌ [{request_id}] Error: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") | |
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """
    Serve a previously generated audio file.
    """
| file_path = os.path.join("./temp_audio", filename) | |
| if not os.path.exists(file_path): | |
| raise HTTPException(status_code=404, detail="Audio file not found") | |
| return FileResponse( | |
| path=file_path, | |
| media_type="audio/wav", | |
| filename=filename | |
| ) | |
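# A minimal sketch of the standard two-step flow (assumes `requests` and a
# running server; never called here): POST /tts returns an audio_url, which
# GET /audio/{filename} then serves as a WAV file.
def _example_tts_roundtrip():
    import requests  # assumed to be installed in the client environment
    payload = {
        "text": "Xin chào, đây là bản thử nghiệm.",
        "voice_choice": "Tuyên (nam miền Bắc)",
        "speed_factor": 1.0,
    }
    resp = requests.post("http://127.0.0.1:8000/tts", json=payload, timeout=60)
    resp.raise_for_status()
    audio_url = resp.json()["audio_url"]  # e.g. "/audio/tts_1.0x_<timestamp>.wav"
    audio = requests.get(f"http://127.0.0.1:8000{audio_url}", timeout=30)
    with open("tts_example.wav", "wb") as f:
        f.write(audio.content)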
| # ========================================== | |
| # ADMIN ENDPOINTS | |
| # ========================================== | |
| class SettingsUpdate(BaseModel): | |
| gpu_semaphore: int = 2 | |
| cpu_semaphore: int = 4 | |
| io_semaphore: int = 6 | |
| thread_pool: int = 6 | |
@app.post("/admin/update_settings")
async def update_settings(settings: SettingsUpdate):
    """
    Update async settings at runtime.
    """
| global gpu_semaphore, cpu_semaphore, io_semaphore, thread_pool | |
| try: | |
| # Validate settings | |
| if not (1 <= settings.gpu_semaphore <= 4): | |
| raise HTTPException(status_code=400, detail="GPU semaphore must be between 1-4") | |
| if not (2 <= settings.cpu_semaphore <= 16): | |
| raise HTTPException(status_code=400, detail="CPU semaphore must be between 2-16") | |
| if not (3 <= settings.io_semaphore <= 16): | |
| raise HTTPException(status_code=400, detail="I/O semaphore must be between 3-16") | |
| if not (2 <= settings.thread_pool <= 20): | |
| raise HTTPException(status_code=400, detail="Thread pool must be between 2-20") | |
| # Update semaphores | |
| gpu_semaphore = asyncio.Semaphore(settings.gpu_semaphore) | |
| cpu_semaphore = asyncio.Semaphore(settings.cpu_semaphore) | |
| io_semaphore = asyncio.Semaphore(settings.io_semaphore) | |
| # Update thread pool (need to shutdown old one) | |
| if thread_pool: | |
| old_pool = thread_pool | |
| thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=settings.thread_pool) | |
| # Shutdown old pool gracefully | |
| threading.Thread(target=lambda: old_pool.shutdown(wait=True), daemon=True).start() | |
| else: | |
| thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=settings.thread_pool) | |
| print(f"🔄 Settings updated: GPU({settings.gpu_semaphore}) CPU({settings.cpu_semaphore}) I/O({settings.io_semaphore}) Threads({settings.thread_pool})") | |
| return { | |
| "status": "success", | |
| "message": "Settings updated successfully", | |
| "new_settings": { | |
| "gpu_semaphore": settings.gpu_semaphore, | |
| "cpu_semaphore": settings.cpu_semaphore, | |
| "io_semaphore": settings.io_semaphore, | |
| "thread_pool": settings.thread_pool | |
| }, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| print(f"❌ Failed to update settings: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to update settings: {str(e)}") | |
@app.post("/admin/clear_cache")
async def clear_cache():
    """
    Clear the reference cache.
    """
| global reference_cache | |
| try: | |
| cache_size = len(reference_cache) | |
| reference_cache.clear() | |
| # Clear CUDA cache if available | |
| if device == "cuda" and torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print(f"🧹 Cache cleared: {cache_size} references removed") | |
| return { | |
| "status": "success", | |
| "message": f"Cache cleared successfully. {cache_size} references removed.", | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| except Exception as e: | |
| print(f"❌ Failed to clear cache: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}") | |
@app.get("/admin/settings")
async def get_current_settings():
    """
    Return the current settings.
    """
| current_settings = { | |
| "gpu_semaphore": { | |
| "current": gpu_semaphore._value if gpu_semaphore else 0, | |
| "max_capacity": 2 # Default, will be updated based on actual semaphore | |
| }, | |
| "cpu_semaphore": { | |
| "current": cpu_semaphore._value if cpu_semaphore else 0, | |
| "max_capacity": 4 | |
| }, | |
| "io_semaphore": { | |
| "current": io_semaphore._value if io_semaphore else 0, | |
| "max_capacity": 6 | |
| }, | |
| "thread_pool": { | |
| "max_workers": thread_pool._max_workers if thread_pool else 0, | |
| "active_threads": len(thread_pool._threads) if thread_pool and hasattr(thread_pool, '_threads') else 0 | |
| } | |
| } | |
| return { | |
| "status": "success", | |
| "settings": current_settings, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
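# A minimal sketch of adjusting the concurrency limits at runtime (assumes
# `requests` and a running server; never called here). The values must fall
# inside the ranges validated in update_settings above.
def _example_update_settings():
    import requests  # assumed to be installed in the client environment
    new_limits = {"gpu_semaphore": 2, "cpu_semaphore": 8, "io_semaphore": 8, "thread_pool": 10}
    resp = requests.post("http://127.0.0.1:8000/admin/update_settings", json=new_limits, timeout=5)
    print(resp.json()["new_settings"])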
| # ========================================== | |
| # FAST TTS ENDPOINT FOR EXTERNAL APPS | |
| # ========================================== | |
| class FastTTSRequest(BaseModel): | |
| text: str | |
| voice_choice: str = "Tuyên (nam miền Bắc)" | |
| speed_factor: float = 1.0 | |
| return_base64: bool = False # Option to return audio as base64 | |
| skip_file_save: bool = False # Option to skip saving file | |
@app.post("/fast-tts")
async def fast_tts(request: FastTTSRequest):
    """
    Fast TTS endpoint optimized for external apps sending many requests:
    - lighter validation
    - can return base64 instead of a file
    - can skip saving the file
    """
| start_time = time.time() | |
| request_id = f"FAST-{int(time.time() * 1000) % 100000}" | |
| try: | |
| # Minimal validation | |
| if not request.text or len(request.text.strip()) == 0: | |
| raise HTTPException(status_code=400, detail="Text cannot be empty") | |
| if len(request.text) > 1000: # Increased limit for external apps | |
| raise HTTPException(status_code=400, detail="Text too long (max 1000 characters)") | |
| if request.voice_choice not in VOICE_SAMPLES: | |
| # Auto fallback to default voice instead of error | |
| request.voice_choice = "Tuyên (nam miền Bắc)" | |
| if not (0.5 <= request.speed_factor <= 2.0): | |
| request.speed_factor = 1.0 # Auto fallback instead of error | |
| print(f"⚡ [{request_id}] Fast processing: {request.text[:30]}... | {request.voice_choice} | {request.speed_factor}x") | |
| # Get reference (with cache) | |
| voice_info = VOICE_SAMPLES[request.voice_choice] | |
| ref_audio_path = voice_info["audio"] | |
| ref_text_path = voice_info["text"] | |
| # Load reference text (async I/O) | |
| ref_text_raw = await _load_reference_text(ref_text_path) | |
| # Encode reference (with cache) - async GPU | |
| cache_key = request.voice_choice | |
| if cache_key in reference_cache: | |
| ref_codes = reference_cache[cache_key] | |
| else: | |
| ref_codes = await _encode_reference_async(ref_audio_path, cache_key, request_id) | |
| # Generate speech - async GPU | |
| wav = await _generate_speech_async(request.text, ref_codes, ref_text_raw, request_id) | |
| # Apply speed control - async CPU | |
| wav = await _apply_speed_control_async(wav, request.speed_factor, request_id) | |
| processing_time = time.time() - start_time | |
| # Return options | |
| if request.return_base64: | |
| # Return audio as base64 (no file saving) | |
| import base64 | |
| import io | |
| # Convert to bytes | |
| audio_buffer = io.BytesIO() | |
| sf.write(audio_buffer, wav, 24000, format='WAV') | |
| audio_bytes = audio_buffer.getvalue() | |
| audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') | |
| print(f" ✅ [{request_id}] Fast success (base64): {processing_time:.2f}s") | |
| return { | |
| "audio_base64": audio_base64, | |
| "status": "success", | |
| "processing_time": processing_time, | |
| "voice_used": request.voice_choice, | |
| "speed_applied": request.speed_factor, | |
| "format": "wav", | |
| "sample_rate": 24000 | |
| } | |
| elif request.skip_file_save: | |
| # Return raw audio data info (for streaming) | |
| print(f" ✅ [{request_id}] Fast success (no save): {processing_time:.2f}s") | |
| return { | |
| "status": "success", | |
| "processing_time": processing_time, | |
| "voice_used": request.voice_choice, | |
| "speed_applied": request.speed_factor, | |
| "audio_length": len(wav), | |
| "sample_rate": 24000, | |
| "message": "Audio generated but not saved" | |
| } | |
| else: | |
| # Standard file saving | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] | |
| filename = f"fast_{request.speed_factor}x_{timestamp}.wav" | |
| temp_dir = "./temp_audio" | |
| os.makedirs(temp_dir, exist_ok=True) | |
| output_path = os.path.join(temp_dir, filename) | |
| await _save_audio_async(wav, output_path) | |
| print(f" ✅ [{request_id}] Fast success: {processing_time:.2f}s | {filename}") | |
| return { | |
| "audio_url": f"/audio/{filename}", | |
| "status": "success", | |
| "processing_time": processing_time, | |
| "voice_used": request.voice_choice, | |
| "speed_applied": request.speed_factor, | |
| "filename": filename | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| processing_time = time.time() - start_time | |
| print(f" ❌ [{request_id}] Fast error: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Fast TTS error: {str(e)}") | |
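# A minimal client sketch for /fast-tts with base64 output (assumes `requests`;
# never called here). return_base64=True avoids file I/O on the server and
# lets the client decode the WAV bytes directly.
def _example_fast_tts_client():
    import base64
    import requests  # assumed to be installed in the client environment
    payload = {
        "text": "Xin chào",
        "voice_choice": "Tuyên (nam miền Bắc)",
        "speed_factor": 1.0,
        "return_base64": True,
    }
    resp = requests.post("http://127.0.0.1:8000/fast-tts", json=payload, timeout=60)
    resp.raise_for_status()
    audio_bytes = base64.b64decode(resp.json()["audio_base64"])
    with open("fast_tts_example.wav", "wb") as f:
        f.write(audio_bytes)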
@app.post("/bulk-tts")
async def bulk_tts(requests: List[FastTTSRequest]):
    """
    Bulk TTS endpoint - process multiple requests concurrently.
    Optimized for external apps sending batches.
    """
| if len(requests) > 50: # Limit batch size | |
| raise HTTPException(status_code=400, detail="Too many requests in batch (max 50)") | |
| start_time = time.time() | |
| batch_id = f"BULK-{int(time.time() * 1000) % 100000}" | |
| print(f"📦 [{batch_id}] Processing bulk: {len(requests)} requests") | |
| # Process all requests concurrently | |
| async def process_single(req, index): | |
| try: | |
| # Add index to request for tracking | |
| req_copy = req.copy() | |
| result = await fast_tts(req_copy) | |
| return {"index": index, "status": "success", "result": result} | |
| except Exception as e: | |
| return {"index": index, "status": "error", "error": str(e)} | |
| # Create tasks for all requests | |
| tasks = [process_single(req, i) for i, req in enumerate(requests)] | |
| results = await asyncio.gather(*tasks, return_exceptions=True) | |
| # Process results | |
| processed_results = [] | |
| for result in results: | |
| if isinstance(result, Exception): | |
| processed_results.append({"status": "exception", "error": str(result)}) | |
| else: | |
| processed_results.append(result) | |
| total_time = time.time() - start_time | |
| success_count = len([r for r in processed_results if r.get("status") == "success"]) | |
| print(f" ✅ [{batch_id}] Bulk completed: {success_count}/{len(requests)} success in {total_time:.2f}s") | |
| return { | |
| "batch_id": batch_id, | |
| "total_requests": len(requests), | |
| "successful": success_count, | |
| "failed": len(requests) - success_count, | |
| "total_time": total_time, | |
| "avg_time_per_request": total_time / len(requests), | |
| "results": processed_results | |
| } | |
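# A minimal sketch of a batch call (assumes `requests`; never called here):
# /bulk-tts takes a JSON list of FastTTSRequest objects (max 50) and reports
# per-item results plus aggregate timing.
def _example_bulk_tts():
    import requests  # assumed to be installed in the client environment
    batch = [
        {"text": f"Câu số {i}", "voice_choice": "Tuyên (nam miền Bắc)", "return_base64": True}
        for i in range(1, 4)
    ]
    resp = requests.post("http://127.0.0.1:8000/bulk-tts", json=batch, timeout=300)
    summary = resp.json()
    print(f"{summary['successful']}/{summary['total_requests']} requests succeeded")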
| # ========================================== | |
| # STARTUP EVENT | |
| # ========================================== | |
@app.on_event("startup")
async def startup_event():
| global tts, device, gpu_semaphore, cpu_semaphore, io_semaphore, thread_pool | |
| print("=" * 60) | |
| print("🎙️ VieNeu-TTS FastAPI Server (Async)") | |
| print("=" * 60) | |
| # Setup async controls | |
| gpu_semaphore = asyncio.Semaphore(2) # 2 GPU tasks (parallel inference) | |
| cpu_semaphore = asyncio.Semaphore(4) # 4 CPU tasks (more speed processing) | |
| io_semaphore = asyncio.Semaphore(6) # 6 I/O tasks (more file operations) | |
| thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=6) | |
| print("🔄 Async setup: GPU(2) | CPU(4) | I/O(6) | ThreadPool(6)") | |
| # Device | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"🖥️ Using device: {device}") | |
| # Check if local backbone exists | |
| local_backbone = "./models/VieNeu-TTS" | |
| if os.path.exists(local_backbone): | |
| print("📦 Loading VieNeu-TTS model (hybrid: local backbone + online codec)...") | |
| backbone_repo = local_backbone | |
| codec_repo = "neuphonic/neucodec" # Codec must be online (VieNeuTTS limitation) | |
| print(" 🔧 Using local backbone (no internet for backbone)") | |
| print(" 🌐 Using online codec (small download)") | |
| else: | |
| print("📦 Loading VieNeu-TTS model from HuggingFace...") | |
| backbone_repo = "pnnbao-ump/VieNeu-TTS" | |
| codec_repo = "neuphonic/neucodec" | |
| print(" 🌐 Using online models (internet required)") | |
| print(" 💡 Run 'python download_models.py' to use local backbone") | |
| try: | |
| tts = VieNeuTTS( | |
| backbone_repo=backbone_repo, | |
| backbone_device=device, | |
| codec_repo=codec_repo, | |
| codec_device=device | |
| ) | |
| print("✅ Model loaded successfully!") | |
| except Exception as e: | |
| print(f"❌ Failed to load model: {e}") | |
| if not os.path.exists(local_backbone): | |
| print("💡 Try running: python download_models.py") | |
| raise e | |
| print(f"📦 Model: VieNeu-TTS-1000h") | |
| print(f"🎚️ Speed Control: Pydub") | |
| print("=" * 60) | |
@app.on_event("shutdown")
async def shutdown_event():
| global thread_pool | |
| print("🔄 Shutting down server...") | |
| # Cleanup thread pool | |
| if thread_pool: | |
| print(" 🧹 Shutting down thread pool...") | |
| thread_pool.shutdown(wait=True) | |
| print(" ✅ Thread pool shutdown complete") | |
| # Clear CUDA cache | |
| if device == "cuda" and torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print(" 🧹 CUDA cache cleared") | |
| print("✅ Server shutdown complete") | |
| # ========================================== | |
| # STARTUP | |
| # ========================================== | |
| def start_gui(): | |
| """Start GUI in separate thread""" | |
| import tkinter as tk | |
| from tkinter import ttk, messagebox | |
| import webbrowser | |
| import os | |
| import sys | |
| class ServerGUI: | |
| def __init__(self, root): | |
| self.root = root | |
| self.root.title("VieNeu-TTS Server Control") | |
| self.root.geometry("600x500") | |
| # Handle window close event | |
| self.root.protocol("WM_DELETE_WINDOW", self.on_closing) | |
| # Add menu bar | |
| self.setup_menu() | |
| # Server info | |
| info_frame = ttk.LabelFrame(root, text="Server Information", padding="10") | |
| info_frame.pack(fill=tk.X, padx=10, pady=5) | |
| ttk.Label(info_frame, text="🎙️ VieNeu-TTS FastAPI Server", font=("Arial", 14, "bold")).pack() | |
| url_frame = ttk.Frame(info_frame) | |
| url_frame.pack() | |
| ttk.Label(url_frame, text="Server URL: ").pack(side=tk.LEFT) | |
| url_label = ttk.Label(url_frame, text="http://127.0.0.1:8000", foreground="blue", cursor="hand2") | |
| url_label.pack(side=tk.LEFT) | |
| url_label.bind("<Button-1>", lambda e: webbrowser.open("http://127.0.0.1:8000")) | |
| self.status_label = ttk.Label(info_frame, text="Status: ✅ Running", foreground="green") | |
| self.status_label.pack() | |
| # Quick actions | |
| actions_frame = ttk.LabelFrame(root, text="Quick Actions", padding="10") | |
| actions_frame.pack(fill=tk.X, padx=10, pady=5) | |
| btn_frame = ttk.Frame(actions_frame) | |
| btn_frame.pack() | |
| ttk.Button(btn_frame, text="Open API Docs", command=self.open_docs).pack(side=tk.LEFT, padx=5) | |
| ttk.Button(btn_frame, text="Test Server", command=self.test_server).pack(side=tk.LEFT, padx=5) | |
| ttk.Button(btn_frame, text="Clear Cache", command=self.clear_cache).pack(side=tk.LEFT, padx=5) | |
| # Shutdown button | |
| shutdown_btn = ttk.Button(btn_frame, text="Tắt Server", command=self.shutdown_server) | |
| shutdown_btn.pack(side=tk.RIGHT, padx=5) | |
| shutdown_btn.configure(style="Accent.TButton") # Make it stand out | |
| # Settings control | |
| settings_frame = ttk.LabelFrame(root, text="Async Settings Control", padding="10") | |
| settings_frame.pack(fill=tk.X, padx=10, pady=5) | |
| # Current settings display | |
| self.settings_text = tk.Text(settings_frame, height=4, width=60, state="disabled") | |
| self.settings_text.pack(pady=(0, 10)) | |
| # Settings controls | |
| control_frame = ttk.Frame(settings_frame) | |
| control_frame.pack(fill=tk.X) | |
| # GPU Semaphore | |
| gpu_frame = ttk.Frame(control_frame) | |
| gpu_frame.pack(fill=tk.X, pady=2) | |
| ttk.Label(gpu_frame, text="GPU Semaphore (1-4):").pack(side=tk.LEFT) | |
| self.gpu_var = tk.StringVar(value="2") | |
| ttk.Spinbox(gpu_frame, from_=1, to=4, width=10, textvariable=self.gpu_var).pack(side=tk.RIGHT) | |
| # CPU Semaphore | |
| cpu_frame = ttk.Frame(control_frame) | |
| cpu_frame.pack(fill=tk.X, pady=2) | |
| ttk.Label(cpu_frame, text="CPU Semaphore (2-16):").pack(side=tk.LEFT) | |
| self.cpu_var = tk.StringVar(value="4") | |
| ttk.Spinbox(cpu_frame, from_=2, to=16, width=10, textvariable=self.cpu_var).pack(side=tk.RIGHT) | |
| # I/O Semaphore | |
| io_frame = ttk.Frame(control_frame) | |
| io_frame.pack(fill=tk.X, pady=2) | |
| ttk.Label(io_frame, text="I/O Semaphore (3-16):").pack(side=tk.LEFT) | |
| self.io_var = tk.StringVar(value="6") | |
| ttk.Spinbox(io_frame, from_=3, to=16, width=10, textvariable=self.io_var).pack(side=tk.RIGHT) | |
| # Thread Pool | |
| thread_frame = ttk.Frame(control_frame) | |
| thread_frame.pack(fill=tk.X, pady=2) | |
| ttk.Label(thread_frame, text="Thread Pool (2-20):").pack(side=tk.LEFT) | |
| self.thread_var = tk.StringVar(value="6") | |
| ttk.Spinbox(thread_frame, from_=2, to=20, width=10, textvariable=self.thread_var).pack(side=tk.RIGHT) | |
| # Apply button | |
| ttk.Button(control_frame, text="Apply Settings", command=self.apply_settings).pack(pady=10) | |
| # Presets | |
| presets_frame = ttk.LabelFrame(settings_frame, text="Performance Presets") | |
| presets_frame.pack(fill=tk.X, pady=(10, 0)) | |
| preset_grid = ttk.Frame(presets_frame) | |
| preset_grid.pack(pady=5) | |
| ttk.Button(preset_grid, text="Light (1,2,4,4)", command=lambda: self.apply_preset(1,2,4,4), width=15).grid(row=0, column=0, padx=2, pady=2) | |
| ttk.Button(preset_grid, text="Balanced (2,4,6,6)", command=lambda: self.apply_preset(2,4,6,6), width=15).grid(row=0, column=1, padx=2, pady=2) | |
| ttk.Button(preset_grid, text="Performance (3,8,10,10)", command=lambda: self.apply_preset(3,8,10,10), width=15).grid(row=1, column=0, padx=2, pady=2) | |
| ttk.Button(preset_grid, text="Ultra (4,12,12,16)", command=lambda: self.apply_preset(4,12,12,16), width=15).grid(row=1, column=1, padx=2, pady=2) | |
| # Status monitor | |
| monitor_frame = ttk.LabelFrame(root, text="Resource Monitor", padding="10") | |
| monitor_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5) | |
| self.monitor_text = tk.Text(monitor_frame, height=8, state="disabled") | |
| scrollbar = ttk.Scrollbar(monitor_frame, orient="vertical", command=self.monitor_text.yview) | |
| self.monitor_text.configure(yscrollcommand=scrollbar.set) | |
| self.monitor_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) | |
| scrollbar.pack(side=tk.RIGHT, fill=tk.Y) | |
| # Start monitoring | |
| self.update_display() | |
| self.start_monitoring() | |
| def setup_menu(self): | |
| """Setup menu bar""" | |
| menubar = tk.Menu(self.root) | |
| self.root.config(menu=menubar) | |
| # File menu | |
| file_menu = tk.Menu(menubar, tearoff=0) | |
| menubar.add_cascade(label="File", menu=file_menu) | |
| file_menu.add_command(label="Ẩn giao diện", command=lambda: self.root.withdraw()) | |
| file_menu.add_separator() | |
| file_menu.add_command(label="Tắt server", command=self.shutdown_server) | |
| # View menu | |
| view_menu = tk.Menu(menubar, tearoff=0) | |
| menubar.add_cascade(label="View", menu=view_menu) | |
| view_menu.add_command(label="Refresh", command=self.update_display) | |
| view_menu.add_command(label="Open API Docs", command=self.open_docs) | |
| # Help menu | |
| help_menu = tk.Menu(menubar, tearoff=0) | |
| menubar.add_cascade(label="Help", menu=help_menu) | |
| help_menu.add_command(label="About", command=self.show_about) | |
| def show_about(self): | |
| """Show about dialog""" | |
| about_text = """VieNeu-TTS Server Control | |
| Version: 1.0.0 | |
| Server: FastAPI with Async Processing | |
| Model: VieNeu-TTS-1000h | |
| Features: | |
| • Real-time settings adjustment | |
| • Performance monitoring | |
| • Cache management | |
| • Multiple presets | |
| Server URL: http://127.0.0.1:8000""" | |
| messagebox.showinfo("About", about_text) | |
| def on_closing(self): | |
| """Handle window closing event""" | |
| result = messagebox.askyesnocancel( | |
| "Tắt Server", | |
| "Đóng giao diện sẽ tắt server.\n\n" + | |
| "• Yes: Tắt server và giao diện\n" + | |
| "• No: Chỉ ẩn giao diện (server vẫn chạy)\n" + | |
| "• Cancel: Không làm gì" | |
| ) | |
| if result is True: # Yes - Shutdown server | |
| self.log_message("🔄 Đang tắt server...") | |
| self.status_label.config(text="Status: 🔄 Shutting down...", foreground="orange") | |
| self.root.destroy() | |
| # Force exit the entire application | |
| import threading | |
| def force_exit(): | |
| import time | |
| time.sleep(1) # Give time for cleanup | |
| os._exit(0) # Force exit | |
| threading.Thread(target=force_exit, daemon=True).start() | |
| elif result is False: # No - Just hide GUI | |
| self.log_message("ℹ️ Giao diện đã ẩn. Server vẫn chạy tại http://127.0.0.1:8000") | |
| self.root.withdraw() # Hide window instead of destroying | |
| # Add system tray notification (if possible) | |
| try: | |
| import subprocess | |
| subprocess.run([ | |
| 'powershell', '-Command', | |
| f'Add-Type -AssemblyName System.Windows.Forms; ' + | |
| f'[System.Windows.Forms.MessageBox]::Show("Server vẫn chạy tại http://127.0.0.1:8000", "VieNeu-TTS", "OK", "Information")' | |
| ], capture_output=True) | |
| except: | |
| pass | |
| # If Cancel (None), do nothing | |
| def open_docs(self): | |
| webbrowser.open("http://127.0.0.1:8000/docs") | |
| def test_server(self): | |
| import requests | |
| try: | |
| test_data = { | |
| "text": "Test từ giao diện server", | |
| "voice_choice": "Tuyên (nam miền Bắc)", | |
| "speed_factor": 1.0 | |
| } | |
| response = requests.post("http://127.0.0.1:8000/tts", json=test_data, timeout=30) | |
| if response.status_code == 200: | |
| result = response.json() | |
| self.log_message(f"✅ Test thành công! Thời gian: {result.get('processing_time', 0):.2f}s") | |
| else: | |
| self.log_message(f"❌ Test thất bại: HTTP {response.status_code}") | |
| except Exception as e: | |
| self.log_message(f"❌ Lỗi test: {str(e)}") | |
| def clear_cache(self): | |
| import requests | |
| try: | |
| response = requests.post("http://127.0.0.1:8000/admin/clear_cache", timeout=5) | |
| if response.status_code == 200: | |
| self.log_message("✅ Cache đã được xóa") | |
| else: | |
| self.log_message("❌ Không thể xóa cache") | |
| except Exception as e: | |
| self.log_message(f"❌ Lỗi xóa cache: {str(e)}") | |
| def shutdown_server(self): | |
| """Shutdown server gracefully""" | |
| if messagebox.askokcancel("Tắt Server", "Bạn có chắc chắn muốn tắt server?"): | |
| self.log_message("🔄 Đang tắt server...") | |
| self.status_label.config(text="Status: 🔄 Shutting down...", foreground="orange") | |
| # Close the GUI and exit | |
| self.root.after(1000, lambda: [self.root.destroy(), os._exit(0)]) | |
| def apply_preset(self, gpu, cpu, io, threads): | |
| preset_names = { | |
| (1,2,4,4): "Light", | |
| (2,4,6,6): "Balanced", | |
| (3,8,10,10): "Performance", | |
| (4,12,12,16): "Ultra" | |
| } | |
| preset_name = preset_names.get((gpu, cpu, io, threads), "Custom") | |
| self.log_message(f"🎯 Áp dụng preset {preset_name}...") | |
| self.gpu_var.set(str(gpu)) | |
| self.cpu_var.set(str(cpu)) | |
| self.io_var.set(str(io)) | |
| self.thread_var.set(str(threads)) | |
| self.apply_settings() | |
| def apply_settings(self): | |
| import requests | |
| try: | |
| settings = { | |
| "gpu_semaphore": int(self.gpu_var.get()), | |
| "cpu_semaphore": int(self.cpu_var.get()), | |
| "io_semaphore": int(self.io_var.get()), | |
| "thread_pool": int(self.thread_var.get()) | |
| } | |
| self.log_message(f"🔄 Đang áp dụng cài đặt...") | |
| response = requests.post("http://127.0.0.1:8000/admin/update_settings", json=settings, timeout=5) | |
| if response.status_code == 200: | |
| self.log_message(f"✅ Cài đặt đã áp dụng: GPU({settings['gpu_semaphore']}) CPU({settings['cpu_semaphore']}) I/O({settings['io_semaphore']}) Threads({settings['thread_pool']})") | |
| # Update display after a short delay to see the changes | |
| self.root.after(1000, self.update_display) | |
| else: | |
| self.log_message(f"❌ Không thể áp dụng cài đặt: {response.text}") | |
| except Exception as e: | |
| self.log_message(f"❌ Lỗi áp dụng cài đặt: {str(e)}") | |
| def update_display(self): | |
| import requests | |
| try: | |
| response = requests.get("http://127.0.0.1:8000/status", timeout=3) | |
| if response.status_code == 200: | |
| data = response.json() | |
| # Update settings display | |
| settings_info = "=== CÀI ĐẶT HIỆN TẠI ===\n" | |
| if 'async_processing' in data: | |
| async_data = data['async_processing'] | |
| for resource, info in async_data.items(): | |
| available = info.get('available', 0) | |
| max_cap = info.get('max_capacity', 0) | |
| in_use = max(0, max_cap - available) # Ensure non-negative | |
| settings_info += f"{resource}: {max_cap} max, {in_use} đang dùng\n" | |
| self.settings_text.config(state="normal") | |
| self.settings_text.delete(1.0, tk.END) | |
| self.settings_text.insert(1.0, settings_info) | |
| self.settings_text.config(state="disabled") | |
| # Update monitor | |
| monitor_info = f"=== TRẠNG THÁI SERVER ===\n" | |
| monitor_info += f"Cập nhật: {datetime.now().strftime('%H:%M:%S')}\n\n" | |
| if 'async_processing' in data: | |
| monitor_info += "📊 Tài nguyên Async:\n" | |
| async_data = data['async_processing'] | |
| for resource, info in async_data.items(): | |
| available = info.get('available', 0) | |
| max_cap = info.get('max_capacity', 1) | |
| in_use = max(0, max_cap - available) | |
| usage_pct = (in_use / max_cap) * 100 if max_cap > 0 else 0 | |
| # Visual progress bar | |
| bar_length = 10 | |
| filled = int((usage_pct / 100) * bar_length) | |
| bar = "█" * filled + "░" * (bar_length - filled) | |
| monitor_info += f" {resource}: {bar} {usage_pct:.0f}% ({in_use}/{max_cap})\n" | |
| if 'model_info' in data: | |
| model_info = data['model_info'] | |
| monitor_info += f"\n🖥️ Model Info:\n" | |
| monitor_info += f" Device: {model_info.get('device', 'unknown')}\n" | |
| monitor_info += f" Cache: {model_info.get('reference_cache_size', 0)} giọng nói\n" | |
| cached_voices = model_info.get('cached_voices', []) | |
| if cached_voices: | |
| monitor_info += f" Cached: {', '.join(cached_voices[:3])}" | |
| if len(cached_voices) > 3: | |
| monitor_info += f" (+{len(cached_voices)-3} khác)" | |
| monitor_info += "\n" | |
| self.monitor_text.config(state="normal") | |
| self.monitor_text.delete(1.0, tk.END) | |
| self.monitor_text.insert(1.0, monitor_info) | |
| self.monitor_text.config(state="disabled") | |
| except Exception: | |
| pass # Ignore errors during startup | |
| def log_message(self, message): | |
| timestamp = datetime.now().strftime("%H:%M:%S") | |
| log_msg = f"[{timestamp}] {message}\n" | |
| self.monitor_text.config(state="normal") | |
| self.monitor_text.insert(tk.END, log_msg) | |
| self.monitor_text.see(tk.END) | |
| self.monitor_text.config(state="disabled") | |
| def start_monitoring(self): | |
| def monitor_loop(): | |
| while True: | |
| try: | |
| self.update_display() | |
| except: | |
| pass | |
| time.sleep(5) | |
| import threading | |
| threading.Thread(target=monitor_loop, daemon=True).start() | |
| try: | |
| root = tk.Tk() | |
| gui = ServerGUI(root) | |
| root.mainloop() | |
| except Exception as e: | |
| print(f"GUI Error: {e}") | |
| if __name__ == "__main__": | |
| import uvicorn | |
| import threading | |
| # Start GUI in separate thread | |
| gui_thread = threading.Thread(target=start_gui, daemon=True) | |
| gui_thread.start() | |
| # Start server | |
| uvicorn.run( | |
| "tts_server:app", | |
| host="127.0.0.1", | |
| port=8000, | |
| reload=False, | |
| log_level="info" | |
| ) |