|
|
import torch
|
|
|
import asyncio
|
|
|
import websockets
|
|
|
import json
|
|
|
import threading
|
|
|
import numpy as np
|
|
|
import logging
|
|
|
import time
|
|
|
import tempfile
|
|
|
import os
|
|
|
import re
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
import subprocess
|
|
|
import struct
|
|
|
|
|
|
|
|
|
import nemo.collections.asr as nemo_asr
|
|
|
import soundfile as sf
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
|
|
|
|
|
|
|
|
|
|
|
|
# Optional dependency: pyarabic supplies text2number(), which the Whisper path
# uses to turn Arabic number words into digits. The server still runs without
# it; arabic_numbers_available records whether the import succeeded.
try:
    from pyarabic.number import text2number
except ImportError:
    arabic_numbers_available = False
    print("✗ pyarabic not available - install with: pip install pyarabic")
    print("Arabic numbers will not be converted to digits for Whisper")
else:
    arabic_numbers_available = True
    print("✓ pyarabic library available for Whisper number conversion")

# Root logging configuration plus the module-level logger used throughout
# this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
# Lookup table: Arabic number words (standard spellings, colloquial spellings
# and frequent ASR misspellings) and Arabic-Indic digit strings -> Western
# digits. Keys carry no leading/trailing whitespace; multi-word keys use a
# single space between words.
arabic_numbers_nemo = {
    # 0 - 9
    "سفر": "0", "فيرو": "0", "هيرو": "0", "صفر": "0", "زيرو": "0", "٠": "0",
    "زيو": "0", "زير": "0", "زر": "0", "زروا": "0", "زرا": "0", "زيره": "0", "زرو": "0",
    "واحد": "1", "واحدة": "1", "١": "1",
    "اتنين": "2", "اثنين": "2", "إثنين": "2", "اثنان": "2", "إثنان": "2", "٢": "2",
    "تلاتة": "3", "ثلاثة": "3", "٣": "3", "تلاته": "3", "ثلاثه": "3", "ثلاثا": "3", "تلاتا": "3",
    "اربعة": "4", "أربعة": "4", "٤": "4", "اربعه": "4", "أربعه": "4", "أربع": "4", "اربع": "4", "اربعا": "4", "أربعا": "4",
    "خمسة": "5", "خمسه": "5", "٥": "5", "خمس": "5", "خمسا": "5",
    "ستة": "6", "سته": "6", "٦": "6", "ست": "6", "ستّا": "6", "ستةً": "6",
    "سبعة": "7", "سبعه": "7", "٧": "7", "سبع": "7", "سبعا": "7",
    "ثمانية": "8", "ثمانيه": "8", "٨": "8", "ثمان": "8", "ثمنية": "8", "ثمنيه": "8", "ثمانيا": "8", "ثمن": "8",
    "تسعة": "9", "تسعه": "9", "٩": "9", "تسع": "9", "تسعا": "9",

    # 10 - 19
    "عشرة": "10", "١٠": "10",
    "حداشر": "11", "احد عشر": "11", "احداشر": "11",
    "اتناشر": "12", "اثنا عشر": "12",
    "تلتاشر": "13", "ثلاثة عشر": "13",
    "اربعتاشر": "14", "أربعة عشر": "14",
    "خمستاشر": "15", "خمسة عشر": "15",
    "ستاشر": "16", "ستة عشر": "16",
    "سبعتاشر": "17", "سبعة عشر": "17",
    "طمنتاشر": "18", "ثمانية عشر": "18",
    "تسعتاشر": "19", "تسعة عشر": "19",

    # Tens
    "عشرين": "20", "٢٠": "20",
    "تلاتين": "30", "ثلاثين": "30", "٣٠": "30",
    "اربعين": "40", "أربعين": "40", "٤٠": "40",
    "خمسين": "50", "٥٠": "50",
    "ستين": "60", "٦٠": "60",
    "سبعين": "70", "٧٠": "70",
    "تمانين": "80", "ثمانين": "80", "٨٠": "80", "تمانون": "80", "ثمانون": "80",
    "تسعين": "90", "٩٠": "90",

    # Hundreds
    "مية": "100", "مائة": "100", "مئة": "100", "١٠٠": "100",
    "ميتين": "200", "مائتين": "200",
    "تلاتمية": "300", "ثلاثمائة": "300",
    "اربعمية": "400", "أربعمائة": "400",
    "خمسمية": "500", "خمسمائة": "500",
    "ستمية": "600", "ستمائة": "600",
    "سبعمية": "700", "سبعمائة": "700",
    "تمانمية": "800", "ثمانمائة": "800",
    "تسعمية": "900", "تسعمائة": "900",

    # Thousands
    "ألف": "1000", "الف": "1000", "١٠٠٠": "1000",
    "ألفين": "2000", "الفين": "2000",
    "تلات تلاف": "3000", "ثلاثة آلاف": "3000",
    "اربعة آلاف": "4000", "أربعة آلاف": "4000",
    "خمسة آلاف": "5000",
    "ستة آلاف": "6000",
    "سبعة آلاف": "7000",
    "تمانية آلاف": "8000", "ثمانية آلاف": "8000",
    "تسعة آلاف": "9000",

    # Large numbers
    "عشرة آلاف": "10000",
    "مية ألف": "100000", "مائة ألف": "100000",
    "مليون": "1000000", "١٠٠٠٠٠٠": "1000000",
    "ملايين": "1000000",
    "مليار": "1000000000", "١٠٠٠٠٠٠٠٠٠": "1000000000",
}

