from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
| |
|
| |
|
class EnergyExtractor(nn.Module):
    """Extract a frame-level energy contour (in dB) from a raw waveform.

    The waveform is squared, padded, and framed into overlapping windows with
    ``F.unfold``; each window's mean power is converted to dB, optionally
    min-max normalized per utterance, and optionally quantized to a fixed
    number of levels.
    """

    def __init__(self, hop_size: int = 512, window_size: int = 1024,
                 padding: str = 'reflect', min_db: float = -60,
                 norm: bool = True, quantize_levels: Optional[int] = None):
        """
        Args:
            hop_size: Stride (in samples) between successive analysis frames.
            window_size: Length (in samples) of each analysis window.
            padding: Padding mode passed to ``F.pad`` (e.g. ``'reflect'``).
            min_db: Power floor in dB applied before the log; also the fixed
                lower bound used when ``norm`` is enabled.
            norm: If True, min-max normalize the dB contour per utterance
                into [0, 1].
            quantize_levels: If given, round the (normalized) contour to this
                many evenly spaced levels.
        """
        super().__init__()
        self.hop_size = hop_size
        self.window_size = window_size
        self.padding = padding
        self.min_db = min_db
        self.norm = norm
        self.quantize_levels = quantize_levels

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute the per-frame energy contour.

        Args:
            audio: Waveform tensor of shape ``(batch, num_samples)``.

        Returns:
            Tensor of shape ``(batch, n_frames, 1)`` holding the per-frame
            energy in dB (normalized/quantized if configured).
        """
        n_frames = int(audio.size(-1) // self.hop_size)

        # Pad so each analysis window is centered on its hop position.
        pad_amount = (self.window_size - self.hop_size) // 2
        audio_padded = F.pad(audio, (pad_amount, pad_amount), mode=self.padding)

        audio_squared = audio_padded ** 2

        # Frame the squared signal: (B, 1, 1, T) -> (B, window_size, n_frames),
        # then average within each window to get the mean power per frame.
        audio_squared = audio_squared[:, None, None, :]
        energy = F.unfold(audio_squared, (1, self.window_size),
                          stride=self.hop_size)[:, :, :n_frames]
        energy = energy.mean(dim=1)

        # Floor the power at 10^(min_db/10) so log10 never sees zero.
        # torch.clamp with a Python float replaces the original
        # torch.maximum against a freshly allocated scalar tensor: same
        # result, but no per-call tensor allocation and no numpy dependency.
        energy = torch.clamp(energy, min=10.0 ** (self.min_db / 10))
        gain_db = 10 * torch.log10(energy)

        if self.norm:
            # Min-max normalize to [0, 1] per utterance: min_db is the fixed
            # lower bound, the utterance maximum is the upper bound.
            min_gain_db = self.min_db
            max_gain_db = torch.max(gain_db, dim=-1, keepdim=True)[0]
            epsilon = 1e-8  # guards division by zero on constant signals
            gain_db = (gain_db - min_gain_db) / (max_gain_db - min_gain_db + epsilon)

        if self.quantize_levels is not None:
            # Snap values to quantize_levels evenly spaced points in [0, 1].
            gain_db = torch.round(gain_db * (self.quantize_levels - 1)) / (self.quantize_levels - 1)

        return gain_db.unsqueeze(-1)
| |
|
| |
|
if __name__ == "__main__":
    energy_extractor = EnergyExtractor(hop_size=512, window_size=1024, padding='reflect',
                                       min_db=-60, norm=True)

    # Smoke test on random noise: 1 second at 16 kHz.
    audio = torch.rand(1, 16000)
    energy = energy_extractor(audio)
    print(energy.shape)

    import librosa
    import matplotlib.pyplot as plt

    # Load a real file and take a 5-second excerpt.
    # BUGFIX: the excerpt length must use the actual load sample rate (24 kHz);
    # the original sliced a1[:5*16000], giving only ~3.3 s of audio.
    sr = 24000
    a1, _ = librosa.load('eg2.wav', sr=sr)
    audio = torch.tensor(a1[:5 * sr]).unsqueeze(0)
    energy = energy_extractor(audio)
    print(energy.shape)

    # Plot each batch item's energy contour over time.
    plt.figure(figsize=(12, 6))
    for i in range(energy.shape[0]):
        plt.plot(energy[i, :, 0].cpu().numpy(), label=f'Audio {i+1}')
    plt.xlabel('Frame')
    plt.ylabel('Energy (dB)')
    plt.title('Energy over Time')
    plt.legend()
    plt.savefig('debug.png')