nexusbert committed
Commit 4d777b5 · 1 Parent(s): 0139faa

Add voice chat feature with audio storage

Files changed (1):
  1. main.py (+197, −0)
main.py CHANGED
@@ -15,6 +15,10 @@ import requests
  import soundfile as sf
  import subprocess
  import imageio_ffmpeg
+ import uuid
+ import time
+ import threading
+ from pathlib import Path

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
@@ -42,6 +46,51 @@ for _p in ["/tmp/huggingface", "/tmp/models", "/tmp/hf_asr"]:

  ASK_URL = os.getenv("ASK_URL", "https://remostart-farmlingua-ai-conversational.hf.space/ask")

+ AUDIO_STORAGE_DIR = Path("/tmp/voice_chat_audio")
+ AUDIO_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
+ AUDIO_EXPIRY_SECONDS = 3600
+
+ audio_registry: Dict[str, Dict[str, Any]] = {}
+ audio_registry_lock = threading.Lock()
+
+
+ def cleanup_expired_audio():
+     now = time.time()
+     expired_ids = []
+     with audio_registry_lock:
+         for audio_id, info in audio_registry.items():
+             if now - info["created_at"] > AUDIO_EXPIRY_SECONDS:
+                 expired_ids.append(audio_id)
+         for audio_id in expired_ids:
+             info = audio_registry.pop(audio_id, None)
+             if info and os.path.exists(info["path"]):
+                 try:
+                     os.unlink(info["path"])
+                 except Exception:
+                     pass
+
+
+ def store_audio(audio_data: bytes, suffix: str = ".wav") -> str:
+     cleanup_expired_audio()
+     audio_id = str(uuid.uuid4())
+     file_path = AUDIO_STORAGE_DIR / f"{audio_id}{suffix}"
+     with open(file_path, "wb") as f:
+         f.write(audio_data)
+     with audio_registry_lock:
+         audio_registry[audio_id] = {
+             "path": str(file_path),
+             "created_at": time.time()
+         }
+     return audio_id
+
+
+ def get_audio_path(audio_id: str) -> Optional[str]:
+     with audio_registry_lock:
+         info = audio_registry.get(audio_id)
+     if info and os.path.exists(info["path"]):
+         return info["path"]
+     return None
+
  asr_models = {
      "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
      "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
@@ -76,6 +125,13 @@ class SpeakRequest(BaseModel):
      repetition_penalty: float | None = 1.1
      max_length: int | None = 4000

+
+ class VoiceChatResponse(BaseModel):
+     user_transcription: str
+     user_audio_id: str
+     ai_response: str
+     ai_audio_id: str
+
  def load_audio_tokenizer():
      global audio_tokenizer

@@ -367,6 +423,147 @@ def _map_lang_code(code: str) -> str:
      m = {"yo": "yoruba", "ha": "hausa", "ig": "igbo", "en": "english"}
      return m.get(code.lower(), "english")

+
+ @app.get("/audio/{audio_id}")
+ async def get_audio(audio_id: str):
+     file_path = get_audio_path(audio_id)
+     if not file_path:
+         raise HTTPException(status_code=404, detail="Audio not found or expired")
+     return FileResponse(
+         file_path,
+         media_type="audio/wav",
+         filename=f"{audio_id}.wav"
+     )
+
+
+ @app.post("/voice-chat", response_model=VoiceChatResponse)
+ async def voice_chat(audio_file: UploadFile = File(...), language: str = Form(...)):
+     global model, audio_tokenizer
+
+     if language not in ["yo", "ha", "ig", "en"]:
+         raise HTTPException(status_code=400, detail="Language must be one of: yo, ha, ig, en")
+
+     audio_bytes = await audio_file.read()
+     user_audio_id = store_audio(audio_bytes, suffix=".webm")
+
+     audio_array = _preprocess_audio_ffmpeg(audio_bytes)
+     model_asr, proc = _get_asr(language)
+     if model_asr is None or proc is None:
+         raise HTTPException(status_code=500, detail="ASR model not available")
+
+     try:
+         device_t = next(model_asr.parameters()).device
+         inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
+         input_features = inputs.input_features.to(device_t)
+         with torch.no_grad():
+             pred_ids = model_asr.generate(input_features)
+         text_list = proc.batch_decode(pred_ids, skip_special_tokens=True)
+         user_transcription = text_list[0].strip() if text_list else ""
+     except Exception as e:
+         logger.error(f"ASR inference failed: {e}")
+         raise HTTPException(status_code=500, detail="ASR inference failed")
+
+     if not user_transcription:
+         raise HTTPException(status_code=400, detail="Could not transcribe audio")
+
+     try:
+         ans = requests.post(ASK_URL, json={"query": user_transcription}, timeout=30)
+         ans.raise_for_status()
+         ai_response = ans.json().get("answer", "")
+         if not ai_response:
+             ai_response = "I'm sorry, I couldn't generate a response."
+     except Exception as e:
+         logger.warning(f"Ask failed ({e}); using fallback response")
+         ai_response = "I'm sorry, I'm having trouble connecting. Please try again."
+
+     if model is None:
+         logger.info("Loading YarnGPT2 model (lazy loading)...")
+         load_model()
+     if audio_tokenizer is None:
+         logger.info("Loading audio tokenizer (lazy loading)...")
+         load_audio_tokenizer()
+
+     if model is None or audio_tokenizer is None:
+         raise HTTPException(status_code=503, detail="TTS model loading failed")
+
+     tts_language = _map_lang_code(language)
+     default_speakers = {
+         "english": "idera",
+         "yoruba": "yoruba_male2",
+         "igbo": "igbo_male2",
+         "hausa": "hausa_female1",
+     }
+     speaker = default_speakers.get(tts_language, "idera")
+
+     try:
+         prompt = audio_tokenizer.create_prompt(
+             ai_response,
+             lang=tts_language,
+             speaker_name=speaker,
+         )
+         tokenized = audio_tokenizer.tokenize_prompt(prompt)
+         if isinstance(tokenized, torch.Tensor):
+             input_ids = tokenized
+             attention_mask = None
+         else:
+             input_ids = tokenized.get("input_ids", tokenized)
+             attention_mask = tokenized.get("attention_mask", None)
+
+         if hasattr(audio_tokenizer, 'tokenizer') and audio_tokenizer.tokenizer.pad_token is None:
+             audio_tokenizer.tokenizer.pad_token = audio_tokenizer.tokenizer.eos_token
+
+         with torch.no_grad():
+             gen_kwargs = {
+                 "input_ids": input_ids,
+                 "repetition_penalty": 1.1,
+                 "max_length": 4000,
+             }
+             if attention_mask is not None:
+                 gen_kwargs["attention_mask"] = attention_mask
+
+             use_beams = tts_language in ["yoruba", "igbo", "hausa"]
+             if use_beams:
+                 gen_kwargs["num_beams"] = 5
+                 gen_kwargs["early_stopping"] = False
+             else:
+                 gen_kwargs["do_sample"] = True
+                 gen_kwargs["temperature"] = 0.1
+
+             output = model.generate(**gen_kwargs)
+
+         codes = audio_tokenizer.get_codes(output)
+         audio = audio_tokenizer.get_audio(codes)
+
+         if isinstance(audio, torch.Tensor):
+             audio_tensor = audio.detach()
+         else:
+             audio_tensor = torch.tensor(np.asarray(audio))
+         audio_tensor = audio_tensor.to(torch.float32).cpu()
+         if audio_tensor.ndim > 1:
+             audio_tensor = audio_tensor.squeeze()
+         peak = audio_tensor.abs().max()
+         if peak > 1.0:
+             audio_tensor = audio_tensor / peak
+
+         buffer = io.BytesIO()
+         torchaudio.save(buffer, audio_tensor.unsqueeze(0), 24000, format="wav")
+         buffer.seek(0)
+         ai_audio_bytes = buffer.read()
+
+         ai_audio_id = store_audio(ai_audio_bytes, suffix=".wav")
+
+     except Exception as e:
+         logger.error(f"TTS generation failed: {e}")
+         raise HTTPException(status_code=500, detail=f"TTS generation failed: {e}")
+
+     return VoiceChatResponse(
+         user_transcription=user_transcription,
+         user_audio_id=user_audio_id,
+         ai_response=ai_response,
+         ai_audio_id=ai_audio_id
+     )
+
+
  @app.post("/tts")
  async def text_to_speech(request: TTSRequest):
      global model, audio_tokenizer
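
For reference, a minimal client sketch for exercising the two endpoints added in this commit (POST /voice-chat and GET /audio/{audio_id}). The base URL, the sample file name, and the output path are placeholders, not part of the commit; it only assumes the field names and response fields defined in the diff above.

# Minimal client sketch (assumes the `requests` package and a reachable instance of the service;
# BASE_URL and sample.webm are placeholders).
import requests

BASE_URL = "http://localhost:7860"  # placeholder; substitute the deployed Space URL

# 1) Send a recorded clip plus a language code ("yo", "ha", "ig", or "en").
with open("sample.webm", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/voice-chat",
        files={"audio_file": ("sample.webm", f, "audio/webm")},
        data={"language": "yo"},
        timeout=120,
    )
resp.raise_for_status()
result = resp.json()
print("User transcription:", result["user_transcription"])
print("AI response:", result["ai_response"])

# 2) Fetch the synthesized reply (and, if needed, the stored user clip) by audio id.
#    Stored files expire after AUDIO_EXPIRY_SECONDS (one hour), after which the GET returns 404.
reply = requests.get(f"{BASE_URL}/audio/{result['ai_audio_id']}", timeout=60)
reply.raise_for_status()
with open("reply.wav", "wb") as out:
    out.write(reply.content)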