| import io | |
| import os | |
| import subprocess | |
| import wave | |
| import av | |
| import numpy as np | |
| import torch | |
def _audio_layout_name(channels: int) -> str:
    """Return the channel-layout name used for a given channel count.

    A count of exactly 1 maps to ``"mono"``; every other value maps to
    ``"stereo"``.
    """
    if channels == 1:
        return "mono"
    return "stereo"
def _decode_audio_array(
    path: str,
    sample_rate: int | None = None,
    channels: int | None = None,
) -> tuple[np.ndarray, int]:
    """Decode the first audio stream of ``path`` into a float32 array.

    Args:
        path: Media location handed to ``av.open``.
        sample_rate: Requested output rate. When falsy, the rate detected
            from the stream is kept.
        channels: Requested channel count. When ``None``, the codec's
            reported channel count is used (1 if it reports none).

    Returns:
        A ``(audio, rate)`` pair: the resampler's output frames joined along
        axis 1 as a C-contiguous float32 array, and the output sample rate.

    Raises:
        RuntimeError: If decoding and resampling produced no frames.
    """
    with av.open(path) as container:
        audio_stream = container.streams.audio[0]
        codec = audio_stream.codec_context

        # Rate fallback order: codec, stream, caller's request, then 16 kHz.
        source_rate = int(codec.sample_rate or audio_stream.rate or sample_rate or 16000)
        target_rate = sample_rate or source_rate

        if channels is None:
            channels = int(codec.channels or 1)

        resampler = av.AudioResampler(
            format="fltp",
            layout=_audio_layout_name(channels),
            rate=target_rate,
        )

        pieces = [
            converted.to_ndarray()
            for decoded in container.decode(audio_stream)
            for converted in resampler.resample(decoded)
        ]
        # Passing None flushes whatever the resampler still buffers.
        pieces.extend(flushed.to_ndarray() for flushed in resampler.resample(None))

    if not pieces:
        raise RuntimeError(f"Failed to decode audio: {path}")

    joined = np.concatenate(pieces, axis=1).astype(np.float32, copy=False)
    return np.ascontiguousarray(joined), int(target_rate)
def load_audio_tensor(path: str):
    """Decode ``path`` with its detected rate and channel count.

    Returns:
        A ``(tensor, sample_rate)`` pair, where ``tensor`` is a float torch
        tensor built from a copy of the decoded samples.
    """
    samples, rate = _decode_audio_array(path)
    # Copy first so the tensor does not share memory with the decoded array.
    tensor = torch.from_numpy(samples.copy()).float()
    return tensor, rate
def load_audio_mono(path: str, sample_rate: int) -> np.ndarray:
    """Decode ``path`` as a single channel at ``sample_rate``.

    Returns:
        The first row of the decoded array, i.e. the one requested channel.
    """
    decoded = _decode_audio_array(path, sample_rate=sample_rate, channels=1)[0]
    return decoded[0]
