"""Whisper + Pyannote Transcription & Diarization Web Interface."""
import logging
import tempfile
from pathlib import Path
import gradio as gr
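# Local helpers (see src/): AudioProcessor runs Whisper transcription plus
# pyannote diarization, SpeakerManager rewrites speaker labels, and vtt_utils
# cleans and validates the WebVTT output.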
from src.audio_processor import AudioProcessor
from src.speaker_manager import SpeakerManager
from src.vtt_utils import clean_vtt, validate_vtt
logging.basicConfig(level=logging.INFO)

def process_audio(
    audio_path: str,
    openai_api_key: str,
    hf_api_key: str,
    transcription_model: str,
    pyannote_model: str,
    openai_whisper_prompt: str,
    openai_whisper_language: str | None,
    progress=gr.Progress()
):
    """
    Process an audio file with diarization and transcription.

    Args:
        audio_path: Path to the uploaded audio file
        openai_api_key: OpenAI API key used for Whisper transcription
        hf_api_key: Hugging Face API key used for the pyannote pipeline
        transcription_model: Whisper model identifier
        pyannote_model: Pyannote diarization model identifier
        openai_whisper_prompt: Optional prompt to steer Whisper's output
        openai_whisper_language: ISO 639-1 language code, or None to auto-detect

    Returns:
        Tuple of (vtt_content, transcripts, audio_filename)
    """
    if not audio_path:
        return "", [], ""

    processor = AudioProcessor(
        openai_api_key=openai_api_key,
        hf_api_key=hf_api_key,
        transcription_model=transcription_model,
        pyannote_model=pyannote_model,
        whisper_prompt=openai_whisper_prompt,
        whisper_language=openai_whisper_language
    )
    return processor.process(
        audio_path=audio_path,
        # Adapt the processor's (fraction, description) callback to Gradio's progress bar
        progress_callback=lambda p, desc: progress(p, desc=desc)
    )
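
# Illustrative sketch of the WebVTT shape this app displays. The real cues come
# from AudioProcessor; the <v> voice-tag convention shown here is an assumption:
#
#   WEBVTT
#
#   00:00:00.000 --> 00:00:04.200
#   <v SPEAKER_00>Hello and welcome.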

def rename_speaker_in_vtt(vtt_content: str, transcripts_state, old_speaker: str, new_speaker: str):
    """Rename speaker and regenerate VTT."""
    if not vtt_content or not transcripts_state:
        return vtt_content
    return SpeakerManager.rename_speaker(transcripts_state, old_speaker, new_speaker)
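
# Note: the event wiring below writes this return value back into the
# transcription textbox, so SpeakerManager.rename_speaker is expected to return
# the regenerated VTT text rather than mutate anything in place.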

def prepare_download(vtt_content: str, audio_filename: str) -> gr.File:
    """
    Prepare the VTT file for download.

    Args:
        vtt_content: VTT content as string
        audio_filename: Base filename for the audio

    Returns:
        gr.File update: hidden if the inputs are invalid, otherwise visible and
        pointing at a temporary VTT file
    """
    if not vtt_content or not audio_filename:
        return gr.File(visible=False)

    download_path = Path(tempfile.gettempdir()) / f"{audio_filename}.vtt"
    with open(download_path, 'w', encoding='utf-8') as f:
        f.write(vtt_content)
    # Return a component update so the hidden File component becomes visible
    # once there is actually something to download
    return gr.File(value=str(download_path), visible=True)
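
# gr.File accepts a filesystem path as its value and serves that file for
# download, so writing the VTT into the OS temp directory is enough to expose it.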

with gr.Blocks(title="Transcription & Diarization") as app:
    gr.Markdown("""
    # 🎙️ Transcription & Diarization
    Fill in the required settings, upload an audio file, and start the transcription using Whisper and Pyannote!
    """)
    transcripts_state = gr.State([])
    audio_filename_state = gr.State("")
    with gr.Row():
        with gr.Column():
            with gr.Accordion("⚙️ Settings", open=True):
                openai_api_key = gr.Textbox(label="OpenAI API key", type="password")
                hf_api_key = gr.Textbox(label="Hugging Face API key", type="password")

            with gr.Accordion("⚙️ Additional settings", open=False):
                transcription_model = gr.Dropdown(
                    label="Transcription model",
                    choices=[("Whisper", "whisper-1")],
                    value="whisper-1"
                )
                pyannote_model = gr.Dropdown(
                    label="Pyannote model",
                    choices=[("Speaker diarization community 1", "pyannote/speaker-diarization-community-1")],
                    value="pyannote/speaker-diarization-community-1"
                )
                openai_whisper_prompt = gr.Textbox(label="Additional Whisper prompt", value="")
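                # None lets Whisper auto-detect the language; otherwise the
                # ISO 639-1 code selected below is passed straight to the API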
                openai_whisper_language = gr.Dropdown(
                    label="Whisper language",
                    choices=[
                        ("Default (Auto-detect)", None),
                        ("🇮🇹 Italian", "it"),
                        ("🇩🇪 German", "de"),
                        ("🇬🇧 English", "en"),
                        ("🇪🇸 Spanish", "es"),
                        ("🇫🇷 French", "fr"),
                    ],
                    value=None
                )
            audio_input = gr.Audio(type="filepath", label="Upload audio")
            submit_btn = gr.Button("Transcribe", variant="primary", interactive=False)
        with gr.Column():
            with gr.Group():
                output_vtt = gr.Textbox(
                    label="Transcription",
                    lines=20,
                    placeholder="Your transcription will appear here...",
                    buttons=["copy"],
                    container=False,
                )
                validation_status = gr.Markdown("⚪ No content", container=True)

            with gr.Row():
                clean_btn = gr.Button("Clean & improve VTT", variant="secondary", interactive=False)
                download_file = gr.File(label="Download VTT", visible=False)
                download_btn = gr.Button("Download VTT", variant="secondary", interactive=False)

            with gr.Accordion("🎭 Rename speakers", open=False):
                with gr.Row():
                    old_speaker_name = gr.Textbox(label="Current speaker name (e.g., SPEAKER_00)", placeholder="SPEAKER_00", value="SPEAKER_00")
                    new_speaker_name = gr.Textbox(label="New speaker name", placeholder="Davide")
                rename_btn = gr.Button("Rename")
    def check_inputs(openai_key: str, hf_key: str, audio) -> gr.Button:
        """
        Enable submit button only if both API keys and audio are provided.

        Args:
            openai_key: OpenAI API key
            hf_key: Hugging Face API key
            audio: Audio file path

        Returns:
            Button component with updated interactive state
        """
        is_ready = bool(openai_key and hf_key and audio)
        return gr.Button(interactive=is_ready)
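
    # Returning a bare component such as gr.Button(interactive=...) is Gradio's
    # update pattern: only the properties explicitly set on it are applied to
    # the output component.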
    def update_validation(vtt_content: str):
        """
        Update validation status and button states when VTT content changes.

        Args:
            vtt_content: VTT content to validate

        Returns:
            Tuple of (status_message, clean_button, download_button)
        """
        status, status_type = validate_vtt(vtt_content)
        # Enable buttons only if VTT is valid
        is_valid = status_type == "success"
        return (
            status,
            gr.Button(interactive=is_valid),  # clean_btn
            gr.Button(interactive=is_valid)   # download_btn
        )
    # Enable/disable the submit button whenever an API key or the audio input changes
    for component in (openai_api_key, hf_api_key, audio_input):
        component.change(
            fn=check_inputs,
            inputs=[openai_api_key, hf_api_key, audio_input],
            outputs=submit_btn
        )
    # Main transcription process
    submit_btn.click(
        fn=process_audio,
        inputs=[
            audio_input,
            openai_api_key,
            hf_api_key,
            transcription_model,
            pyannote_model,
            openai_whisper_prompt,
            openai_whisper_language
        ],
        outputs=[output_vtt, transcripts_state, audio_filename_state],
    )
    # Real-time VTT validation and button state management
    output_vtt.change(
        fn=update_validation,
        inputs=[output_vtt],
        outputs=[validation_status, clean_btn, download_btn]
    )

    # VTT cleaning and improvement
    clean_btn.click(
        fn=clean_vtt,
        inputs=[output_vtt],
        outputs=[output_vtt]
    )

    # VTT file download
    download_btn.click(
        fn=prepare_download,
        inputs=[output_vtt, audio_filename_state],
        outputs=[download_file]
    )

    # Speaker renaming
    rename_btn.click(
        fn=rename_speaker_in_vtt,
        inputs=[output_vtt, transcripts_state, old_speaker_name, new_speaker_name],
        outputs=output_vtt
    )

if __name__ == "__main__":
    app.launch()