# NOTE(review): the lines originally here were web-scrape residue (file-viewer
# chrome: "Spaces: Sleeping", file size, commit hashes, a line-number gutter)
# and were not part of the Python source; they have been neutralized.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from src.models.cnn import CNN
from src.data.augment import AudioAugment
from src.config.config import ProcessingConfig, DatasetConfig, TrainConfig
# Module-level processing configuration; used as the default for AudioPredictor.
config = ProcessingConfig()
class AudioPredictor:
    """Clip-level audio classification with a trained CNN.

    Slides a fixed-length window over a (time, freq) spectrogram, runs the
    CNN on every patch, and averages the per-patch outputs to obtain a single
    clip-level prediction.
    """

    def __init__(
        self,
        model_path: str,
        config: ProcessingConfig = config,
        device: str = 'cuda'
    ) -> None:
        """Load the checkpoint at ``model_path`` and prepare for inference.

        Args:
            model_path: Path to a ``torch.save``'d checkpoint — either a raw
                state dict or a dict containing ``"model_state_dict"``.
            config: Audio processing configuration (module-level default).
            device: Torch device string the model and patches are moved to.
        """
        self.config = config
        self.audio_dataset = AudioAugment()
        self.dataset_config = DatasetConfig()
        self.train_config = TrainConfig()
        self.device = device
        self.model = self._load_model(model_path)

    def _load_model(self, model_path: str) -> CNN:
        """Build the CNN, load weights from ``model_path``, and set eval mode."""
        model = CNN(n_classes=len(self.dataset_config.esc50_labels))
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (weights_only=True would be safer if the
        # installed torch version supports it; confirm before changing).
        checkpoint = torch.load(model_path, map_location=self.device)
        # Accept both a bare state dict and a full training checkpoint dict.
        state_dict = checkpoint.get("model_state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
        model.load_state_dict(state_dict)
        if isinstance(checkpoint, dict) and "best_val_acc" in checkpoint:
            print(f"Model validation accuracy: {checkpoint['best_val_acc']:.4f}")
        model.to(self.device).eval()
        print("Model loaded successfully!\n")
        return model

    def _extract_patches(self, spectrogram: np.ndarray, hop: int) -> torch.Tensor:
        """Cut a (time, freq) spectrogram into overlapping fixed-length patches.

        Args:
            spectrogram: 2-D array of shape (n_frames, n_features) —
                assumed 2-D; TODO confirm against the augment pipeline.
            hop: Stride, in frames, between consecutive window starts.

        Returns:
            Float tensor of shape (n_patches, 1, cnn_input_length, n_features)
            on ``self.device``.
        """
        window = self.dataset_config.cnn_input_length
        n_frames = spectrogram.shape[0]
        if n_frames < window:
            # Zero-pad short clips up to one full window so at least one
            # patch is always produced.
            spectrogram = np.pad(spectrogram, ((0, window - n_frames), (0, 0)), mode="constant")
            n_frames = window
        starts = range(0, n_frames - window + 1, hop)
        # Stack the windows, then add the singleton channel axis the CNN expects.
        patches = np.stack([spectrogram[s:s + window] for s in starts])[:, np.newaxis]
        return torch.tensor(patches, dtype=torch.float32).to(self.device)

    def _run_inference(self, patches: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Run the model over ``patches`` in mini-batches and return the mean output."""
        all_outputs = []
        with torch.no_grad():
            for i in range(0, len(patches), batch_size):
                all_outputs.append(self.model(patches[i:i + batch_size]))
        # Average per-patch activations into a single clip-level vector.
        return torch.cat(all_outputs, dim=0).mean(dim=0)

    def predict_class(self, spectrogram: np.ndarray, hop: int = 1) -> int:
        """Return the index of the most likely class for ``spectrogram``."""
        patches = self._extract_patches(spectrogram, hop)
        mean_activations = self._run_inference(patches, self.train_config.batch_size)
        return mean_activations.argmax().item()

    def predict_top_k(self, spectrogram: np.ndarray, hop: int = 1, top_k: int = 5):
        """Return the ``top_k`` most probable classes for ``spectrogram``.

        Returns:
            Tuple ``(top_probs, top_indices)`` of numpy arrays, sorted by
            descending probability.
        """
        patches = self._extract_patches(spectrogram, hop)
        mean_logits = self._run_inference(patches, self.train_config.batch_size)
        probs = F.softmax(mean_logits, dim=0)
        top_probs, top_indices = torch.topk(probs, min(top_k, len(self.dataset_config.esc50_labels)))
        return top_probs.cpu().numpy(), top_indices.cpu().numpy()

    def predict_file(self, audio_file: str, top_k: int = 5):
        """Predict the class of ``audio_file`` along with its top-k candidates.

        Returns:
            Tuple ``(predicted_class, top_probs, top_indices)``.
        """
        spectrogram = np.array(self.audio_dataset._data_treatment_testing(audio_file)).squeeze()
        top_probs, top_indices = self.predict_top_k(spectrogram, top_k=top_k)
        if len(top_indices) == 0:
            # Degenerate top_k < 1: fall back to a plain argmax pass,
            # matching the original behavior for this edge case.
            return self.predict_class(spectrogram), top_probs, top_indices
        # BUGFIX: the original ran the full sliding-window inference twice
        # (once in predict_class, once in predict_top_k). Softmax is
        # monotonic, so the top-1 index IS the argmax class — one pass suffices.
        predicted_class = int(top_indices[0])
        return predicted_class, top_probs, top_indices
# (end of file — trailing scrape-gutter character removed)