# Whitespace-normalised view of the table, used to look up whatever the
# regex matched (the regex tolerates any run of whitespace between words).
_ARABIC_NUMBER_LOOKUP = {
    " ".join(key.split()): digits for key, digits in arabic_numbers_nemo.items()
}

# One alternation over every key, longest key first, so that multi-word
# phrases win over the single words they contain. Lookarounds on \w are used
# instead of \b because \b misplaces the boundary for keys that end in a
# combining diacritic. Built once at import; later mutation of
# arabic_numbers_nemo is not picked up.
_ARABIC_NUMBER_PATTERN = re.compile(
    r"(?<!\w)(?:"
    + "|".join(
        r"\s+".join(re.escape(part) for part in key.split())
        for key in sorted(_ARABIC_NUMBER_LOOKUP, key=len, reverse=True)
    )
    + r")(?!\w)"
)


def replace_arabic_numbers_nemo(text: str) -> str:
    """Replace Arabic number words in *text* with Western digits (NeMo path).

    Every whole-word occurrence of a key of ``arabic_numbers_nemo`` is
    replaced by its digit string in a single pass. Longer keys take priority,
    so a phrase that is itself a key is replaced as a unit rather than word by
    word. Text that matches no key is returned unchanged.

    Args:
        text: Transcript to post-process.

    Returns:
        The transcript with recognised number words replaced by digits.
    """
    if not text:
        return text

    def _to_digits(match):
        # Collapse internal whitespace so the match maps back onto a table key.
        return _ARABIC_NUMBER_LOOKUP[" ".join(match.group(0).split())]

    return _ARABIC_NUMBER_PATTERN.sub(_to_digits, text)
|
|
|
|
|
|
def convert_arabic_numbers_whisper(sentence: str) -> str:
    """Turn Arabic number words in a Whisper transcript into digits.

    The sentence is first normalised by rewriting a handful of spelling
    variants to the spelling in the right-hand column of the table below.
    It is then split on whitespace (separators are kept) and each token is
    passed to pyarabic's ``text2number``. A token is replaced by the number
    only when the result is a non-zero int, or when the token is literally
    "صفر"; every other token and all whitespace is kept untouched.

    Returns the input unchanged when pyarabic is unavailable, when the
    sentence is blank, or when anything unexpected fails.
    """
    if not (arabic_numbers_available and sentence.strip()):
        return sentence

    try:
        spelling_fixes = {
            "اربعة": "أربعة", "اربع": "أربع", "اثنين": "اثنان",
            "اتنين": "اثنان", "ثلاث": "ثلاثة", "خمس": "خمسة",
            "ست": "ستة", "سبع": "سبعة", "ثمان": "ثمانية",
            "تسع": "تسعة", "عشر": "عشرة",
        }
        for variant, canonical in spelling_fixes.items():
            sentence = re.sub(rf"\b{variant}\b", canonical, sentence)

        pieces = []
        # The capturing group keeps the whitespace runs in the split result,
        # so joining the pieces reproduces the original spacing.
        for token in re.split(r'(\s+)', sentence):
            core = token.strip()
            replacement = token
            if core:
                try:
                    value = text2number(core)
                except Exception:
                    value = None
                # A result of 0 is only trusted for the literal word for zero.
                if isinstance(value, int) and (value != 0 or core == "صفر"):
                    replacement = str(value)
            pieces.append(replacement)

        return ''.join(pieces)

    except Exception as e:
        logger.warning(f"Error converting Arabic numbers: {e}")
        return sentence
|
|
|
|
|
|
|
|
|
# Module-level model handles, filled in by initialize_models(). A handle that
# is still None afterwards means that model could not be loaded.
asr_model_nemo = None
whisper_model = None
whisper_processor = None
whisper_tokenizer = None
device = None
torch_dtype = None


