# Fixes 1333284 (lucamartinelli) — commit-page residue converted to a comment
"""Whisper + Pyannote Transcription & Diarization Web Interface."""
import logging
import tempfile
from pathlib import Path
from datetime import datetime
import gradio as gr
from src.audio_processor import AudioProcessor
from src.speaker_manager import SpeakerManager
from src.vtt_utils import validate_vtt
logging.basicConfig(level=logging.INFO)
def process_audio(
    audio_path: str,
    openai_api_key: str,
    hf_api_key: str,
    transcription_model: str,
    pyannote_model: str,
    openai_whisper_prompt: str,
    openai_whisper_language: str | None,
    progress=gr.Progress(),
):
    """
    Run diarization and transcription on the uploaded audio file.

    Returns:
        Tuple of (vtt_content, transcripts, audio_filename); empty values
        when no audio path was supplied.
    """
    # Nothing uploaded yet — hand back empty outputs for every component.
    if not audio_path:
        return "", [], ""

    def _report(fraction, description):
        # Bridge the processor's (progress, description) callback onto
        # Gradio's progress API.
        progress(fraction, desc=description)

    pipeline = AudioProcessor(
        openai_api_key=openai_api_key,
        hf_api_key=hf_api_key,
        transcription_model=transcription_model,
        pyannote_model=pyannote_model,
        whisper_prompt=openai_whisper_prompt,
        whisper_language=openai_whisper_language,
    )
    return pipeline.process(audio_path=audio_path, progress_callback=_report)
def rename_speaker_in_vtt(
    vtt_content: str, transcripts_state, old_speaker: str, new_speaker: str
):
    """Rename a speaker label and rebuild the VTT from stored transcripts.

    Returns the VTT content unchanged when there is no content or no
    transcript state to regenerate from.
    """
    nothing_to_rename = not vtt_content or not transcripts_state
    if nothing_to_rename:
        return vtt_content
    return SpeakerManager.rename_speaker(transcripts_state, old_speaker, new_speaker)
def prepare_download(vtt_content: str, audio_filename: str) -> str | None:
"""
Prepare VTT file for download.
Args:
vtt_content: VTT content as string
audio_filename: Base filename for the audio
Returns:
Path to temporary VTT file, or None if inputs are invalid
"""
if not vtt_content:
return None
if not audio_filename:
audio_filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Create a unique temp directory to avoid caching issues
temp_dir = Path(tempfile.mkdtemp())
download_path = temp_dir / f"{audio_filename}.vtt"
with open(download_path, "w", encoding="utf-8") as f:
f.write(vtt_content)
return str(download_path)
# --- UI layout --------------------------------------------------------------
with gr.Blocks(title="Transcription & Diarization") as app:
    gr.Markdown(
        """
        # 🎙️ Transcription & Diarization
        Fill the required settings, upload an audio file, and start the transcription using Whisper and Pyannote!
        """
    )
    # Cross-callback state: raw transcript segments from processing, and the
    # uploaded file's base name (used to name the downloadable VTT).
    transcripts_state = gr.State([])
    audio_filename_state = gr.State("")
    with gr.Row():
        # Left column: credentials, model settings, audio upload, submit.
        with gr.Column():
            with gr.Accordion("⚙️ Settings", open=True):
                # NOTE(review): the variable is spelled "openapi_api_key" but
                # holds the OpenAI key; kept as-is because the event wiring
                # below references this exact name.
                openapi_api_key = gr.Textbox(label="OpenAI API key")
                hf_api_key = gr.Textbox(label="Hugging Face API key")
            with gr.Accordion("⚙️ Additional settings", open=False):
                # Single-choice dropdowns today; structured as (label, value)
                # pairs so more models can be added later.
                transcription_model = gr.Dropdown(
                    label="Transcription model",
                    choices=[("Whisper", "whisper-1")],
                    value="whisper-1",
                )
                pyannote_model = gr.Dropdown(
                    label="Pyannote model",
                    choices=[
                        (
                            "Speaker diarization community 1",
                            "pyannote/speaker-diarization-community-1",
                        )
                    ],
                    value="pyannote/speaker-diarization-community-1",
                )
                openai_whisper_prompt = gr.Textbox(
                    label="Additional whisper prompt", value=""
                )
                # None value means Whisper auto-detects the language.
                openai_whisper_language = gr.Dropdown(
                    label="Whisper language",
                    choices=[
                        ("Default (Auto-detect)", None),
                        ("🇮🇹 Italian", "it"),
                        ("🇩🇪 German", "de"),
                        ("🇬🇧 English", "en"),
                        ("🇪🇸 Spanish", "es"),
                        ("🇫🇷 French", "fr"),
                    ],
                    value=None,
                )
            audio_input = gr.Audio(type="filepath", label="Upload audio")
            # Disabled until both API keys and an audio file are present
            # (toggled by check_inputs via the .change wiring below).
            submit_btn = gr.Button("Transcript", variant="primary", interactive=False)
        # Right column: editable transcript, validation badge, download,
        # and speaker renaming controls.
        with gr.Column():
            with gr.Group():
                output_vtt = gr.Code(
                    label="Transcription",
                    max_lines=40,
                    wrap_lines=True,
                )
                validation_status = gr.Markdown("⚪ No content", container=True)
            # Hidden until a valid VTT file has been prepared.
            download_btn = gr.DownloadButton(
                "Download VTT", variant="primary", visible=False
            )
            with gr.Accordion("🎭 Rename speakers", open=True):
                with gr.Row():
                    old_speaker_name = gr.Textbox(
                        label="Current speaker name (e.g., SPEAKER_00)",
                        placeholder="SPEAKER_00",
                        value="SPEAKER_00",
                    )
                    new_speaker_name = gr.Textbox(
                        label="New speaker name", placeholder="Davide"
                    )
                rename_btn = gr.Button("Rename")
