|
|
import os |
|
|
import torch |
|
|
import logging |
|
|
from typing import Optional, Dict, Any |
|
|
from fastapi import FastAPI, HTTPException, status, File, UploadFile, Form |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import FileResponse, StreamingResponse |
|
|
from starlette.background import BackgroundTask |
|
|
from pydantic import BaseModel |
|
|
import torchaudio |
|
|
import io |
|
|
import tempfile |
|
|
import numpy as np |
|
|
import requests |
|
|
import soundfile as sf |
|
|
import subprocess |
|
|
import imageio_ffmpeg |
|
|
import uuid |
|
|
import time |
|
|
import threading |
|
|
from pathlib import Path |
|
|
|
|
|
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application; title/description/version are passed to FastAPI as metadata.
app = FastAPI(
    title="Farmlingua Speech System",
    description="ASR → Ask → YarnGPT2 TTS with default voices per language",
    version="1.0.0"
)

# CORS is configured fully open: any origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Best-effort preparation of the scratch/cache directories used for model
# downloads. A failure (e.g. read-only or permission-restricted /tmp) is logged
# instead of being silently discarded, but it still does not abort startup.
for _p in ["/tmp/huggingface", "/tmp/models", "/tmp/hf_asr"]:
    try:
        os.makedirs(_p, exist_ok=True)
        # World-writable so that whichever user the server runs as can write here.
        os.chmod(_p, 0o777)
    except OSError as _err:
        logger.warning("Could not prepare cache directory %s: %s", _p, _err)
|
|
|
|
|
# Upstream question-answering endpoint; can be overridden with the ASK_URL
# environment variable.
ASK_URL = os.getenv("ASK_URL", "https://remostart-farmlingua-ai-conversational.hf.space/ask")

# Directory in which store_audio() writes audio files; created at import time.
AUDIO_STORAGE_DIR = Path("/tmp/voice_chat_audio")
AUDIO_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
# Age in seconds after which cleanup_expired_audio() removes a stored clip.
AUDIO_EXPIRY_SECONDS = 3600

# In-memory index of stored clips: audio_id -> {"path": str, "created_at": float}.
# All reads and writes go through audio_registry_lock.
audio_registry: Dict[str, Dict[str, Any]] = {}
audio_registry_lock = threading.Lock()
|
|
|
|
|
|
|
|
def cleanup_expired_audio():
    """Forget and delete stored audio older than ``AUDIO_EXPIRY_SECONDS``.

    Expired entries are removed from ``audio_registry`` while holding
    ``audio_registry_lock``; the files themselves are deleted after the lock
    has been released so slow disk I/O does not block other requests.

    A file that has already disappeared is ignored. Any other deletion error
    is logged (the function stays best-effort and never raises for it).
    """
    now = time.time()

    with audio_registry_lock:
        expired_ids = [
            audio_id
            for audio_id, info in audio_registry.items()
            if now - info["created_at"] > AUDIO_EXPIRY_SECONDS
        ]
        expired_paths = [audio_registry.pop(audio_id)["path"] for audio_id in expired_ids]

    for path in expired_paths:
        try:
            # Unlink directly instead of exists()+unlink(): avoids the race where
            # the file vanishes between the check and the delete.
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as exc:
            logger.warning("Could not delete expired audio file %s: %s", path, exc)
|
|
|
|
|
|
|
|
def store_audio(audio_data: bytes, suffix: str = ".wav") -> str:
    """Persist *audio_data* under a fresh UUID and return that id.

    Expired clips are purged first. The bytes are written to
    ``AUDIO_STORAGE_DIR/<uuid><suffix>`` and the path plus creation time are
    recorded in ``audio_registry`` under the lock.
    """
    cleanup_expired_audio()

    new_id = str(uuid.uuid4())
    target = AUDIO_STORAGE_DIR / f"{new_id}{suffix}"
    target.write_bytes(audio_data)

    with audio_registry_lock:
        audio_registry[new_id] = {
            "path": str(target),
            "created_at": time.time(),
        }
    return new_id
|
|
|
|
|
|
|
|
def get_audio_path(audio_id: str) -> Optional[str]:
    """Return the on-disk path registered for *audio_id*, or ``None``.

    ``None`` is returned both when the id is unknown and when the registered
    file no longer exists on disk.
    """
    with audio_registry_lock:
        record = audio_registry.get(audio_id)
        if not record:
            return None
        stored_path = record["path"]
        if os.path.exists(stored_path):
            return stored_path
    return None
