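"""Multi-speech FastAPI service: ASR → Ask → TTS.

Transcribes Nigerian-language audio with per-language Whisper checkpoints,
optionally routes the transcript through an external Ask backend, and
synthesizes replies with YarnGPT2.
"""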
import io
import logging
import os
import subprocess
import tempfile
from typing import Optional, Dict, Any

import imageio_ffmpeg
import numpy as np
import requests
import soundfile as sf
import torch
import torchaudio
from fastapi import FastAPI, HTTPException, status, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Multi-speech system",
    description="ASR → Ask → TTS",
    version="0.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Ensure writable cache/model directories exist (needed on Hugging Face Spaces).
for _p in ["/tmp/huggingface", "/tmp/models", "/tmp/hf_asr"]:
    try:
        os.makedirs(_p, exist_ok=True)
        os.chmod(_p, 0o777)
    except Exception:
        pass

ASK_URL = os.getenv("ASK_URL", "put your text generation endpoint here")

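# Per-language Whisper ASR checkpoints; "model"/"proc" start empty and are
# filled lazily by _get_asr() on first use.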
asr_models = {
    "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
    "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
    "ig": {"repo": "NCAIR1/Igbo-ASR", "model": None, "proc": None},
    "en": {"repo": "NCAIR1/NigerianAccentedEnglish", "model": None, "proc": None},
}

model = None
audio_tokenizer = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

class TTSRequest(BaseModel):
    text: str
    language: str = "english"
    speaker_name: str = "idera"
    temperature: float = 0.1
    repetition_penalty: float = 1.1
    max_length: int = 4000


class TTSResponse(BaseModel):
    message: str
    audio_url: str


class SpeakRequest(BaseModel):
    text: str
    language: str
    temperature: float | None = 0.1
    repetition_penalty: float | None = 1.1
    max_length: int | None = 4000

def load_audio_tokenizer():
    global audio_tokenizer
    try:
        config_paths = [
            "./wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
            "./models/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
        ]
        model_paths = [
            "./wavtokenizer_large_speech_320_24k.ckpt",
            "./models/wavtokenizer_large_speech_320_24k.ckpt",
        ]
        config_path = next((p for p in config_paths if os.path.exists(p)), config_paths[0])
        model_path = next((p for p in model_paths if os.path.exists(p)), None)
        if not model_path:
            logger.warning("Checkpoint file not found, attempting to download...")
            try:
                import shutil as _shutil
                target_dir = os.environ.get("MODEL_DIR", "/tmp/models")
                os.makedirs(target_dir, exist_ok=True)
                # Download to a per-process .part file, then move it into place.
                tmp_file = os.path.join("/tmp", f"wavtokenizer_large_speech_320_24k.ckpt.{os.getpid()}.part")
                result = subprocess.run(
                    ["gdown", "--fuzzy", "1-ASeEkrn4HY49yZWHTASgfGFNXdVnLTt", "-O", tmp_file],
                    check=False, capture_output=True, text=True, env=os.environ.copy(),
                )
                final_path = os.path.join(target_dir, "wavtokenizer_large_speech_320_24k.ckpt")
                if result.returncode == 0 and os.path.exists(tmp_file):
                    _shutil.move(tmp_file, final_path)
                    model_path = final_path
                    logger.info("Checkpoint downloaded successfully")
                else:
                    model_path = model_paths[0]
                    logger.warning(f"Checkpoint download failed: {result.stderr}, using fallback path")
            except Exception as e:
                logger.warning(f"Could not download checkpoint: {e}, using fallback path")
                model_path = model_paths[0]
        from yarngpt.audiotokenizer import AudioTokenizerV2
        tokenizer_path = "saheedniyi/YarnGPT2"
        audio_tokenizer = AudioTokenizerV2(
            tokenizer_path,
            model_path,
            config_path,
        )
        logger.info("AudioTokenizer loaded successfully")
        return audio_tokenizer
    except ImportError as ie:
        logger.warning(f"yarngpt package not found: {ie}")
        try:
            from transformers import AutoTokenizer
            tokenizer_path = "saheedniyi/YarnGPT2"
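            # Fallback stub that mirrors the AudioTokenizerV2 interface; its
            # get_audio() returns random noise rather than real speech, so this
            # path only keeps the API responsive when yarngpt is unavailable.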
            class AudioTokenizerWrapper:
                def __init__(self, tokenizer_path):
                    self.tokenizer_path = tokenizer_path
                    self.device = device
                    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
                    logger.info("Using fallback tokenizer")

                def create_prompt(self, text, lang="english", speaker_name="idera"):
                    speaker_tag = f"<{speaker_name}>"
                    lang_tag = f"<{lang}>"
                    return f"{speaker_tag}{lang_tag}{text}</s>"

                def tokenize_prompt(self, prompt):
                    return self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

                def get_codes(self, output):
                    return output

                def get_audio(self, codes):
                    sample_rate = 24000
                    duration = 3.0
                    audio = np.random.randn(int(duration * sample_rate)).astype(np.float32)
                    return torch.from_numpy(audio)

            audio_tokenizer = AudioTokenizerWrapper(tokenizer_path)
            logger.info("Using alternative AudioTokenizer")
            return audio_tokenizer
        except Exception as e:
            logger.error(f"Failed to load audio tokenizer: {e}")
            raise

def load_model():
    global model
    try:
        from transformers import AutoModelForCausalLM
        tokenizer_path = "saheedniyi/YarnGPT2"
        logger.info("Loading YarnGPT2 model from HuggingFace...")
        model = AutoModelForCausalLM.from_pretrained(
            tokenizer_path,
            torch_dtype="auto",
        ).to(device)
        if model.config.pad_token_id is None and model.config.eos_token_id is not None:
            model.config.pad_token_id = model.config.eos_token_id
        logger.info("YarnGPT2 model loaded successfully")
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

def _get_asr(lang_code: str):
    try:
        from transformers import WhisperProcessor, WhisperForConditionalGeneration
        from huggingface_hub import snapshot_download
    except Exception as e:
        logger.error(f"Transformers missing whisper classes: {e}")
        return None, None
    entry = asr_models.get(lang_code)
    if not entry:
        return None, None
    if entry["model"] is not None and entry["proc"] is not None:
        return entry["model"], entry["proc"]
    repo_id = entry["repo"]
    hf_token = os.getenv("HF_TOKEN")
    try:
        device_t = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Lazy-loading ASR for {lang_code} from {repo_id}...")
        safe_name = repo_id.replace('/', '__')
        local_dir = f"/tmp/hf_asr/{safe_name}"
        os.makedirs(local_dir, exist_ok=True)
        try:
            snapshot_download(repo_id=repo_id, token=hf_token, local_dir=local_dir)
        except Exception as pre_e:
            logger.warning(f"ASR snapshot prefetch skipped/failed for {repo_id}: {pre_e}")
        proc = WhisperProcessor.from_pretrained(local_dir, local_files_only=True)
        model_asr = WhisperForConditionalGeneration.from_pretrained(local_dir, local_files_only=True)
        model_asr.to(device_t)
        model_asr.eval()
        entry["model"], entry["proc"] = model_asr, proc
        return model_asr, proc
    except Exception as e:
        logger.error(f"Failed to load ASR for {lang_code}: {e}")
        entry["model"], entry["proc"] = None, None
        return None, None

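# Decode arbitrary uploaded audio to mono float32 at target_sr via the ffmpeg
# binary bundled with imageio-ffmpeg, then clip and remove the DC offset.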
def _preprocess_audio_ffmpeg(audio_bytes: bytes, target_sr: int = 16000) -> np.ndarray:
    in_path = out_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as in_file:
            in_file.write(audio_bytes)
            in_path = in_file.name
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as out_file:
            out_path = out_file.name
        ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
        subprocess.run([
            ffmpeg_exe, '-y', '-i', in_path,
            '-ac', '1',             # downmix to mono
            '-ar', str(target_sr),  # resample
            out_path,
        ], check=True, capture_output=True)
        with open(out_path, 'rb') as f:
            wav_data = f.read()
        audio_array, sr = sf.read(io.BytesIO(wav_data))
        if audio_array.ndim > 1:
            audio_array = np.mean(audio_array, axis=1)
        if sr != target_sr:
            # Fallback linear resample in case ffmpeg did not hit target_sr.
            new_len = int(len(audio_array) * target_sr / sr)
            audio_array = np.interp(
                np.linspace(0, len(audio_array) - 1, new_len),
                np.arange(len(audio_array)),
                audio_array,
            )
        audio_array = np.clip(audio_array, -0.99, 0.99)
        audio_array = audio_array - float(np.mean(audio_array))
        return audio_array.astype(np.float32)
    except Exception as e:
        logger.error(f"FFmpeg preprocessing failed: {e}")
        raise HTTPException(status_code=400, detail="Audio preprocessing failed")
    finally:
        # Always remove the temp files, even when decoding fails.
        for p in (in_path, out_path):
            if p and os.path.exists(p):
                os.unlink(p)

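# The HTTP route paths on the decorators below are assumed, not confirmed;
# adjust them to match the paths your clients expect.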
@app.on_event("startup")
async def startup_event():
    logger.info("Server started. Models will be loaded on first request.")


@app.get("/")
async def root():
    return {
        "name": "Multi Speech System",
        "description": "ASR → Ask → YarnGPT2 TTS with default voices per language",
        "status": "running" if model is not None else "models_not_loaded",
        "available_languages": ["english", "yoruba", "igbo", "hausa"],
        "available_speakers": {
            "english": ["idera"],
            "yoruba": ["yoruba_male2"],
            "igbo": ["igbo_male2"],
            "hausa": ["hausa_female1"],
        },
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy" if model is not None else "degraded",
        "device": str(device),
        "model_loaded": model is not None,
        "tokenizer_loaded": audio_tokenizer is not None,
    }

@app.post("/ask")
async def ask(query: str = Form(...)):
    try:
        resp = requests.post(ASK_URL, json={"query": query}, timeout=30)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        logger.error(f"ASK error: {e}")
        raise HTTPException(status_code=502, detail="Ask backend error")

async def _transcribe_impl(audio_file: UploadFile, language: str):
    if not audio_file.content_type or not audio_file.content_type.startswith('audio/'):
        raise HTTPException(status_code=400, detail="File must be an audio file")
    if language not in ["yo", "ha", "ig", "en"]:
        raise HTTPException(status_code=400, detail="Language must be one of: yo, ha, ig, en")
    audio_bytes = await audio_file.read()
    audio_array = _preprocess_audio_ffmpeg(audio_bytes)
    model_asr, proc = _get_asr(language)
    if model_asr is None or proc is None:
        raise HTTPException(status_code=500, detail="ASR model not available")
    try:
        device_t = next(model_asr.parameters()).device
        inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device_t)
        with torch.no_grad():
            pred_ids = model_asr.generate(input_features)
        text_list = proc.batch_decode(pred_ids, skip_special_tokens=True)
        transcript = text_list[0].strip() if text_list else ""
        return {"language": language, "transcription": transcript}
    except Exception as e:
        logger.error(f"ASR inference failed: {e}")
        raise HTTPException(status_code=500, detail="ASR inference failed")

@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...), language: str = Form(...)):
    return await _transcribe_impl(audio_file, language)

@app.post("/transcribe/english")
async def transcribe_english(audio_file: UploadFile = File(...)):
    return await _transcribe_impl(audio_file, "en")

@app.post("/transcribe/yoruba")
async def transcribe_yoruba(audio_file: UploadFile = File(...)):
    return await _transcribe_impl(audio_file, "yo")

@app.post("/transcribe/igbo")
async def transcribe_igbo(audio_file: UploadFile = File(...)):
    return await _transcribe_impl(audio_file, "ig")

@app.post("/transcribe/hausa")
async def transcribe_hausa(audio_file: UploadFile = File(...)):
    return await _transcribe_impl(audio_file, "ha")

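# End-to-end voice pipeline: transcribe the upload, query the Ask backend,
# then synthesize the answer (falling back to reading the transcript aloud
# if the Ask backend fails).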
@app.post("/speak-ai")
async def speak_ai(audio_file: UploadFile = File(...), language: str = Form(...)):
    tr = await _transcribe_impl(audio_file, language)
    query = tr.get("transcription", "")
    if not query:
        raise HTTPException(status_code=400, detail="No transcription obtained from audio")
    try:
        ans = requests.post(ASK_URL, json={"query": query}, timeout=30)
        ans.raise_for_status()
        answer_text = ans.json().get("answer", "")
        if not answer_text:
            answer_text = query
    except Exception as e:
        logger.warning(f"Ask failed ({e}); falling back to transcript")
        answer_text = query
    speak_req = SpeakRequest(text=answer_text, language=_map_lang_code(language))
    return await speak(speak_req)

def _map_lang_code(code: str) -> str:
    m = {"yo": "yoruba", "ha": "hausa", "ig": "igbo", "en": "english"}
    return m.get(code.lower(), "english")

@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    return await _dispatch_tts_request(request)

async def _dispatch_tts_request(request: TTSRequest, language_override: Optional[str] = None):
    global model, audio_tokenizer
    if model is None:
        logger.info("Loading YarnGPT2 model (lazy loading)...")
        load_model()
    if audio_tokenizer is None:
        logger.info("Loading audio tokenizer (lazy loading)...")
        load_audio_tokenizer()
    if model is None or audio_tokenizer is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model loading failed. Please check logs.",
        )
    language = (language_override or request.language or "english").lower().strip()
    speak_req = SpeakRequest(
        text=request.text,
        language=language,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        max_length=request.max_length,
    )
    return await speak(speak_req)

@app.post("/tts/english")
async def text_to_speech_english(request: TTSRequest):
    return await _dispatch_tts_request(request, "english")

@app.post("/tts/yoruba")
async def text_to_speech_yoruba(request: TTSRequest):
    return await _dispatch_tts_request(request, "yoruba")

@app.post("/tts/igbo")
async def text_to_speech_igbo(request: TTSRequest):
    return await _dispatch_tts_request(request, "igbo")

@app.post("/tts/hausa")
async def text_to_speech_hausa(request: TTSRequest):
    return await _dispatch_tts_request(request, "hausa")

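# Core synthesis endpoint. Tonal languages (yoruba/igbo/hausa) are generated
# with beam search; english uses low-temperature sampling.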
@app.post("/speak")
async def speak(request: SpeakRequest):
    global model, audio_tokenizer
    if model is None:
        logger.info("Loading YarnGPT2 model (lazy loading)...")
        load_model()
    if audio_tokenizer is None:
        logger.info("Loading audio tokenizer (lazy loading)...")
        load_audio_tokenizer()
    if model is None or audio_tokenizer is None:
        raise HTTPException(status_code=503, detail="Model loading failed. Please check logs.")
    default_speakers = {
        "english": "idera",
        "yoruba": "yoruba_male2",
        "igbo": "igbo_male2",
        "hausa": "hausa_female1",
    }
    language = request.language.lower().strip()
    speaker = default_speakers.get(language, "idera")
    try:
        prompt = audio_tokenizer.create_prompt(
            request.text,
            lang=language,
            speaker_name=speaker,
        )
        tokenized = audio_tokenizer.tokenize_prompt(prompt)
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
            attention_mask = None
        else:
            input_ids = tokenized.get("input_ids", tokenized)
            attention_mask = tokenized.get("attention_mask", None)
        if hasattr(audio_tokenizer, 'tokenizer') and audio_tokenizer.tokenizer.pad_token is None:
            audio_tokenizer.tokenizer.pad_token = audio_tokenizer.tokenizer.eos_token
        with torch.no_grad():
            gen_kwargs = {
                "input_ids": input_ids,
                "repetition_penalty": request.repetition_penalty or 1.1,
                "max_length": request.max_length or 4000,
            }
            if attention_mask is not None:
                gen_kwargs["attention_mask"] = attention_mask
            use_beams = language in ["yoruba", "igbo", "hausa"]
            if use_beams:
                gen_kwargs["num_beams"] = 5
                gen_kwargs["early_stopping"] = False
            else:
                temp = request.temperature or 0.1
                if temp > 0:
                    gen_kwargs["do_sample"] = True
                    gen_kwargs["temperature"] = temp
            output = model.generate(**gen_kwargs)
        logger.info(
            f"Generated output length: {output.shape[1]}, input length: {input_ids.shape[1]}, "
            f"generated tokens: {output.shape[1] - input_ids.shape[1]}"
        )
        codes = audio_tokenizer.get_codes(output)
        logger.info(f"Extracted {len(codes)} audio codes")
        audio = audio_tokenizer.get_audio(codes)
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
        if isinstance(audio, torch.Tensor):
            audio_tensor = audio.detach()
        else:
            audio_tensor = torch.tensor(np.asarray(audio))
        audio_tensor = audio_tensor.to(torch.float32).cpu()
        if audio_tensor.ndim > 1:
            audio_tensor = audio_tensor.squeeze()
        peak = audio_tensor.abs().max()
        if peak > 1.0:
            audio_tensor = audio_tensor / peak
        torchaudio.save(temp_file.name, audio_tensor.unsqueeze(0), 24000)
        return FileResponse(
            temp_file.name,
            media_type="audio/wav",
            filename="speech.wav",
            background=BackgroundTask(lambda: os.path.exists(temp_file.name) and os.unlink(temp_file.name)),
        )
    except Exception as e:
        logger.error(f"Speak error: {e}")
        raise HTTPException(status_code=500, detail=f"Speak failed: {e}")

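# Same synthesis path as speak(), but the WAV is written to an in-memory
# buffer and streamed back instead of going through a temp file.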
@app.post("/tts/stream")
async def text_to_speech_stream(request: TTSRequest):
    global model, audio_tokenizer
    if model is None:
        logger.info("Loading YarnGPT2 model (lazy loading)...")
        load_model()
    if audio_tokenizer is None:
        logger.info("Loading audio tokenizer (lazy loading)...")
        load_audio_tokenizer()
    if model is None or audio_tokenizer is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model loading failed. Please check logs.",
        )
    try:
        default_speakers = {
            "english": "idera",
            "yoruba": "yoruba_male2",
            "igbo": "igbo_male2",
            "hausa": "hausa_female1",
        }
        lang_norm = request.language.lower().strip()
        spk = default_speakers.get(lang_norm, "idera")
        prompt = audio_tokenizer.create_prompt(request.text, lang=lang_norm, speaker_name=spk)
        tokenized = audio_tokenizer.tokenize_prompt(prompt)
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
            attention_mask = None
        else:
            input_ids = tokenized.get("input_ids", tokenized)
            attention_mask = tokenized.get("attention_mask", None)
        if hasattr(audio_tokenizer, 'tokenizer') and audio_tokenizer.tokenizer.pad_token is None:
            audio_tokenizer.tokenizer.pad_token = audio_tokenizer.tokenizer.eos_token
        logger.info(f"Generating speech (streaming) for text: {request.text[:50]}...")
        with torch.no_grad():
            gen_kwargs = {
                "input_ids": input_ids,
                "repetition_penalty": request.repetition_penalty or 1.1,
                "max_length": request.max_length or 4000,
            }
            if attention_mask is not None:
                gen_kwargs["attention_mask"] = attention_mask
            use_beams = lang_norm in ["yoruba", "igbo", "hausa"]
            if use_beams:
                gen_kwargs["num_beams"] = 5
                gen_kwargs["early_stopping"] = False
            else:
                temp = request.temperature or 0.1
                if temp > 0:
                    gen_kwargs["do_sample"] = True
                    gen_kwargs["temperature"] = temp
            output = model.generate(**gen_kwargs)
        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)
        buffer = io.BytesIO()
        if isinstance(audio, torch.Tensor):
            audio_tensor = audio.detach()
        else:
            audio_tensor = torch.tensor(np.asarray(audio))
        audio_tensor = audio_tensor.to(torch.float32).cpu()
        if audio_tensor.ndim > 1:
            audio_tensor = audio_tensor.squeeze()
        peak = audio_tensor.abs().max()
        if peak > 1.0:
            audio_tensor = audio_tensor / peak
        torchaudio.save(buffer, audio_tensor.unsqueeze(0), 24000, format="wav")
        buffer.seek(0)
        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=speech.wav"},
            background=BackgroundTask(buffer.close),
        )
    except Exception as e:
        logger.error(f"Error generating speech: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate speech: {str(e)}",
        )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)