File size: 6,221 Bytes
ed147e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""

Whisper-based speech-to-text transcription module.

Converts audio files to text using OpenAI's Whisper model.

"""

from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
import whisper

from src.utils.config import settings
from src.utils.logger import setup_logger

# Module-level logger obtained from the project's setup_logger helper,
# which is passed this module's name.
logger = setup_logger(__name__)


class WhisperTranscriber:
    """Handles audio transcription using Whisper ASR model.

    Constructing the transcriber only records the model size and selects a
    device; the model itself is loaded on the first call to ``load_model``
    (which ``transcribe`` invokes automatically) and is reused afterwards.
    """

    def __init__(self, model_size: Optional[str] = None):
        """
        Initialize the Whisper transcriber.

        Args:
            model_size: Whisper model size (tiny, base, small, medium, large).
                Falls back to ``settings.whisper_model_size`` when omitted
                or empty.
        """
        self.model_size = model_size or settings.whisper_model_size
        self.model = None  # set by load_model()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.info(f"Initializing Whisper transcriber with model: {self.model_size}")
        logger.info(f"Using device: {self.device}")

    def load_model(self) -> None:
        """
        Load the Whisper model into memory.

        Does nothing (apart from logging) if a model is already loaded.

        Raises:
            RuntimeError: If the model cannot be loaded. The original
                exception is attached as ``__cause__``.
        """
        if self.model is not None:
            logger.info("Model already loaded")
            return

        try:
            logger.info(f"Loading Whisper {self.model_size} model...")
            self.model = whisper.load_model(self.model_size, device=self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {e}")
            # Chain explicitly so the underlying failure stays in the traceback.
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def transcribe(
        self,
        audio_path: Path,
        language: str = "en",
        verbose: bool = True
    ) -> Dict[str, Any]:
        """
        Transcribe audio file to text.

        Args:
            audio_path: Path to the audio file. A plain string path is
                accepted as well and converted to ``Path``.
            language: Language code (default: "en" for English)
            verbose: Forwarded unchanged to the Whisper model's
                ``transcribe`` call, where it controls console output
                during decoding (see the Whisper documentation for the
                exact behaviour of each value).

        Returns:
            Dictionary containing:
                - text: Full transcript, stripped of surrounding whitespace
                - segments: List of timestamped segments
                  (``id``, ``start``, ``end``, ``text``)
                - language: Language reported in the Whisper result

        Raises:
            FileNotFoundError: If audio file doesn't exist
            RuntimeError: If model loading or transcription fails. The
                original exception is attached as ``__cause__``.
        """
        # Accept str as well as Path; calling .exists() on a str would
        # otherwise raise AttributeError.
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        # Load model if not already loaded
        self.load_model()

        try:
            logger.info(f"Starting transcription of: {audio_path}")
            logger.info(f"Language: {language}")

            # Transcribe with Whisper
            result = self.model.transcribe(
                str(audio_path),
                language=language,
                verbose=verbose,
                task="transcribe",
                # Use FP16 only when the model lives on a CUDA device; tie
                # the decision to the device chosen for this instance.
                fp16=str(self.device).startswith("cuda"),
            )

            # Extract relevant information
            transcript_data = {
                'text': result['text'].strip(),
                'segments': self._process_segments(result['segments']),
                'language': result['language'],
            }

            logger.info(f"Transcription complete. Length: {len(transcript_data['text'])} characters")
            logger.info(f"Number of segments: {len(transcript_data['segments'])}")

            return transcript_data

        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            # Chain explicitly so the underlying failure stays in the traceback.
            raise RuntimeError(f"Transcription error: {str(e)}") from e

    def _process_segments(self, raw_segments: List[Dict]) -> List[Dict]:
        """
        Process raw Whisper segments into a cleaner format.

        Args:
            raw_segments: Raw segment data from Whisper. Each entry must
                provide the keys ``id``, ``start``, ``end`` and ``text``;
                any other keys are dropped.

        Returns:
            List of processed segments with timestamps and stripped text,
            in the same order as the input.
        """
        return [
            {
                'id': segment['id'],
                'start': segment['start'],
                'end': segment['end'],
                'text': segment['text'].strip(),
            }
            for segment in raw_segments
        ]

    def transcribe_with_timestamps(
        self,
        audio_path: Path,
        language: str = "en"
    ) -> str:
        """
        Transcribe audio and format with timestamps.

        Args:
            audio_path: Path to the audio file
            language: Language code

        Returns:
            Formatted transcript with one line per segment, each prefixed
            by the segment's start time as ``[MM:SS]`` or ``[HH:MM:SS]``.

        Raises:
            FileNotFoundError: If audio file doesn't exist
            RuntimeError: If transcription fails
        """
        result = self.transcribe(audio_path, language, verbose=False)

        return "\n".join(
            f"[{self._format_timestamp(segment['start'])}] {segment['text']}"
            for segment in result['segments']
        )

    @staticmethod
    def _format_timestamp(seconds: float) -> str:
        """
        Format seconds into MM:SS or HH:MM:SS.

        Fractional seconds are truncated. Negative values are clamped to
        zero so they render as "00:00" rather than wrapping around.

        Args:
            seconds: Time in seconds

        Returns:
            Formatted timestamp string; the hours field is included only
            when it is non-zero.
        """
        total_seconds = max(0, int(seconds))
        minutes, secs = divmod(total_seconds, 60)
        hours, minutes = divmod(minutes, 60)

        if hours > 0:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"

    def get_plain_text(self, audio_path: Path, language: str = "en") -> str:
        """
        Get plain text transcript without timestamps.

        Args:
            audio_path: Path to the audio file
            language: Language code

        Returns:
            Plain text transcript

        Raises:
            FileNotFoundError: If audio file doesn't exist
            RuntimeError: If transcription fails
        """
        result = self.transcribe(audio_path, language, verbose=False)
        return result['text']