|
|
|
|
|
# Per-language ASR registry keyed by the short language code. "repo" is the
# Hugging Face repository id; "model" and "proc" start as None and are filled in
# by _get_asr() the first time that language is requested.
asr_models = {
    "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
    "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
    "ig": {"repo": "NCAIR1/Igbo-ASR", "model": None, "proc": None},
    "en": {"repo": "NCAIR1/NigerianAccentedEnglish", "model": None, "proc": None},
}

# TTS state, populated lazily by load_model() / load_audio_tokenizer().
model = None
audio_tokenizer = None
# NOTE(review): this None placeholder is overwritten immediately below.
device = None

# Prefer CUDA when torch reports it as available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
|
|
|
|
|
class TTSRequest(BaseModel):
    """Request body accepted by the ``/tts`` and ``/tts-stream`` endpoints."""

    # Text to synthesise.
    text: str
    # Language name; /tts-stream lower-cases and strips it before use.
    language: str = "english"
    # NOTE(review): not read by the handlers visible in this file, which pick
    # the voice from the language instead — confirm whether it should be honoured.
    speaker_name: str = "idera"
    # Generation settings forwarded to model.generate() by /tts-stream.
    temperature: float = 0.1
    repetition_penalty: float = 1.1
    max_length: int = 4000
|
|
|
|
|
class TTSResponse(BaseModel):
    """Message plus audio URL response shape.

    NOTE(review): no handler visible in this file constructs or declares this
    model — confirm whether it is still needed.
    """

    message: str
    audio_url: str
|
|
|
|
|
|
|
|
class SpeakRequest(BaseModel):
    """Request body for ``/speak``: text, language and optional generation knobs."""

    # Text to synthesise and the language name used to choose the default voice.
    text: str
    language: str
    # Optional generation settings; each may be sent as null.
    temperature: Optional[float] = 0.1
    repetition_penalty: Optional[float] = 1.1
    max_length: Optional[int] = 4000
|
|
|
|
|
|
|
|
class VoiceChatResponse(BaseModel):
    """Response model of ``/voice-chat``."""

    # Text recognised from the uploaded audio.
    user_transcription: str
    # Id returned by store_audio() for the uploaded clip.
    user_audio_id: str
    # Text answer obtained from the Ask backend (or a fallback apology).
    ai_response: str
    # Id returned by store_audio() for the synthesised answer.
    ai_audio_id: str
|
|
|
|
|
def _first_existing_path(candidates):
    """Return the first entry of *candidates* that exists on disk, else ``None``."""
    return next((p for p in candidates if os.path.exists(p)), None)


def _download_wavtokenizer_checkpoint(target_dir: str, filename: str) -> Optional[str]:
    """Download the WavTokenizer checkpoint with ``gdown`` into *target_dir*.

    Returns the final checkpoint path on success and ``None`` on any failure.
    The partially downloaded file is always removed when the download fails.
    """
    import shutil

    tmp_file = os.path.join("/tmp", f"{filename}.{os.getpid()}.part")
    final_path = os.path.join(target_dir, filename)
    try:
        os.makedirs(target_dir, exist_ok=True)
        result = subprocess.run(
            ["gdown", "--fuzzy", "1-ASeEkrn4HY49yZWHTASgfGFNXdVnLTt", "-O", tmp_file],
            check=False, capture_output=True, text=True, env=os.environ.copy(),
        )
        if result.returncode == 0 and os.path.exists(tmp_file):
            shutil.move(tmp_file, final_path)
            logger.info("Checkpoint downloaded successfully")
            return final_path
        logger.warning(f"Checkpoint download failed: {result.stderr}, using fallback path")
    except Exception as e:
        # Best effort: a missing gdown binary or a filesystem error must not
        # prevent the caller from trying the fallback path.
        logger.warning(f"Could not download checkpoint: {e}, using fallback path")
    finally:
        # Do not leave a stale partial download behind.
        try:
            os.unlink(tmp_file)
        except OSError:
            pass
    return None


