File size: 4,327 Bytes
dd5bcef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Audio processing and transcription logic."""
import logging
import shutil
import tempfile
from pathlib import Path
from typing import Callable, List, Tuple

from src.diarization import get_pipeline
from src.vtt import create_vtt
from src.whisper import TranscriptSegment, get_transcripts

logger = logging.getLogger(__name__)


class AudioProcessor:
    """Handles audio processing, diarization, and transcription.

    Combines the diarization pipeline (``get_pipeline``), per-segment
    transcription (``get_transcripts``) and VTT rendering (``create_vtt``)
    behind a single :meth:`process` call.
    """

    def __init__(
        self,
        openai_api_key: str,
        hf_api_key: str,
        transcription_model: str,
        pyannote_model: str,
        whisper_prompt: str = "",
        whisper_language: str | None = None
    ):
        """
        Initialize AudioProcessor.

        The arguments are only stored here; they are used when
        :meth:`process` runs.

        Args:
            openai_api_key: OpenAI API key for Whisper
            hf_api_key: Hugging Face API key for Pyannote
            transcription_model: Model name for transcription
            pyannote_model: Model name for diarization
            whisper_prompt: Optional prompt for Whisper
            whisper_language: Optional language code for Whisper
        """
        self.openai_api_key = openai_api_key
        self.hf_api_key = hf_api_key
        self.transcription_model = transcription_model
        self.pyannote_model = pyannote_model
        self.whisper_prompt = whisper_prompt
        self.whisper_language = whisper_language

    def process(
        self,
        audio_path: str | Path,
        progress_callback: Callable[[float, str], None] | None = None
    ) -> Tuple[str, List[TranscriptSegment], str]:
        """
        Process audio file: diarization + transcription.

        Progress is reported in non-decreasing order:
        0.0 (loading) -> 0.3 (diarization done) -> 0.3..0.9 (per segment)
        -> 0.9 (VTT) -> 0.95 (cleanup) -> 1.0 (complete).

        A temporary working directory is created for the run and removed
        afterwards, whether or not a step raised. Removal failures are
        logged rather than raised, so they never hide the result or an
        exception that is already propagating.

        Args:
            audio_path: Path to audio file. A falsy value (e.g. ``""``)
                short-circuits to an empty result without doing any work.
            progress_callback: Optional callback for progress updates
                ``(progress, description)``.

        Returns:
            Tuple of (vtt_content, transcripts, audio_filename), where
            audio_filename is the input file name without its suffix.
        """
        if not audio_path:
            return "", [], ""

        def report(progress: float, description: str) -> None:
            # Single place that tolerates a missing callback.
            if progress_callback:
                progress_callback(progress, description)

        def transcription_progress(i: int, total: int) -> None:
            # Map segment i of total onto the 0.3..0.9 band; guard total == 0
            # so an empty diarization cannot raise ZeroDivisionError.
            fraction = 0.6 * i / total if total else 0.6
            report(0.3 + fraction, f"Transcribing segment {i}/{total}...")

        audio_path = Path(audio_path).absolute()
        tmp_dir = Path(tempfile.mkdtemp(prefix="whisper_diarization_"))
        logger.info("📁 Created temporary directory: %s", tmp_dir)

        try:
            # Step 1: Diarization
            report(0, "Loading diarization model...")
            logger.info("🔄 Starting diarization process")

            audio_segment, diarization = get_pipeline(
                audio_path,
                self.hf_api_key,
                self.pyannote_model,
                tmp_dir
            )

            report(0.3, "Diarization complete. Starting transcription...")
            logger.info("✅ Diarization complete")

            # Step 2: Transcription
            total_segments = sum(1 for _ in diarization.speaker_diarization.itertracks())
            logger.info("📊 Found %d segments to transcribe", total_segments)

            transcripts = get_transcripts(
                diarization,
                audio_segment,
                self.openai_api_key,
                self.transcription_model,
                self.whisper_prompt,
                self.whisper_language,
                tmp_dir,
                progress_callback=transcription_progress
            )

            # Step 3: Create VTT
            report(0.9, "Creating VTT file...")
            logger.info("📝 Creating VTT file")

            vtt = create_vtt(transcripts)

            # Reported here (success path only) so progress never moves
            # backwards after "Complete!" and the final status is "Complete!".
            report(0.95, "Cleaning up temporary files...")
        finally:
            # Deliberately does not call progress_callback: a raising callback
            # inside `finally` would replace the exception already propagating.
            self._remove_tmp_dir(tmp_dir)

        report(1.0, "Complete!")
        logger.info("✅ Process complete")

        return vtt.content, transcripts, audio_path.stem

    @staticmethod
    def _remove_tmp_dir(tmp_dir: Path) -> None:
        """Best-effort removal of the per-run temporary directory."""
        logger.info("🧹 Cleaning up")

        if not tmp_dir.exists():
            return

        try:
            shutil.rmtree(tmp_dir)
        except OSError:
            # Leaking a temp dir is preferable to failing a finished run or
            # masking the exception that triggered this cleanup.
            logger.warning(
                "Failed to remove temporary directory: %s", tmp_dir, exc_info=True
            )
        else:
            logger.info("🗑️ Removed temporary directory: %s", tmp_dir)