# Spaces: Sleeping
import torch
import torchaudio
from pyannote.audio import Model
from pyannote.audio.utils.powerset import Powerset
from pyannote.core import Annotation, Segment
# 1. PATHS
# Audio clip that the debugging run is executed on.
TEST_AUDIO = "dataset/audio/clip_07.wav"
# Lightning checkpoint of the trained model that gets loaded for inference.
CHECKPOINT_PATH = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"
def run_test(checkpoint_path=None, audio_path=None, threshold=0.5):
    """Run the checkpoint on one audio clip and print per-speaker activity.

    Loads the model, feeds it the whole clip in a single forward pass and,
    for every speaker slot with at least one frame above ``threshold``,
    prints the time of the first and last active frame.

    Args:
        checkpoint_path: Checkpoint to load. Defaults to ``CHECKPOINT_PATH``.
        audio_path: Audio file to analyse. Defaults to ``TEST_AUDIO``.
        threshold: A frame counts as active when its score is strictly
            greater than this value.
    """
    if checkpoint_path is None:
        checkpoint_path = CHECKPOINT_PATH
    if audio_path is None:
        audio_path = TEST_AUDIO

    print("Loading model directly...")
    model = Model.from_pretrained(checkpoint_path)
    model.eval()  # Set to evaluation mode

    # 2. Load audio manually: torchaudio returns (channels, samples).
    waveform, sample_rate = torchaudio.load(audio_path)

    # Match the input layout the model was trained with. The values are read
    # from the checkpoint hyper-parameters when present; otherwise mono input
    # and the file's own sample rate are assumed.
    hparams = model.hparams
    expected_channels = getattr(hparams, "num_channels", 1)
    expected_rate = getattr(hparams, "sample_rate", sample_rate)

    if expected_channels == 1 and waveform.shape[0] > 1:
        # Downmix multi-channel recordings instead of feeding extra channels.
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != expected_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, expected_rate
        )
        sample_rate = expected_rate

    duration = waveform.shape[-1] / sample_rate

    # Model expects [batch, channels, samples] - adding a batch dimension.
    if waveform.ndim == 2:
        waveform = waveform.unsqueeze(0)

    print("Running raw inference...")
    with torch.no_grad():
        # Raw scores, indexed below as [batch, frames, classes].
        scores = model(waveform)

    # Powerset models emit log-probabilities over speaker *combinations*
    # (all values <= 0), so thresholding them directly never fires. Convert
    # them to one 0/1 activation track per speaker first.
    specs = getattr(model, "specifications", None)
    if getattr(specs, "powerset", False):
        powerset = Powerset(len(specs.classes), specs.powerset_max_classes)
        scores = powerset.to_multilabel(scores)

    # 3. Simple thresholding to find speakers.
    print("\n--- Raw Model Detections ---")

    # Derive the frame step from the clip itself rather than hard-coding it.
    # This spreads the frames evenly over the clip and ignores edge effects
    # of the model's receptive field, so times are still approximate.
    num_frames = scores.shape[1]
    frame_step = duration / num_frames if num_frames else 0.0

    for s, track in enumerate(scores[0].unbind(dim=-1)):
        active_frames = torch.where(track > threshold)[0]
        if active_frames.numel() == 0:
            continue
        # Only the first and last detection of the slot are shown to keep the
        # output short; gaps in between are not reported.
        start_time = active_frames[0].item() * frame_step
        # +1 so the reported end is the end of the last active frame.
        end_time = (active_frames[-1].item() + 1) * frame_step
        print(f"Speaker Slot {s}: Detected activity between {start_time:.2f}s and {end_time:.2f}s")
# Script entry point: execute the debugging run only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_test()