def initialize_models():
    """Load the NeMo and Whisper models into the module-level handles.

    Selects CUDA with float16 when available, otherwise CPU with float32.
    Each model is loaded independently: a failure is logged and leaves that
    model's handle(s) set to None instead of raising, so the server can still
    serve the other model.
    """
    global asr_model_nemo, whisper_model, whisper_processor, whisper_tokenizer, device, torch_dtype

    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    torch_dtype = torch.float16 if cuda_available else torch.float32

    logger.info(f"Using device: {device}")
    logger.info(f"CUDA available: {cuda_available}")

    # --- NeMo -------------------------------------------------------------
    logger.info("Loading NeMo FastConformer Arabic ASR model...")
    model_path = "stt_ar_fastconformer_hybrid_large_pcd_v1.0.nemo"

    if os.path.exists(model_path):
        try:
            # NOTE(review): the checkpoint name says "hybrid" but it is
            # restored through EncDecCTCModel -- confirm this is the intended
            # model class for this checkpoint.
            asr_model_nemo = nemo_asr.models.EncDecCTCModel.restore_from(model_path)
            asr_model_nemo.eval()
            logger.info("✓ NeMo FastConformer model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load NeMo model: {e}")
            asr_model_nemo = None
    else:
        logger.warning(f"NeMo model not found at: {model_path}")
        asr_model_nemo = None

    # --- Whisper ----------------------------------------------------------
    logger.info("Loading Whisper large-v3 model...")
    MODEL_NAME = "alaatiger989/FT_Arabic_Whisper_V1_1"

    try:
        try:
            # Imported only to check that flash-attn is installed.
            import flash_attn  # noqa: F401

            whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                attn_implementation="flash_attention_2",
            )
            logger.info("✓ Whisper loaded with flash attention")
        except Exception as flash_err:
            # Deliberately not a bare ``except``: KeyboardInterrupt/SystemExit
            # during the first (possibly long) load must abort start-up, not
            # silently kick off a second load.
            logger.info(f"Flash attention not used ({flash_err}); loading with standard attention")
            whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            logger.info("✓ Whisper loaded with standard attention")

        whisper_model.to(device)
        whisper_processor = AutoProcessor.from_pretrained(MODEL_NAME)
        whisper_tokenizer = whisper_processor.tokenizer

        logger.info("✓ Whisper model + tokenizer loaded successfully")

    except Exception as e:
        logger.error(f"Failed to load Whisper model: {e}")
        # Reset all three together so a half-initialised Whisper stack is
        # never left behind.
        whisper_model = None
        whisper_processor = None
        whisper_tokenizer = None


# Models are loaded once, at import time, before the server starts.
initialize_models()

# Shared thread pool on which the blocking transcription calls are run so
# they do not stall the asyncio event loop.
executor = ThreadPoolExecutor(max_workers=4)
|
|
|
|
|
|
|
|
|
|
|
|
class JambonzAudioBuffer:
    """Lock-protected accumulator for float32 audio samples.

    Two arrays are maintained side by side: ``buffer`` (compared against
    ``chunk_samples`` to decide whether a chunk is ready) and ``total_audio``
    (everything received since the last clear/reset, which is what gets
    transcribed). Neither array is trimmed when a chunk is requested.
    """

    def __init__(self, sample_rate=8000, chunk_duration=1.0):
        self.sample_rate = sample_rate
        self.chunk_duration = chunk_duration
        # Number of samples that make up one processing chunk.
        self.chunk_samples = int(chunk_duration * sample_rate)

        self.lock = threading.Lock()
        self.buffer = np.empty(0, dtype=np.float32)
        self.total_audio = np.empty(0, dtype=np.float32)

        # Energy-based VAD parameters used by is_speech().
        self.silence_threshold = 0.01
        self.min_speech_samples = int(0.3 * sample_rate)

    def add_audio(self, audio_data):
        """Append *audio_data* to both the chunk buffer and the full history."""
        with self.lock:
            self.buffer = np.concatenate((self.buffer, audio_data))
            self.total_audio = np.concatenate((self.total_audio, audio_data))
            logger.debug(f"Added {len(audio_data)} audio samples, total: {len(self.total_audio)}")

    def has_chunk_ready(self):
        """Return True once at least ``chunk_samples`` samples are buffered."""
        with self.lock:
            buffered = len(self.buffer)
            ready = buffered >= self.chunk_samples
            if ready:
                logger.debug(f"Chunk ready: {len(self.buffer)} >= {self.chunk_samples}")
            return ready

    def is_speech(self, audio_chunk):
        """Energy-based voice activity check.

        Chunks shorter than ``min_speech_samples`` are never speech. Otherwise
        the chunk counts as speech when its RMS exceeds ``silence_threshold``
        or its peak amplitude exceeds twice that threshold.
        """
        if len(audio_chunk) < self.min_speech_samples:
            logger.debug(f"Audio too short for VAD: {len(audio_chunk)} < {self.min_speech_samples}")
            return False

        rms_energy = np.sqrt(np.mean(audio_chunk ** 2))
        peak_amplitude = np.abs(audio_chunk).max()

        is_speech = rms_energy > self.silence_threshold or peak_amplitude > (self.silence_threshold * 2)

        logger.debug(f"VAD check - RMS: {rms_energy:.4f}, Peak: {peak_amplitude:.4f}, "
                     f"Threshold: {self.silence_threshold}, Speech: {is_speech}")
        return is_speech

    def get_chunk_for_processing(self):
        """Return a one-element marker array when a chunk is ready, else None.

        The return value carries no audio; callers fetch the samples through
        get_all_audio().
        """
        with self.lock:
            if len(self.buffer) >= self.chunk_samples:
                logger.debug(f"Returning processing signal, buffer size: {len(self.buffer)}")
                return np.array([1])
            return None

    def get_all_audio(self):
        """Return a copy of every sample accumulated so far."""
        with self.lock:
            snapshot = self.total_audio.copy()
            logger.debug(f"Returning {len(snapshot)} total audio samples")
            return snapshot

    def clear(self):
        """Drop all buffered and accumulated audio."""
        with self.lock:
            self.buffer = np.empty(0, dtype=np.float32)
            self.total_audio = np.empty(0, dtype=np.float32)
            logger.debug("Audio buffer cleared")

    def reset_for_new_segment(self):
        """Drop all audio so the next transcription segment starts empty."""
        with self.lock:
            self.buffer = np.empty(0, dtype=np.float32)
            self.total_audio = np.empty(0, dtype=np.float32)
            logger.debug("Audio buffer reset for new segment")
