Add voice chat feature with audio storage
main.py
CHANGED
@@ -15,6 +15,10 @@ import requests
 import soundfile as sf
 import subprocess
 import imageio_ffmpeg
+import uuid
+import time
+import threading
+from pathlib import Path
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -42,6 +46,51 @@ for _p in ["/tmp/huggingface", "/tmp/models", "/tmp/hf_asr"]:
 
 ASK_URL = os.getenv("ASK_URL", "https://remostart-farmlingua-ai-conversational.hf.space/ask")
 
+AUDIO_STORAGE_DIR = Path("/tmp/voice_chat_audio")
+AUDIO_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
+AUDIO_EXPIRY_SECONDS = 3600
+
+audio_registry: Dict[str, Dict[str, Any]] = {}
+audio_registry_lock = threading.Lock()
+
+
+def cleanup_expired_audio():
+    now = time.time()
+    expired_ids = []
+    with audio_registry_lock:
+        for audio_id, info in audio_registry.items():
+            if now - info["created_at"] > AUDIO_EXPIRY_SECONDS:
+                expired_ids.append(audio_id)
+        for audio_id in expired_ids:
+            info = audio_registry.pop(audio_id, None)
+            if info and os.path.exists(info["path"]):
+                try:
+                    os.unlink(info["path"])
+                except Exception:
+                    pass
+
+
+def store_audio(audio_data: bytes, suffix: str = ".wav") -> str:
+    cleanup_expired_audio()
+    audio_id = str(uuid.uuid4())
+    file_path = AUDIO_STORAGE_DIR / f"{audio_id}{suffix}"
+    with open(file_path, "wb") as f:
+        f.write(audio_data)
+    with audio_registry_lock:
+        audio_registry[audio_id] = {
+            "path": str(file_path),
+            "created_at": time.time()
+        }
+    return audio_id
+
+
+def get_audio_path(audio_id: str) -> Optional[str]:
+    with audio_registry_lock:
+        info = audio_registry.get(audio_id)
+        if info and os.path.exists(info["path"]):
+            return info["path"]
+    return None
+
 asr_models = {
     "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
     "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
@@ -76,6 +125,13 @@ class SpeakRequest(BaseModel):
     repetition_penalty: float | None = 1.1
     max_length: int | None = 4000
 
+
+class VoiceChatResponse(BaseModel):
+    user_transcription: str
+    user_audio_id: str
+    ai_response: str
+    ai_audio_id: str
+
 def load_audio_tokenizer():
     global audio_tokenizer
 
@@ -367,6 +423,147 @@ def _map_lang_code(code: str) -> str:
     m = {"yo": "yoruba", "ha": "hausa", "ig": "igbo", "en": "english"}
     return m.get(code.lower(), "english")
 
+
+@app.get("/audio/{audio_id}")
+async def get_audio(audio_id: str):
+    file_path = get_audio_path(audio_id)
+    if not file_path:
+        raise HTTPException(status_code=404, detail="Audio not found or expired")
+    return FileResponse(
+        file_path,
+        media_type="audio/webm" if file_path.endswith(".webm") else "audio/wav",
+        filename=os.path.basename(file_path)
+    )
+
+
+@app.post("/voice-chat", response_model=VoiceChatResponse)
+async def voice_chat(audio_file: UploadFile = File(...), language: str = Form(...)):
+    global model, audio_tokenizer
+
+    if language not in ["yo", "ha", "ig", "en"]:
+        raise HTTPException(status_code=400, detail="Language must be one of: yo, ha, ig, en")
+
+    audio_bytes = await audio_file.read()
+    user_audio_id = store_audio(audio_bytes, suffix=".webm")
+
+    audio_array = _preprocess_audio_ffmpeg(audio_bytes)
+    model_asr, proc = _get_asr(language)
+    if model_asr is None or proc is None:
+        raise HTTPException(status_code=500, detail="ASR model not available")
+
+    try:
+        device_t = next(model_asr.parameters()).device
+        inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
+        input_features = inputs.input_features.to(device_t)
+        with torch.no_grad():
+            pred_ids = model_asr.generate(input_features)
+        text_list = proc.batch_decode(pred_ids, skip_special_tokens=True)
+        user_transcription = text_list[0].strip() if text_list else ""
+    except Exception as e:
+        logger.error(f"ASR inference failed: {e}")
+        raise HTTPException(status_code=500, detail="ASR inference failed")
+
+    if not user_transcription:
+        raise HTTPException(status_code=400, detail="Could not transcribe audio")
+
+    try:
+        ans = requests.post(ASK_URL, json={"query": user_transcription}, timeout=30)
+        ans.raise_for_status()
+        ai_response = ans.json().get("answer", "")
+        if not ai_response:
+            ai_response = "I'm sorry, I couldn't generate a response."
+    except Exception as e:
+        logger.warning(f"Ask failed ({e}); using fallback response")
+        ai_response = "I'm sorry, I'm having trouble connecting. Please try again."
+
+    if model is None:
+        logger.info("Loading YarnGPT2 model (lazy loading)...")
+        load_model()
+    if audio_tokenizer is None:
+        logger.info("Loading audio tokenizer (lazy loading)...")
+        load_audio_tokenizer()
+
+    if model is None or audio_tokenizer is None:
+        raise HTTPException(status_code=503, detail="TTS model loading failed")
+
+    tts_language = _map_lang_code(language)
+    default_speakers = {
+        "english": "idera",
+        "yoruba": "yoruba_male2",
+        "igbo": "igbo_male2",
+        "hausa": "hausa_female1",
+    }
+    speaker = default_speakers.get(tts_language, "idera")
+
+    try:
+        prompt = audio_tokenizer.create_prompt(
+            ai_response,
+            lang=tts_language,
+            speaker_name=speaker,
+        )
+        tokenized = audio_tokenizer.tokenize_prompt(prompt)
+        if isinstance(tokenized, torch.Tensor):
+            input_ids = tokenized
+            attention_mask = None
+        else:
+            input_ids = tokenized.get("input_ids", tokenized)
+            attention_mask = tokenized.get("attention_mask", None)
+
+        if hasattr(audio_tokenizer, 'tokenizer') and audio_tokenizer.tokenizer.pad_token is None:
+            audio_tokenizer.tokenizer.pad_token = audio_tokenizer.tokenizer.eos_token
+
+        with torch.no_grad():
+            gen_kwargs = {
+                "input_ids": input_ids,
+                "repetition_penalty": 1.1,
+                "max_length": 4000,
+            }
+            if attention_mask is not None:
+                gen_kwargs["attention_mask"] = attention_mask
+
+            use_beams = tts_language in ["yoruba", "igbo", "hausa"]
+            if use_beams:
+                gen_kwargs["num_beams"] = 5
+                gen_kwargs["early_stopping"] = False
+            else:
+                gen_kwargs["do_sample"] = True
+                gen_kwargs["temperature"] = 0.1
+
+            output = model.generate(**gen_kwargs)
+
+        codes = audio_tokenizer.get_codes(output)
+        audio = audio_tokenizer.get_audio(codes)
+
+        if isinstance(audio, torch.Tensor):
+            audio_tensor = audio.detach()
+        else:
+            audio_tensor = torch.tensor(np.asarray(audio))
+        audio_tensor = audio_tensor.to(torch.float32).cpu()
+        if audio_tensor.ndim > 1:
+            audio_tensor = audio_tensor.squeeze()
+        peak = audio_tensor.abs().max()
+        if peak > 1.0:
+            audio_tensor = audio_tensor / peak
+
+        buffer = io.BytesIO()
+        torchaudio.save(buffer, audio_tensor.unsqueeze(0), 24000, format="wav")
+        buffer.seek(0)
+        ai_audio_bytes = buffer.read()
+
+        ai_audio_id = store_audio(ai_audio_bytes, suffix=".wav")
+
+    except Exception as e:
+        logger.error(f"TTS generation failed: {e}")
+        raise HTTPException(status_code=500, detail=f"TTS generation failed: {e}")
+
+    return VoiceChatResponse(
+        user_transcription=user_transcription,
+        user_audio_id=user_audio_id,
+        ai_response=ai_response,
+        ai_audio_id=ai_audio_id
+    )
+
+
 @app.post("/tts")
 async def text_to_speech(request: TTSRequest):
     global model, audio_tokenizer
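
For reviewers, a minimal sketch of how a client could exercise the two new endpoints end to end. The base URL, port, file name, and language code below are illustrative assumptions, not part of this change; the multipart field names (audio_file, language) and the response keys match the diff above.

import requests

BASE = "http://localhost:7860"  # assumed local dev address; adjust to your deployment

# POST a recorded clip to /voice-chat; language must be one of yo, ha, ig, en.
with open("question.webm", "rb") as f:
    resp = requests.post(
        f"{BASE}/voice-chat",
        files={"audio_file": ("question.webm", f, "audio/webm")},
        data={"language": "yo"},
        timeout=120,  # ASR + TTS can be slow on the first call due to lazy model loading
    )
resp.raise_for_status()
body = resp.json()
print("Transcription:", body["user_transcription"])
print("AI response:", body["ai_response"])

# GET the synthesized reply by id; ids expire after AUDIO_EXPIRY_SECONDS (one hour).
audio = requests.get(f"{BASE}/audio/{body['ai_audio_id']}", timeout=30)
audio.raise_for_status()
with open("reply.wav", "wb") as out:
    out.write(audio.content)

One design note on the storage added here: audio_registry is an in-process dict guarded by a threading.Lock, and the files live under /tmp/voice_chat_audio, so an audio id resolves only on the worker that created it and does not survive a restart. Clients should fetch audio promptly rather than persisting ids.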