import torch
import torchaudio
import numpy as np

from model import BoundaryDetectionModel  # Assume the model definition is in model.py
from audio_dataset import pad_audio  # Use the provided padding function


def load_model(checkpoint_path, device):
    """Load a trained BoundaryDetectionModel from a checkpoint file.

    Args:
        checkpoint_path: Path to a checkpoint saved with a
            ``{"model_state_dict": ...}`` entry.
        device: torch.device to place the model on.

    Returns:
        The model in eval mode on ``device``.
    """
    model = BoundaryDetectionModel().to(device)
    # NOTE(review): torch.load unpickles arbitrary objects; for untrusted
    # checkpoints prefer torch.load(..., weights_only=True) — confirm the
    # checkpoint format supports it before switching.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    model.eval()
    return model


def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
    """Load an audio file, resample it, and pad/trim to a fixed length.

    Args:
        audio_path: Path to the audio file.
        sample_rate: Target sample rate in Hz (default 16 kHz).
        target_length: Target duration passed to ``pad_audio``
            (presumably seconds — verify against audio_dataset.pad_audio).

    Returns:
        The preprocessed waveform tensor.
    """
    waveform, sr = torchaudio.load(audio_path)
    # Resample to the model's expected rate regardless of the file's native rate.
    waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
    waveform = pad_audio(waveform, sample_rate, target_length)
    # Diagnostic: shape after preprocessing. Assumes the model accepts this
    # shape directly (no explicit batch dim added here) — TODO confirm.
    print(f"Preprocessed waveform shape: {waveform.shape}")
    return waveform


def infer_single_audio(model, audio_path, device):
    """Run the model on one audio file and threshold its output.

    Args:
        model: A BoundaryDetectionModel in eval mode.
        audio_path: Path to the audio file.
        device: torch.device for inference.

    Returns:
        Tuple of (raw sigmoid-like outputs as a numpy array,
        binary predictions thresholded at 0.5).
    """
    audio_tensor = preprocess_audio(audio_path).to(device)
    with torch.no_grad():
        # Squeeze the trailing singleton dim so output is 1-D per frame.
        output = model(audio_tensor).squeeze(-1).cpu().numpy()
    prediction = (output > 0.5).astype(int)  # binarize at the 0.5 threshold
    return output, prediction


def main_inference(audio_path, checkpoint_path):
    """Load the checkpointed model and print inference results for one file."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(checkpoint_path, device)
    print(f"Running inference on: {audio_path}")
    output, prediction = infer_single_audio(model, audio_path, device)
    print(f"Model Output: {output}")
    print(f"Binary Prediction: {prediction}")


if __name__ == "__main__":
    # Raw string: "Real\RFP..." would trigger a SyntaxWarning (invalid escape
    # '\R') and is a future SyntaxError; r"" keeps the backslash literal.
    audio_path = r"Real\RFP_R_24918.wav"  # Path to the audio file for inference
    checkpoint_path = "checkpoint_epoch_21_eer_0.24.pth"  # Trained model checkpoint
    main_inference(audio_path, checkpoint_path)