|
|
|
|
|
|
def linear16_to_audio(audio_bytes, sample_rate=8000):
    """Decode raw 16-bit PCM bytes into float32 samples.

    The bytes are read as native-endian int16 and divided by 32768, giving
    values in [-1.0, 1.0). ``sample_rate`` is accepted for interface symmetry
    but is not used by the conversion. On any decoding error the problem is
    logged and an empty float32 array is returned.
    """
    try:
        pcm = np.frombuffer(audio_bytes, dtype=np.int16)
        return pcm.astype(np.float32) / 32768.0
    except Exception as e:
        logger.error(f"Error converting LINEAR16 to audio: {e}")
        return np.array([], dtype=np.float32)
|
|
|
|
|
|
from scipy.signal import resample_poly
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
from scipy.signal import resample_poly, butter, lfilter
|
|
|
import webrtcvad
|
|
|
import noisereduce as nr
|
|
|
|
|
|
|
|
|
# Shared WebRTC voice-activity detector, constructed with aggressiveness 2.
_vad = webrtcvad.Vad(2)


def resample_audio(audio_data, source_rate, target_rate=16000,
                   lowcut=80.0, highcut=7600.0,
                   frame_ms=30, required_ratio=0.55):
    """Resample, clean up and speech-check an audio buffer.

    Pipeline: polyphase resample -> 4th-order Butterworth band-pass ->
    noise reduction (noise profile taken from the first 250 ms) -> WebRTC VAD
    over consecutive ``frame_ms`` frames.

    The filter and noise-reduction stages are best-effort: if one fails the
    audio from the previous stage is used and the failure is logged.

    Args:
        audio_data: 1-D float array of samples (expected range [-1, 1]).
        source_rate: Sample rate of ``audio_data`` in Hz.
        target_rate: Output sample rate in Hz.
        lowcut: Band-pass lower edge in Hz.
        highcut: Band-pass upper edge in Hz.
        frame_ms: VAD frame length in milliseconds. (WebRTC VAD is documented
            to accept 10/20/30 ms frames -- other values are presumably
            rejected per frame and counted as non-speech.)
        required_ratio: Minimum fraction of voiced frames for the buffer to
            count as speech.

    Returns:
        (processed_audio, is_speech): the cleaned float32 audio at
        ``target_rate`` and a bool that is True when the voiced-frame ratio
        reaches ``required_ratio``.
    """
    # 1. Resample -----------------------------------------------------------
    if source_rate != target_rate:
        gcd = np.gcd(source_rate, target_rate)
        up = target_rate // gcd
        down = source_rate // gcd
        try:
            audio_data = resample_poly(audio_data, up, down).astype(np.float32)
        except Exception as e:
            logger.warning(f"resample_poly failed ({e}); falling back to sample repetition")
            audio_data = np.repeat(audio_data, int(target_rate / source_rate)).astype(np.float32)
    else:
        audio_data = audio_data.astype(np.float32)

    # 2. Band-pass filter (best effort) ---------------------------------------
    try:
        nyq = 0.5 * target_rate
        b, a = butter(4, [lowcut / nyq, highcut / nyq], btype='band')
        audio_data = lfilter(b, a, audio_data).astype(np.float32)
    except Exception as e:
        logger.debug(f"Band-pass filter skipped: {e}")

    # 3. Noise reduction (best effort) ----------------------------------------
    try:
        noise_samples = int(0.25 * target_rate)
        if len(audio_data) >= noise_samples:
            noise_clip = audio_data[:noise_samples]
            audio_data = nr.reduce_noise(y=audio_data, y_noise=noise_clip, sr=target_rate).astype(np.float32)
    except Exception as e:
        logger.debug(f"Noise reduction skipped: {e}")

    # 4. WebRTC VAD -----------------------------------------------------------
    samples_per_frame = int(target_rate * (frame_ms / 1000.0))
    # Clip before the int16 cast: filtering can push samples outside [-1, 1],
    # and an unclipped cast would wrap them around into garbage PCM.
    pcm16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
    # A non-positive frame length yields no frames (and so "no speech")
    # instead of looping forever.
    frame_count = len(pcm16) // samples_per_frame if samples_per_frame > 0 else 0

    voiced = 0
    vad_failures = 0
    for index in range(frame_count):
        start = index * samples_per_frame
        frame_bytes = pcm16[start:start + samples_per_frame].tobytes()
        try:
            if _vad.is_speech(frame_bytes, target_rate):
                voiced += 1
        except Exception:
            # Frames the VAD cannot process count as non-speech.
            vad_failures += 1
    if vad_failures:
        logger.debug(f"VAD failed on {vad_failures}/{frame_count} frames")

    ratio = voiced / max(1, frame_count)
    is_speech = ratio >= required_ratio

    return audio_data, is_speech
