# whisper-diarization / src/audio_processor.py
# lucamartinelli's picture
# Gradio
# dd5bcef
"""Audio processing and transcription logic."""
import logging
import shutil
import tempfile
from pathlib import Path
from typing import Callable, List, Tuple
from src.diarization import get_pipeline
from src.vtt import create_vtt
from src.whisper import TranscriptSegment, get_transcripts
logger = logging.getLogger(__name__)
class AudioProcessor:
    """Handles audio processing, diarization, and transcription."""

    def __init__(
        self,
        openai_api_key: str,
        hf_api_key: str,
        transcription_model: str,
        pyannote_model: str,
        whisper_prompt: str = "",
        whisper_language: str | None = None
    ):
        """
        Initialize AudioProcessor.

        Args:
            openai_api_key: OpenAI API key for Whisper
            hf_api_key: Hugging Face API key for Pyannote
            transcription_model: Model name for transcription
            pyannote_model: Model name for diarization
            whisper_prompt: Optional prompt passed to Whisper
            whisper_language: Optional language code passed to Whisper
        """
        self.openai_api_key = openai_api_key
        self.hf_api_key = hf_api_key
        self.transcription_model = transcription_model
        self.pyannote_model = pyannote_model
        self.whisper_prompt = whisper_prompt
        self.whisper_language = whisper_language

    def process(
        self,
        audio_path: str | Path,
        progress_callback: Callable[[float, str], None] | None = None
    ) -> Tuple[str, List[TranscriptSegment], str]:
        """
        Process audio file: diarization + transcription.

        Args:
            audio_path: Path to audio file
            progress_callback: Optional callback for progress updates
                (progress in [0, 1], description)

        Returns:
            Tuple of (vtt_content, transcripts, audio_filename).
            All three are empty when audio_path is falsy.
        """
        if not audio_path:
            return "", [], ""

        audio_path = Path(audio_path).absolute()
        tmp_dir = Path(tempfile.mkdtemp(prefix="whisper_diarization_"))
        logger.info("📁 Created temporary directory: %s", tmp_dir)
        try:
            # Step 1: Diarization (progress 0 -> 0.3)
            if progress_callback:
                progress_callback(0, "Loading diarization model...")
            logger.info("🔄 Starting diarization process")
            audio_segment, diarization = get_pipeline(
                audio_path,
                self.hf_api_key,
                self.pyannote_model,
                tmp_dir
            )
            if progress_callback:
                progress_callback(0.3, "Diarization complete. Starting transcription...")
            logger.info("✅ Diarization complete")

            # Step 2: Transcription (progress 0.3 -> 0.9)
            total_segments = sum(1 for _ in diarization.speaker_diarization.itertracks())
            logger.info("📊 Found %s segments to transcribe", total_segments)

            def transcription_progress(i: int, total: int):
                # Map segment index onto the 0.3-0.9 window of the bar.
                if progress_callback:
                    progress_callback(
                        0.3 + (0.6 * i / total),
                        f"Transcribing segment {i}/{total}..."
                    )

            transcripts = get_transcripts(
                diarization,
                audio_segment,
                self.openai_api_key,
                self.transcription_model,
                self.whisper_prompt,
                self.whisper_language,
                tmp_dir,
                progress_callback=transcription_progress
            )

            # Step 3: Create VTT (progress 0.9)
            if progress_callback:
                progress_callback(0.9, "Creating VTT file...")
            logger.info("📝 Creating VTT file")
            vtt = create_vtt(transcripts)
        finally:
            # Cleanup always runs, on success and on failure alike (0.95).
            # BUG FIX: previously "Complete!" (1.0) was reported inside the
            # try block, so this callback made the progress bar jump
            # backwards from 100% to 95% on every successful run.
            if progress_callback:
                progress_callback(0.95, "Cleaning up temporary files...")
            logger.info("🧹 Cleaning up")
            if tmp_dir.exists():
                shutil.rmtree(tmp_dir)
                logger.info("🗑️ Removed temporary directory: %s", tmp_dir)

        # Only reached on success: report completion after cleanup so the
        # progress sequence is monotonic (0 -> 0.3 -> 0.9 -> 0.95 -> 1.0).
        if progress_callback:
            progress_callback(1.0, "Complete!")
        logger.info("✅ Process complete")
        return vtt.content, transcripts, audio_path.stem