def _resample_audio_array(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample an in-memory array from ``src_sr`` to ``dst_sr`` using PyAV.

    Args:
        audio: A 1-D array (treated as one channel) or a 2-D array whose
            first axis is taken as the channel count.
        src_sr: Sample rate of ``audio``.
        dst_sr: Desired sample rate.

    Returns:
        The resampler's output frames joined along axis 1 as a C-contiguous
        float32 array.

    Raises:
        RuntimeError: If the resampler produced no output frames.
    """
    if audio.ndim == 1:
        channel_count = 1
        audio = audio[np.newaxis, :]
    else:
        channel_count = int(audio.shape[0])
    layout = _audio_layout_name(channel_count)

    planar = np.ascontiguousarray(audio, dtype=np.float32)
    frame = av.AudioFrame.from_ndarray(planar, format="fltp", layout=layout)
    frame.sample_rate = int(src_sr)

    resampler = av.AudioResampler(format="fltp", layout=layout, rate=int(dst_sr))
    pieces = [converted.to_ndarray() for converted in resampler.resample(frame)]
    # Passing None flushes whatever the resampler still buffers.
    pieces += [flushed.to_ndarray() for flushed in resampler.resample(None)]

    if not pieces:
        raise RuntimeError(f"Failed to resample audio from {src_sr} to {dst_sr}")
    return np.ascontiguousarray(np.concatenate(pieces, axis=1), dtype=np.float32)
def resample_audio_tensor(audio_tensor: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
    """Resample a torch tensor, keeping its dtype, device and rank.

    Args:
        audio_tensor: A 1-D tensor, or a tensor passed through as-is to the
            array resampler.
        src_sr: Sample rate of ``audio_tensor``.
        dst_sr: Desired sample rate.

    Returns:
        ``audio_tensor`` itself when the two rates are equal; otherwise a new
        tensor with the original dtype, on the original device. A 1-D input
        yields a 1-D output.
    """
    if src_sr == dst_sr:
        return audio_tensor

    is_flat = audio_tensor.dim() == 1
    host_audio = audio_tensor.detach().cpu().float().numpy()
    if is_flat:
        # The array resampler works on 2-D data; add a leading channel axis.
        host_audio = host_audio[np.newaxis, :]

    resampled = _resample_audio_array(host_audio, src_sr, dst_sr)
    output = torch.from_numpy(resampled.copy()).to(dtype=audio_tensor.dtype)
    if is_flat:
        output = output.squeeze(0)
    return output.to(audio_tensor.device)
def _atempo_filter_chain(speed: float) -> str:
    """Build an ffmpeg ``atempo`` filter expression whose product tempo is ``speed``."""
    # A single atempo instance only accepts tempo within [0.5, 100.0]; keep
    # the original single-filter expression whenever that is sufficient.
    if 0.5 <= speed <= 100.0:
        return f"atempo={speed}"

    # Outside that range, daisy-chain several instances so that each factor
    # is individually valid and the factors multiply to the requested speed.
    factors = []
    remaining = float(speed)
    while remaining < 0.5:
        factors.append(0.5)
        remaining *= 2.0
    while remaining > 100.0:
        factors.append(100.0)
        remaining /= 100.0
    factors.append(remaining)
    return ",".join(f"atempo={factor}" for factor in factors)


def change_speed_int16(input_audio: np.ndarray, speed: float, sample_rate: int) -> np.ndarray:
    """Change the tempo of mono 16-bit PCM audio with ffmpeg's ``atempo`` filter.

    The samples are piped to an ``ffmpeg`` subprocess as raw ``s16le`` mono
    data and the filtered raw ``s16le`` output is read back from its stdout.

    Args:
        input_audio: Samples to process; cast to ``int16`` before piping.
        speed: Tempo multiplier. Must be a positive, finite number. Values
            outside ffmpeg's single-filter range of [0.5, 100.0] are realised
            by chaining several ``atempo`` filters.
        sample_rate: Sample rate declared to ffmpeg for the input stream.

    Returns:
        A writable 1-D ``int16`` array holding ffmpeg's output samples.

    Raises:
        ValueError: If ``speed`` is not a positive finite number.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # Reject invalid tempo up front instead of spawning ffmpeg only to get
    # an opaque CalledProcessError back.
    if not np.isfinite(speed) or speed <= 0:
        raise ValueError(f"speed must be a positive finite number, got {speed!r}")

    raw_audio = input_audio.astype(np.int16, copy=False).tobytes()
    command = [
        "ffmpeg",
        "-nostdin",
        "-v",
        "error",
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(sample_rate),
        "-ac",
        "1",
        "-i",
        "pipe:0",
        "-filter:a",
        _atempo_filter_chain(speed),
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "pipe:1",
    ]
    process = subprocess.run(
        command,
        input=raw_audio,
        check=True,
        capture_output=True,
    )
    # np.frombuffer over bytes yields a read-only view; copy so callers can
    # modify the result in place.
    return np.frombuffer(process.stdout, dtype=np.int16).copy()
def _normalize_audio_array(audio: np.ndarray) -> np.ndarray:
    """Coerce ``audio`` to a C-contiguous 1-D array.

    A 2-D input with exactly one row is reduced to that row. Anything that
    is not 1-D after that step is rejected.

    Raises:
        ValueError: If the (possibly reduced) array is not 1-D.
    """
    samples = np.asarray(audio)
    has_single_row = samples.ndim == 2 and samples.shape[0] == 1
    mono = samples[0] if has_single_row else samples
    if mono.ndim == 1:
        return np.ascontiguousarray(mono)
    raise ValueError(f"Only mono audio is supported, got shape {mono.shape}")
def _audio_frame_from_mono(audio: np.ndarray) -> av.AudioFrame:
    """Wrap a 1-D sample array in a mono ``av.AudioFrame``.

    Floating-point input is clipped to [-1.0, 1.0] and stored as float32 in
    the ``fltp`` format. ``int16`` and ``int32`` inputs are stored unchanged
    as ``s16`` and ``s32`` respectively.

    Raises:
        ValueError: If the dtype is neither floating, ``int16`` nor ``int32``.
    """
    dtype = audio.dtype

    if np.issubdtype(dtype, np.floating):
        clipped = np.clip(audio, -1.0, 1.0)
        planar = clipped.astype(np.float32, copy=False)[np.newaxis, :]
        return av.AudioFrame.from_ndarray(np.ascontiguousarray(planar), format="fltp", layout="mono")

    # Integer PCM passes through untouched; only the format label differs.
    for int_dtype, sample_format in ((np.int16, "s16"), (np.int32, "s32")):
        if dtype == int_dtype:
            return av.AudioFrame.from_ndarray(audio[np.newaxis, :], format=sample_format, layout="mono")

    raise ValueError(f"Unsupported audio dtype for encoding: {audio.dtype}")
def write_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode mono audio as an in-memory WAV file.

    Floating-point samples are clipped to [-1.0, 1.0], scaled by 32767 and
    cast to ``int16``. ``int16`` and ``int32`` samples are written as-is with
    2- and 4-byte sample widths respectively.

    Args:
        audio: Mono samples; normalised to 1-D before encoding.
        sample_rate: Frame rate written to the WAV header.

    Returns:
        The complete WAV file contents.

    Raises:
        ValueError: If the audio is not mono or its dtype is unsupported.
    """
    samples = _normalize_audio_array(audio)
    dtype = samples.dtype

    if np.issubdtype(dtype, np.floating):
        width = 2
        scaled = np.clip(samples, -1.0, 1.0) * 32767.0
        payload = scaled.astype(np.int16).tobytes()
    elif dtype == np.int16:
        width, payload = 2, samples.tobytes()
    elif dtype == np.int32:
        width, payload = 4, samples.tobytes()
    else:
        raise ValueError(f"Unsupported audio dtype for WAV: {samples.dtype}")

    sink = io.BytesIO()
    with wave.open(sink, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(width)
        writer.setframerate(int(sample_rate))
        writer.writeframes(payload)
    # Read the buffer only after the writer has been closed.
    return sink.getvalue()
def write_wav_file(path: str, audio: np.ndarray, sample_rate: int) -> None:
    """Encode mono ``audio`` as WAV and write it to ``path``.

    The destination is opened (and therefore created or truncated) before
    the audio is encoded.
    """
    with open(path, mode="wb") as sink:
        encoded = write_wav_bytes(audio, sample_rate)
        sink.write(encoded)
def write_ogg_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode mono audio as an in-memory Ogg file using the ``libvorbis`` codec.

    Args:
        audio: Mono samples; normalised to 1-D before encoding.
        sample_rate: Rate assigned to both the output stream and the frame.

    Returns:
        The bytes PyAV wrote to the in-memory buffer.

    Raises:
        ValueError: If the audio is not mono or its dtype is unsupported.
    """
    mono = _normalize_audio_array(audio)
    sink = io.BytesIO()
    with av.open(sink, mode="w", format="ogg") as container:
        vorbis = container.add_stream("libvorbis", rate=int(sample_rate))
        vorbis.layout = "mono"

        frame = _audio_frame_from_mono(mono)
        frame.sample_rate = int(sample_rate)

        # Encode the frame, then pass None to flush the encoder.
        for pending in (frame, None):
            for packet in vorbis.encode(pending):
                container.mux(packet)
    # Read the buffer only after the container has been closed.
    return sink.getvalue()
| def write_audio_file(path: str, audio: np.ndarray, sample_rate: int, format: str | None = None) -> None: | |
| target_format = (format or os.path.splitext(path)[1].lstrip(".") or "wav").lower() | |
| if target_format == "wav": | |
| write_wav_file(path, audio, sample_rate) | |
| return | |
| if target_format == "ogg": | |
| payload = write_ogg_bytes(audio, sample_rate) | |
| with open(path, "wb") as fw: | |
| fw.write(payload) | |
| return | |
| raise ValueError(f"Unsupported output audio format: {target_format}") | |