|
|
|
|
|
|
def transcribe_with_nemo(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Transcribe an audio buffer with the NeMo model.

    The audio is passed through ``resample_audio``; if no speech is detected,
    or fewer than 0.3 s of audio remain, nothing is transcribed. Otherwise the
    audio is written to a temporary WAV file (the model is given a file
    path), transcribed, and number words in the result are converted to
    digits with ``replace_arabic_numbers_nemo``.

    Args:
        audio_data: 1-D float array of samples at ``source_sample_rate``.
        source_sample_rate: Sample rate of the input in Hz.
        target_sample_rate: Sample rate handed to the model in Hz.

    Returns:
        The cleaned transcript, or "" when there is nothing to transcribe, the
        model is unavailable, or an error occurs. Always a ``str``.
    """
    try:
        if len(audio_data) == 0 or asr_model_nemo is None:
            return ""

        resampled_audio, has_speech = resample_audio(audio_data, source_sample_rate, target_sample_rate)

        if not has_speech:
            print("Silence/noise, skipping...")
            # Explicit "" (not an implicit None): callers call .strip() on
            # the result.
            return ""

        print("Speech detected, sending to ASR...")

        min_samples = int(0.3 * target_sample_rate)
        if len(resampled_audio) < min_samples:
            return ""

        start_time = time.time()

        # Reserve a path, then close the handle before writing to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_path = tmp_file.name

        try:
            # The write sits inside the try so the file is removed even when
            # writing fails.
            sf.write(tmp_path, resampled_audio, target_sample_rate)
            result = asr_model_nemo.transcribe([tmp_path])

            if result and len(result) > 0:
                first = result[0]
                # The first result may be an object with .text, a plain
                # string, or something else; normalise all of them to str.
                if hasattr(first, 'text'):
                    raw_text = first.text
                elif isinstance(first, str):
                    raw_text = first
                else:
                    raw_text = str(first)

                if not isinstance(raw_text, str):
                    raw_text = str(raw_text)

                if raw_text and raw_text.strip():
                    cleaned_text = replace_arabic_numbers_nemo(raw_text)
                    end_time = time.time()

                    if cleaned_text.strip():
                        logger.info(f"NeMo transcription: '{cleaned_text}' (processed in {end_time - start_time:.2f}s)")
                        return cleaned_text.strip()
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        return ""

    except Exception as e:
        logger.error(f"Error during NeMo transcription: {e}")
        return ""
|
|
|
|
|
|
def transcribe_with_whisper(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Transcribe an audio buffer with the Whisper model.

    The audio is passed through ``resample_audio``; if no speech is detected,
    or fewer than 0.1 s of audio remain, nothing is transcribed. Otherwise
    log-mel features are computed by the processor and decoded greedily
    (no sampling, one beam, at most 128 new tokens).

    Args:
        audio_data: 1-D float array of samples at ``source_sample_rate``.
        source_sample_rate: Sample rate of the input in Hz.
        target_sample_rate: Sample rate handed to the processor in Hz.

    Returns:
        The stripped transcript, or "" when there is nothing to transcribe,
        the model is unavailable, or an error occurs. Always a ``str``.
    """
    try:
        if len(audio_data) == 0 or whisper_model is None:
            return ""

        resampled_audio, has_speech = resample_audio(audio_data, source_sample_rate, target_sample_rate)

        if not has_speech:
            print("Silence/noise, skipping...")
            # Explicit "" (not an implicit None): callers call .strip() on
            # the result.
            return ""

        print("Speech detected, sending to ASR...")

        min_samples = int(0.1 * target_sample_rate)
        if len(resampled_audio) < min_samples:
            return ""

        start_time = time.time()

        input_features = whisper_processor(
            resampled_audio,
            sampling_rate=target_sample_rate,
            return_tensors="pt"
        ).input_features
        input_features = input_features.to(device=device, dtype=torch_dtype)

        # All-ones mask over (batch, frames). The time axis is the LAST axis
        # of input_features; shape[:-1] would build the mask over the
        # mel-bin axis instead.
        attention_mask = torch.ones(
            (input_features.shape[0], input_features.shape[-1]),
            dtype=torch.long,
            device=device
        )

        with torch.no_grad():
            predicted_ids = whisper_model.generate(
                input_features,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=False,
                num_beams=1,
                # NOTE(review): "english"/"translate" is surprising for an
                # Arabic STT service. Presumably the fine-tuned checkpoint
                # was trained with these prompt tokens -- confirm before
                # changing.
                language="english",
                task="translate",
                pad_token_id=whisper_tokenizer.pad_token_id,
                eos_token_id=whisper_tokenizer.eos_token_id
            )

        transcription = whisper_tokenizer.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0].strip()

        end_time = time.time()
        logger.info(f"Whisper transcription completed in {end_time - start_time:.2f}s: '{transcription}'")
        return transcription

    except Exception as e:
        logger.error(f"Error during Whisper transcription: {e}")
        return ""
