File size: 8,309 Bytes
dd5bcef
ce8528d
dd5bcef
 
 
ce8528d
dd5bcef
 
 
 
 
3d5ee3a
dd5bcef
 
 
 
 
 
 
 
 
 
 
 
ce8528d
dd5bcef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce8528d
dd5bcef
 
 
ce8528d
dd5bcef
 
 
ce8528d
 
 
dd5bcef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce8528d
dd5bcef
 
ce8528d
 
 
 
 
 
dd5bcef
ce8528d
dd5bcef
 
26055ba
dd5bcef
 
 
 
ce8528d
 
dd5bcef
 
ce8528d
 
dd5bcef
 
 
 
 
 
 
2a77a56
 
dd5bcef
 
 
 
 
ce8528d
dd5bcef
 
 
ce8528d
 
 
 
 
 
 
dd5bcef
 
ce8528d
 
 
dd5bcef
 
 
 
 
 
 
 
 
 
ce8528d
dd5bcef
 
 
 
 
 
 
1333284
dd5bcef
1333284
 
dd5bcef
 
 
 
ce8528d
 
 
 
379a259
dd5bcef
ce8528d
 
 
 
 
 
 
 
727ff34
dd5bcef
26055ba
dd5bcef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26055ba
dd5bcef
 
 
 
 
26055ba
dd5bcef
 
3d5ee3a
dd5bcef
 
 
 
 
 
26055ba
 
ce8528d
26055ba
 
dd5bcef
 
ce8528d
 
 
dd5bcef
 
 
 
 
 
ce8528d
dd5bcef
 
 
 
ce8528d
dd5bcef
 
 
 
ce8528d
dd5bcef
 
 
 
 
 
 
 
 
 
 
 
ce8528d
dd5bcef
 
 
 
 
26055ba
1333284
dd5bcef
26055ba
ce8528d
26055ba
 
 
 
 
ce8528d
dd5bcef
 
 
 
 
 
