File size: 3,675 Bytes
b40215d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torchaudio
import torchaudio.transforms as T
from pyannote.audio import Pipeline
from faster_whisper import WhisperModel
import os
import numpy as np

# --- Configuration ---
# IMPORTANT: pyannote/speaker-diarization-3.1 is a gated model; you must
# accept its terms of use on Hugging Face:
#   https://huggingface.co/pyannote/speaker-diarization-3.1
# and expose your access token via the HF_TOKEN environment variable,
# e.g.: export HF_TOKEN="your_token_here"
if os.environ.get("HF_TOKEN") is None:
    print("WARNING: Hugging Face token not found. Diarization may fail.")
    print("Please set the HF_TOKEN environment variable with your token.")


# Input recording to diarize and transcribe.
AUDIO_FILE = "WhatsApp Audio 2026-01-24 at 12.43.45 PM.ogg"

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- 1. Diarization Model ---
print("Loading diarization pipeline...")
# Build the full speaker-diarization pipeline. The model is gated on the
# Hugging Face Hub, so pass the access token explicitly: the original code
# checked for HF_TOKEN above but never actually used it, which made the
# download of the gated model fail even when the token was set.
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.environ.get("HF_TOKEN"),
)
# Run inference on the selected device (GPU when available).
diarization_pipeline.to(torch.device(DEVICE))
print("Diarization pipeline loaded.")


# --- 2. Transcription Model ---
# faster-whisper provides efficient CTranslate2-based transcription.
# Available sizes include "tiny", "base", "small", "medium", "large-v2";
# "base" is a reasonable speed/quality starting point.
transcription_model_size = "base"
print(f"Loading transcription model '{transcription_model_size}'...")
transcription_model = WhisperModel(
    transcription_model_size,
    device=DEVICE,
    compute_type="int8",
)
print("Transcription model loaded.")


# --- 3. Processing ---
if not os.path.exists(AUDIO_FILE):
    print(f"❌ Error: Audio file not found at '{AUDIO_FILE}'")
else:
    print(f"\nProcessing audio file: {AUDIO_FILE}")

    # Load the full recording; torchaudio returns a (channels, samples)
    # float tensor on the CPU.
    waveform, sample_rate = torchaudio.load(AUDIO_FILE)

    # Downmix to mono. Averaging the channels keeps information from every
    # channel instead of silently discarding all but the first one.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    print(f"Original audio loaded: {waveform.shape[1]/sample_rate:.2f} seconds at {sample_rate} Hz.")

    # Resample to 16 kHz, the rate both pyannote and Whisper expect.
    TARGET_SAMPLE_RATE = 16000
    if sample_rate != TARGET_SAMPLE_RATE:
        print(f"Resampling audio from {sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz...")
        # Keep the resampler on the CPU: `waveform` lives on the CPU, and
        # moving only the transform to DEVICE (as the original code did)
        # raises a device-mismatch error whenever DEVICE is "cuda".
        resampler = T.Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)
        sample_rate = TARGET_SAMPLE_RATE
        print(f"Audio resampled: {waveform.shape[1]/sample_rate:.2f} seconds at {sample_rate} Hz.")

    # --- Diarization Step ---
    print("Diarizing speakers...")
    # pyannote accepts an in-memory {"waveform", "sample_rate"} mapping and
    # moves the data to the pipeline's device itself.
    diarization_input = {"waveform": waveform, "sample_rate": sample_rate}
    diarization = diarization_pipeline(diarization_input)

    print("Diarization complete.")

    # --- Transcription Step ---
    print("Transcribing segments...")

    num_samples = waveform.shape[1]
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        # Clamp to the valid sample range — diarization boundaries can
        # slightly overshoot the end of the audio — and skip degenerate
        # (empty) turns, which would otherwise produce a zero-length array.
        start_sample = max(0, int(turn.start * sample_rate))
        end_sample = min(num_samples, int(turn.end * sample_rate))
        if end_sample <= start_sample:
            continue

        segment_waveform = waveform[:, start_sample:end_sample]

        # faster-whisper expects a mono float32 numpy array (here at 16 kHz).
        segment_np = segment_waveform.squeeze(0).cpu().numpy().astype(np.float32)

        # Transcribe this speaker turn.
        segments, _ = transcription_model.transcribe(segment_np, beam_size=5)

        # Concatenate the partial transcripts produced for this turn.
        text = "".join(segment.text for segment in segments)

        # Report: [start - end] SPEAKER_xx: transcript
        print(f"[{turn.start:.2f}s - {turn.end:.2f}s] {speaker}: {text.strip()}")

    print("\n✅ Processing finished.")