def load_audio_tokenizer():
    """Load the YarnGPT2 audio tokenizer into the module-level ``audio_tokenizer``.

    The WavTokenizer config and checkpoint are looked up in the working
    directory, ``./models`` and (for the checkpoint) the download directory
    ``$MODEL_DIR`` (default ``/tmp/models``), so a previously downloaded
    checkpoint is reused instead of being fetched again. If no checkpoint is
    found a download is attempted; on failure the first candidate path is
    passed through unchanged.

    If importing ``yarngpt`` fails, a minimal stand-in built on
    ``transformers.AutoTokenizer`` is installed instead.

    Returns:
        The loaded tokenizer object (also stored in ``audio_tokenizer``).

    Raises:
        Exception: whatever the tokenizer constructors raise; failures of the
        fallback path are logged and re-raised.
    """
    global audio_tokenizer

    config_name = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
    ckpt_name = "wavtokenizer_large_speech_320_24k.ckpt"
    download_dir = os.environ.get("MODEL_DIR", "/tmp/models")

    try:
        config_paths = [
            f"./{config_name}",
            f"./models/{config_name}",
        ]
        model_paths = [
            f"./{ckpt_name}",
            f"./models/{ckpt_name}",
            # Where _download_wavtokenizer_checkpoint() stores the file.
            os.path.join(download_dir, ckpt_name),
        ]

        config_path = _first_existing_path(config_paths) or config_paths[0]
        model_path = _first_existing_path(model_paths)

        if model_path is None:
            logger.warning("Checkpoint file not found, attempting to download...")
            downloaded = _download_wavtokenizer_checkpoint(download_dir, ckpt_name)
            model_path = downloaded or model_paths[0]

        from yarngpt.audiotokenizer import AudioTokenizerV2

        tokenizer_path = "saheedniyi/YarnGPT2"

        audio_tokenizer = AudioTokenizerV2(
            tokenizer_path,
            model_path,
            config_path
        )
        logger.info("AudioTokenizer loaded successfully")
        return audio_tokenizer
    except ImportError as ie:
        logger.warning(f"yarngpt package not found: {ie}")
        try:
            from transformers import AutoTokenizer

            tokenizer_path = "saheedniyi/YarnGPT2"

            class AudioTokenizerWrapper:
                """Stand-in exposing the methods the handlers call on the tokenizer."""

                def __init__(self, tokenizer_path):
                    self.tokenizer_path = tokenizer_path
                    self.device = device
                    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
                    logger.info("Using fallback tokenizer")

                def create_prompt(self, text, lang="english", speaker_name="idera"):
                    speaker_tag = f"<{speaker_name}>"
                    lang_tag = f"<{lang}>"
                    return f"{speaker_tag}{lang_tag}{text}</s>"

                def tokenize_prompt(self, prompt):
                    return self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

                def get_codes(self, output):
                    return output

                def get_audio(self, codes):
                    # NOTE(review): this does NOT decode *codes*; it returns three
                    # seconds of random noise at 24 kHz. Kept as a placeholder, but
                    # warn so nobody mistakes the output for synthesised speech.
                    logger.warning("Fallback tokenizer cannot decode audio; returning placeholder noise")
                    sample_rate = 24000
                    duration = 3.0
                    audio = np.random.randn(int(duration * sample_rate)).astype(np.float32)
                    return torch.from_numpy(audio)

            audio_tokenizer = AudioTokenizerWrapper(tokenizer_path)
            logger.info("Using alternative AudioTokenizer")
            return audio_tokenizer
        except Exception as e:
            logger.error(f"Failed to load audio tokenizer: {e}")
            raise
|
|
|
|
|
def load_model():
    """Load the YarnGPT2 causal LM onto ``device`` and store it in ``model``.

    When the loaded config has no pad token id but does have an EOS token id,
    the EOS id is reused for padding. Any failure is logged and re-raised.
    """
    global model

    repo_id = "saheedniyi/YarnGPT2"
    try:
        from transformers import AutoModelForCausalLM

        logger.info("Loading YarnGPT2 model from HuggingFace...")

        model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto").to(device)

        cfg = model.config
        if cfg.pad_token_id is None and cfg.eos_token_id is not None:
            cfg.pad_token_id = cfg.eos_token_id

        logger.info("YarnGPT2 model loaded successfully")
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
|
|
|
|
|
|
|
|
def _get_asr(lang_code: str):
    """Return ``(model, processor)`` for *lang_code*, loading them on first use.

    Returns ``(None, None)`` when the Whisper classes cannot be imported, when
    the language code is not in ``asr_models``, or when loading fails. A
    successful load is cached in the ``asr_models`` entry.
    """
    try:
        from transformers import WhisperProcessor, WhisperForConditionalGeneration
        from huggingface_hub import snapshot_download
    except Exception as e:
        logger.error(f"Transformers missing whisper classes: {e}")
        return None, None

    slot = asr_models.get(lang_code)
    if not slot:
        return None, None

    cached_model, cached_proc = slot["model"], slot["proc"]
    if cached_model is not None and cached_proc is not None:
        return cached_model, cached_proc

    repo = slot["repo"]
    token = os.getenv("HF_TOKEN")
    try:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Lazy-loading ASR for {lang_code} from {repo}...")

        snapshot_dir = "/tmp/hf_asr/" + repo.replace('/', '__')
        os.makedirs(snapshot_dir, exist_ok=True)
        try:
            snapshot_download(repo_id=repo, token=token, local_dir=snapshot_dir)
        except Exception as pre_e:
            # A failed prefetch is tolerated; loading below uses whatever is on disk.
            logger.warning(f"ASR snapshot prefetch skipped/failed for {repo}: {pre_e}")

        processor = WhisperProcessor.from_pretrained(snapshot_dir, local_files_only=True)
        whisper = WhisperForConditionalGeneration.from_pretrained(snapshot_dir, local_files_only=True)
        whisper.to(target_device)
        whisper.eval()

        slot["model"], slot["proc"] = whisper, processor
        return whisper, processor
    except Exception as e:
        logger.error(f"Failed to load ASR for {lang_code}: {e}")
        slot["model"], slot["proc"] = None, None
        return None, None