ce8528d
dd5bcef
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""Whisper + Pyannote Transcription & Diarization Web Interface."""

import logging
import tempfile
from pathlib import Path
from datetime import datetime

import gradio as gr

from src.audio_processor import AudioProcessor
from src.speaker_manager import SpeakerManager
from src.vtt_utils import validate_vtt

# Configure the root logger once at import time so progress/info messages
# from the processing modules are visible on the console.
logging.basicConfig(level=logging.INFO)


def process_audio(
    audio_path: str,
    openai_api_key: str,
    hf_api_key: str,
    transcription_model: str,
    pyannote_model: str,
    openai_whisper_prompt: str,
    openai_whisper_language: str | None,
    progress=gr.Progress(),
):
    """
    Run speaker diarization and transcription on an uploaded audio file.

    Returns:
        Tuple of (vtt_content, transcripts, audio_filename); empty values
        when no audio path was supplied.
    """
    # Guard clause: nothing to process without an input file.
    if not audio_path:
        return "", [], ""

    def _report(fraction, description):
        # Bridge the processor's (fraction, description) callback to the
        # Gradio progress widget's keyword signature.
        progress(fraction, desc=description)

    audio_processor = AudioProcessor(
        openai_api_key=openai_api_key,
        hf_api_key=hf_api_key,
        transcription_model=transcription_model,
        pyannote_model=pyannote_model,
        whisper_prompt=openai_whisper_prompt,
        whisper_language=openai_whisper_language,
    )

    return audio_processor.process(audio_path=audio_path, progress_callback=_report)


def rename_speaker_in_vtt(
    vtt_content: str, transcripts_state, old_speaker: str, new_speaker: str
):
    """
    Rename a speaker across all transcript entries and regenerate the VTT.

    Args:
        vtt_content: Current VTT text; returned unchanged when a rename
            cannot be performed.
        transcripts_state: Transcript entries held in Gradio state.
        old_speaker: Existing speaker label to replace (e.g. SPEAKER_00).
        new_speaker: Replacement label.

    Returns:
        Regenerated VTT content, or the original content when inputs are
        missing or either speaker name is blank.
    """
    if not vtt_content or not transcripts_state:
        return vtt_content

    # Guard against blank/whitespace names: clicking "Rename" with an empty
    # textbox must not wipe speaker labels to empty strings.
    if not old_speaker or not old_speaker.strip():
        return vtt_content
    if not new_speaker or not new_speaker.strip():
        return vtt_content

    return SpeakerManager.rename_speaker(transcripts_state, old_speaker, new_speaker)


def prepare_download(vtt_content: str, audio_filename: str) -> str | None:
    """
    Prepare VTT file for download.

    Args:
        vtt_content: VTT content as string
        audio_filename: Base filename for the audio

    Returns:
        Path to temporary VTT file, or None if inputs are invalid
    """
    if not vtt_content:
        return None

    if not audio_filename:
        audio_filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Create a unique temp directory to avoid caching issues
    temp_dir = Path(tempfile.mkdtemp())
    download_path = temp_dir / f"{audio_filename}.vtt"

    with open(download_path, "w", encoding="utf-8") as f:
        f.write(vtt_content)

    return str(download_path)


with gr.Blocks(title="Transcription & Diarization") as app:

    gr.Markdown(
        """
                # ๐ŸŽ™๏ธ Transcription & Diarization
                Fill the required settings, upload an audio file, and start the transcription using Whisper and Pyannote!
                """
    )

    transcripts_state = gr.State([])
    audio_filename_state = gr.State("")

    with gr.Row():
        with gr.Column():
            with gr.Accordion("โš™๏ธ Settings", open=True):
                openapi_api_key = gr.Textbox(label="OpenAI API key")
                hf_api_key = gr.Textbox(label="Hugging Face API key")

            with gr.Accordion("โš™๏ธ Additional settings", open=False):
                transcription_model = gr.Dropdown(
                    label="Transcription model",
                    choices=[("Whisper", "whisper-1")],
                    value="whisper-1",
                )
                pyannote_model = gr.Dropdown(
                    label="Pyannote model",
                    choices=[
                        (
                            "Speaker diarization community 1",
                            "pyannote/speaker-diarization-community-1",
                        )
                    ],
                    value="pyannote/speaker-diarization-community-1",
                )

                openai_whisper_prompt = gr.Textbox(
                    label="Additional whisper prompt", value=""
                )
                openai_whisper_language = gr.Dropdown(
                    label="Whisper language",
                    choices=[
                        ("Default (Auto-detect)", None),
                        ("๐Ÿ‡ฎ๐Ÿ‡น Italian", "it"),
                        ("๐Ÿ‡ฉ๐Ÿ‡ช German", "de"),
                        ("๐Ÿ‡ฌ๐Ÿ‡ง English", "en"),
                        ("๐Ÿ‡ช๐Ÿ‡ธ Spanish", "es"),
                        ("๐Ÿ‡ซ๐Ÿ‡ท French", "fr"),
                    ],
                    value=None,
                )

            audio_input = gr.Audio(type="filepath", label="Upload audio")
            submit_btn = gr.Button("Transcript", variant="primary", interactive=False)

        with gr.Column():
            with gr.Group():
                output_vtt = gr.Code(
                    label="Transcription",
                    max_lines=40,
                    wrap_lines=True,
                )

                validation_status = gr.Markdown("โšช No content", container=True)

            download_btn = gr.DownloadButton(
                "Download VTT", variant="primary", visible=False
            )

            with gr.Accordion("๐ŸŽญ Rename speakers", open=True):
                with gr.Row():
                    old_speaker_name = gr.Textbox(
                        label="Current speaker name (e.g., SPEAKER_00)",
                        placeholder="SPEAKER_00",
                        value="SPEAKER_00",
                    )
                    new_speaker_name = gr.Textbox(
                        label="New speaker name", placeholder="Davide"
                    )

                rename_btn = gr.Button("Rename")

    def check_inputs(openai_key: str, hf_key: str, audio) -> gr.Button:
        """
        Enable the submit button only when both API keys and an audio file
        have been provided.

        Args:
            openai_key: OpenAI API key
            hf_key: Hugging Face API key
            audio: Audio file path

        Returns:
            Button component with updated interactive state
        """
        all_present = all((openai_key, hf_key, audio))
        return gr.Button(interactive=all_present)

    def update_validation(vtt_content: str, audio_filename: str):
        """
        Refresh the validation message and download button whenever the VTT
        content changes.

        Args:
            vtt_content: VTT content to validate
            audio_filename: Audio filename for download

        Returns:
            Tuple of (status_message, download_file)
        """
        status, status_type = validate_vtt(vtt_content)

        # A download file is prepared only for valid, non-empty VTT content.
        download_file = (
            prepare_download(vtt_content, audio_filename)
            if status_type == "success" and vtt_content
            else None
        )

        download_button = gr.DownloadButton(
            value=download_file, visible=bool(download_file), interactive=True
        )
        return status, download_button

    # Enable/disable submit button based on API keys and audio input.
    # All three inputs share the same handler so the button state is
    # re-evaluated no matter which field changed.
    openapi_api_key.change(
        fn=check_inputs,
        inputs=[openapi_api_key, hf_api_key, audio_input],
        outputs=submit_btn,
    )
    hf_api_key.change(
        fn=check_inputs,
        inputs=[openapi_api_key, hf_api_key, audio_input],
        outputs=submit_btn,
    )
    audio_input.change(
        fn=check_inputs,
        inputs=[openapi_api_key, hf_api_key, audio_input],
        outputs=submit_btn,
    )

    # Main transcription process: fills the VTT editor plus both state slots.
    # Input order must match process_audio's parameter order.
    submit_btn.click(
        fn=process_audio,
        inputs=[
            audio_input,
            openapi_api_key,
            hf_api_key,
            transcription_model,
            pyannote_model,
            openai_whisper_prompt,
            openai_whisper_language,
        ],
        outputs=[output_vtt, transcripts_state, audio_filename_state],
    )

    # Real-time VTT validation and button state management.
    # We need to update validation whenever VTT content OR filename changes:
    # .input fires on user edits to the code editor, .change on the state slot.
    output_vtt.input(
        fn=update_validation,
        inputs=[output_vtt, audio_filename_state],
        outputs=[validation_status, download_btn],
    )

    audio_filename_state.change(
        fn=update_validation,
        inputs=[output_vtt, audio_filename_state],
        outputs=[validation_status, download_btn],
    )

    # Speaker renaming: rewrites the VTT editor content in place.
    rename_btn.click(
        fn=rename_speaker_in_vtt,
        inputs=[output_vtt, transcripts_state, old_speaker_name, new_speaker_name],
        outputs=output_vtt,
    )

if __name__ == "__main__":
    # Start the Gradio server when run as a script (not on import).
    app.launch()