multimodalart's picture
Upload 25 files
bbb0e68 verified
"""Audio processing utilities for KugelAudio."""
import os
from typing import Optional, Union, List, Dict, Any
import numpy as np
import torch
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AudioNormalizer:
    """Normalize audio to a target dB FS level.

    Normalization happens in two steps: the signal is first scaled so that
    its RMS level matches ``target_dB_FS``, and is then scaled down again if
    that scaling pushed any sample above an absolute value of 1.0. This
    ensures consistent input levels for the model while maintaining audio
    quality and avoiding clipping.

    Args:
        target_dB_FS: Desired RMS level, in dB relative to full scale.
        eps: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def normalize_db(self, audio: np.ndarray) -> tuple:
        """Adjust audio to the target dB FS level.

        Args:
            audio: Audio samples as a NumPy array.

        Returns:
            A ``(scaled_audio, rms, scalar)`` tuple, where ``rms`` is the RMS
            of the input and ``scalar`` is the gain that was applied. For an
            empty input the audio is returned untouched with ``rms == 0.0``
            and ``scalar == 1.0``.
        """
        if audio.size == 0:
            # The mean of an empty array is NaN (and emits a RuntimeWarning);
            # there is nothing to scale, so report a neutral gain instead.
            return audio, 0.0, 1.0
        rms = np.sqrt(np.mean(audio**2))
        # Convert the dB target to a linear amplitude, then divide by the
        # current RMS; eps keeps the division finite for silent input.
        scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
        return audio * scalar, rms, scalar

    def avoid_clipping(self, audio: np.ndarray) -> tuple:
        """Scale down if necessary to avoid clipping.

        Args:
            audio: Audio samples as a NumPy array.

        Returns:
            A ``(audio, scalar)`` tuple. ``scalar`` is the divisor that was
            applied, or ``1.0`` when the input was returned unchanged (peak
            not above 1.0, or empty input).
        """
        if audio.size == 0:
            # np.max raises ValueError on a zero-size array.
            return audio, 1.0
        max_val = np.max(np.abs(audio))
        if max_val > 1.0:
            # Dividing by slightly more than the peak keeps every sample
            # strictly below 1.0 in magnitude.
            scalar = max_val + self.eps
            return audio / scalar, scalar
        return audio, 1.0

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """Normalize audio: adjust dB FS then avoid clipping."""
        audio, _, _ = self.normalize_db(audio)
        audio, _ = self.avoid_clipping(audio)
        return audio
class AudioProcessor(FeatureExtractionMixin):
    """Processor for audio preprocessing and postprocessing.

    Handles:
        - Audio format conversion (stereo to mono)
        - Normalization
        - Loading from various file formats
        - Saving to WAV files

    Example:
        >>> processor = AudioProcessor(sampling_rate=24000)
        >>> audio = processor("path/to/audio.wav")
        >>> processor.save_audio(generated_audio, "output.wav")
    """

    # NOTE(review): __call__ returns its result under the "audio" key, not
    # "input_features" -- confirm which name downstream code expects.
    model_input_names = ["input_features"]

    def __init__(
        self,
        sampling_rate: int = 24000,
        normalize_audio: bool = True,
        target_dB_FS: float = -25,
        eps: float = 1e-6,
        **kwargs,
    ):
        """Create a processor.

        Args:
            sampling_rate: Sampling rate used when loading and saving audio.
            normalize_audio: Whether processed audio is level-normalized.
            target_dB_FS: Target level passed to the normalizer.
            eps: Numerical-stability constant passed to the normalizer.
            **kwargs: Forwarded to ``FeatureExtractionMixin.__init__``.
        """
        super().__init__(**kwargs)
        self.sampling_rate = sampling_rate
        self.normalize_audio = normalize_audio
        self.normalizer = AudioNormalizer(target_dB_FS, eps) if normalize_audio else None
        # Snapshot of the constructor arguments, returned by to_dict().
        self.feature_extractor_dict = {
            "sampling_rate": sampling_rate,
            "normalize_audio": normalize_audio,
            "target_dB_FS": target_dB_FS,
            "eps": eps,
        }

    def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
        """Convert stereo to mono if needed.

        1D input is returned unchanged. For 2D input, an axis of length 2 is
        averaged away and an axis of length 1 is squeezed away; the first
        axis is checked before the second in each case.

        Raises:
            ValueError: If the array is 2D without an axis of length 1 or 2,
                or is not 1D/2D at all.
        """
        if len(audio.shape) == 1:
            return audio
        if len(audio.shape) != 2:
            raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")

        if audio.shape[0] == 2:
            return np.mean(audio, axis=0)
        if audio.shape[1] == 2:
            return np.mean(audio, axis=1)
        if audio.shape[0] == 1:
            return audio.squeeze(0)
        if audio.shape[1] == 1:
            return audio.squeeze(1)
        raise ValueError(f"Unexpected audio shape: {audio.shape}")

    def _process_single(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
        """Process a single audio array.

        Converts to a float32 copy, reduces to mono and, when enabled,
        normalizes the level. Empty audio skips normalization.
        """
        if not isinstance(audio, np.ndarray):
            audio = np.array(audio, dtype=np.float32)
        else:
            audio = audio.astype(np.float32)
        audio = self._ensure_mono(audio)
        # Level statistics are undefined for zero samples (np.max raises on
        # an empty array), so empty audio is passed through as-is.
        if self.normalize_audio and self.normalizer and audio.size > 0:
            audio = self.normalizer(audio)
        return audio

    def _load_from_path(self, audio_path: str) -> np.ndarray:
        """Load audio from file path.

        The loader is chosen by file extension (case-insensitive):
        ``.wav/.mp3/.flac/.m4a/.ogg`` go through librosa, which is asked for
        mono output at ``self.sampling_rate``; ``.pt`` goes through
        ``torch.load`` and ``.npy`` through ``np.load``. The ``.pt`` and
        ``.npy`` paths perform no resampling.

        Raises:
            ValueError: If the extension is not one of the supported ones.
        """
        ext = os.path.splitext(audio_path)[1].lower()
        if ext in [".wav", ".mp3", ".flac", ".m4a", ".ogg"]:
            # Imported lazily so librosa is only required for this path.
            import librosa

            audio, _ = librosa.load(audio_path, sr=self.sampling_rate, mono=True)
            return audio
        elif ext == ".pt":
            # Assumes the file stores a single tensor -- TODO confirm.
            tensor = torch.load(audio_path, map_location="cpu", weights_only=True).squeeze()
            return tensor.numpy().astype(np.float32)
        elif ext == ".npy":
            return np.load(audio_path).astype(np.float32)
        else:
            raise ValueError(f"Unsupported format: {ext}")

    def __call__(
        self,
        audio: Optional[Union[str, np.ndarray, List[float], List[np.ndarray], List[str]]] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Process audio input(s).

        Args:
            audio: Audio input - path, array, or list of either. A list may
                mix paths and arrays; paths are loaded individually.
            sampling_rate: Input sampling rate (for validation). A mismatch
                only logs a warning; arrays are not resampled here.
            return_tensors: Return format ("pt" for PyTorch, "np" for NumPy).
                Any other value returns the processed array(s) as they are.

        Returns:
            Dictionary with the processed audio under the ``"audio"`` key.
            With ``return_tensors`` set, the value has shape
            ``(batch, 1, samples)``; batched items must then share one
            length because they are stacked without padding. Otherwise the
            value is a single 1D array, or a list of arrays for a batch.

        Raises:
            ValueError: If ``audio`` is None.
        """
        if audio is None:
            raise ValueError("Audio input is required")

        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            logger.warning(
                f"Input sampling rate ({sampling_rate}) differs from expected ({self.sampling_rate}). "
                "Please resample your audio."
            )

        # Handle different input types
        if isinstance(audio, str):
            audio = self._load_from_path(audio)
            is_batched = False
        elif isinstance(audio, list):
            if not audio or any(isinstance(item, str) for item in audio):
                # A list containing paths is always a batch. Non-path items
                # are kept so paths and arrays can be mixed. An empty list
                # is treated as an empty batch.
                audio = [
                    self._load_from_path(item) if isinstance(item, str) else item
                    for item in audio
                ]
                is_batched = True
            else:
                # A list of arrays/lists is a batch; a flat list of numbers
                # is a single waveform.
                is_batched = isinstance(audio[0], (np.ndarray, list))
        else:
            is_batched = False

        # Process
        if is_batched:
            processed = [self._process_single(a) for a in audio]
        else:
            processed = [self._process_single(audio)]

        # Convert to tensors
        if return_tensors == "pt":
            if len(processed) == 1:
                features = torch.from_numpy(processed[0]).unsqueeze(0).unsqueeze(1)
            else:
                features = torch.stack([torch.from_numpy(a) for a in processed]).unsqueeze(1)
        elif return_tensors == "np":
            if len(processed) == 1:
                features = processed[0][np.newaxis, np.newaxis, :]
            else:
                features = np.stack(processed)[:, np.newaxis, :]
        else:
            features = processed[0] if len(processed) == 1 else processed

        return {"audio": features}

    @staticmethod
    def _to_numpy(item: Any) -> Any:
        """Convert a torch tensor to a float32 NumPy array; pass others through."""
        if isinstance(item, torch.Tensor):
            # .float() first so dtypes NumPy lacks are converted before .numpy().
            return item.float().detach().cpu().numpy()
        return item

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ) -> List[str]:
        """Save audio to WAV file(s).

        A list, or an array/tensor with at least three dimensions whose
        first dimension is larger than one, is treated as a batch: each item
        is written to ``<output_path>/<batch_prefix><index>.wav``. Anything
        else is squeezed and written to ``output_path`` as a single file.

        Args:
            audio: Audio data to save
            output_path: Output path (directory for batched audio)
            sampling_rate: Sampling rate (defaults to processor's rate)
            normalize: Whether to peak-normalize before saving
            batch_prefix: Prefix for batch files

        Returns:
            List of saved file paths
        """
        # Imported lazily so soundfile is only required when saving.
        import soundfile as sf

        if sampling_rate is None:
            sampling_rate = self.sampling_rate

        # Convert to numpy. List items are converted one by one so a list
        # mixing tensors and arrays is handled too.
        if isinstance(audio, list):
            audio_np = [self._to_numpy(a) for a in audio]
        else:
            audio_np = self._to_numpy(audio)

        if isinstance(audio_np, list):
            batch_items = audio_np
        elif len(audio_np.shape) >= 3 and audio_np.shape[0] > 1:
            batch_items = [audio_np[i].squeeze() for i in range(audio_np.shape[0])]
        else:
            batch_items = None

        saved_paths = []
        if batch_items is not None:
            os.makedirs(output_path, exist_ok=True)
            for i, item in enumerate(batch_items):
                item = self._prepare_for_save(item, normalize)
                path = os.path.join(output_path, f"{batch_prefix}{i}.wav")
                sf.write(path, item, sampling_rate)
                saved_paths.append(path)
        else:
            item = self._prepare_for_save(audio_np.squeeze(), normalize)
            # Create the parent directory, matching what the batch path does
            # for its output directory.
            parent_dir = os.path.dirname(output_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            sf.write(output_path, item, sampling_rate)
            saved_paths.append(output_path)

        return saved_paths

    def _prepare_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
        """Prepare audio for saving.

        Drops a leading axis of length 1 and, when ``normalize`` is True,
        divides by the peak absolute value (all-zero audio is left alone).
        """
        if len(audio.shape) > 1 and audio.shape[0] == 1:
            audio = audio.squeeze(0)
        if normalize:
            max_val = np.abs(audio).max()
            if max_val > 0:
                audio = audio / max_val
        return audio

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        Returns a shallow copy so that callers modifying the result cannot
        alter the processor's stored configuration.
        """
        return dict(self.feature_extractor_dict)