# Spaces: Sleeping / Sleeping  (Hugging Face Spaces page header — scrape artifact, not source code)
import numpy as np
import torch
import torchaudio

from audio_dataset import pad_audio  # Use the provided padding function
from model import BoundaryDetectionModel  # Assume the model definition is in model.py
def load_model(checkpoint_path, device):
    """Restore a trained BoundaryDetectionModel from a checkpoint file.

    Parameters
    ----------
    checkpoint_path : str
        Path to a checkpoint containing a ``"model_state_dict"`` entry.
    device : torch.device
        Device the model weights are mapped onto.

    Returns
    -------
    BoundaryDetectionModel
        The model in eval mode, ready for inference.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net = BoundaryDetectionModel().to(device)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()  # disable dropout / batch-norm updates for inference
    return net
def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
    """Load an audio file and prepare it for the boundary-detection model.

    Parameters
    ----------
    audio_path : str
        Path to the input audio file.
    sample_rate : int, optional
        Target sample rate in Hz (default 16000).
    target_length : int, optional
        Target duration forwarded to ``pad_audio`` — presumably seconds to
        pad/crop to; confirm against audio_dataset.py.

    Returns
    -------
    torch.Tensor
        The (resampled, padded) waveform tensor.
    """
    waveform, sr = torchaudio.load(audio_path)
    # Only resample when the file's native rate differs from the target —
    # the original built a Resample transform unconditionally on every call.
    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
    # Debug print of waveform.shape removed — leftover development artifact.
    return pad_audio(waveform, sample_rate, target_length)
def infer_single_audio(model, audio_path, device, threshold=0.5):
    """Run the model on one audio file; return raw scores and binary labels.

    Parameters
    ----------
    model : torch.nn.Module
        Model in eval mode. Scores are compared against ``threshold``
        directly, so the model output is assumed to already be
        sigmoid-activated — confirm in model.py.
    audio_path : str
        Path to the audio file to score.
    device : torch.device
        Device to run inference on.
    threshold : float, optional
        Decision threshold for binarizing the scores (default 0.5,
        matching the original hard-coded value).

    Returns
    -------
    tuple[numpy.ndarray, numpy.ndarray]
        ``(output, prediction)`` — raw scores and 0/1 labels.
    """
    audio_tensor = preprocess_audio(audio_path).to(device)
    with torch.no_grad():
        # squeeze(-1) drops the trailing singleton score dimension;
        # .cpu() is required before converting to numpy.
        output = model(audio_tensor).squeeze(-1).cpu().numpy()
    # Pure numpy work — no need to stay inside the no_grad context.
    prediction = (output > threshold).astype(int)
    return output, prediction
def main_inference(audio_path, checkpoint_path):
    """End-to-end driver: pick a device, load the model, score one file.

    Prints the raw model output and the binarized prediction to stdout.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    net = load_model(checkpoint_path, device)

    print(f"Running inference on: {audio_path}")
    scores, labels = infer_single_audio(net, audio_path, device)
    print(f"Model Output: {scores}")
    print(f"Binary Prediction: {labels}")
if __name__ == "__main__":
    # Raw string: in a normal literal "\R" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+, slated to become an error). The raw
    # literal below has the exact same byte value as the original.
    audio_path = r"Real\RFP_R_24918.wav"  # Path to the audio file for inference
    checkpoint_path = "checkpoint_epoch_21_eer_0.24.pth"  # Path to the trained model checkpoint
    main_inference(audio_path, checkpoint_path)