# NOTE(review): the three lines below were Hugging Face Spaces page chrome
# ("Spaces:" header and two "Runtime error" status badges) captured when the
# source was scraped — they are not part of the Python module.
"""
TTS Model Module
================
Handles model loading, inference optimization, and audio generation.
Implements caching, mixed precision, and efficient batch processing.
"""
import os
import logging
import time
from typing import Dict, List, Tuple, Optional, Union
from pathlib import Path
import torch
import numpy as np
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

# Configure logging
logger = logging.getLogger(__name__)
class OptimizedTTSModel:
    """Optimized TTS model with caching and performance enhancements.

    Wraps a SpeechT5 text-to-speech checkpoint plus a HiFi-GAN vocoder and
    adds device auto-selection, optional fp16 inference, speaker-embedding
    lookup, and simple per-call latency bookkeeping.
    """

    def __init__(self,
                 checkpoint: str = "Edmon02/TTS_NB_2",
                 vocoder_checkpoint: str = "microsoft/speecht5_hifigan",
                 device: Optional[str] = None,
                 use_mixed_precision: bool = True,
                 cache_embeddings: bool = True):
        """
        Initialize the optimized TTS model.

        Args:
            checkpoint: Model checkpoint path
            vocoder_checkpoint: Vocoder checkpoint path
            device: Device to use ('cuda', 'cpu', or None for auto)
            use_mixed_precision: Whether to use mixed precision inference
            cache_embeddings: Whether to cache speaker embeddings
        """
        self.checkpoint = checkpoint
        self.vocoder_checkpoint = vocoder_checkpoint
        self.use_mixed_precision = use_mixed_precision
        self.cache_embeddings = cache_embeddings

        # Auto-detect device unless the caller pinned one explicitly.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        logger.info(f"Using device: {self.device}")

        # Components populated by the loaders below.
        self.processor = None
        self.model = None
        self.vocoder = None
        self.speaker_embeddings = {}  # speaker code -> embedding tensor on self.device
        self.embedding_cache = {}     # reserved for future per-text caching

        # Per-call latency history in seconds; see get_performance_stats().
        self.inference_times = []

        # Load models
        self._load_models()
        self._load_speaker_embeddings()

    def _load_models(self):
        """Load TTS model, processor, and vocoder.

        Raises:
            Exception: re-raised from transformers if any checkpoint fails
                to load; the error is logged first.
        """
        try:
            logger.info("Loading TTS models...")
            start_time = time.time()

            # Load processor (tokenizer + feature extractor)
            self.processor = SpeechT5Processor.from_pretrained(self.checkpoint)

            # Load main model
            self.model = SpeechT5ForTextToSpeech.from_pretrained(self.checkpoint)
            self.model.to(self.device)
            self.model.eval()  # Set to evaluation mode

            # Load vocoder
            self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_checkpoint)
            self.vocoder.to(self.device)
            self.vocoder.eval()

            # Halve weights for fp16 inference on CUDA only; inputs are
            # matched to this dtype in generate_speech().
            if self.use_mixed_precision and self.device.type == "cuda":
                self.model.half()
                self.vocoder.half()
                logger.info("Mixed precision enabled")

            load_time = time.time() - start_time
            logger.info(f"Models loaded in {load_time:.2f}s")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise

    def _load_speaker_embeddings(self):
        """Load speaker embeddings from .npy files.

        Looks for embedding files two directories above this module.

        Raises:
            FileNotFoundError: if no embedding file could be loaded at all.
        """
        try:
            # Define available speaker embeddings
            embedding_files = {
                "BDL": "nb_620.npy",
                # Add more speakers as needed
            }

            base_path = Path(__file__).parent.parent
            for speaker, filename in embedding_files.items():
                filepath = base_path / filename
                if filepath.exists():
                    embedding = np.load(filepath).astype(np.float32)
                    self.speaker_embeddings[speaker] = torch.tensor(embedding).to(self.device)
                    logger.info(f"Loaded embedding for speaker {speaker}")
                else:
                    logger.warning(f"Speaker embedding file not found: {filepath}")

            if not self.speaker_embeddings:
                raise FileNotFoundError("No speaker embeddings found")
        except Exception as e:
            logger.error(f"Failed to load speaker embeddings: {e}")
            raise

    def _get_speaker_embedding(self, speaker: str) -> torch.Tensor:
        """
        Get speaker embedding with caching.

        Args:
            speaker: Speaker identifier

        Returns:
            Speaker embedding tensor of shape (1, dim); falls back to the
            first loaded speaker when the requested code is unknown.
        """
        # Extract speaker code (first 3 characters)
        speaker_code = speaker[:3].upper()

        if speaker_code not in self.speaker_embeddings:
            logger.warning(f"Speaker {speaker_code} not found, using default")
            speaker_code = list(self.speaker_embeddings.keys())[0]

        # Return cached embedding with batch dimension
        embedding = self.speaker_embeddings[speaker_code]
        return embedding.unsqueeze(0)  # Add batch dimension

    def _preprocess_text(self, text: str) -> Optional[torch.Tensor]:
        """
        Preprocess text for model input.

        Args:
            text: Input text

        Returns:
            Processed input-id tensor on self.device, truncated to the
            model's maximum text length, or None for blank input.
        """
        if not text.strip():
            return None

        # Process text
        inputs = self.processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)

        # Limit input length to model's maximum
        max_length = getattr(self.model.config, 'max_text_positions', 600)
        input_ids = input_ids[..., :max_length]

        return input_ids

    def generate_speech(self, text: str, speaker: str = "BDL") -> Tuple[int, np.ndarray]:
        """
        Generate speech from text.

        Args:
            text: Input text
            speaker: Speaker identifier

        Returns:
            Tuple of (sample_rate, audio_array): int16 samples at 16 kHz.
            An empty array is returned for blank input or on failure
            (best-effort API — this method does not raise).
        """
        start_time = time.time()

        try:
            # Handle empty text
            if not text or not text.strip():
                logger.warning("Empty text provided")
                return 16000, np.zeros(0, dtype=np.int16)

            # Preprocess text
            input_ids = self._preprocess_text(text)
            if input_ids is None:
                return 16000, np.zeros(0, dtype=np.int16)

            # Get speaker embedding.
            # BUGFIX: embeddings are stored as fp32 but the model may have
            # been halved in _load_models(); passing a mismatched dtype into
            # generate_speech raises at runtime. Match the model's dtype.
            speaker_embedding = self._get_speaker_embedding(speaker)
            speaker_embedding = speaker_embedding.to(
                dtype=next(self.model.parameters()).dtype)

            # Generate speech with mixed precision if enabled
            if self.use_mixed_precision and self.device.type == "cuda":
                # torch.autocast is the supported replacement for the
                # deprecated torch.cuda.amp.autocast().
                with torch.autocast(device_type="cuda"):
                    speech = self.model.generate_speech(
                        input_ids,
                        speaker_embedding,
                        vocoder=self.vocoder
                    )
            else:
                speech = self.model.generate_speech(
                    input_ids,
                    speaker_embedding,
                    vocoder=self.vocoder
                )

            # BUGFIX: clip to [-1, 1] before scaling so the int16 cast
            # cannot wrap around on out-of-range samples. .float() keeps
            # the scaling in fp32 even when the model emitted fp16, and
            # .detach() drops any lingering autograd reference.
            speech_np = speech.detach().cpu().float().numpy()
            speech_int16 = (np.clip(speech_np, -1.0, 1.0) * 32767).astype(np.int16)

            # Track performance
            inference_time = time.time() - start_time
            self.inference_times.append(inference_time)

            logger.info(f"Generated {len(speech_int16)} samples in {inference_time:.3f}s")
            return 16000, speech_int16
        except Exception as e:
            # Deliberate best-effort behavior: callers receive silence
            # instead of an exception; the failure is logged.
            logger.error(f"Speech generation failed: {e}")
            return 16000, np.zeros(0, dtype=np.int16)

    def generate_speech_chunks(self, text_chunks: List[str], speaker: str = "BDL") -> Tuple[int, np.ndarray]:
        """
        Generate speech from multiple text chunks and concatenate.

        Args:
            text_chunks: List of text chunks
            speaker: Speaker identifier

        Returns:
            Tuple of (sample_rate, concatenated_audio_array)
        """
        if not text_chunks:
            return 16000, np.zeros(0, dtype=np.int16)

        logger.info(f"Generating speech for {len(text_chunks)} chunks")

        audio_segments = []
        total_start_time = time.time()

        for i, chunk in enumerate(text_chunks):
            logger.debug(f"Processing chunk {i+1}/{len(text_chunks)}")
            _sample_rate, audio = self.generate_speech(chunk, speaker)
            # Failed/blank chunks come back empty; skip them rather than
            # inserting zero-length segments.
            if len(audio) > 0:
                audio_segments.append(audio)

        if not audio_segments:
            logger.warning("No audio generated from chunks")
            return 16000, np.zeros(0, dtype=np.int16)

        # Concatenate all audio segments
        concatenated_audio = np.concatenate(audio_segments)

        total_time = time.time() - total_start_time
        logger.info(f"Generated {len(concatenated_audio)} samples from {len(text_chunks)} chunks in {total_time:.3f}s")
        return 16000, concatenated_audio

    def batch_generate_speech(self, texts: List[str], speaker: str = "BDL") -> List[Tuple[int, np.ndarray]]:
        """
        Generate speech for multiple texts (batch processing).

        Note: processes sequentially; each text is an independent call.

        Args:
            texts: List of input texts
            speaker: Speaker identifier

        Returns:
            List of (sample_rate, audio_array) tuples
        """
        return [self.generate_speech(text, speaker) for text in texts]

    def get_performance_stats(self) -> Dict[str, float]:
        """Get performance statistics over all recorded inferences.

        Returns:
            Dict with avg/min/max inference times (seconds) and the total
            inference count; zeros when nothing has been recorded yet.
        """
        if not self.inference_times:
            return {"avg_inference_time": 0.0, "total_inferences": 0}

        # BUGFIX: wrap numpy scalars in float() so the returned values
        # honor the Dict[str, float] contract (and JSON-serialize cleanly).
        return {
            "avg_inference_time": float(np.mean(self.inference_times)),
            "min_inference_time": float(np.min(self.inference_times)),
            "max_inference_time": float(np.max(self.inference_times)),
            "total_inferences": len(self.inference_times)
        }

    def clear_performance_cache(self):
        """Clear performance tracking data."""
        self.inference_times.clear()
        logger.info("Performance cache cleared")

    def get_available_speakers(self) -> List[str]:
        """Get list of available speakers."""
        return list(self.speaker_embeddings.keys())

    def optimize_for_inference(self):
        """Apply additional optimizations for inference.

        Best-effort: failures are logged as warnings, never raised.
        """
        try:
            if hasattr(torch.backends, 'cudnn'):
                # Benchmark mode picks the fastest conv algorithms for the
                # observed shapes; determinism is intentionally traded away.
                torch.backends.cudnn.benchmark = True
                torch.backends.cudnn.deterministic = False

            # Compile model for better performance (PyTorch 2.0+)
            if hasattr(torch, 'compile') and self.device.type == "cuda":
                logger.info("Compiling model for optimization...")
                self.model = torch.compile(self.model)
                self.vocoder = torch.compile(self.vocoder)

            logger.info("Model optimization completed")
        except Exception as e:
            logger.warning(f"Model optimization failed: {e}")

    def warmup(self, warmup_text: str = "Բարև ձեզ"):
        """
        Warm up the model with a simple inference.

        Args:
            warmup_text: Text to use for warmup
        """
        logger.info("Warming up model...")
        try:
            _ = self.generate_speech(warmup_text)
            logger.info("Model warmup completed")
        except Exception as e:
            logger.warning(f"Model warmup failed: {e}")