|
|
|
|
|
|
class UnifiedSTTHandler:
    """Per-connection speech-to-text session for one WebSocket client.

    A session is configured by a "start" message, which also selects the
    model: language "ar-EG-whis" selects Whisper, anything else selects NeMo.
    Two background tasks then run while the session is active:

    * ``_process_audio_chunks`` re-transcribes all audio of the current
      segment whenever a chunk is ready and the newest chunk contains speech,
      storing the result in ``accumulated_transcript``.
    * ``_monitor_for_auto_final`` sends that transcript as a final result once
      no transcript update has arrived for the model-specific timeout, then
      starts a new segment.
    """

    def __init__(self, websocket):
        self.websocket = websocket
        self.audio_buffer = None
        self.config = {}
        self.running = False
        self.transcription_task = None
        self.use_nemo = False

        # Auto-final bookkeeping.
        self.interim_count = 0
        self.last_interim_time = None
        # NOTE(review): silence_timeout is not read anywhere in this class;
        # _monitor_for_auto_final uses its own 2.0 s / 3.0 s values.
        self.silence_timeout = 2.9
        self.min_interim_count = 1
        self.auto_final_task = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.segment_number = 0
        self.last_partial = ""

        self.processing_count = 0

    async def start_processing(self, start_message):
        """Configure the session from a "start" message and launch the tasks.

        Sends an error and returns without starting when the selected model
        is not loaded.
        """
        self.config = {
            "language": start_message.get("language", "ar-EG"),
            "format": start_message.get("format", "raw"),
            "encoding": start_message.get("encoding", "LINEAR16"),
            "sample_rate": start_message.get("sampleRateHz", 8000),
            "interim_results": True,
            "options": start_message.get("options", {})
        }

        language = self.config["language"]
        if language == "ar-EG":
            logger.info("Selected NeMo FastConformer")
            self.use_nemo = True
            model_name = "NeMo FastConformer"
        elif language == "ar-EG-whis":
            logger.info("Selected Whisper large-v3")
            self.use_nemo = False
            model_name = "Whisper large-v3"
        else:
            # Unknown language codes fall back to NeMo.
            self.use_nemo = True
            model_name = "NeMo FastConformer (default)"

        logger.info(f"STT session started with {model_name} for language: {language}")
        logger.info(f"Config: {self.config}")

        if self.use_nemo and asr_model_nemo is None:
            await self.send_error("NeMo model not available")
            return
        elif not self.use_nemo and whisper_model is None:
            await self.send_error("Whisper model not available")
            return

        # Whisper is given longer chunks than NeMo.
        chunk_duration = 1.0 if self.use_nemo else 2.0

        self.audio_buffer = JambonzAudioBuffer(
            sample_rate=self.config["sample_rate"],
            chunk_duration=chunk_duration
        )

        # Whisper sessions use a lower energy threshold for the VAD.
        if not self.use_nemo:
            self.audio_buffer.silence_threshold = 0.005

        self.running = True
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.segment_number = 0
        self.processing_count = 0
        self.last_partial = ""

        self.transcription_task = asyncio.create_task(self._process_audio_chunks())
        self.auto_final_task = asyncio.create_task(self._monitor_for_auto_final())

        logger.info(f"Background tasks started for {model_name}")

    async def stop_processing(self):
        """Stop the session, flushing a final transcript if none was sent."""
        logger.info("Stopping STT session...")
        self.running = False

        for task in (self.transcription_task, self.auto_final_task):
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        # Prefer the transcript that has already been computed.
        if not self.final_sent and self.accumulated_transcript.strip():
            await self.send_transcription(self.accumulated_transcript, is_final=True)

        # Otherwise make one last attempt on whatever audio is buffered.
        if self.audio_buffer:
            all_audio = self.audio_buffer.get_all_audio()
            if len(all_audio) > 0 and not self.final_sent:
                final_transcription = await self._transcribe(all_audio)
                # Guard against a falsy/None result before calling .strip().
                if final_transcription and final_transcription.strip():
                    await self.send_transcription(final_transcription, is_final=True)

        if self.audio_buffer:
            self.audio_buffer.clear()

        logger.info("STT session stopped")

    async def start_new_segment(self):
        """Reset per-segment state and audio so a new utterance starts clean."""
        self.segment_number += 1
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.last_partial = ""
        self.processing_count = 0

        if self.audio_buffer:
            self.audio_buffer.reset_for_new_segment()

        logger.info(f"Started new transcription segment #{self.segment_number}")

    async def add_audio_data(self, audio_bytes):
        """Decode a binary audio frame and append it to the buffer.

        Frames that arrive while no session is running are ignored.
        """
        if self.audio_buffer and self.running:
            audio_data = linear16_to_audio(audio_bytes, self.config["sample_rate"])
            self.audio_buffer.add_audio(audio_data)

    async def _transcribe(self, audio):
        """Run the selected model on *audio* in the shared thread pool."""
        transcribe = transcribe_with_nemo if self.use_nemo else transcribe_with_whisper
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            executor, transcribe, audio, self.config["sample_rate"]
        )

    async def _handle_ready_chunk(self, model_name):
        """Transcribe the current segment if its newest chunk contains speech."""
        chunk_signal = self.audio_buffer.get_chunk_for_processing()
        if chunk_signal is None:
            logger.warning(f"{model_name} - Chunk signal was None")
            return

        all_audio = self.audio_buffer.get_all_audio()
        logger.info(f"{model_name} - Got {len(all_audio)} samples for processing")
        if len(all_audio) == 0:
            return

        # The energy VAD only looks at the most recent chunk; the model is
        # then given the whole segment.
        latest_chunk_start = max(0, len(all_audio) - self.audio_buffer.chunk_samples)
        latest_chunk = all_audio[latest_chunk_start:]

        has_speech = self.audio_buffer.is_speech(latest_chunk)
        logger.info(f"{model_name} - Speech detected: {has_speech}")
        if not has_speech:
            logger.debug(f"{model_name} - No speech in chunk")
            return

        logger.info(f"{model_name} - Starting transcription...")
        segment_at_start = self.segment_number
        start_time = time.time()

        try:
            transcription = await self._transcribe(all_audio)

            process_time = time.time() - start_time
            logger.info(f"{model_name} - Transcription completed in {process_time:.2f}s: '{transcription}'")

            if segment_at_start != self.segment_number:
                # The segment was finalised while this transcription was in
                # flight; its text belongs to the old segment and must not
                # seed the new one.
                logger.info(f"{model_name} - Discarding transcription of finished segment #{segment_at_start}")
                return

            if transcription and transcription.strip():
                self.processing_count += 1
                self.accumulated_transcript = transcription

                if transcription != self.last_partial or self.interim_count == 0:
                    self.last_partial = transcription
                    self.interim_count += 1
                    self.last_interim_time = time.time()
                    logger.info(f"{model_name} - Updated interim_count to {self.interim_count}")
                else:
                    self.last_interim_time = time.time()
                    logger.info(f"{model_name} - Same transcription, updating time only")
            else:
                logger.info(f"{model_name} - No transcription result")

        except Exception as e:
            logger.exception(f"{model_name} - Transcription error: {e}")

    async def _process_audio_chunks(self):
        """Background loop: poll the buffer and transcribe ready chunks."""
        model_name = "NeMo" if self.use_nemo else "Whisper"
        logger.info(f"Starting audio chunk processing for {model_name}")

        chunk_count = 0

        while self.running:
            try:
                if self.audio_buffer and self.audio_buffer.has_chunk_ready():
                    chunk_count += 1
                    logger.info(f"{model_name} - Processing chunk #{chunk_count}")
                    await self._handle_ready_chunk(model_name)
                elif self.audio_buffer:
                    current_size = len(self.audio_buffer.buffer)
                    required_size = self.audio_buffer.chunk_samples
                    if current_size > 0:
                        logger.debug(f"{model_name} - Buffer: {current_size}/{required_size} samples")

                await asyncio.sleep(0.1)

            except Exception as e:
                logger.exception(f"{model_name} - Error in chunk processing: {e}")
                # Back off so a persistent failure cannot spin the loop.
                await asyncio.sleep(1)

    async def _monitor_for_auto_final(self):
        """Background loop: emit a final transcript after a quiet period.

        The quiet period is 2.0 s for NeMo and 3.0 s for Whisper, measured
        from the last transcript update.
        """
        model_name = "NeMo" if self.use_nemo else "Whisper"
        timeout = 2.0 if self.use_nemo else 3.0

        logger.info(f"Starting auto-final monitoring for {model_name} (timeout: {timeout}s)")

        while self.running:
            try:
                current_time = time.time()

                if (self.interim_count >= self.min_interim_count and
                        self.last_interim_time is not None and
                        (current_time - self.last_interim_time) >= timeout and
                        not self.final_sent and
                        self.accumulated_transcript.strip()):

                    silence_duration = current_time - self.last_interim_time
                    logger.info(f"Auto-final triggered for segment #{self.segment_number} ({model_name}) - "
                                f"Interim count: {self.interim_count}, Silence: {silence_duration:.1f}s")

                    await self.send_transcription(self.accumulated_transcript, is_final=True)
                    await self.start_new_segment()

                await asyncio.sleep(0.5)

            except Exception as e:
                logger.error(f"Error in auto-final monitoring: {e}")
                await asyncio.sleep(0.5)

    async def send_transcription(self, text, is_final=True, confidence=0.9):
        """Send a transcription message to the client.

        For Whisper sessions, number words in final transcripts are converted
        to digits first. ``final_sent`` is set only after a final message has
        been sent successfully. Send errors are logged, not raised.

        Args:
            text: Transcript to send.
            is_final: Whether this is the final result for the segment.
            confidence: Confidence value reported to the client.
        """
        try:
            if not self.use_nemo and is_final:
                original_text = text
                converted_text = convert_arabic_numbers_whisper(text)
                if original_text != converted_text:
                    logger.info(f"Whisper - Arabic numbers converted: '{original_text}' -> '{converted_text}'")
                text = converted_text

            message = {
                "type": "transcription",
                # Report what the caller asked for instead of a hard-coded
                # True.
                "is_final": bool(is_final),
                "alternatives": [
                    {
                        "transcript": text,
                        "confidence": confidence
                    }
                ],
                "language": self.config.get("language", "ar-EG"),
                "channel": 1
            }

            await self.websocket.send(json.dumps(message))
            if is_final:
                self.final_sent = True

            model_name = "NeMo" if self.use_nemo else "Whisper"
            label = "FINAL" if is_final else "INTERIM"
            logger.info(f"Sent {label} transcription ({model_name}): '{text}'")

        except Exception as e:
            logger.error(f"Error sending transcription: {e}")

    async def send_error(self, error_message):
        """Send an error message to the client; send failures are only logged."""
        try:
            message = {
                "type": "error",
                "error": error_message
            }
            await self.websocket.send(json.dumps(message))
            logger.error(f"Sent error: {error_message}")
        except Exception as e:
            logger.error(f"Error sending error message: {e}")