def check_inputs(openai_key: str, hf_key: str, audio) -> gr.Button:
"""
Enable submit button only if both API keys and audio are provided.
Args:
openai_key: OpenAI API key
hf_key: Hugging Face API key
audio: Audio file path
Returns:
Button component with updated interactive state
"""
is_ready = bool(openai_key and hf_key and audio)
return gr.Button(interactive=is_ready)
def update_validation(vtt_content: str, audio_filename: str):
"""
Update validation status and button states when VTT content changes.
Args:
vtt_content: VTT content to validate
audio_filename: Audio filename for download
Returns:
Tuple of (status_message, download_file)
"""
status, status_type = validate_vtt(vtt_content)
# Enable buttons only if VTT is valid
is_valid = status_type == "success"
# Prepare download file if valid
file_path = None
if is_valid and vtt_content:
file_path = prepare_download(vtt_content, audio_filename)
return (
status,
gr.DownloadButton(
value=file_path, visible=bool(file_path), interactive=True
),
)
# Enable/disable submit button based on API keys and audio input
openapi_api_key.change(
fn=check_inputs,
inputs=[openapi_api_key, hf_api_key, audio_input],
outputs=submit_btn,
)
hf_api_key.change(
fn=check_inputs,
inputs=[openapi_api_key, hf_api_key, audio_input],
outputs=submit_btn,
)
audio_input.change(
fn=check_inputs,
inputs=[openapi_api_key, hf_api_key, audio_input],
outputs=submit_btn,
)
# Main transcription process
submit_btn.click(
fn=process_audio,
inputs=[
audio_input,
openapi_api_key,
hf_api_key,
transcription_model,
pyannote_model,
openai_whisper_prompt,
openai_whisper_language,
],
outputs=[output_vtt, transcripts_state, audio_filename_state],
)
# Real-time VTT validation and button state management
# We need to update validation whenever VTT content OR filename changes
output_vtt.input(
fn=update_validation,
inputs=[output_vtt, audio_filename_state],
outputs=[validation_status, download_btn],
)
audio_filename_state.change(
fn=update_validation,
inputs=[output_vtt, audio_filename_state],
outputs=[validation_status, download_btn],
)
# Speaker renaming
rename_btn.click(
fn=rename_speaker_in_vtt,
inputs=[output_vtt, transcripts_state, old_speaker_name, new_speaker_name],
outputs=output_vtt,
)
if __name__ == "__main__":
    # Start the Gradio server when run as a script.
    app.launch()