import io
import os
import subprocess
import wave
import av
import numpy as np
import torch
def _audio_layout_name(channels: int) -> str:
return "mono" if channels == 1 else "stereo"
def _decode_audio_array(path: str, sample_rate: int | None = None, channels: int | None = None) -> tuple[np.ndarray, int]:
    """Decode the first audio stream of ``path`` into a float32 array.

    Args:
        path: Location handed to ``av.open``.
        sample_rate: Desired output rate; when falsy the stream's own rate
            is used.
        channels: Desired channel count; when ``None`` the count reported by
            the stream's codec context is used (falling back to 1).

    Returns:
        ``(audio, rate)`` where ``audio`` is a C-contiguous float32 array
        built by concatenating the resampled planar frames along axis 1, and
        ``rate`` is the output sample rate as an int.

    Raises:
        RuntimeError: If decoding yields no frames at all.
    """
    with av.open(path) as container:
        stream = container.streams.audio[0]
        # Prefer the codec-reported rate, then the stream rate, then the
        # caller's request, and finally a 16 kHz default.
        source_rate = int(stream.codec_context.sample_rate or stream.rate or sample_rate or 16000)
        if channels is None:
            channels = int(stream.codec_context.channels or 1)
        target_rate = sample_rate or source_rate
        resampler = av.AudioResampler(format="fltp", layout=_audio_layout_name(channels), rate=target_rate)
        pieces = [
            converted.to_ndarray()
            for decoded in container.decode(stream)
            for converted in resampler.resample(decoded)
        ]
        # Passing None flushes whatever the resampler still has buffered.
        pieces.extend(converted.to_ndarray() for converted in resampler.resample(None))
    if not pieces:
        raise RuntimeError(f"Failed to decode audio: {path}")
    joined = np.concatenate(pieces, axis=1).astype(np.float32, copy=False)
    return np.ascontiguousarray(joined), int(target_rate)
def load_audio_tensor(path: str):
    """Decode ``path`` at its native rate and return ``(tensor, sample_rate)``.

    The tensor is a float torch tensor built from a copy of the decoded
    array, so it does not share memory with the decoder's output.
    """
    decoded, rate = _decode_audio_array(path)
    tensor = torch.from_numpy(decoded.copy()).float()
    return tensor, rate
def load_audio_mono(path: str, sample_rate: int) -> np.ndarray:
    """Decode ``path`` as single-channel audio at ``sample_rate``.

    Returns the first (and only requested) channel row of the decoded array.
    """
    decoded, _rate = _decode_audio_array(path, sample_rate=sample_rate, channels=1)
    first_channel = decoded[0]
    return first_channel