|
|
|
|
|
|
async def handle_jambonz_websocket(websocket):
    """Serve one WebSocket connection.

    Text frames are JSON control messages: "start" begins a session, "stop"
    ends it and closes the socket, anything else is answered with an error.
    Binary frames are audio and are only accepted while a session is
    running. Whatever way the connection ends, a still-running session is
    stopped.
    """
    client_id = f"jambonz_{id(websocket)}"
    logger.info(f"New unified STT connection: {client_id}")

    handler = UnifiedSTTHandler(websocket)

    try:
        async for message in websocket:
            try:
                if isinstance(message, str):
                    data = json.loads(message)
                    message_type = data.get("type")

                    if message_type == "start":
                        logger.info(f"Received start message: {data}")
                        await handler.start_processing(data)

                    elif message_type == "stop":
                        logger.info("Received stop message - closing WebSocket")
                        await handler.stop_processing()
                        await websocket.close(code=1000, reason="Session stopped by client")
                        break

                    else:
                        logger.warning(f"Unknown message type: {message_type}")
                        await handler.send_error(f"Unknown message type: {message_type}")

                else:
                    # Binary frame: audio. Reject it outside an active session.
                    if not handler.running or handler.audio_buffer is None:
                        logger.warning("Received audio data outside of active session")
                        await handler.send_error("Received audio before start message or after stop")
                        continue

                    await handler.add_audio_data(message)

            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")
                await handler.send_error(f"Invalid JSON: {str(e)}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                await handler.send_error(f"Processing error: {str(e)}")

    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Unified STT connection closed: {client_id}")
    except Exception as e:
        logger.error(f"Unified STT WebSocket error: {e}")
        try:
            await handler.send_error(str(e))
        except Exception as send_err:
            # Best effort only -- the socket may already be unusable. Not a
            # bare ``except`` so that task cancellation and interpreter exit
            # still propagate.
            logger.debug(f"Could not report error to client: {send_err}")
    finally:
        if handler.running:
            await handler.stop_processing()
        logger.info(f"Unified STT connection ended: {client_id}")
|
|
|
|
|
|
async def main():
    """Run the STT WebSocket server on 0.0.0.0:3007 until it is closed.

    Returns immediately (after logging an error) when neither model was
    loaded; otherwise starts the server, logs a summary of routing and
    features, and waits for the server to close.
    """
    logger.info("Starting Unified Arabic STT WebSocket server on port 3007...")

    candidates = (
        (asr_model_nemo, "NeMo FastConformer (ar-EG)"),
        (whisper_model, "Whisper large-v3 (ar-EG-whis)"),
    )
    models_available = [label for model, label in candidates if model is not None]

    if len(models_available) == 0:
        logger.error("No models available! Please check model paths and installations.")
        return

    server = await websockets.serve(
        handle_jambonz_websocket,
        "0.0.0.0",
        3007,
        ping_interval=20,
        ping_timeout=10,
        close_timeout=10,
    )

    startup_summary = (
        "Unified Arabic STT WebSocket server started on ws://0.0.0.0:3007",
        "Ready to handle jambonz STT requests with both models",
        "ROUTING:",
        "- language: 'ar-EG' → NeMo FastConformer (with built-in number conversion)",
        "- language: 'ar-EG-whis' → Whisper large-v3 (with pyarabic number conversion)",
        "FEATURES:",
        "- Continuous transcription with segmentation",
        "- Voice Activity Detection",
        "- Auto-final detection (2s silence timeout)",
        "- Model-specific number conversion",
    )
    for line in startup_summary:
        logger.info(line)
    logger.info(f"AVAILABLE MODELS: {', '.join(models_available)}")

    # Block here until the server shuts down.
    await server.wait_closed()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print a start-up banner and the model status, then
    # run the server until it exits or is interrupted.
    rule = "=" * 80
    banner = (
        rule,
        "Unified Arabic STT Server (NeMo + Whisper)",
        rule,
        "WebSocket Port: 3007",
        "Protocol: jambonz STT API",
        "Audio Format: LINEAR16 PCM @ 8kHz → 16kHz",
        "",
        "LANGUAGE ROUTING:",
        "- 'ar-EG' → NeMo FastConformer",
        " • Built-in Arabic number word to digit conversion",
        " • Optimized for Arabic dialects",
        "- 'ar-EG-whis' → Whisper large-v3",
        " • pyarabic library number conversion (final transcripts only)",
        " • OpenAI Whisper model",
        "",
        "FEATURES:",
        "- Automatic model selection based on language parameter",
        "- Voice Activity Detection",
        "- Auto-final detection (2 seconds silence)",
        "- Model-specific number conversion strategies",
        "- Continuous transcription with segmentation",
        "",
    )
    for banner_line in banner:
        print(banner_line)

    available, missing = "✓ Available", "✗ Not Available"
    nemo_status = available if asr_model_nemo is not None else missing
    whisper_status = available if whisper_model is not None else missing
    if arabic_numbers_available:
        arabic_numbers_status = "✓ Available"
    else:
        arabic_numbers_status = "✗ Not Available (install pyarabic)"

    print("MODEL STATUS:")
    print(f"- NeMo FastConformer: {nemo_status}")
    print(f"- Whisper large-v3: {whisper_status}")
    print(f"- pyarabic (Whisper numbers): {arabic_numbers_status}")
    print(rule)

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutting down unified server...")
    except Exception as e:
        print(f"Server error: {e}")