|
|
|
|
|
|
|
|
def _preprocess_audio_ffmpeg(audio_bytes: bytes, target_sr: int = 16000) -> np.ndarray:
    """Decode arbitrary audio bytes to a mono float32 waveform at *target_sr*.

    The bytes are written to a temporary file, converted by ffmpeg to a
    single-channel WAV at *target_sr*, read back, clipped to [-0.99, 0.99]
    and shifted to zero mean.

    Args:
        audio_bytes: Raw contents of the uploaded audio file.
        target_sr: Desired sample rate in Hz.

    Returns:
        One-dimensional ``float32`` array of samples.

    Raises:
        HTTPException: status 400 if conversion or decoding fails, or if the
        decoded audio contains no samples.
    """
    in_path = None
    out_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as in_file:
            in_path = in_file.name
            in_file.write(audio_bytes)
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as out_file:
            out_path = out_file.name

        ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
        subprocess.run([
            ffmpeg_exe, '-y', '-i', in_path,
            '-ac', '1',
            '-ar', str(target_sr),
            out_path
        ], check=True, capture_output=True)

        with open(out_path, 'rb') as f:
            wav_data = f.read()

        audio_array, sr = sf.read(io.BytesIO(wav_data))
        if audio_array.ndim > 1:
            audio_array = np.mean(audio_array, axis=1)
        if len(audio_array) == 0:
            raise ValueError("decoded audio contains no samples")
        if sr != target_sr:
            # Safety net in case the decoded rate differs from the requested one.
            ratio = target_sr / sr
            new_len = int(len(audio_array) * ratio)
            # Output sample j corresponds to input position j / ratio, which
            # keeps the sampling grid inside the source signal.
            audio_array = np.interp(
                np.arange(new_len) / ratio,
                np.arange(len(audio_array)),
                audio_array
            )
        audio_array = np.clip(audio_array, -0.99, 0.99)
        audio_array = audio_array - float(np.mean(audio_array))
        return audio_array.astype(np.float32)
    except Exception as e:
        logger.error(f"FFmpeg preprocessing failed: {e}")
        raise HTTPException(status_code=400, detail="Audio preprocessing failed") from e
    finally:
        # Always remove the temporary files, including when ffmpeg or decoding
        # failed part-way through.
        for tmp_path in (in_path, out_path):
            if tmp_path is None:
                continue
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_err:
                logger.debug("Could not remove temporary file %s: %s", tmp_path, cleanup_err)
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Log that the server is up; no model is loaded at startup."""
    logger.info("Server started. Models will be loaded on first request.")
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service: name, status, supported languages and voices."""
    # NOTE(review): the TTS model is loaded lazily, so "model_loading_failed"
    # is also reported before any load has been attempted.
    service_status = "running" if model is not None else "model_loading_failed"

    speakers_by_language = {
        "english": ["idera"],
        "yoruba": ["yoruba_male2"],
        "igbo": ["igbo_male2"],
        "hausa": ["hausa_female1"],
    }

    return {
        "name": "Farmlingua Speech System",
        "description": "ASR → Ask → YarnGPT2 TTS with default voices per language",
        "status": service_status,
        "available_languages": ["english", "yoruba", "igbo", "hausa"],
        "available_speakers": speakers_by_language,
    }
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report whether the TTS model and tokenizer are loaded, and the device in use."""
    model_ready = model is not None
    tokenizer_ready = audio_tokenizer is not None
    return {
        "status": "healthy" if model_ready else "degraded",
        "device": str(device),
        "model_loaded": model_ready,
        "tokenizer_loaded": tokenizer_ready,
    }
|
|
|
|
|
|
|
|
@app.post("/ask")
async def ask(query: str = Form(...)):
    """Forward *query* to the Ask backend and return its JSON reply.

    Any failure (network error, non-2xx status, invalid JSON) is logged and
    reported to the client as HTTP 502.
    """
    payload = {"query": query}
    try:
        # NOTE(review): requests.post is a blocking call inside an async handler.
        upstream = requests.post(ASK_URL, json=payload, timeout=30)
        upstream.raise_for_status()
        body = upstream.json()
    except Exception as e:
        logger.error(f"ASK error: {e}")
        raise HTTPException(status_code=502, detail="Ask backend error")
    return body
|
|
|
|
|
|
|
|
async def _transcribe_impl(audio_file: UploadFile, language: str):
    """Validate the upload, run ASR and return ``{"language", "transcription"}``.

    Raises:
        HTTPException: 400 for a non-audio content type or unsupported
        language code; 500 when the ASR model is unavailable or inference fails.
    """
    mime = audio_file.content_type
    if not (mime and mime.startswith('audio/')):
        raise HTTPException(status_code=400, detail="File must be an audio file")
    if language not in ("yo", "ha", "ig", "en"):
        raise HTTPException(status_code=400, detail="Language must be one of: yo, ha, ig, en")

    raw = await audio_file.read()
    samples = _preprocess_audio_ffmpeg(raw)

    whisper, processor = _get_asr(language)
    if whisper is None or processor is None:
        raise HTTPException(status_code=500, detail="ASR model not available")

    try:
        # Run on whichever device the model's parameters live on.
        model_device = next(whisper.parameters()).device
        encoded = processor(samples, sampling_rate=16000, return_tensors="pt")
        features = encoded.input_features.to(model_device)
        with torch.no_grad():
            token_ids = whisper.generate(features)
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        transcript = decoded[0].strip() if decoded else ""
        return {"language": language, "transcription": transcript}
    except Exception as e:
        logger.error(f"ASR inference failed: {e}")
        raise HTTPException(status_code=500, detail="ASR inference failed")
|
|
|
|
|
|
|
|
@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...), language: str = Form(...)):
    """Transcribe an uploaded audio file; thin HTTP wrapper around _transcribe_impl()."""
    return await _transcribe_impl(audio_file, language)
|
|
|
|
|
|
|
|
@app.post("/speak-ai")
async def speak_ai(audio_file: UploadFile = File(...), language: str = Form(...)):
    """Transcribe the upload, ask the backend, and speak the answer.

    If the Ask backend fails or returns no answer, the transcript itself is
    spoken back. Raises HTTP 400 when the transcription is empty.
    """
    transcription = await _transcribe_impl(audio_file, language)
    query = transcription.get("transcription", "")
    if not query:
        raise HTTPException(status_code=400, detail="No transcription obtained from audio")

    try:
        reply = requests.post(ASK_URL, json={"query": query}, timeout=30)
        reply.raise_for_status()
        # An empty or missing answer falls back to echoing the transcript.
        answer_text = reply.json().get("answer", "") or query
    except Exception as e:
        logger.warning(f"Ask failed ({e}); falling back to transcript")
        answer_text = query

    return await speak(SpeakRequest(text=answer_text, language=_map_lang_code(language)))
|
|
|
|
|
|
|
|
def _map_lang_code(code: str) -> str: |
|
|
m = {"yo": "yoruba", "ha": "hausa", "ig": "igbo", "en": "english"} |
|
|
return m.get(code.lower(), "english") |
|
|
|
|
|
|
|
|
@app.get("/audio/{audio_id}")
async def get_audio(audio_id: str):
    """Serve a previously stored audio clip by id.

    The content type and download filename follow the stored file's real
    extension, so clips saved with a ``.webm`` suffix are no longer labelled
    as WAV. For ``.wav`` clips the response is unchanged.

    Raises:
        HTTPException: 404 when the id is unknown, expired, or its file is gone.
    """
    file_path = get_audio_path(audio_id)
    if not file_path:
        raise HTTPException(status_code=404, detail="Audio not found or expired")

    stored = Path(file_path)
    # Suffixes that store_audio() is called with in this file.
    media_types = {".wav": "audio/wav", ".webm": "audio/webm"}
    return FileResponse(
        file_path,
        media_type=media_types.get(stored.suffix.lower(), "application/octet-stream"),
        filename=stored.name,
    )
|
|
|
|
|
|
|
|
@app.post("/voice-chat", response_model=VoiceChatResponse)
async def voice_chat(audio_file: UploadFile = File(...), language: str = Form(...)):
    """Full voice round-trip: upload -> ASR -> Ask backend -> TTS.

    The uploaded audio and the synthesised answer are both stored with
    store_audio(); their ids are returned alongside the transcript and the
    answer text.

    Raises:
        HTTPException: 400 for an unsupported language code or an empty
        transcription; 500 when ASR is unavailable/fails or TTS fails; 503
        when the TTS model or tokenizer is still missing after loading.
    """
    global model, audio_tokenizer

    if language not in ["yo", "ha", "ig", "en"]:
        raise HTTPException(status_code=400, detail="Language must be one of: yo, ha, ig, en")

    # Keep the raw upload. NOTE(review): it is stored with a ".webm" suffix
    # regardless of the actual upload format.
    audio_bytes = await audio_file.read()
    user_audio_id = store_audio(audio_bytes, suffix=".webm")

    # --- Speech recognition ---
    audio_array = _preprocess_audio_ffmpeg(audio_bytes)
    model_asr, proc = _get_asr(language)
    if model_asr is None or proc is None:
        raise HTTPException(status_code=500, detail="ASR model not available")

    try:
        device_t = next(model_asr.parameters()).device
        inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device_t)
        with torch.no_grad():
            pred_ids = model_asr.generate(input_features)
        text_list = proc.batch_decode(pred_ids, skip_special_tokens=True)
        user_transcription = text_list[0].strip() if text_list else ""
    except Exception as e:
        logger.error(f"ASR inference failed: {e}")
        raise HTTPException(status_code=500, detail="ASR inference failed")

    if not user_transcription:
        raise HTTPException(status_code=400, detail="Could not transcribe audio")

    # --- Ask backend; any failure degrades to a canned apology ---
    try:
        ans = requests.post(ASK_URL, json={"query": user_transcription}, timeout=30)
        ans.raise_for_status()
        ai_response = ans.json().get("answer", "")
        if not ai_response:
            ai_response = "I'm sorry, I couldn't generate a response."
    except Exception as e:
        logger.warning(f"Ask failed ({e}); using fallback response")
        ai_response = "I'm sorry, I'm having trouble connecting. Please try again."

    # --- Lazy-load the TTS model and tokenizer on first use ---
    if model is None:
        logger.info("Loading YarnGPT2 model (lazy loading)...")
        load_model()
    if audio_tokenizer is None:
        logger.info("Loading audio tokenizer (lazy loading)...")
        load_audio_tokenizer()

    if model is None or audio_tokenizer is None:
        raise HTTPException(status_code=503, detail="TTS model loading failed")

    # Pick the default voice for the language.
    tts_language = _map_lang_code(language)
    default_speakers = {
        "english": "idera",
        "yoruba": "yoruba_male2",
        "igbo": "igbo_male2",
        "hausa": "hausa_female1",
    }
    speaker = default_speakers.get(tts_language, "idera")

    try:
        prompt = audio_tokenizer.create_prompt(
            ai_response,
            lang=tts_language,
            speaker_name=speaker,
        )
        tokenized = audio_tokenizer.tokenize_prompt(prompt)
        # tokenize_prompt may return a bare tensor or a mapping with
        # "input_ids" / "attention_mask"; handle both.
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
            attention_mask = None
        else:
            input_ids = tokenized.get("input_ids", tokenized)
            attention_mask = tokenized.get("attention_mask", None)

        if hasattr(audio_tokenizer, 'tokenizer') and audio_tokenizer.tokenizer.pad_token is None:
            audio_tokenizer.tokenizer.pad_token = audio_tokenizer.tokenizer.eos_token

        with torch.no_grad():
            gen_kwargs = {
                "input_ids": input_ids,
                "repetition_penalty": 1.1,
                "max_length": 4000,
            }
            if attention_mask is not None:
                gen_kwargs["attention_mask"] = attention_mask

            # Yoruba/Igbo/Hausa use 5-beam search; English uses sampling at
            # temperature 0.1.
            use_beams = tts_language in ["yoruba", "igbo", "hausa"]
            if use_beams:
                gen_kwargs["num_beams"] = 5
                gen_kwargs["early_stopping"] = False
            else:
                gen_kwargs["do_sample"] = True
                gen_kwargs["temperature"] = 0.1

            output = model.generate(**gen_kwargs)

        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)

        # Normalise to a 1-D float32 CPU tensor; rescale only if it would clip.
        if isinstance(audio, torch.Tensor):
            audio_tensor = audio.detach()
        else:
            audio_tensor = torch.tensor(np.asarray(audio))
        audio_tensor = audio_tensor.to(torch.float32).cpu()
        if audio_tensor.ndim > 1:
            audio_tensor = audio_tensor.squeeze()
        peak = audio_tensor.abs().max()
        if peak > 1.0:
            audio_tensor = audio_tensor / peak

        # Encode as WAV at 24000 Hz in memory, then persist through store_audio().
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio_tensor.unsqueeze(0), 24000, format="wav")
        buffer.seek(0)
        ai_audio_bytes = buffer.read()

        ai_audio_id = store_audio(ai_audio_bytes, suffix=".wav")

    except Exception as e:
        logger.error(f"TTS generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {e}")

    return VoiceChatResponse(
        user_transcription=user_transcription,
        user_audio_id=user_audio_id,
        ai_response=ai_response,
        ai_audio_id=ai_audio_id
    )
|
|
|
|
|
|
|
|
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Synthesise speech for *request* and return it as WAV audio.

    Previously this handler loaded the models and then returned nothing, so
    clients received an empty (null) response. It now delegates to
    ``text_to_speech_stream``, which accepts the same request model and
    performs the synthesis.

    NOTE(review): the intended response shape for this endpoint is not visible
    in this file — confirm that returning the WAV stream is what clients expect.

    Raises:
        HTTPException: 503 when the model or tokenizer is unavailable; 500
        when generation fails.
    """
    global model, audio_tokenizer
    if model is None:
        logger.info("Loading YarnGPT2 model (lazy loading)...")
        load_model()
    if audio_tokenizer is None:
        logger.info("Loading audio tokenizer (lazy loading)...")
        load_audio_tokenizer()

    if model is None or audio_tokenizer is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model loading failed. Please check logs."
        )

    return await text_to_speech_stream(request)