def _resample_audio_array(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample an in-memory audio array from ``src_sr`` to ``dst_sr``.

    Args:
        audio: 1-D samples, or a 2-D array whose first axis is taken as the
            channel count.
        src_sr: Sample rate assigned to the input frame.
        dst_sr: Sample rate requested from the resampler.

    Returns:
        A C-contiguous float32 array with the resampled frames concatenated
        along axis 1.

    Raises:
        RuntimeError: If the resampler produces no output frames.
    """
    if audio.ndim == 1:
        channel_count = 1
        audio = audio[np.newaxis, :]
    else:
        channel_count = int(audio.shape[0])
    layout = _audio_layout_name(channel_count)

    planar = np.ascontiguousarray(audio, dtype=np.float32)
    source_frame = av.AudioFrame.from_ndarray(planar, format="fltp", layout=layout)
    source_frame.sample_rate = int(src_sr)

    resampler = av.AudioResampler(format="fltp", layout=layout, rate=int(dst_sr))
    # Feed the single frame, then None to flush the resampler's buffer.
    pieces = [
        converted.to_ndarray()
        for pending in (source_frame, None)
        for converted in resampler.resample(pending)
    ]
    if not pieces:
        raise RuntimeError(f"Failed to resample audio from {src_sr} to {dst_sr}")
    return np.ascontiguousarray(np.concatenate(pieces, axis=1), dtype=np.float32)
def resample_audio_tensor(audio_tensor: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
    """Resample a torch tensor, preserving its dtype, device and rank.

    When ``src_sr == dst_sr`` the input tensor object is returned unchanged.
    Otherwise the data is detached, moved to CPU as float, resampled through
    ``_resample_audio_array``, then cast back to the original dtype and moved
    to the original device. A 1-D input yields a 1-D output.
    """
    if src_sr == dst_sr:
        return audio_tensor

    was_flat = audio_tensor.dim() == 1
    samples = audio_tensor.detach().cpu().float().numpy()
    if was_flat:
        # The array resampler works on a (channels, samples) layout.
        samples = samples[np.newaxis, :]

    resampled = _resample_audio_array(samples, src_sr, dst_sr)
    output = torch.from_numpy(resampled.copy()).to(dtype=audio_tensor.dtype)
    if was_flat:
        output = output.squeeze(0)
    return output.to(audio_tensor.device)
def change_speed_int16(input_audio: np.ndarray, speed: float, sample_rate: int) -> np.ndarray:
    """Change playback tempo of int16 PCM by piping it through ``ffmpeg``.

    The samples are cast to int16, sent to an ``ffmpeg`` subprocess on stdin
    as raw single-channel ``s16le`` at ``sample_rate``, filtered with
    ``atempo=<speed>``, and read back from stdout in the same raw format.

    Returns:
        A 1-D int16 array created with ``np.frombuffer`` over the process's
        stdout bytes (numpy exposes such buffers as read-only views).

    Raises:
        subprocess.CalledProcessError: If ``ffmpeg`` exits with a non-zero
            status (``check=True``).
    """
    pcm_bytes = input_audio.astype(np.int16, copy=False).tobytes()

    # The same raw-PCM description is used for both the input and the output.
    raw_pcm_args = ["-f", "s16le", "-acodec", "pcm_s16le"]
    command = [
        "ffmpeg",
        "-nostdin",
        "-v",
        "error",
        *raw_pcm_args,
        "-ar",
        str(sample_rate),
        "-ac",
        "1",
        "-i",
        "pipe:0",
        "-filter:a",
        f"atempo={speed}",
        *raw_pcm_args,
        "pipe:1",
    ]
    completed = subprocess.run(command, input=pcm_bytes, check=True, capture_output=True)
    return np.frombuffer(completed.stdout, dtype=np.int16)
def _normalize_audio_array(audio: np.ndarray) -> np.ndarray:
array = np.asarray(audio)
if array.ndim == 2 and array.shape[0] == 1:
array = array[0]
if array.ndim != 1:
raise ValueError(f"Only mono audio is supported, got shape {array.shape}")
return np.ascontiguousarray(array)
def _audio_frame_from_mono(audio: np.ndarray) -> av.AudioFrame:
    """Wrap 1-D mono samples in an ``av.AudioFrame``.

    Floating-point input is clipped to [-1.0, 1.0], cast to float32 and
    packed as ``fltp``; int16 and int32 inputs are passed through as ``s16``
    and ``s32`` respectively. Every frame uses the ``mono`` layout.

    Raises:
        ValueError: For any other dtype.
    """
    if np.issubdtype(audio.dtype, np.floating):
        clipped = np.clip(audio, -1.0, 1.0).astype(np.float32, copy=False)
        planar = np.ascontiguousarray(clipped[np.newaxis, :])
        return av.AudioFrame.from_ndarray(planar, format="fltp", layout="mono")

    # Integer PCM is handed over untouched with the matching sample format.
    for integer_dtype, sample_format in ((np.int16, "s16"), (np.int32, "s32")):
        if audio.dtype == integer_dtype:
            return av.AudioFrame.from_ndarray(audio[np.newaxis, :], format=sample_format, layout="mono")

    raise ValueError(f"Unsupported audio dtype for encoding: {audio.dtype}")
def write_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Serialize mono audio into an in-memory WAV file.

    Floating-point samples are clipped to [-1.0, 1.0], scaled by 32767 and
    cast to int16; int16 and int32 samples are written as they are, using
    2-byte and 4-byte sample widths respectively.

    Args:
        audio: Mono samples accepted by ``_normalize_audio_array``.
        sample_rate: Frame rate stored in the WAV header (cast to int).

    Returns:
        The complete WAV file contents.

    Raises:
        ValueError: If the dtype is not floating point, int16 or int32, or if
            ``_normalize_audio_array`` rejects the input shape.
    """
    mono = _normalize_audio_array(audio)
    dtype = mono.dtype

    if np.issubdtype(dtype, np.floating):
        sample_width = 2
        raw = (np.clip(mono, -1.0, 1.0) * 32767.0).astype(np.int16).tobytes()
    elif dtype == np.int16 or dtype == np.int32:
        # int16 -> 2 bytes per sample, int32 -> 4 bytes per sample.
        sample_width = dtype.itemsize
        raw = mono.tobytes()
    else:
        raise ValueError(f"Unsupported audio dtype for WAV: {audio.dtype}")

    sink = io.BytesIO()
    with wave.open(sink, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(int(sample_rate))
        wav_file.writeframes(raw)
    return sink.getvalue()
def write_wav_file(path: str, audio: np.ndarray, sample_rate: int) -> None:
    """Write mono audio to ``path`` as a WAV file.

    Args:
        path: Destination file; it is opened in binary write mode.
        audio: Samples forwarded to ``write_wav_bytes``.
        sample_rate: Sample rate forwarded to ``write_wav_bytes``.

    Raises:
        ValueError: Propagated from ``write_wav_bytes`` for unsupported input.
    """
    # Encode before opening the destination: opening with "wb" truncates the
    # file, so encoding first means invalid audio cannot leave an empty or
    # clobbered file behind.
    payload = write_wav_bytes(audio, sample_rate)
    with open(path, "wb") as fw:
        fw.write(payload)
def write_ogg_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode mono audio into an in-memory Ogg container using ``libvorbis``.

    Args:
        audio: Mono samples accepted by ``_normalize_audio_array`` with a
            dtype supported by ``_audio_frame_from_mono``.
        sample_rate: Rate used for both the stream and the frame (cast to
            int).

    Returns:
        The bytes written to the container.
    """
    mono = _normalize_audio_array(audio)
    rate = int(sample_rate)
    sink = io.BytesIO()
    with av.open(sink, mode="w", format="ogg") as container:
        stream = container.add_stream("libvorbis", rate=rate)
        stream.layout = "mono"
        frame = _audio_frame_from_mono(mono)
        frame.sample_rate = rate
        # Encode the single frame, then pass None to flush the encoder.
        for pending in (frame, None):
            for packet in stream.encode(pending):
                container.mux(packet)
    return sink.getvalue()
def write_audio_file(path: str, audio: np.ndarray, sample_rate: int, format: str | None = None) -> None:
target_format = (format or os.path.splitext(path)[1].lstrip(".") or "wav").lower()
if target_format == "wav":
write_wav_file(path, audio, sample_rate)
return
if target_format == "ogg":
payload = write_ogg_bytes(audio, sample_rate)
with open(path, "wb") as fw:
fw.write(payload)
return
raise ValueError(f"Unsupported output audio format: {target_format}")
|