# File size: 12,762 bytes (listing revisions: d01de5d, 130ce6d)
"""
Automatic Speech Recognition (ASR) Engine for Carsa AI

A comprehensive ASR engine that converts speech audio to text using
state-of-the-art speech recognition models. Optimized for English speech
recognition with support for various audio formats.

Features:
- High-quality speech-to-text conversion
- Support for WAV, MP3, and other audio formats
- Automatic audio preprocessing
- GPU acceleration when available
- Robust error handling

Author: Carsa AI Team
Version: 1.0.0
"""
import io
import logging
import os
import re
import tempfile

import librosa
import numpy as np
import soundfile as sf
import torch
from transformers import pipeline
# Configure logging: set up the root logger at INFO level (this runs as an
# import-time side effect), then create the logger named after this module
# that the rest of the file logs through.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ASREngine:
    """
    A production-ready Automatic Speech Recognition engine.

    This class provides speech-to-text capabilities using Hugging Face's
    transformers library with Whisper or similar ASR models.

    Attributes:
        device (int): Pipeline device index: 0 when CUDA is available, else -1.
        model_name (str): Identifier of the loaded ASR model.
        sample_rate (int): Sample rate (Hz) all audio is converted to.
        transcriber: The loaded ``automatic-speech-recognition`` pipeline.
    """

    VERSION = "1.0.0"
    SUPPORTED_FORMATS = ('.wav', '.mp3', '.m4a', '.flac', '.ogg', '.aac')

    # Samples whose normalized magnitude is below this are zeroed (1% threshold).
    NOISE_FLOOR = 0.01

    # Patterns for collapsing stutter-like repetition in transcripts.
    # word-word-word-word... -> word (four or more occurrences)
    _HYPHEN_WORD_REPEAT = re.compile(r'\b(\w+)(?:-\1){3,}\b')
    # I-I-I... -> I (three or more occurrences of a single character)
    _HYPHEN_CHAR_REPEAT = re.compile(r'\b(\w)(?:-\1){2,}\b')
    # yeah yeah yeah yeah... -> yeah (four or more occurrences, any case)
    _SPACED_WORD_REPEAT = re.compile(r'\b(\w+)(?:\s+\1){3,}\b', re.IGNORECASE)
    _WHITESPACE_RUN = re.compile(r'\s+')

    def __init__(self, model_name="openai/whisper-small"):
        """
        Initialize the ASR Engine.

        Args:
            model_name (str): The ASR model to use. Default: "openai/whisper-small"
                Options: "openai/whisper-tiny", "openai/whisper-base",
                "openai/whisper-small", "openai/whisper-medium"

        Raises:
            Exception: If model loading fails (the original error is chained).
        """
        try:
            use_gpu = torch.cuda.is_available()
            self.device = 0 if use_gpu else -1
            device_name = "GPU" if use_gpu else "CPU"
            logger.info("ASR Engine using device: %s", device_name)

            self.model_name = model_name
            self.sample_rate = 16000  # Whisper expects 16kHz audio

            logger.info("Loading ASR model: %s", model_name)
            # Load the ASR pipeline; half precision is only used on GPU.
            self.transcriber = pipeline(
                "automatic-speech-recognition",
                model=model_name,
                device=self.device,
                torch_dtype=torch.float16 if use_gpu else torch.float32,
                return_timestamps=False,  # Set to True if you want word-level timestamps
            )
            logger.info("✅ ASR Engine initialized successfully!")
        except Exception as e:
            logger.error("❌ Failed to initialize ASR Engine: %s", e)
            raise Exception(f"ASR Engine initialization failed: {str(e)}") from e

    def _decode_in_memory(self, audio_bytes):
        """Decode audio bytes with soundfile; return mono samples at ``self.sample_rate``."""
        samples, source_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32")
        # soundfile returns (frames, channels) for multi-channel input.
        if samples.ndim > 1:
            samples = np.mean(samples, axis=1)
        if source_rate != self.sample_rate:
            samples = librosa.resample(
                samples, orig_sr=source_rate, target_sr=self.sample_rate
            )
        return samples

    def _decode_via_tempfile(self, audio_bytes):
        """Decode audio bytes by writing them to a temporary file and loading it with librosa."""
        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
                # Record the path before writing so a failed write is still cleaned up.
                temp_path = temp_file.name
                temp_file.write(audio_bytes)
            samples, _ = librosa.load(temp_path, sr=self.sample_rate, mono=True)
            return samples
        finally:
            if temp_path is not None and os.path.exists(temp_path):
                os.unlink(temp_path)

    def _condition_audio(self, samples):
        """Remove DC offset, peak-normalize, and zero samples below ``NOISE_FLOOR``; return float32."""
        samples = np.asarray(samples, dtype=np.float32)
        if samples.size == 0:
            return samples
        samples = samples - np.mean(samples)
        peak = np.max(np.abs(samples))
        if peak > 0:
            samples = samples / peak
        # Gentle noise gate: silence samples quieter than the noise floor.
        samples = np.where(np.abs(samples) < self.NOISE_FLOOR, 0.0, samples)
        return samples.astype(np.float32, copy=False)

    def _preprocess_audio(self, audio_bytes):
        """
        Preprocess audio data for speech recognition.

        Decoding is tried in memory first; if that fails, the bytes are written
        to a temporary file and loaded with librosa. Both paths then receive
        the same conditioning (DC removal, peak normalization, noise gate).

        Args:
            audio_bytes (bytes): Raw audio data

        Returns:
            np.ndarray: Mono float32 audio at ``self.sample_rate``

        Raises:
            Exception: If audio preprocessing fails (the original error is chained).
        """
        try:
            try:
                # First try using BytesIO (faster method)
                samples = self._decode_in_memory(audio_bytes)
                method = "BytesIO"
            except Exception as decode_error:
                logger.warning(
                    "BytesIO method failed: %s, trying temporary file method...",
                    decode_error,
                )
                samples = self._decode_via_tempfile(audio_bytes)
                method = "file"

            samples = self._condition_audio(samples)
            logger.info(
                "Audio preprocessed (%s): %d samples at %sHz",
                method, len(samples), self.sample_rate,
            )
            return samples
        except Exception as e:
            logger.error("❌ Audio preprocessing failed: %s", e)
            raise Exception(f"Failed to preprocess audio: {str(e)}") from e

    def transcribe(self, audio_bytes):
        """
        Transcribe audio bytes to text.

        Args:
            audio_bytes (bytes): Audio data in bytes format

        Returns:
            str: Transcribed text ("" when the audio decodes to no samples or
                the transcript is judged to be noise)

        Raises:
            ValueError: If audio data is empty
            RuntimeError: If preprocessing or transcription fails
        """
        if not audio_bytes:
            raise ValueError("Audio data cannot be empty")

        try:
            logger.info("Starting speech transcription...")
            audio_array = self._preprocess_audio(audio_bytes)
            if len(audio_array) == 0:
                logger.warning("Empty audio array after preprocessing")
                return ""

            result = self.transcriber(audio_array)

            # The result is handled as a dict with a 'text' key, a plain
            # string, or anything else (stringified).
            if isinstance(result, dict):
                transcribed_text = result.get('text', '').strip()
            elif isinstance(result, str):
                transcribed_text = result.strip()
            else:
                transcribed_text = str(result).strip()

            transcribed_text = self._clean_transcription(transcribed_text)

            preview = transcribed_text[:100]
            ellipsis = "..." if len(transcribed_text) > 100 else ""
            logger.info("Transcription completed: '%s%s'", preview, ellipsis)
            return transcribed_text
        except Exception as e:
            logger.error("Transcription failed: %s", e)
            raise RuntimeError(f"Speech transcription failed: {str(e)}") from e

    def _clean_transcription(self, text):
        """
        Clean up common transcription artifacts and repetitive patterns.

        Long runs of a repeated word are collapsed to one occurrence and
        whitespace is normalized. The result is discarded (returns "") when it
        is shorter than 3 characters or still consists of one word repeated
        several times. A single word that occurs once is kept.

        Args:
            text (str): Raw transcription text

        Returns:
            str: Cleaned transcription text
        """
        if not text:
            return ""

        text = self._HYPHEN_WORD_REPEAT.sub(r'\1', text)
        text = self._HYPHEN_CHAR_REPEAT.sub(r'\1', text)
        text = self._SPACED_WORD_REPEAT.sub(r'\1', text)
        text = self._WHITESPACE_RUN.sub(' ', text).strip()

        tokens = text.lower().split()
        # Only a *repeated* single word counts as noise; a lone word such as
        # "Hello" is a valid transcript and must not be dropped.
        only_repetition = len(tokens) > 1 and len(set(tokens)) == 1
        if len(text) < 3 or only_repetition:
            return ""
        return text

    def transcribe_file(self, file_path):
        """
        Transcribe audio from a file.

        Args:
            file_path (str): Path to the audio file

        Returns:
            str: Transcribed text

        Raises:
            FileNotFoundError: If file doesn't exist
            RuntimeError: If reading or transcription fails
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        try:
            with open(file_path, 'rb') as f:
                audio_bytes = f.read()
            return self.transcribe(audio_bytes)
        except Exception as e:
            logger.error("File transcription failed: %s", e)
            raise RuntimeError(f"File transcription failed: {str(e)}") from e

    def get_supported_formats(self):
        """
        Get list of supported audio formats.

        Returns:
            list: List of supported audio file extensions (a fresh list per call)
        """
        return list(self.SUPPORTED_FORMATS)

    def get_engine_info(self):
        """
        Get information about the ASR engine.

        Returns:
            dict: Engine information including model and device details
        """
        return {
            "engine": "ASR Engine",
            "version": self.VERSION,
            "model": self.model_name,
            # Report the device chosen at initialization, not the current CUDA state.
            "device": "GPU" if self.device >= 0 else "CPU",
            "sample_rate": self.sample_rate,
            "supported_formats": self.get_supported_formats(),
            "framework": "transformers + whisper",
        }

    def health_check(self):
        """
        Perform a health check on the ASR engine.

        Runs the pipeline on one second of a synthetic 440 Hz tone; any
        exception marks the engine unhealthy. The transcript is ignored.

        Returns:
            dict: Health status with "status", "message" and "model_loaded" keys
        """
        try:
            # One second of a 440 Hz sine, as float32 with amplitude 0.5.
            timeline = np.arange(self.sample_rate, dtype=np.float32) / self.sample_rate
            test_audio = (0.5 * np.sin(2 * np.pi * 440.0 * timeline)).astype(np.float32)

            # Try transcription (should return empty or noise)
            self.transcriber(test_audio)
            return {
                "status": "healthy",
                "message": "ASR engine is functioning correctly",
                "model_loaded": True,
            }
        except Exception as e:
            logger.error("Health check failed: %s", e)
            return {
                "status": "unhealthy",
                "message": f"ASR engine health check failed: {str(e)}",
                "model_loaded": getattr(self, 'transcriber', None) is not None,
            }
def main():
    """
    Example usage and testing of the ASR Engine.

    Builds an engine with the default model, logs its info and health check,
    then transcribes the first readable file among a few conventional test
    file names in the current directory. Any exception is logged rather than
    propagated.
    """
    try:
        # Initialize the engine
        logger.info("Testing ASR Engine...")
        engine = ASREngine()

        # Print engine info
        info = engine.get_engine_info()
        logger.info(f"Engine Info: {info}")

        # Perform health check
        health = engine.health_check()
        logger.info(f"Health Check: {health}")

        # Test with a simple audio file if available
        test_files = ["test_audio.wav", "sample.wav", "test.wav"]
        available = [path for path in test_files if os.path.exists(path)]
        if not available:
            logger.info("No test audio files found. Engine is ready for use.")

        # Stop at the first file that transcribes successfully; a failure on
        # one candidate is logged and the next one is tried.
        for test_file in available:
            try:
                transcription = engine.transcribe_file(test_file)
                logger.info(f"🎯 Transcription: {transcription}")
                break
            except Exception as e:
                logger.error(f"Failed to transcribe {test_file}: {e}")

        logger.info("🎉 ASR Engine testing completed!")
    except Exception as e:
        logger.error(f"❌ ASR Engine test failed: {e}")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()