|
|
|
|
|
|
|
|
@app.post("/speak")
async def speak(request: SpeakRequest):
    """Synthesise ``request.text`` with the default voice for its language.

    Yoruba/Igbo/Hausa use 5-beam search. Other languages sample at
    ``request.temperature`` (default 0.1 when null); a temperature of 0 or
    below selects plain greedy decoding instead of being replaced by 0.1.

    Returns:
        A ``FileResponse`` with the generated WAV; the temporary file is
        deleted after the response has been sent.

    Raises:
        HTTPException: 503 when the model or tokenizer is unavailable; 500
        when generation or encoding fails.
    """
    global model, audio_tokenizer
    if model is None:
        logger.info("Loading YarnGPT2 model (lazy loading)...")
        load_model()
    if audio_tokenizer is None:
        logger.info("Loading audio tokenizer (lazy loading)...")
        load_audio_tokenizer()

    if model is None or audio_tokenizer is None:
        raise HTTPException(status_code=503, detail="Model loading failed. Please check logs.")

    default_speakers = {
        "english": "idera",
        "yoruba": "yoruba_male2",
        "igbo": "igbo_male2",
        "hausa": "hausa_female1",
    }
    language = request.language.lower().strip()
    speaker = default_speakers.get(language, "idera")

    def _discard(path):
        """Remove the temporary WAV at *path*; a missing file is not an error."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as exc:
            logger.warning("Could not remove temporary file %s: %s", path, exc)

    wav_path = None
    try:
        prompt = audio_tokenizer.create_prompt(
            request.text,
            lang=language,
            speaker_name=speaker,
        )
        tokenized = audio_tokenizer.tokenize_prompt(prompt)
        # tokenize_prompt may return a bare tensor or a mapping.
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
            attention_mask = None
        else:
            input_ids = tokenized.get("input_ids", tokenized)
            attention_mask = tokenized.get("attention_mask", None)

        if hasattr(audio_tokenizer, 'tokenizer') and audio_tokenizer.tokenizer.pad_token is None:
            audio_tokenizer.tokenizer.pad_token = audio_tokenizer.tokenizer.eos_token

        with torch.no_grad():
            gen_kwargs = {
                "input_ids": input_ids,
                "repetition_penalty": request.repetition_penalty or 1.1,
                "max_length": request.max_length or 4000,
            }
            if attention_mask is not None:
                gen_kwargs["attention_mask"] = attention_mask

            use_beams = language in ["yoruba", "igbo", "hausa"]
            if use_beams:
                gen_kwargs["num_beams"] = 5
                gen_kwargs["early_stopping"] = False
            else:
                # Only a missing (null) temperature falls back to the default;
                # `or` would also have swallowed an explicit 0.
                temp = 0.1 if request.temperature is None else request.temperature
                if temp > 0:
                    gen_kwargs["do_sample"] = True
                    gen_kwargs["temperature"] = temp
            output = model.generate(**gen_kwargs)
        logger.info(f"Generated output length: {output.shape[1]}, input length: {input_ids.shape[1]}, generated tokens: {output.shape[1] - input_ids.shape[1]}")

        codes = audio_tokenizer.get_codes(output)
        logger.info(f"Extracted {len(codes)} audio codes")
        audio = audio_tokenizer.get_audio(codes)

        # Normalise to a 1-D float32 CPU tensor; rescale only if it would clip.
        if isinstance(audio, torch.Tensor):
            audio_tensor = audio.detach()
        else:
            audio_tensor = torch.tensor(np.asarray(audio))
        audio_tensor = audio_tensor.to(torch.float32).cpu()
        if audio_tensor.ndim > 1:
            audio_tensor = audio_tensor.squeeze()
        peak = audio_tensor.abs().max()
        if peak > 1.0:
            audio_tensor = audio_tensor / peak

        # Reserve a file name and close the handle straight away so no file
        # descriptor is leaked; torchaudio writes to the path itself.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            wav_path = temp_file.name
        torchaudio.save(wav_path, audio_tensor.unsqueeze(0), 24000)

        return FileResponse(
            wav_path,
            media_type="audio/wav",
            filename="speech.wav",
            background=BackgroundTask(_discard, wav_path)
        )
    except Exception as e:
        # The response will never be sent, so its cleanup task will not run:
        # remove the temporary file here.
        if wav_path is not None:
            _discard(wav_path)
        logger.error(f"Speak error: {e}")
        raise HTTPException(status_code=500, detail=f"Speak failed: {e}")
|
|
|
|
|
@app.post("/tts-stream")
async def text_to_speech_stream(request: TTSRequest):
    """Synthesise ``request.text`` and return the WAV from an in-memory buffer.

    The voice is the default speaker for the requested language.
    Yoruba/Igbo/Hausa use 5-beam search; other languages sample at
    ``request.temperature``. A temperature of 0 or below selects greedy
    decoding instead of being replaced by 0.1.

    NOTE(review): ``request.speaker_name`` is not consulted here.

    Raises:
        HTTPException: 503 when the model or tokenizer is unavailable; 500
        when generation or encoding fails.
    """
    global model, audio_tokenizer
    if model is None:
        logger.info("Loading YarnGPT2 model (lazy loading)...")
        load_model()
    if audio_tokenizer is None:
        logger.info("Loading audio tokenizer (lazy loading)...")
        load_audio_tokenizer()

    if model is None or audio_tokenizer is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model loading failed. Please check logs."
        )

    try:
        default_speakers = {
            "english": "idera",
            "yoruba": "yoruba_male2",
            "igbo": "igbo_male2",
            "hausa": "hausa_female1",
        }
        lang_norm = request.language.lower().strip()
        spk = default_speakers.get(lang_norm, "idera")
        prompt = audio_tokenizer.create_prompt(request.text, lang=lang_norm, speaker_name=spk)

        tokenized = audio_tokenizer.tokenize_prompt(prompt)
        # tokenize_prompt may return a bare tensor or a mapping.
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
            attention_mask = None
        else:
            input_ids = tokenized.get("input_ids", tokenized)
            attention_mask = tokenized.get("attention_mask", None)

        if hasattr(audio_tokenizer, 'tokenizer') and audio_tokenizer.tokenizer.pad_token is None:
            audio_tokenizer.tokenizer.pad_token = audio_tokenizer.tokenizer.eos_token

        logger.info(f"Generating speech (streaming) for text: {request.text[:50]}...")
        with torch.no_grad():
            gen_kwargs = {
                "input_ids": input_ids,
                "repetition_penalty": request.repetition_penalty or 1.1,
                "max_length": request.max_length or 4000,
            }
            if attention_mask is not None:
                gen_kwargs["attention_mask"] = attention_mask

            use_beams = lang_norm in ["yoruba", "igbo", "hausa"]
            if use_beams:
                gen_kwargs["num_beams"] = 5
                gen_kwargs["early_stopping"] = False
            else:
                # Only a missing temperature falls back to the default; `or`
                # would also have swallowed an explicit 0.
                temp = 0.1 if request.temperature is None else request.temperature
                if temp > 0:
                    gen_kwargs["do_sample"] = True
                    gen_kwargs["temperature"] = temp
            output = model.generate(**gen_kwargs)

        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)

        # Normalise to a 1-D float32 CPU tensor; rescale only if it would clip.
        buffer = io.BytesIO()
        if isinstance(audio, torch.Tensor):
            audio_tensor = audio.detach()
        else:
            audio_tensor = torch.tensor(np.asarray(audio))
        audio_tensor = audio_tensor.to(torch.float32).cpu()
        if audio_tensor.ndim > 1:
            audio_tensor = audio_tensor.squeeze()
        peak = audio_tensor.abs().max()
        if peak > 1.0:
            audio_tensor = audio_tensor / peak
        torchaudio.save(buffer, audio_tensor.unsqueeze(0), 24000, format="wav")
        buffer.seek(0)

        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=speech.wav"},
            # Release the buffer once the response has been sent.
            background=BackgroundTask(buffer.close)
        )

    except Exception as e:
        logger.error(f"Error generating speech: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate speech: {str(e)}"
        )
|
|
|
|
|
# Script entry point: serve the app with uvicorn on all interfaces, port 7860,
# using a single worker process.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
|
|
|
|
|
|