Spaces:
Running
Running
| # utils.py | |
| # Utility functions for the Dia TTS server | |
| import logging | |
| import time | |
| import os | |
| import io | |
| import numpy as np | |
| import soundfile as sf | |
| from typing import Optional, Tuple | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
| # --- Audio Processing --- | |
def encode_audio(
    audio_array: np.ndarray, sample_rate: int, output_format: str = "opus"
) -> Optional[bytes]:
    """
    Encodes a NumPy audio array into the specified format in memory.

    Args:
        audio_array: NumPy array containing audio data (float32, range [-1, 1]).
            For 'wav', samples outside [-1, 1] are clipped before the
            conversion to 16-bit PCM.
        sample_rate: Sample rate of the audio data.
        output_format: Desired output format ('opus' or 'wav').

    Returns:
        Bytes object containing the encoded audio, or None on failure
        (empty/None input, unsupported format, or an encoding error).

    Raises:
        ImportError: Re-raised if the audio backend reports it is unavailable.
    """
    if audio_array is None or audio_array.size == 0:
        logger.warning("encode_audio received empty or None audio array.")
        return None

    # perf_counter is monotonic, so the measured duration cannot go negative
    # if the wall clock is adjusted mid-encode.
    start_time = time.perf_counter()
    output_buffer = io.BytesIO()
    try:
        if output_format == "opus":
            # Float samples are passed straight through to the Ogg/Opus writer.
            # NOTE(review): libsndfile's Opus encoder only accepts 8, 12, 16, 24
            # or 48 kHz; any other rate raises here and the function returns
            # None below. Confirm the caller's sample rate (resample upstream
            # if necessary).
            sf.write(
                output_buffer, audio_array, sample_rate, format="ogg", subtype="opus"
            )
        elif output_format == "wav":
            # Clip before scaling: casting out-of-range floats to int16 is
            # undefined in NumPy and in practice wraps around, turning mild
            # overshoot into full-scale clicks.
            audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
            sf.write(
                output_buffer, audio_int16, sample_rate, format="wav", subtype="pcm_16"
            )
        else:
            logger.error("Unsupported output format requested: %s", output_format)
            return None

        encoded_bytes = output_buffer.getvalue()
        elapsed = time.perf_counter() - start_time
        logger.info(
            "Encoded %d bytes to %s in %.3f seconds.",
            len(encoded_bytes),
            output_format,
            elapsed,
        )
        return encoded_bytes
    except ImportError:
        logger.critical(
            "`soundfile` or its dependency `libsndfile` not found/installed correctly. Cannot encode audio."
        )
        raise  # Re-raise critical error
    except Exception as e:
        logger.error(f"Error encoding audio to {output_format}: {e}", exc_info=True)
        return None
def save_audio_to_file(
    audio_array: np.ndarray, sample_rate: int, file_path: str
) -> bool:
    """
    Saves a NumPy audio array to a 16-bit PCM WAV file.

    Args:
        audio_array: NumPy array containing audio data (float32, range [-1, 1]).
            Samples outside [-1, 1] are clipped before conversion to int16.
        sample_rate: Sample rate of the audio data.
        file_path: Path to save the WAV file (a str or any os.PathLike). The
            file is always written as WAV, even if the extension differs.
            Missing parent directories are created.

    Returns:
        True if saving was successful, False otherwise.
    """
    if audio_array is None or audio_array.size == 0:
        logger.warning("save_audio_to_file received empty or None audio array.")
        return False

    # Normalize pathlib.Path and other os.PathLike values to a plain string so
    # the string operations below work for them too.
    file_path = os.fspath(file_path)

    if not file_path.lower().endswith(".wav"):
        logger.warning(
            f"File path '{file_path}' does not end with .wav. Saving as WAV anyway."
        )

    start_time = time.perf_counter()
    try:
        # os.path.dirname() returns "" for a bare filename, and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when one is named.
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Clip before scaling: casting out-of-range floats to int16 is undefined
        # in NumPy and in practice wraps around into audible clicks.
        audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
        sf.write(file_path, audio_int16, sample_rate, format="wav", subtype="pcm_16")

        elapsed = time.perf_counter() - start_time
        logger.info(f"Saved WAV file to {file_path} in {elapsed:.3f} seconds.")
        return True
    except ImportError:
        logger.critical(
            "`soundfile` or its dependency `libsndfile` not found/installed correctly. Cannot save audio."
        )
        return False  # Indicate failure
    except Exception as e:
        logger.error(f"Error saving WAV file to {file_path}: {e}", exc_info=True)
        return False
| # --- Other Utilities (Optional) --- | |
class PerformanceMonitor:
    """Simple performance monitoring.

    Records named checkpoints and renders the time elapsed between consecutive
    checkpoints (the first is measured from construction), followed by the
    total time since construction.
    """

    def __init__(self):
        # Wall-clock timestamp at construction; baseline for every delta.
        # NOTE(review): time.time() can jump if the system clock is adjusted;
        # time.perf_counter() is monotonic, but switching would change what the
        # public start_time/events attributes mean to any external reader.
        self.start_time = time.time()
        # Ordered (event_name, timestamp) tuples appended by record().
        self.events = []

    def record(self, event_name: str):
        """Append a checkpoint called *event_name*, stamped with the current time."""
        now = time.time()
        self.events.append((event_name, now))

    def report(self) -> str:
        """Return a multi-line summary of per-checkpoint and total durations."""
        total = time.time() - self.start_time
        # Each event's delta is taken against the previous checkpoint; the
        # construction time stands in as the "previous" for the first event.
        baselines = [self.start_time] + [stamp for _, stamp in self.events]
        lines = ["Performance Report:"]
        lines.extend(
            f" - {label}: {stamp - before:.3f}s"
            for (label, stamp), before in zip(self.events, baselines)
        )
        lines.append(f"Total Duration: {total:.3f}s")
        return "\n".join(lines)