# File: carsa_api/asr_engine.py
# Last commit by athmontech (130ce6d):
#   Remove Hausa language support - model discontinued
"""
Automatic Speech Recognition (ASR) Engine for Carsa AI
A comprehensive ASR engine that converts speech audio to text using
state-of-the-art speech recognition models. Optimized for English speech
recognition with support for various audio formats.
Features:
- High-quality speech-to-text conversion
- Support for WAV, MP3, and other audio formats
- Automatic audio preprocessing
- GPU acceleration when available
- Robust error handling
Author: Carsa AI Team
Version: 1.0.0
"""
import torch
import logging
import io
import tempfile
import os
import soundfile as sf
from transformers import pipeline
import librosa
import numpy as np
# Configure logging at INFO level.
# NOTE(review): calling basicConfig() at import time configures the root
# logger for every application that imports this module - confirm this is
# intended rather than leaving configuration to the host application.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by ASREngine and main().
logger = logging.getLogger(__name__)
class ASREngine:
    """
    A production-ready Automatic Speech Recognition engine.

    This class provides speech-to-text capabilities using Hugging Face's
    transformers library with Whisper or similar ASR models.
    """

    def __init__(self, model_name="openai/whisper-small"):
        """
        Initialize the ASR Engine.

        Args:
            model_name (str): The ASR model to use. Default: "openai/whisper-small"
                Options: "openai/whisper-tiny", "openai/whisper-base",
                "openai/whisper-small", "openai/whisper-medium"

        Raises:
            RuntimeError: If model loading fails. RuntimeError is a subclass
                of Exception, so handlers catching Exception keep working.
        """
        try:
            # Probe CUDA once and derive every device-dependent setting from it.
            use_gpu = torch.cuda.is_available()
            self.device = 0 if use_gpu else -1
            device_name = "GPU" if use_gpu else "CPU"
            logger.info(f"ASR Engine using device: {device_name}")

            self.model_name = model_name
            self.sample_rate = 16000  # Whisper expects 16kHz audio

            logger.info(f"Loading ASR model: {model_name}")
            # Load the ASR pipeline (half precision only when running on GPU).
            self.transcriber = pipeline(
                "automatic-speech-recognition",
                model=model_name,
                device=self.device,
                torch_dtype=torch.float16 if use_gpu else torch.float32,
                return_timestamps=False,  # Set to True if you want word-level timestamps
            )
            logger.info("✅ ASR Engine initialized successfully!")
        except Exception as e:
            logger.error(f"❌ Failed to initialize ASR Engine: {e}")
            raise RuntimeError(f"ASR Engine initialization failed: {str(e)}") from e

    @staticmethod
    def _condition_audio(samples, noise_floor=0.01):
        """
        Remove DC offset, peak-normalize, and gate near-silent samples.

        Args:
            samples (np.ndarray): Mono audio samples.
            noise_floor (float): Samples whose normalized magnitude is below
                this threshold are set to zero. Default: 0.01 (1%).

        Returns:
            np.ndarray: Conditioned samples. Empty input is returned as-is and
                an all-constant signal becomes all zeros.
        """
        if len(samples) == 0:
            return samples

        # Remove DC offset
        samples = samples - np.mean(samples)

        # Peak-normalize; a zero peak means pure silence, nothing left to scale.
        peak = np.max(np.abs(samples))
        if peak > 0:
            samples = samples / peak
            # Gentle noise gate. `samples` is a fresh array at this point, so
            # in-place masking cannot touch the caller's data.
            samples[np.abs(samples) < noise_floor] = 0.0
        return samples

    def _decode_in_memory(self, audio_bytes):
        """Decode with soundfile from an in-memory buffer as mono at self.sample_rate."""
        audio_data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")

        # Convert to mono if stereo
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Resample if needed
        if sr != self.sample_rate:
            audio_data = librosa.resample(
                audio_data, orig_sr=sr, target_sr=self.sample_rate
            )
        return audio_data

    def _decode_via_tempfile(self, audio_bytes):
        """Decode with librosa through a temporary file that is always removed."""
        temp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
        temp_path = temp_file.name
        try:
            # Close the handle before librosa opens the path, and keep the
            # write inside the try so a failed write still gets cleaned up.
            with temp_file:
                temp_file.write(audio_bytes)
            audio_array, _ = librosa.load(temp_path, sr=self.sample_rate, mono=True)
            return audio_array
        finally:
            # Clean up temporary file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def _preprocess_audio(self, audio_bytes):
        """
        Preprocess audio data for speech recognition.

        Decoding is attempted in memory first (faster) and falls back to a
        temporary file read by librosa. Both paths go through the same
        conditioning step, so the output does not depend on which decoder
        succeeded.

        Args:
            audio_bytes (bytes): Raw audio data

        Returns:
            np.ndarray: Mono audio at self.sample_rate, DC-removed,
                peak-normalized and noise-gated.

        Raises:
            RuntimeError: If audio preprocessing fails
        """
        try:
            try:
                # First try using BytesIO (faster method)
                audio = self._decode_in_memory(audio_bytes)
                method = "BytesIO"
            except Exception as e1:
                logger.warning(f"BytesIO method failed: {e1}, trying temporary file method...")
                # Fallback to temporary file method
                audio = self._decode_via_tempfile(audio_bytes)
                method = "file"

            audio = self._condition_audio(audio)
            logger.info(
                f"Audio preprocessed ({method}): {len(audio)} samples at {self.sample_rate}Hz"
            )
            return audio
        except Exception as e:
            logger.error(f"❌ Audio preprocessing failed: {e}")
            raise RuntimeError(f"Failed to preprocess audio: {str(e)}") from e

    def transcribe(self, audio_bytes):
        """
        Transcribe audio bytes to text.

        Args:
            audio_bytes (bytes): Audio data in bytes format

        Returns:
            str: Transcribed text ("" when the audio decodes to no samples or
                the transcript is rejected by _clean_transcription)

        Raises:
            ValueError: If audio data is empty
            RuntimeError: If preprocessing or transcription fails
        """
        if not audio_bytes:
            raise ValueError("Audio data cannot be empty")

        try:
            logger.info("Starting speech transcription...")

            # Preprocess audio
            audio_array = self._preprocess_audio(audio_bytes)
            if len(audio_array) == 0:
                logger.warning("Empty audio array after preprocessing")
                return ""

            # Perform transcription
            result = self.transcriber(audio_array)

            # Extract text from result; tolerate a missing/None "text" entry.
            if isinstance(result, dict):
                transcribed_text = (result.get('text') or '').strip()
            elif isinstance(result, str):
                transcribed_text = result.strip()
            else:
                transcribed_text = str(result).strip()

            # Clean up common transcription artifacts
            transcribed_text = self._clean_transcription(transcribed_text)

            logger.info(f"Transcription completed: '{transcribed_text[:100]}{'...' if len(transcribed_text) > 100 else ''}'")
            return transcribed_text
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise RuntimeError(f"Speech transcription failed: {str(e)}") from e

    def _clean_transcription(self, text):
        """
        Clean up common transcription artifacts and repetitive patterns.

        Args:
            text (str): Raw transcription text

        Returns:
            str: Cleaned transcription text, or "" when the result is shorter
                than 3 characters or the input was nothing but one word
                repeated. A genuine single-word utterance is kept.
        """
        if not text:
            return ""

        import re

        # Remember how many tokens the raw text had: a lone word is a valid
        # transcript, whereas "yeah yeah yeah" is repetition noise.
        raw_token_count = len(text.split())

        # Hyphenated stutter of a word: word-word-word-word... -> word
        text = re.sub(r'\b(\w+)(?:-\1){3,}\b', r'\1', text)
        # Hyphenated stutter of a single character: I-I-I... -> I
        text = re.sub(r'\b(\w)(?:-\1){2,}\b', r'\1', text)
        # Space-separated repetition: yeah yeah yeah yeah... -> yeah
        text = re.sub(r'\b(\w+)(?:\s+\1){3,}\b', r'\1', text, flags=re.IGNORECASE)

        # Collapse runs of whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Too short to be meaningful
        if len(text) < 3:
            return ""

        # Only one distinct word left AND the input had several tokens:
        # the utterance was pure repetition, treat it as noise.
        if raw_token_count > 1 and len(set(text.lower().split())) == 1:
            return ""

        return text

    def transcribe_file(self, file_path):
        """
        Transcribe audio from a file.

        Args:
            file_path (str): Path to the audio file

        Returns:
            str: Transcribed text

        Raises:
            FileNotFoundError: If file doesn't exist
            RuntimeError: If reading or transcription fails
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        try:
            with open(file_path, 'rb') as f:
                audio_bytes = f.read()
            return self.transcribe(audio_bytes)
        except Exception as e:
            logger.error(f"File transcription failed: {e}")
            raise RuntimeError(f"File transcription failed: {str(e)}") from e

    def get_supported_formats(self):
        """
        Get list of supported audio formats.

        Returns:
            list: List of supported audio file extensions
        """
        return ['.wav', '.mp3', '.m4a', '.flac', '.ogg', '.aac']

    def get_engine_info(self):
        """
        Get information about the ASR engine.

        Returns:
            dict: Engine information including model and device details. The
                device entry reflects the device chosen in __init__.
        """
        return {
            "engine": "ASR Engine",
            "version": "1.0.0",
            "model": self.model_name,
            "device": "GPU" if self.device >= 0 else "CPU",
            "sample_rate": self.sample_rate,
            "supported_formats": self.get_supported_formats(),
            "framework": "transformers + whisper",
        }

    def health_check(self):
        """
        Perform a health check on the ASR engine.

        Runs the loaded pipeline on one second of a 440 Hz sine wave; the
        transcription itself is ignored, only success or failure matters.

        Returns:
            dict: Health status information with "status", "message" and
                "model_loaded" keys
        """
        try:
            # Float32 sine in [-1, 1] - the same amplitude range that
            # _preprocess_audio hands to the model.
            timeline = np.linspace(0, 1, self.sample_rate)
            test_audio = np.sin(2 * np.pi * 440 * timeline).astype(np.float32)

            # Try transcription (should return empty or noise)
            self.transcriber(test_audio)

            return {
                "status": "healthy",
                "message": "ASR engine is functioning correctly",
                "model_loaded": True,
            }
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "status": "unhealthy",
                "message": f"ASR engine health check failed: {str(e)}",
                "model_loaded": getattr(self, 'transcriber', None) is not None,
            }
def main():
    """
    Example usage and testing of the ASR Engine.

    Builds an engine with the default model, logs its info and health check,
    then transcribes the first candidate test file that exists and can be
    transcribed. Any failure is logged rather than raised.
    """
    try:
        # Initialize the engine
        logger.info("Testing ASR Engine...")
        engine = ASREngine()

        # Print engine info
        info = engine.get_engine_info()
        logger.info(f"Engine Info: {info}")

        # Perform health check
        health = engine.health_check()
        logger.info(f"Health Check: {health}")

        # Test with a simple audio file if available
        test_files = ["test_audio.wav", "sample.wav", "test.wav"]
        available = [path for path in test_files if os.path.exists(path)]

        if not available:
            # Reported only when nothing exists on disk; files that exist but
            # fail to transcribe are covered by the error log below.
            logger.info("No test audio files found. Engine is ready for use.")

        for test_file in available:
            try:
                transcription = engine.transcribe_file(test_file)
                logger.info(f"🎯 Transcription: {transcription}")
                break  # one successful transcription is enough
            except Exception as e:
                logger.error(f"Failed to transcribe {test_file}: {e}")

        logger.info("🎉 ASR Engine testing completed!")
    except Exception as e:
        logger.error(f"❌ ASR Engine test failed: {e}")


if __name__ == "__main__":
    main()