Spaces:
Sleeping
Sleeping
File size: 6,221 Bytes
ed147e2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | """
Whisper-based speech-to-text transcription module.
Converts audio files to text using OpenAI's Whisper model.
"""
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
import whisper

from src.utils.logger import setup_logger
from src.utils.config import settings
# Module-level logger, created via the project's setup_logger helper (src.utils.logger).
logger = setup_logger(__name__)
class WhisperTranscriber:
    """Handles audio transcription using the Whisper ASR model.

    The model is not loaded at construction time; it is loaded on the first
    call to :meth:`transcribe` (or explicitly via :meth:`load_model`) and then
    reused for every later call on the same instance.
    """

    def __init__(self, model_size: Optional[str] = None):
        """
        Initialize the Whisper transcriber.

        Args:
            model_size: Whisper model size (tiny, base, small, medium, large).
                Falls back to ``settings.whisper_model_size`` when omitted
                (or otherwise falsy).
        """
        self.model_size = model_size or settings.whisper_model_size
        self.model = None  # populated lazily by load_model()
        # Prefer the GPU whenever PyTorch can see one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.info(f"Initializing Whisper transcriber with model: {self.model_size}")
        logger.info(f"Using device: {self.device}")

    def load_model(self) -> None:
        """
        Load the Whisper model into memory; does nothing if already loaded.

        Raises:
            RuntimeError: If Whisper fails to load the model. The original
                exception is chained as ``__cause__``.
        """
        if self.model is not None:
            # transcribe() calls this every time, so keep it out of INFO output.
            logger.debug("Model already loaded")
            return

        try:
            logger.info(f"Loading Whisper {self.model_size} model...")
            self.model = whisper.load_model(self.model_size, device=self.device)
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {e}")
            raise RuntimeError(f"Model loading failed: {str(e)}") from e
        else:
            logger.info("Model loaded successfully")

    def transcribe(
        self,
        audio_path: Union[str, Path],
        language: Optional[str] = "en",
        verbose: Optional[bool] = True,
    ) -> Dict[str, Any]:
        """
        Transcribe an audio file to text.

        Args:
            audio_path: Path to the audio file. A plain ``str`` is accepted
                and converted to :class:`pathlib.Path`.
            language: Language code (default: "en" for English). Pass
                ``None`` to let Whisper detect the language itself.
            verbose: Forwarded unchanged to Whisper's ``transcribe``
                (True prints decoded text to the console, False shows
                minimal progress output, None suppresses it).

        Returns:
            Dictionary containing:
                - text: Full transcript, with surrounding whitespace stripped
                - segments: List of timestamped segments (see _process_segments)
                - language: Language reported by Whisper

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            RuntimeError: If the model cannot be loaded or transcription fails
                (the underlying exception is chained as ``__cause__``).
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        # Load model if not already loaded
        self.load_model()

        try:
            logger.info(f"Starting transcription of: {audio_path}")
            logger.info(f"Language: {language}")

            result = self.model.transcribe(
                str(audio_path),
                language=language,
                verbose=verbose,
                task="transcribe",
                # FP16 is only useful on GPU; tie it to the device the
                # model was actually loaded on.
                fp16=self.device == "cuda",
            )

            transcript_data = {
                'text': result['text'].strip(),
                'segments': self._process_segments(result['segments']),
                'language': result['language'],
            }
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise RuntimeError(f"Transcription error: {str(e)}") from e

        logger.info(f"Transcription complete. Length: {len(transcript_data['text'])} characters")
        logger.info(f"Number of segments: {len(transcript_data['segments'])}")
        return transcript_data

    def _process_segments(self, raw_segments: List[Dict]) -> List[Dict]:
        """
        Reduce raw Whisper segments to a compact, cleaned format.

        Args:
            raw_segments: Segment dicts as returned by Whisper; each must
                provide 'id', 'start', 'end' and 'text'. Any other keys are
                dropped.

        Returns:
            One dict per input segment with keys 'id', 'start', 'end' and
            'text' (the text stripped of surrounding whitespace).
        """
        return [
            {
                'id': segment['id'],
                'start': segment['start'],
                'end': segment['end'],
                'text': segment['text'].strip(),
            }
            for segment in raw_segments
        ]

    def transcribe_with_timestamps(
        self,
        audio_path: Union[str, Path],
        language: Optional[str] = "en",
    ) -> str:
        """
        Transcribe audio and render one "[timestamp] text" line per segment.

        Args:
            audio_path: Path to the audio file (``str`` or ``Path``).
            language: Language code, or ``None`` for auto-detection.

        Returns:
            Newline-separated transcript; each line is prefixed with the
            segment's start time (see _format_timestamp).

        Raises:
            FileNotFoundError, RuntimeError: Propagated from :meth:`transcribe`.
        """
        result = self.transcribe(audio_path, language, verbose=False)
        return "\n".join(
            f"[{self._format_timestamp(segment['start'])}] {segment['text']}"
            for segment in result['segments']
        )

    @staticmethod
    def _format_timestamp(seconds: float) -> str:
        """
        Format seconds as MM:SS, or HH:MM:SS once an hour has elapsed.

        Fractional seconds are truncated.

        Args:
            seconds: Time in seconds

        Returns:
            Formatted timestamp string
        """
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)

        if hours > 0:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"

    def get_plain_text(self, audio_path: Union[str, Path], language: Optional[str] = "en") -> str:
        """
        Return the plain-text transcript, without timestamps.

        Args:
            audio_path: Path to the audio file (``str`` or ``Path``).
            language: Language code, or ``None`` for auto-detection.

        Returns:
            Plain text transcript (stripped of surrounding whitespace).

        Raises:
            FileNotFoundError, RuntimeError: Propagated from :meth:`transcribe`.
        """
        result = self.transcribe(audio_path, language, verbose=False)
        return result['text']
|