chatbot / speech_io.py
Nguyen5's picture
commit
b2fa85d
import numpy as np
import soundfile as sf
import librosa
from transformers import pipeline
# Model ids handed to transformers.pipeline() by the loader functions below.
ASR_MODEL_ID = "openai/whisper-small" # multilingual
TTS_MODEL_ID = "facebook/mms-tts-deu" # you can swap this model if you want multilingual output
# Cached pipeline instances; created lazily by get_asr_pipeline() / get_tts_pipeline().
_asr = None
_tts = None
# ============================================
# LOAD AUDIO – normalize to 16 kHz mono
# ============================================
def load_audio_16k(path):
    """Read an audio file and return it as 16 kHz mono float32 samples.

    Args:
        path: Location of the audio file, passed straight to ``sf.read``.

    Returns:
        A ``(samples, sample_rate)`` tuple. ``samples`` is a float32 array;
        ``sample_rate`` is 16000 whenever resampling was needed, otherwise
        the rate reported by the file (which is then already 16000).
    """
    target_rate = 16000
    samples, file_rate = sf.read(path)

    # Multi-dimensional data (e.g. stereo) is averaged over axis 1 to get mono.
    mono = samples.mean(axis=1) if samples.ndim > 1 else samples

    # Already at the target rate: only the dtype conversion is left to do.
    if file_rate == target_rate:
        return mono.astype(np.float32), file_rate

    resampled = librosa.resample(mono, orig_sr=file_rate, target_sr=target_rate)
    return resampled.astype(np.float32), target_rate
# ============================================
# LOAD WHISPER PIPELINE (multilingual)
# ============================================
def get_asr_pipeline():
    """Return the shared speech-recognition pipeline, creating it on first use.

    The instance is stored in the module-level ``_asr`` variable so that
    later calls reuse it instead of building a new pipeline.
    """
    global _asr

    # Fast path: the pipeline was already built by an earlier call.
    if _asr is not None:
        return _asr

    _asr = pipeline(
        task="automatic-speech-recognition",
        model=ASR_MODEL_ID,
        chunk_length_s=30,
        return_timestamps=False,
    )
    return _asr
# ============================================
# MULTILINGUAL STT
# ============================================
def transcribe_audio(audio_path: str) -> str:
    """Transcribe an audio file to text with the shared ASR pipeline.

    The clip is loaded through ``load_audio_16k`` and handed to the pipeline
    from ``get_asr_pipeline``. No language is passed in the generation
    arguments (per the original author's note, so that Whisper detects it
    itself), and the task is ``"transcribe"`` rather than translate so the
    text stays in the spoken language.

    Args:
        audio_path: Path of the recording. ``None`` or an empty string means
            nothing was recorded.

    Returns:
        The stripped transcript, or ``""`` when there is no input, the clip
        is shorter than 0.4 seconds, or the output consists solely of the
        junk character "ვ" and spaces.
    """
    # Nothing recorded (None or empty path): do not try to open a file.
    if not audio_path:
        return ""

    audio, sr = load_audio_16k(audio_path)

    # Very short clips make Whisper produce garbage characters, so skip them.
    min_samples = sr * 0.4
    if len(audio) < min_samples:
        return ""

    asr = get_asr_pipeline()
    result = asr(
        {"array": audio, "sampling_rate": sr},
        generate_kwargs={
            "task": "transcribe",  # do not translate; keep the original language
            "temperature": 0.0,  # reduces hallucinated output such as "ვვვ..."
        },
    )

    # Tolerate a missing or None text field instead of failing on .strip().
    text = (result.get("text") or "").strip()

    # Edge case: the transcript is only the known meaningless character.
    if set(text) <= {"ვ", " "}:
        return ""
    return text
# ============================================
# TEXT → SPEECH (not yet multilingual)
# ============================================
def get_tts_pipeline():
    """Return the shared text-to-speech pipeline, creating it on first use.

    The instance is stored in the module-level ``_tts`` variable so that
    later calls reuse it instead of building a new pipeline.
    """
    global _tts

    # Fast path: the pipeline was already built by an earlier call.
    if _tts is not None:
        return _tts

    _tts = pipeline(task="text-to-speech", model=TTS_MODEL_ID)
    return _tts
def synthesize_speech(text: str):
    """Convert text to speech and return it as 16-bit PCM samples.

    Args:
        text: Text to speak. ``None``, empty, or whitespace-only text
            produces no audio.

    Returns:
        ``None`` when there is nothing to speak or the model returned no
        samples; otherwise a ``(sampling_rate, samples)`` tuple in which
        ``samples`` is a peak-normalized ``np.int16`` array.
    """
    # Guard against None as well as blank strings before touching the model.
    if text is None or not text.strip():
        return None

    tts = get_tts_pipeline()
    out = tts(text)
    audio = np.array(out["audio"], dtype=np.float32)
    sr = out.get("sampling_rate", 16000)
    # NOTE(review): the array shape is passed through unchanged; if the
    # pipeline returns a batched (1, n) array, confirm the consumer expects it.

    # An empty waveform has no peak to normalize against; report "no audio".
    if audio.size == 0:
        return None

    # Scale so the loudest sample hits full range; silent audio is left as is
    # to avoid dividing by zero.
    peak = float(np.max(np.abs(audio)))
    if peak > 0:
        audio = audio / peak
    return sr, (audio * 32767).astype(np.int16)