File size: 3,393 Bytes
f7e0be6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d1795f
f7e0be6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from io import BytesIO
import os
import gradio as gr
import spaces
import torch
from pyannote.audio import Pipeline
import torchaudio
from pydub import AudioSegment
from pyannote.audio import Pipeline
import json
import requests


# Hugging Face access token, read from the HF_TOKEN environment variable
AUTH_TOKEN = os.environ.get("HF_TOKEN")

# Build the diarization pipeline once at import time and move it to the GPU
device = torch.device("cuda")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1",
    token=AUTH_TOKEN,
)
pipeline = pipeline.to(device)

def preprocess_audio(audio_path):
    """Convert audio to a mono, 16 kHz WAV file suitable for pyannote.

    Args:
        audio_path: Either a filesystem path (``str`` or ``os.PathLike``) to
            an audio file, or the encoded audio itself as a bytes-like object.

    Returns:
        Path of a newly created, uniquely named temporary WAV file. The
        caller owns the file and is responsible for deleting it.

    Raises:
        ValueError: If no audio was supplied, or if the audio could not be
            decoded, converted or written.
    """
    import tempfile  # local import: the module header does not import it

    # Nothing uploaded and no URL content: fail with a clear message instead
    # of letting the decoder choke on an empty stream.
    if not audio_path:
        raise ValueError("Error preprocessing audio: no audio input provided")

    temp_wav = None
    try:
        # Paths are handed to pydub directly; raw bytes are wrapped in a
        # file-like object.
        is_path = isinstance(audio_path, (str, os.PathLike))
        source = audio_path if is_path else BytesIO(audio_path)

        audio = AudioSegment.from_file(source)
        # Convert to mono and set sample rate to 16kHz
        audio = audio.set_channels(1).set_frame_rate(16000)

        # A unique file per call: a fixed file name would let concurrent
        # requests overwrite or delete each other's audio.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        audio.export(temp_wav, format="wav")
        return temp_wav
    except Exception as e:
        # Do not leave a partially written temp file behind on failure.
        if temp_wav is not None and os.path.exists(temp_wav):
            os.remove(temp_wav)
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e

def handle_audio(url, audio_path, num_speakers):
    """Fetch or accept audio, preprocess it, and run diarization on it.

    Args:
        url: Optional URL of an audio file. When non-empty it takes
            precedence over ``audio_path`` and the audio is downloaded.
        audio_path: Path of an uploaded audio file; used when ``url`` is
            empty.
        num_speakers: Requested number of speakers, forwarded unchanged to
            ``diarize_audio``.

    Returns:
        The value returned by ``diarize_audio`` for the preprocessed audio.

    Raises:
        requests.RequestException: If the download fails or the server
            answers with an HTTP error status.
        ValueError: Propagated from ``preprocess_audio`` when the audio is
            missing or cannot be decoded.
    """
    if isinstance(url, str):
        url = url.strip()

    if url:
        # NOTE(security): the URL is user-supplied and fetched server-side;
        # consider restricting schemes/hosts if this app is exposed publicly.
        response = requests.get(url, timeout=60)
        # Fail on 4xx/5xx instead of feeding an error page to the decoder.
        response.raise_for_status()
        audio_path = response.content

    wav_path = preprocess_audio(audio_path)
    try:
        return diarize_audio(wav_path, num_speakers)
    finally:
        # Clean up the temporary file even when diarization raises.
        try:
            os.remove(wav_path)
        except FileNotFoundError:
            pass

    
@spaces.GPU(duration=180)
def diarize_audio(audio_path, num_speakers):
    """Perform speaker diarization and return formatted results.

    Args:
        audio_path: Path of an audio file that ``torchaudio.load`` can read.
        num_speakers: Exact number of speakers to look for. A value of 0 (or
            ``None``/negative) lets the pipeline choose between 2 and 6
            speakers.

    Returns:
        A JSON string holding a list of objects with the keys ``start`` and
        ``end`` (rounded to 3 decimals) and ``speaker_id``. If anything
        fails, a plain string of the form ``"Error: <message>"`` is returned
        instead of raising.
    """
    try:
        # Load audio for pyannote
        waveform, sample_rate = torchaudio.load(audio_path)
        audio_dict = {"waveform": waveform, "sample_rate": sample_rate}

        # The UI slider may deliver a float; normalise to an int count.
        speaker_count = int(num_speakers) if num_speakers else 0
        if speaker_count > 0:
            pipeline_params = {"num_speakers": speaker_count}
        else:
            pipeline_params = {"min_speakers": 2, "max_speakers": 6}
        diarization = pipeline(audio_dict, **pipeline_params)

        # Format results
        results = [
            {
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "speaker_id": speaker,
            }
            for turn, speaker in diarization.exclusive_speaker_diarization
        ]
        return json.dumps(results, indent=2)

    except Exception as e:
        # Return a single string, matching the success path and the single
        # output component the UI binds to (a tuple here was the wrong shape).
        return f"Error: {str(e)}"

# Gradio interface: an input row, a trigger button and a results row
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with speaker-diarization-community-1")
    gr.Markdown("Upload an audio file and specify the number of speakers to diarize the audio.")

    # Inputs: optional URL, uploaded file (handed over as a path), speaker count
    with gr.Row():
        url_input = gr.Textbox(label="URL")
        audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        num_speakers = gr.Slider(
            label="Number of Speakers",
            minimum=0,
            maximum=10,
            step=1,
            value=2,
        )

    submit_btn = gr.Button("Diarize")

    # Output: the diarization result shown as text
    with gr.Row():
        json_output = gr.Textbox(label="Diarization Results (JSON)")

    # Wire the button to handle_audio; concurrency for this event is capped at 2
    submit_btn.click(
        handle_audio,
        [url_input, audio_input, num_speakers],
        [json_output],
        concurrency_limit=2,
    )

# Launch the Gradio app
demo.launch()