|
|
"""
Automatic Speech Recognition (ASR) Engine for Carsa AI

A comprehensive ASR engine that converts speech audio to text using
state-of-the-art speech recognition models. Optimized for English speech
recognition with support for various audio formats.

Features:
- High-quality speech-to-text conversion
- Support for WAV, MP3, and other audio formats
- Automatic audio preprocessing
- GPU acceleration when available
- Robust error handling

Author: Carsa AI Team
Version: 1.0.0
"""
|
|
|
|
|
import io
import logging
import os
import re
import tempfile

import librosa
import numpy as np
import soundfile as sf
import torch
from transformers import pipeline
|
|
|
|
|
|
|
|
# Configure root logging at INFO when this module is imported, and create
# the module-level logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ASREngine:
    """
    A production-ready Automatic Speech Recognition engine.

    This class provides speech-to-text capabilities using Hugging Face's
    transformers library with Whisper or similar ASR models.
    """

    # Normalized samples with a magnitude below this are zeroed on the
    # in-memory decoding path.
    NOISE_FLOOR = 0.01

    # Cleaned transcriptions shorter than this many characters are discarded.
    MIN_TRANSCRIPT_CHARS = 3

    # Patterns used by _clean_transcription, compiled once for the class.
    # A word joined to itself by hyphens four or more times ("no-no-no-no").
    _HYPHEN_WORD_REPEAT = re.compile(r'\b(\w+)(?:-\1){3,}\b')
    # A single character joined to itself by hyphens three or more times.
    _HYPHEN_CHAR_REPEAT = re.compile(r'\b(\w)(?:-\1){2,}\b')
    # A word repeated four or more times separated by whitespace.
    _SPACED_WORD_REPEAT = re.compile(r'\b(\w+)(?:\s+\1){3,}\b', re.IGNORECASE)
    _WHITESPACE_RUN = re.compile(r'\s+')

    def __init__(self, model_name="openai/whisper-small"):
        """
        Initialize the ASR Engine.

        Args:
            model_name (str): The ASR model to use. Default: "openai/whisper-small"
                Options: "openai/whisper-tiny", "openai/whisper-base",
                "openai/whisper-small", "openai/whisper-medium"

        Raises:
            RuntimeError: If model loading fails (the original error is
                attached as ``__cause__``)
        """
        try:
            cuda_available = torch.cuda.is_available()
            # Device index handed to the pipeline: 0 = first GPU, -1 = CPU.
            self.device = 0 if cuda_available else -1
            device_name = "GPU" if cuda_available else "CPU"
            logger.info(f"ASR Engine using device: {device_name}")

            self.model_name = model_name
            # All audio is resampled to this rate before transcription.
            self.sample_rate = 16000

            logger.info(f"Loading ASR model: {model_name}")

            self.transcriber = pipeline(
                "automatic-speech-recognition",
                model=model_name,
                device=self.device,
                # Half precision on GPU, full precision on CPU.
                torch_dtype=torch.float16 if cuda_available else torch.float32,
                return_timestamps=False,
            )

            logger.info("✅ ASR Engine initialized successfully!")

        except Exception as e:
            logger.error(f"❌ Failed to initialize ASR Engine: {e}")
            raise RuntimeError(f"ASR Engine initialization failed: {str(e)}") from e

    @staticmethod
    def _peak_normalize(samples):
        """
        Scale samples so that the largest magnitude becomes 1.0.

        Args:
            samples (np.ndarray): Audio samples

        Returns:
            np.ndarray: Scaled samples; empty or all-zero input is returned
                unchanged
        """
        if len(samples) == 0:
            return samples
        peak = np.max(np.abs(samples))
        return samples / peak if peak > 0 else samples

    def _decode_in_memory(self, audio_bytes):
        """
        Decode audio bytes with soundfile directly from memory.

        The signal is mixed down to mono, resampled to ``self.sample_rate``,
        mean-centred, peak-normalized and gated at ``NOISE_FLOOR``.

        Args:
            audio_bytes (bytes): Raw audio data

        Returns:
            np.ndarray: Preprocessed mono audio array
        """
        audio_data, sr = sf.read(io.BytesIO(audio_bytes))

        # More than one dimension means multiple channels: average them.
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        if sr != self.sample_rate:
            audio_data = librosa.resample(
                audio_data, orig_sr=sr, target_sr=self.sample_rate
            )

        if len(audio_data) > 0:
            # Remove any constant offset before scaling.
            audio_data = audio_data - np.mean(audio_data)
            audio_data = self._peak_normalize(audio_data)
            # Zero out samples below the noise floor.
            audio_data = np.where(
                np.abs(audio_data) < self.NOISE_FLOOR, 0, audio_data
            )

        return audio_data

    def _decode_via_tempfile(self, audio_bytes):
        """
        Decode audio bytes by writing them to a temporary file for librosa.

        librosa performs the mono mix-down and resampling; the result is then
        peak-normalized. The temporary file is always removed, including when
        writing or decoding fails.

        NOTE(review): unlike the in-memory path, this path applies no mean
        removal or noise gating. That asymmetry is preserved from the original
        implementation; confirm whether it is intentional.

        Args:
            audio_bytes (bytes): Raw audio data

        Returns:
            np.ndarray: Preprocessed mono audio array
        """
        temp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
        temp_path = temp_file.name
        try:
            # Close the handle before librosa opens the path.
            with temp_file:
                temp_file.write(audio_bytes)

            audio_array, _ = librosa.load(
                temp_path, sr=self.sample_rate, mono=True
            )
            return self._peak_normalize(audio_array)
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def _preprocess_audio(self, audio_bytes):
        """
        Preprocess audio data for speech recognition.

        Decoding is first attempted in memory with soundfile; if that fails
        for any reason, the bytes are written to a temporary file and decoded
        with librosa instead.

        Args:
            audio_bytes (bytes): Raw audio data

        Returns:
            np.ndarray: Preprocessed audio array at ``self.sample_rate``

        Raises:
            RuntimeError: If audio preprocessing fails on both paths
        """
        try:
            try:
                audio_data = self._decode_in_memory(audio_bytes)
                logger.info(
                    f"Audio preprocessed (BytesIO): {len(audio_data)} samples at {self.sample_rate}Hz"
                )
                return audio_data
            except Exception as e1:
                logger.warning(
                    f"BytesIO method failed: {e1}, trying temporary file method..."
                )

            audio_array = self._decode_via_tempfile(audio_bytes)
            logger.info(
                f"Audio preprocessed (file): {len(audio_array)} samples at {self.sample_rate}Hz"
            )
            return audio_array

        except Exception as e:
            logger.error(f"❌ Audio preprocessing failed: {e}")
            raise RuntimeError(f"Failed to preprocess audio: {str(e)}") from e

    def transcribe(self, audio_bytes):
        """
        Transcribe audio bytes to text.

        Args:
            audio_bytes (bytes): Audio data in bytes format

        Returns:
            str: Transcribed text ("" when preprocessing yields no samples or
                the cleaned transcription is discarded)

        Raises:
            ValueError: If audio data is empty
            RuntimeError: If preprocessing or transcription fails
        """
        if not audio_bytes:
            raise ValueError("Audio data cannot be empty")

        try:
            logger.info("Starting speech transcription...")

            audio_array = self._preprocess_audio(audio_bytes)

            if len(audio_array) == 0:
                logger.warning("Empty audio array after preprocessing")
                return ""

            result = self.transcriber(audio_array)

            # Accept a dict with a 'text' entry, a bare string, or anything
            # else that can be stringified.
            if isinstance(result, dict):
                transcribed_text = result.get('text', '').strip()
            elif isinstance(result, str):
                transcribed_text = result.strip()
            else:
                transcribed_text = str(result).strip()

            transcribed_text = self._clean_transcription(transcribed_text)

            logger.info(f"Transcription completed: '{transcribed_text[:100]}{'...' if len(transcribed_text) > 100 else ''}'")

            return transcribed_text

        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise RuntimeError(f"Speech transcription failed: {str(e)}") from e

    def _clean_transcription(self, text):
        """
        Clean up common transcription artifacts and repetitive patterns.

        Hyphen- and whitespace-separated runs of the same word are collapsed
        to a single occurrence and whitespace is normalized. The result is
        discarded (``""``) when it is shorter than ``MIN_TRANSCRIPT_CHARS``
        characters, or when it consists solely of one word that was repeated.
        A genuine single-word transcription such as "Hello" is kept.

        Args:
            text (str): Raw transcription text

        Returns:
            str: Cleaned transcription text
        """
        if not text:
            return ""

        # Count substitutions so a lone surviving word can be told apart from
        # a word that was only ever said once.
        collapsed = 0
        for pattern in (
            self._HYPHEN_WORD_REPEAT,
            self._HYPHEN_CHAR_REPEAT,
            self._SPACED_WORD_REPEAT,
        ):
            text, hits = pattern.subn(r'\1', text)
            collapsed += hits

        text = self._WHITESPACE_RUN.sub(' ', text).strip()

        if len(text) < self.MIN_TRANSCRIPT_CHARS:
            return ""

        tokens = text.lower().split()
        only_one_word = len(set(tokens)) == 1
        was_repeated = len(tokens) > 1 or collapsed > 0
        if only_one_word and was_repeated:
            return ""

        return text

    def transcribe_file(self, file_path):
        """
        Transcribe audio from a file.

        Args:
            file_path (str): Path to the audio file

        Returns:
            str: Transcribed text

        Raises:
            FileNotFoundError: If file doesn't exist
            RuntimeError: If reading the file or transcription fails
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        try:
            with open(file_path, 'rb') as f:
                audio_bytes = f.read()

            return self.transcribe(audio_bytes)

        except Exception as e:
            logger.error(f"File transcription failed: {e}")
            raise RuntimeError(f"File transcription failed: {str(e)}") from e

    def get_supported_formats(self):
        """
        Get list of supported audio formats.

        Returns:
            list: List of supported audio file extensions (a new list on
                every call)
        """
        return ['.wav', '.mp3', '.m4a', '.flac', '.ogg', '.aac']

    def get_engine_info(self):
        """
        Get information about the ASR engine.

        Returns:
            dict: Engine information including model and device details
        """
        return {
            "engine": "ASR Engine",
            "version": "1.0.0",
            "model": self.model_name,
            # Report the device chosen at initialization (0 = GPU, -1 = CPU).
            "device": "GPU" if self.device >= 0 else "CPU",
            "sample_rate": self.sample_rate,
            "supported_formats": self.get_supported_formats(),
            "framework": "transformers + whisper",
        }

    def health_check(self):
        """
        Perform a health check on the ASR engine.

        Runs the loaded model on one second of a synthetic 440 Hz sine wave
        (float32, amplitude within [-1, 1]) and reports whether the call
        completed without raising.

        Returns:
            dict: Health status information with keys "status", "message"
                and "model_loaded"
        """
        try:
            timeline = np.arange(self.sample_rate) / self.sample_rate
            test_audio = np.sin(2 * np.pi * 440 * timeline).astype(np.float32)

            self.transcriber(test_audio)

            return {
                "status": "healthy",
                "message": "ASR engine is functioning correctly",
                "model_loaded": True,
            }

        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "status": "unhealthy",
                "message": f"ASR engine health check failed: {str(e)}",
                "model_loaded": getattr(self, 'transcriber', None) is not None,
            }
|
|
|
|
|
|
|
|
def main():
    """
    Example usage and testing of the ASR Engine.

    Builds an engine with the default model, logs its info and health-check
    result, then transcribes the first of a few well-known test file names
    that exists and can be transcribed. Any failure is logged rather than
    raised.
    """
    try:
        logger.info("Testing ASR Engine...")
        engine = ASREngine()

        info = engine.get_engine_info()
        logger.info(f"Engine Info: {info}")

        health = engine.health_check()
        logger.info(f"Health Check: {health}")

        test_files = ["test_audio.wav", "sample.wav", "test.wav"]
        available = [path for path in test_files if os.path.exists(path)]

        # Only report "not found" when no candidate exists, not when a
        # candidate exists but fails to transcribe.
        if not available:
            logger.info("No test audio files found. Engine is ready for use.")

        for test_file in available:
            try:
                transcription = engine.transcribe_file(test_file)
                logger.info(f"🎯 Transcription: {transcription}")
                break
            except Exception as e:
                logger.error(f"Failed to transcribe {test_file}: {e}")

        logger.info("🎉 ASR Engine testing completed!")

    except Exception as e:
        logger.error(f"❌ ASR Engine test failed: {e}")


if __name__ == "__main__":
    main()
|
|
|