Spaces:
Sleeping
Sleeping
| """Audio processing and transcription logic.""" | |
| import logging | |
| import shutil | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Callable, List, Tuple | |
| from src.diarization import get_pipeline | |
| from src.vtt import create_vtt | |
| from src.whisper import TranscriptSegment, get_transcripts | |
| logger = logging.getLogger(__name__) | |
class AudioProcessor:
    """Handles audio processing, diarization, and transcription.

    An instance only stores credentials and model settings; all work happens
    in :meth:`process`, which chains ``get_pipeline`` (diarization),
    ``get_transcripts`` (transcription) and ``create_vtt`` (VTT rendering).
    """

    def __init__(
        self,
        openai_api_key: str,
        hf_api_key: str,
        transcription_model: str,
        pyannote_model: str,
        whisper_prompt: str = "",
        whisper_language: str | None = None
    ):
        """
        Initialize AudioProcessor.

        Args:
            openai_api_key: OpenAI API key for Whisper
            hf_api_key: Hugging Face API key for Pyannote
            transcription_model: Model name for transcription
            pyannote_model: Model name for diarization
            whisper_prompt: Optional prompt for Whisper
            whisper_language: Optional language code for Whisper
        """
        self.openai_api_key = openai_api_key
        self.hf_api_key = hf_api_key
        self.transcription_model = transcription_model
        self.pyannote_model = pyannote_model
        self.whisper_prompt = whisper_prompt
        self.whisper_language = whisper_language

    def process(
        self,
        audio_path: str | Path,
        progress_callback: Callable[[float, str], None] | None = None
    ) -> Tuple[str, List[TranscriptSegment], str]:
        """
        Process audio file: diarization + transcription.

        Progress is reported monotonically: 0 (start), 0.3 (diarization done),
        0.3-0.9 (transcription), 0.9 (VTT creation), 0.95 (cleanup) and 1.0
        (complete, emitted only after the temporary directory was handled).

        Args:
            audio_path: Path to audio file
            progress_callback: Optional callback for progress updates (progress, description)

        Returns:
            Tuple of (vtt_content, transcripts, audio_filename), where
            audio_filename is the stem of ``audio_path``. A falsy
            ``audio_path`` yields ``("", [], "")`` without doing any work.

        Raises:
            Any exception raised by the diarization, transcription or VTT
            helpers (or by ``progress_callback``) propagates unchanged; the
            temporary directory is still cleaned up in that case.
        """
        if not audio_path:
            return "", [], ""

        audio_path = Path(audio_path).absolute()
        tmp_dir = Path(tempfile.mkdtemp(prefix="whisper_diarization_"))
        logger.info(f"π Created temporary directory: {tmp_dir}")

        def report(fraction: float, description: str) -> None:
            # Single place for the "callback is optional" check.
            if progress_callback:
                progress_callback(fraction, description)

        def transcription_progress(i: int, total: int):
            # Map segment progress onto the 0.3-0.9 band. Guard against a
            # zero total so an empty diarization cannot raise
            # ZeroDivisionError.
            done = i / total if total else 1.0
            report(0.3 + (0.6 * done), f"Transcribing segment {i}/{total}...")

        try:
            # Step 1: Diarization
            report(0, "Loading diarization model...")
            logger.info("π Starting diarization process")
            audio_segment, diarization = get_pipeline(
                audio_path,
                self.hf_api_key,
                self.pyannote_model,
                tmp_dir
            )
            report(0.3, "Diarization complete. Starting transcription...")
            logger.info("β Diarization complete")

            # Step 2: Transcription
            total_segments = sum(1 for _ in diarization.speaker_diarization.itertracks())
            logger.info(f"π Found {total_segments} segments to transcribe")
            transcripts = get_transcripts(
                diarization,
                audio_segment,
                self.openai_api_key,
                self.transcription_model,
                self.whisper_prompt,
                self.whisper_language,
                tmp_dir,
                progress_callback=transcription_progress
            )

            # Step 3: Create VTT
            report(0.9, "Creating VTT file...")
            logger.info("π Creating VTT file")
            vtt = create_vtt(transcripts)

            # Announce cleanup here, before "Complete!", so progress never
            # moves backwards (previously 0.95 was reported after 1.0).
            report(0.95, "Cleaning up temporary files...")
        finally:
            # Cleanup is best-effort and does not call the user callback:
            # a raising callback or a failed delete must neither leak the
            # directory nor replace the real result/exception.
            logger.info("π§Ή Cleaning up")
            try:
                if tmp_dir.exists():
                    shutil.rmtree(tmp_dir)
                    logger.info(f"ποΈ Removed temporary directory: {tmp_dir}")
            except OSError:
                logger.warning(
                    "Could not remove temporary directory: %s",
                    tmp_dir,
                    exc_info=True,
                )

        report(1.0, "Complete!")
        logger.info("β Process complete")
        audio_filename = audio_path.stem
        return vtt.content, transcripts, audio_filename